feat: Add RAG based OpenAPI service integration (#6555)

* Add OpenAPIServiceConnector and OpenAPIServiceToFunctions

* Add release note

* Add test deps

* Better docs on OpenAPI spec reqs, improve tests

* Silvano PR feedback
This commit is contained in:
Vladimir Blagojevic 2023-12-19 13:27:41 +01:00 committed by GitHub
parent 94cfe5d9ae
commit 2dd5a94b04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 658 additions and 0 deletions

View File

@ -0,0 +1,3 @@
from haystack.components.connectors.openapi_service import OpenAPIServiceConnector
__all__ = ["OpenAPIServiceConnector"]

View File

@ -0,0 +1,128 @@
import json
import logging
from typing import List, Dict, Any, Optional
from haystack import component
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
with LazyImport("Run 'pip install openapi3'") as openapi_imports:
from openapi3 import OpenAPI
@component
class OpenAPIServiceConnector:
"""
OpenAPIServiceConnector connects to OpenAPI services, allowing for the invocation of methods specified in
an OpenAPI specification of that service. It integrates with ChatMessage interface, where messages are used to
determine the method to be called and the parameters to be passed. The message payload should be a JSON formatted
string consisting of the method name and the parameters to be passed to the method. The method name and parameters
are then used to invoke the method on the OpenAPI service. The response from the service is returned as a
ChatMessage.
Before using this component, one needs to register functions from the OpenAPI specification with LLM.
This can be done using the OpenAPIServiceToFunctions component.
"""
def __init__(self, service_auths: Optional[Dict[str, Any]] = None):
"""
Initializes the OpenAPIServiceConnector instance
:param service_auths: A dictionary containing the service name and token to be used for authentication.
"""
openapi_imports.check()
self.service_authentications = service_auths or {}
@component.output_types(service_response=Dict[str, Any])
def run(self, messages: List[ChatMessage], service_openapi_spec: Dict[str, Any]) -> Dict[str, List[ChatMessage]]:
"""
Processes a list of chat messages to invoke a method on an OpenAPI service. It parses the last message in the
list, expecting it to contain an OpenAI function calling descriptor (name & parameters) in JSON format.
:param messages: A list of `ChatMessage` objects representing the chat history.
:type messages: List[ChatMessage]
:param service_openapi_spec: The OpenAPI JSON specification object of the service.
:type service_openapi_spec: JSON object
:return: A dictionary with a key `"service_response"`, containing the response from the OpenAPI service.
:rtype: Dict[str, List[ChatMessage]]
:raises ValueError: If the last message is not from the assistant or if it does not contain the correct payload
to invoke a method on the service.
"""
last_message = messages[-1]
if not last_message.is_from(ChatRole.ASSISTANT):
raise ValueError(f"{last_message} is not from the assistant.")
method_invocation_descriptor = self._parse_message(last_message.content)
# instantiate the OpenAPI service for the given specification
openapi_service = OpenAPI(service_openapi_spec)
self._authenticate_service(openapi_service)
service_response = self._invoke_method(openapi_service, method_invocation_descriptor)
return {"service_response": [ChatMessage.from_user(str(service_response))]}
def _parse_message(self, content: str) -> Dict[str, Any]:
"""
Parses the message content to extract the method invocation descriptor.
:param content: The JSON string content of the message.
:type content: str
:return: A dictionary with method name and arguments.
:rtype: Dict[str, Any]
:raises ValueError: If the content is not valid JSON or lacks required fields.
"""
try:
method_invocation_descriptor = json.loads(content)
except json.JSONDecodeError:
raise ValueError("Invalid JSON content, cannot parse invocation message.", content)
if "name" not in method_invocation_descriptor or "arguments" not in method_invocation_descriptor:
raise ValueError("Missing required fields in the invocation message content.", content)
method_invocation_descriptor["arguments"] = json.loads(method_invocation_descriptor["arguments"])
return method_invocation_descriptor
def _authenticate_service(self, openapi_service: OpenAPI):
"""
Authenticates with the OpenAPI service if required.
:param openapi_service: The OpenAPI service instance.
:type openapi_service: OpenAPI
:raises ValueError: If authentication fails or is not found.
"""
if openapi_service.components.securitySchemes:
auth_method = list(openapi_service.components.securitySchemes.keys())[0]
service_title = openapi_service.info.title
if service_title not in self.service_authentications:
raise ValueError(f"Service {service_title} not found in service_authentications.")
openapi_service.authenticate(auth_method, self.service_authentications[service_title])
def _invoke_method(self, openapi_service: OpenAPI, method_invocation_descriptor: Dict[str, Any]) -> Any:
"""
Invokes the specified method on the OpenAPI service.
:param openapi_service: The OpenAPI service instance.
:type openapi_service: OpenAPI
:param method_invocation_descriptor: The method name and arguments.
:type method_invocation_descriptor: Dict[str, Any]
:return: A service JSON response.
:rtype: Any
:raises RuntimeError: If the method is not found or invocation fails.
"""
name = method_invocation_descriptor["name"]
# a bit convoluted, but we need to pass parameters, data, or both to the method
# depending on the openapi operation specification, can't use None as a default value
method_call_params = {}
if (parameters := method_invocation_descriptor["arguments"].get("parameters")) is not None:
method_call_params["parameters"] = parameters
if (arguments := method_invocation_descriptor["arguments"].get("requestBody")) is not None:
method_call_params["data"] = arguments
method_to_call = getattr(openapi_service, f"call_{name}", None)
if not callable(method_to_call):
raise RuntimeError(f"Operation {name} not found in OpenAPI specification {openapi_service.info.title}")
# this will call the underlying service REST API
return method_to_call(**method_call_params)

View File

@ -4,6 +4,7 @@ from haystack.components.converters.azure import AzureOCRDocumentConverter
from haystack.components.converters.pypdf import PyPDFToDocument
from haystack.components.converters.html import HTMLToDocument
from haystack.components.converters.markdown import MarkdownToDocument
from haystack.components.converters.openapi_functions import OpenAPIServiceToFunctions
__all__ = [
"TextFileToDocument",
@ -12,4 +13,5 @@ __all__ = [
"PyPDFToDocument",
"HTMLToDocument",
"MarkdownToDocument",
"OpenAPIServiceToFunctions",
]

View File

@ -0,0 +1,216 @@
import json
import logging
import os
from pathlib import Path
from typing import List, Dict, Any, Union, Optional
import requests
import yaml
from requests import RequestException
from haystack import component, Document
from haystack.dataclasses.byte_stream import ByteStream
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
with LazyImport("Run 'pip install jsonref'") as openapi_imports:
import jsonref
@component
class OpenAPIServiceToFunctions:
"""
OpenAPIServiceToFunctions is responsible for converting an OpenAPI service specification into a format suitable
for OpenAI function calling, based on the provided OpenAPI specification. Given an OpenAPI specification,
OpenAPIServiceToFunctions processes it, and extracts function definitions that can be invoked via OpenAI's
function calling mechanism. The format of the extracted functions is compatible with OpenAI's function calling
JSON format.
Minimal requirements for OpenAPI specification:
- OpenAPI version 3.0.0 or higher
- Each function must have a unique operationId
- Each function must have a description
- Each function must have a requestBody or parameters or both
- Each function must have a schema for the requestBody and/or parameters
See https://github.com/OAI/OpenAPI-Specification for more details on OpenAPI specification.
See https://platform.openai.com/docs/guides/function-calling for more details on OpenAI function calling.
"""
MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3
def __init__(self):
"""
Initializes the OpenAPIServiceToFunctions instance
"""
openapi_imports.check()
@component.output_types(documents=List[Document])
def run(self, sources: List[Union[str, Path, ByteStream]], system_messages: List[str]) -> Dict[str, Any]:
"""
Processes OpenAPI specification URLs or files to extract functions that can be invoked via OpenAI function
calling mechanism. Each source is paired with a system message in one-to-one correspondence. The system message
is used to assist LLM in the response generation.
:param sources: A list of OpenAPI specification sources, which can be URLs, file paths, or ByteStream objects.
:type sources: List[Union[str, Path, ByteStream]]
:param system_messages: A list of system messages corresponding to each source.
:type system_messages: List[str]
:return: A dictionary with a key 'documents' containing a list of Document objects. Each Document object
encapsulates a function definition and relevant metadata.
:rtype: Dict[str, Any]
:raises RuntimeError: If the OpenAPI specification cannot be downloaded or processed.
:raises ValueError: If the source type is not recognized or no functions are found in the OpenAPI specification.
"""
documents: List[Document] = []
for source, system_message in zip(sources, system_messages):
openapi_spec_content = None
if isinstance(source, (str, Path)):
# check if the source is a file path or a URL
if os.path.exists(source):
openapi_spec_content = self._read_from_file(source)
else:
openapi_spec_content = self._read_from_url(str(source))
elif isinstance(source, ByteStream):
openapi_spec_content = source.data.decode("utf-8")
else:
logger.warning("Invalid source type %s. Only str, Path, and ByteStream are supported.", type(source))
continue
if openapi_spec_content:
try:
service_openapi_spec = self._parse_openapi_spec(openapi_spec_content)
functions: List[Dict[str, Any]] = self._openapi_to_functions(service_openapi_spec)
docs = [
Document(
content=json.dumps(function),
meta={"spec": service_openapi_spec, "system_message": system_message},
)
for function in functions
]
documents.extend(docs)
except Exception as e:
logger.error("Error processing OpenAPI specification from source %s: %s", source, e)
return {"documents": documents}
def _openapi_to_functions(self, service_openapi_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Extracts functions from the OpenAPI specification of the service and converts them into a format
suitable for OpenAI function calling.
:param service_openapi_spec: The OpenAPI specification from which functions are to be extracted.
:type service_openapi_spec: Dict[str, Any]
:return: A list of dictionaries, each representing a function. Each dictionary includes the function's
name, description, and a schema of its parameters.
:rtype: List[Dict[str, Any]]
"""
# Doesn't enforce rigid spec validation because that would require a lot of dependencies
# We check the version and require minimal fields to be present, so we can extract functions
spec_version = service_openapi_spec.get("openapi")
if not spec_version:
raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}")
service_openapi_spec_version = int(spec_version.split(".")[0])
# Compare the versions
if service_openapi_spec_version < OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION:
raise ValueError(
f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be "
f"at least {OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
)
functions: List[Dict[str, Any]] = []
for path_methods in service_openapi_spec["paths"].values():
for method_specification in path_methods.values():
resolved_spec = jsonref.replace_refs(method_specification)
if isinstance(resolved_spec, dict):
function_name = resolved_spec.get("operationId")
desc = resolved_spec.get("description") or resolved_spec.get("summary", "")
schema: Dict[str, Any] = {"type": "object", "properties": {}}
req_body = (
resolved_spec.get("requestBody", {})
.get("content", {})
.get("application/json", {})
.get("schema")
)
if req_body:
schema["properties"]["requestBody"] = req_body
params = resolved_spec.get("parameters", [])
if params:
param_properties = {param["name"]: param["schema"] for param in params if "schema" in param}
schema["properties"]["parameters"] = {"type": "object", "properties": param_properties}
# these three fields are minimal requirement for OpenAI function calling
if function_name and desc and schema:
functions.append({"name": function_name, "description": desc, "parameters": schema})
else:
logger.warning(
"Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec
)
else:
logger.warning(
"Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec
)
return functions
def _parse_openapi_spec(self, content: str) -> Dict[str, Any]:
"""
Parses OpenAPI specification content, supporting both JSON and YAML formats.
:param content: The content of the OpenAPI specification.
:type content: str
:return: The parsed OpenAPI specification.
:rtype: Dict[str, Any]
"""
try:
return json.loads(content)
except json.JSONDecodeError as json_error:
# heuristic to confirm that the content is likely malformed JSON
if content.strip().startswith(("{", "[")):
raise json_error
try:
return yaml.safe_load(content)
except yaml.YAMLError:
error_message = (
"Failed to parse the OpenAPI specification. "
"The content does not appear to be valid JSON or YAML.\n\n"
)
raise RuntimeError(error_message, content)
def _read_from_file(self, path: Union[str, Path]) -> Optional[str]:
"""
Reads the content of a file, given its path.
:param path: The path of the file.
:type path: Union[str, Path]
:return: The content of the file or None if the file cannot be read.
"""
try:
with open(path, "r") as f:
return f.read()
except IOError as e:
logger.warning("IO error reading file: %s. Error: %s", path, e)
return None
def _read_from_url(self, url: str) -> Optional[str]:
"""
Reads the content of a URL.
:param url: The URL to read.
:type url: str
:return: The content of the URL or None if the URL cannot be read.
"""
try:
response = requests.get(url, timeout=10)
response.raise_for_status()
return response.text
except RequestException as e:
logger.warning("Error fetching URL: %s. Error: %s", url, e)
return None

View File

@ -0,0 +1,4 @@
---
features:
- |
Adds RAG OpenAPI services integration.

View File

@ -0,0 +1,80 @@
import json
import pytest
from unittest.mock import MagicMock, Mock
from openapi3 import OpenAPI
from haystack.components.connectors import OpenAPIServiceConnector
@pytest.fixture
def openapi_service_mock():
return MagicMock(spec=OpenAPI)
@pytest.fixture
def service_auths():
return {"TestService": "auth_token"}
class TestOpenAPIServiceConnector:
@pytest.fixture
def connector(self, service_auths):
return OpenAPIServiceConnector(service_auths)
def test_parse_message_invalid_json(self, connector):
# Test invalid JSON content
with pytest.raises(ValueError):
connector._parse_message("invalid json")
def test_parse_valid_json_message(self):
connector = OpenAPIServiceConnector()
# The content format here is OpenAI function calling descriptor
content = (
'{"name": "compare_branches","arguments": "{\\n \\"parameters\\": {\\n '
' \\"basehead\\": \\"main...openapi_container_v5\\",\\n '
' \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack\\"\\n }\\n}"}'
)
descriptor = connector._parse_message(content)
# Assert that the descriptor contains the expected method name and arguments
assert descriptor["name"] == "compare_branches"
assert descriptor["arguments"]["parameters"] == {
"basehead": "main...openapi_container_v5",
"owner": "deepset-ai",
"repo": "haystack",
}
# but not the requestBody
assert "requestBody" not in descriptor["arguments"]
# The content format here is OpenAI function calling descriptor
content = '{"name": "search","arguments": "{\\n \\"requestBody\\": {\\n \\"q\\": \\"haystack\\"\\n }\\n}"}'
descriptor = connector._parse_message(content)
assert descriptor["name"] == "search"
assert descriptor["arguments"]["requestBody"] == {"q": "haystack"}
# but not the parameters
assert "parameters" not in descriptor["arguments"]
def test_parse_message_missing_fields(self, connector):
# Test JSON content with missing fields
with pytest.raises(ValueError):
connector._parse_message(json.dumps({"name": "test_method"}))
def test_authenticate_service_invalid(self, connector, openapi_service_mock):
# Test invalid or missing authentication
openapi_service_mock.components.securitySchemes = {"apiKey": {}}
with pytest.raises(ValueError):
connector._authenticate_service(openapi_service_mock)
def test_invoke_method_valid(self, connector, openapi_service_mock):
# Test valid method invocation
method_invocation_descriptor = {"name": "test_method", "arguments": {}}
openapi_service_mock.call_test_method = Mock(return_value="response")
result = connector._invoke_method(openapi_service_mock, method_invocation_descriptor)
assert result == "response"
def test_invoke_method_invalid(self, connector, openapi_service_mock):
# Test invalid method invocation
method_invocation_descriptor = {"name": "invalid_method", "arguments": {}}
with pytest.raises(RuntimeError):
connector._invoke_method(openapi_service_mock, method_invocation_descriptor)

View File

@ -0,0 +1,221 @@
import json
import sys
import tempfile
import pytest
from haystack.components.converters import OpenAPIServiceToFunctions
from haystack.dataclasses import ByteStream
@pytest.fixture
def json_serperdev_openapi_spec():
serper_spec = """
{
"openapi": "3.0.0",
"info": {
"title": "SerperDev",
"version": "1.0.0",
"description": "API for performing search queries"
},
"servers": [
{
"url": "https://google.serper.dev"
}
],
"paths": {
"/search": {
"post": {
"operationId": "search",
"description": "Search the web with Google",
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"q": {
"type": "string"
}
}
}
}
}
},
"responses": {
"200": {
"description": "Successful response",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"searchParameters": {
"type": "undefined"
},
"knowledgeGraph": {
"type": "undefined"
},
"answerBox": {
"type": "undefined"
},
"organic": {
"type": "undefined"
},
"topStories": {
"type": "undefined"
},
"peopleAlsoAsk": {
"type": "undefined"
},
"relatedSearches": {
"type": "undefined"
}
}
}
}
}
}
},
"security": [
{
"apikey": []
}
]
}
}
},
"components": {
"securitySchemes": {
"apikey": {
"type": "apiKey",
"name": "x-api-key",
"in": "header"
}
}
}
}
"""
return serper_spec
@pytest.fixture
def yaml_serperdev_openapi_spec():
serper_spec = """
openapi: 3.0.0
info:
title: SerperDev
version: 1.0.0
description: API for performing search queries
servers:
- url: 'https://google.serper.dev'
paths:
/search:
post:
operationId: search
description: Search the web with Google
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
q:
type: string
responses:
'200':
description: Successful response
content:
application/json:
schema:
type: object
properties:
searchParameters:
type: undefined
knowledgeGraph:
type: undefined
answerBox:
type: undefined
organic:
type: undefined
topStories:
type: undefined
peopleAlsoAsk:
type: undefined
relatedSearches:
type: undefined
security:
- apikey: []
components:
securitySchemes:
apikey:
type: apiKey
name: x-api-key
in: header
"""
return serper_spec
class TestOpenAPIServiceToFunctions:
# test we can parse openapi spec given in json
def test_openapi_spec_parsing_json(self, json_serperdev_openapi_spec):
service = OpenAPIServiceToFunctions()
serper_spec_json = service._parse_openapi_spec(json_serperdev_openapi_spec)
assert serper_spec_json["openapi"] == "3.0.0"
assert serper_spec_json["info"]["title"] == "SerperDev"
# test we can parse openapi spec given in yaml
def test_openapi_spec_parsing_yaml(self, yaml_serperdev_openapi_spec):
service = OpenAPIServiceToFunctions()
serper_spec_yaml = service._parse_openapi_spec(yaml_serperdev_openapi_spec)
assert serper_spec_yaml["openapi"] == "3.0.0"
assert serper_spec_yaml["info"]["title"] == "SerperDev"
# test we can extract functions from openapi spec given
def test_run_with_bytestream_source(self, json_serperdev_openapi_spec):
service = OpenAPIServiceToFunctions()
spec_stream = ByteStream.from_string(json_serperdev_openapi_spec)
result = service.run(sources=[spec_stream], system_messages=["Some system message we don't care about here"])
assert len(result["documents"]) == 1
doc = result["documents"][0]
# check that the content is as expected
assert (
doc.content
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
'"properties": {"requestBody": {"type": "object", "properties": {"q": {"type": "string"}}}}}}'
)
# check that the metadata is as expected
assert doc.meta["system_message"] == "Some system message we don't care about here"
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
@pytest.mark.skipif(
sys.platform in ["win32", "cygwin"],
reason="Can't run on Windows Github CI, need access temp file but windows does not allow it",
)
def test_run_with_file_source(self, json_serperdev_openapi_spec):
# test we can extract functions from openapi spec given in file
service = OpenAPIServiceToFunctions()
# write the spec to NamedTemporaryFile and check that it is parsed correctly
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(json_serperdev_openapi_spec.encode("utf-8"))
tmp.seek(0)
result = service.run(sources=[tmp.name], system_messages=["Some system message we don't care about here"])
assert len(result["documents"]) == 1
doc = result["documents"][0]
# check that the content is as expected
assert (
doc.content
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
'"properties": {"requestBody": {"type": "object", "properties": {"q": {"type": "string"}}}}}}'
)
# check that the metadata is as expected
assert doc.meta["system_message"] == "Some system message we don't care about here"
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)

View File

@ -14,3 +14,7 @@ azure-ai-formrecognizer>=3.2.0b2 # AzureOCRDocumentConverter
langdetect # TextLanguageRouter and DocumentLanguageClassifier
sentence-transformers>=2.2.0 # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
openai-whisper>=20231106 # LocalWhisperTranscriber
# OpenAPI
jsonref # OpenAPIServiceConnector, OpenAPIServiceToFunctions
openapi3