diff --git a/haystack/components/connectors/__init__.py b/haystack/components/connectors/__init__.py
new file mode 100644
index 000000000..77789e08c
--- /dev/null
+++ b/haystack/components/connectors/__init__.py
@@ -0,0 +1,3 @@
+from haystack.components.connectors.openapi_service import OpenAPIServiceConnector
+
+__all__ = ["OpenAPIServiceConnector"]
diff --git a/haystack/components/connectors/openapi_service.py b/haystack/components/connectors/openapi_service.py
new file mode 100644
index 000000000..f56ffa3bf
--- /dev/null
+++ b/haystack/components/connectors/openapi_service.py
@@ -0,0 +1,149 @@
+import json
+import logging
+from typing import List, Dict, Any, Optional
+
+from haystack import component
+from haystack.dataclasses import ChatMessage, ChatRole
+from haystack.lazy_imports import LazyImport
+
+logger = logging.getLogger(__name__)
+
+with LazyImport("Run 'pip install openapi3'") as openapi_imports:
+    from openapi3 import OpenAPI
+
+
+@component
+class OpenAPIServiceConnector:
+    """
+    OpenAPIServiceConnector connects to OpenAPI services, allowing for the invocation of methods specified in
+    an OpenAPI specification of that service. It integrates with ChatMessage interface, where messages are used to
+    determine the method to be called and the parameters to be passed. The message payload should be a JSON formatted
+    string consisting of the method name and the parameters to be passed to the method. The method name and parameters
+    are then used to invoke the method on the OpenAPI service. The response from the service is returned as a
+    ChatMessage.
+
+    Before using this component, one needs to register functions from the OpenAPI specification with LLM.
+    This can be done using the OpenAPIServiceToFunctions component.
+    """
+
+    def __init__(self, service_auths: Optional[Dict[str, Any]] = None):
+        """
+        Initializes the OpenAPIServiceConnector instance
+        :param service_auths: A dictionary containing the service name and token to be used for authentication.
+        """
+        openapi_imports.check()
+        self.service_authentications = service_auths or {}
+
+    @component.output_types(service_response=List[ChatMessage])
+    def run(self, messages: List[ChatMessage], service_openapi_spec: Dict[str, Any]) -> Dict[str, List[ChatMessage]]:
+        """
+        Processes a list of chat messages to invoke a method on an OpenAPI service. It parses the last message in the
+        list, expecting it to contain an OpenAI function calling descriptor (name & parameters) in JSON format.
+
+        :param messages: A list of `ChatMessage` objects representing the chat history.
+        :type messages: List[ChatMessage]
+        :param service_openapi_spec: The OpenAPI JSON specification object of the service.
+        :type service_openapi_spec: JSON object
+        :return: A dictionary with a key `"service_response"`, containing the response from the OpenAPI service.
+        :rtype: Dict[str, List[ChatMessage]]
+        :raises ValueError: If `messages` is empty, if the last message is not from the assistant or if it does
+            not contain the correct payload to invoke a method on the service.
+        """
+        if not messages:
+            raise ValueError("Cannot invoke an OpenAPI service: the list of messages is empty.")
+
+        last_message = messages[-1]
+        if not last_message.is_from(ChatRole.ASSISTANT):
+            raise ValueError(f"{last_message} is not from the assistant.")
+
+        method_invocation_descriptor = self._parse_message(last_message.content)
+
+        # instantiate the OpenAPI service for the given specification
+        openapi_service = OpenAPI(service_openapi_spec)
+        self._authenticate_service(openapi_service)
+
+        service_response = self._invoke_method(openapi_service, method_invocation_descriptor)
+        return {"service_response": [ChatMessage.from_user(str(service_response))]}
+
+    def _parse_message(self, content: str) -> Dict[str, Any]:
+        """
+        Parses the message content to extract the method invocation descriptor.
+
+        :param content: The JSON string content of the message.
+        :type content: str
+        :return: A dictionary with method name and arguments.
+        :rtype: Dict[str, Any]
+        :raises ValueError: If the content is not valid JSON, lacks required fields, or its arguments are not a
+            JSON object.
+        """
+        try:
+            method_invocation_descriptor = json.loads(content)
+        except (json.JSONDecodeError, TypeError) as e:
+            raise ValueError("Invalid JSON content, cannot parse invocation message.", content) from e
+
+        if (
+            not isinstance(method_invocation_descriptor, dict)
+            or "name" not in method_invocation_descriptor
+            or "arguments" not in method_invocation_descriptor
+        ):
+            raise ValueError("Missing required fields in the invocation message content.", content)
+
+        arguments = method_invocation_descriptor["arguments"]
+        if isinstance(arguments, str):
+            # OpenAI function calling serializes the arguments as a JSON string nested in the descriptor
+            try:
+                arguments = json.loads(arguments)
+            except json.JSONDecodeError as e:
+                raise ValueError("Invalid JSON in the arguments of the invocation message.", content) from e
+        if not isinstance(arguments, dict):
+            raise ValueError("The arguments of the invocation message must be a JSON object.", content)
+
+        method_invocation_descriptor["arguments"] = arguments
+        return method_invocation_descriptor
+
+    def _authenticate_service(self, openapi_service: OpenAPI):
+        """
+        Authenticates with the OpenAPI service if required. Only the first declared security scheme is used.
+
+        :param openapi_service: The OpenAPI service instance.
+        :type openapi_service: OpenAPI
+        :raises ValueError: If no authentication data is registered for the title of the service.
+        """
+        # `components` is optional in an OpenAPI document; a spec without it needs no authentication
+        components = getattr(openapi_service, "components", None)
+        if components is None or not components.securitySchemes:
+            return
+
+        auth_method = next(iter(components.securitySchemes))
+        service_title = openapi_service.info.title
+        if service_title not in self.service_authentications:
+            raise ValueError(f"Service {service_title} not found in service_authentications.")
+        openapi_service.authenticate(auth_method, self.service_authentications[service_title])
+
+    def _invoke_method(self, openapi_service: OpenAPI, method_invocation_descriptor: Dict[str, Any]) -> Any:
+        """
+        Invokes the specified method on the OpenAPI service.
+
+        :param openapi_service: The OpenAPI service instance.
+        :type openapi_service: OpenAPI
+        :param method_invocation_descriptor: The method name and arguments.
+        :type method_invocation_descriptor: Dict[str, Any]
+        :return: A service JSON response.
+        :rtype: Any
+        :raises RuntimeError: If the method is not found or invocation fails.
+        """
+        name = method_invocation_descriptor["name"]
+        # a bit convoluted, but we need to pass parameters, data, or both to the method
+        # depending on the openapi operation specification, can't use None as a default value
+        method_call_params = {}
+        if (parameters := method_invocation_descriptor["arguments"].get("parameters")) is not None:
+            method_call_params["parameters"] = parameters
+        if (request_body := method_invocation_descriptor["arguments"].get("requestBody")) is not None:
+            method_call_params["data"] = request_body
+
+        method_to_call = getattr(openapi_service, f"call_{name}", None)
+        if not callable(method_to_call):
+            raise RuntimeError(f"Operation {name} not found in OpenAPI specification {openapi_service.info.title}")
+
+        # this will call the underlying service REST API
+        return method_to_call(**method_call_params)
diff --git a/haystack/components/converters/__init__.py b/haystack/components/converters/__init__.py
index 0bc40cc05..24bbdf9c3 100644
--- a/haystack/components/converters/__init__.py
+++ b/haystack/components/converters/__init__.py
@@ -4,6 +4,7 @@ from haystack.components.converters.azure import AzureOCRDocumentConverter
 from haystack.components.converters.pypdf import PyPDFToDocument
 from haystack.components.converters.html import HTMLToDocument
 from haystack.components.converters.markdown import MarkdownToDocument
+from haystack.components.converters.openapi_functions import OpenAPIServiceToFunctions
 
 __all__ = [
     "TextFileToDocument",
@@ -12,4 +13,5 @@ __all__ = [
     "PyPDFToDocument",
     "HTMLToDocument",
     "MarkdownToDocument",
+    "OpenAPIServiceToFunctions",
 ]
diff --git a/haystack/components/converters/openapi_functions.py b/haystack/components/converters/openapi_functions.py
new file mode 100644
index 000000000..b35386003
--- /dev/null
+++ b/haystack/components/converters/openapi_functions.py
@@ -0,0 +1,216 @@
+import json
+import logging
+import os
+from pathlib import Path
+from typing import List, Dict, Any, Union, Optional
+
+import requests
+import yaml
+from requests import RequestException
+
+from haystack import component, Document
+from haystack.dataclasses.byte_stream import ByteStream
+from haystack.lazy_imports import LazyImport
+
+logger = logging.getLogger(__name__)
+
+with LazyImport("Run 'pip install jsonref'") as openapi_imports:
+    import jsonref
+
+
+@component
+class OpenAPIServiceToFunctions:
+    """
+    OpenAPIServiceToFunctions is responsible for converting an OpenAPI service specification into a format suitable
+    for OpenAI function calling, based on the provided OpenAPI specification. Given an OpenAPI specification,
+    OpenAPIServiceToFunctions processes it, and extracts function definitions that can be invoked via OpenAI's
+    function calling mechanism. The format of the extracted functions is compatible with OpenAI's function calling
+    JSON format.
+
+    Minimal requirements for OpenAPI specification:
+    - OpenAPI version 3.0.0 or higher
+    - Each function must have a unique operationId
+    - Each function must have a description
+    - Each function must have a requestBody or parameters or both
+    - Each function must have a schema for the requestBody and/or parameters
+
+    See https://github.com/OAI/OpenAPI-Specification for more details on OpenAPI specification.
+    See https://platform.openai.com/docs/guides/function-calling for more details on OpenAI function calling.
+    """
+
+    MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3
+
+    def __init__(self):
+        """
+        Initializes the OpenAPIServiceToFunctions instance
+        """
+        openapi_imports.check()
+
+    @component.output_types(documents=List[Document])
+    def run(self, sources: List[Union[str, Path, ByteStream]], system_messages: List[str]) -> Dict[str, Any]:
+        """
+        Processes OpenAPI specification URLs or files to extract functions that can be invoked via OpenAI function
+        calling mechanism. Each source is paired with a system message in one-to-one correspondence. The system message
+        is used to assist LLM in the response generation.
+
+        :param sources: A list of OpenAPI specification sources, which can be URLs, file paths, or ByteStream objects.
+        :type sources: List[Union[str, Path, ByteStream]]
+        :param system_messages: A list of system messages corresponding to each source.
+        :type system_messages: List[str]
+        :return: A dictionary with a key 'documents' containing a list of Document objects. Each Document object
+                 encapsulates a function definition and relevant metadata.
+        :rtype: Dict[str, Any]
+        :raises ValueError: If `sources` and `system_messages` differ in length. A source that cannot be read or
+                 processed does not raise: the error is logged and the source is skipped.
+        """
+        if len(sources) != len(system_messages):
+            raise ValueError("`sources` and `system_messages` must have the same length.")
+
+        documents: List[Document] = []
+        for source, system_message in zip(sources, system_messages):
+            if not isinstance(source, (str, Path, ByteStream)):
+                logger.warning("Invalid source type %s. Only str, Path, and ByteStream are supported.", type(source))
+                continue
+
+            try:
+                if isinstance(source, ByteStream):
+                    openapi_spec_content = source.data.decode("utf-8")
+                elif os.path.exists(source):  # a str or Path source is either a local file path or a URL
+                    openapi_spec_content = self._read_from_file(source)
+                else:
+                    openapi_spec_content = self._read_from_url(str(source))
+
+                if not openapi_spec_content:
+                    continue
+                service_openapi_spec = self._parse_openapi_spec(openapi_spec_content)
+                functions: List[Dict[str, Any]] = self._openapi_to_functions(service_openapi_spec)
+                docs = [
+                    Document(
+                        content=json.dumps(function),
+                        meta={"spec": service_openapi_spec, "system_message": system_message},
+                    )
+                    for function in functions
+                ]
+                documents.extend(docs)
+            except Exception as e:  # best effort: a failing source must not abort the remaining ones
+                logger.error("Error processing OpenAPI specification from source %s: %s", source, e)
+
+        return {"documents": documents}
+
+    def _openapi_to_functions(self, service_openapi_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        Extracts functions from the OpenAPI specification of the service and converts them into a format
+        suitable for OpenAI function calling.
+
+        :param service_openapi_spec: The OpenAPI specification from which functions are to be extracted.
+        :type service_openapi_spec: Dict[str, Any]
+        :return: A list of dictionaries, each representing a function. Each dictionary includes the function's
+                 name, description, and a schema of its parameters.
+        :rtype: List[Dict[str, Any]]
+        """
+        # Doesn't enforce rigid spec validation because that would require a lot of dependencies
+        # We check the version and require minimal fields to be present, so we can extract functions
+        spec_version = service_openapi_spec.get("openapi")
+        if not spec_version:
+            raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}")
+        service_openapi_spec_version = int(str(spec_version).split(".")[0])  # YAML may yield a number, e.g. 3.0
+
+        if service_openapi_spec_version < OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION:
+            raise ValueError(
+                f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be "
+                f"at least {OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
+            )
+
+        # Resolve `$ref` pointers against the whole document: a reference such as `#/components/schemas/X` cannot
+        # be followed from an isolated operation object. proxies=False substitutes the referenced data directly.
+        resolved_openapi_spec = jsonref.replace_refs(service_openapi_spec, proxies=False)
+        functions: List[Dict[str, Any]] = []
+        for path_methods in (resolved_openapi_spec.get("paths") or {}).values():
+            for resolved_spec in path_methods.values():
+                if isinstance(resolved_spec, dict):
+                    function_name = resolved_spec.get("operationId")
+                    desc = resolved_spec.get("description") or resolved_spec.get("summary", "")
+                    schema: Dict[str, Any] = {"type": "object", "properties": {}}
+
+                    req_body = (
+                        resolved_spec.get("requestBody", {})
+                        .get("content", {})
+                        .get("application/json", {})
+                        .get("schema")
+                    )
+                    if req_body:
+                        schema["properties"]["requestBody"] = req_body
+
+                    params = resolved_spec.get("parameters", [])
+                    if params:
+                        param_properties = {param["name"]: param["schema"] for param in params if "schema" in param}
+                        schema["properties"]["parameters"] = {"type": "object", "properties": param_properties}
+
+                    # name and description are required; the parameter schema object always exists at this point
+                    if function_name and desc:
+                        functions.append({"name": function_name, "description": desc, "parameters": schema})
+                    else:
+                        logger.warning(
+                            "Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec
+                        )
+                else:
+                    logger.warning(
+                        "Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec
+                    )
+
+        return functions
+
+    def _parse_openapi_spec(self, content: str) -> Dict[str, Any]:
+        """
+        Parses OpenAPI specification content, supporting both JSON and YAML formats.
+
+        :param content: The content of the OpenAPI specification.
+        :type content: str
+        :return: The parsed OpenAPI specification.
+        :rtype: Dict[str, Any]
+        :raises json.JSONDecodeError: If the content starts like a JSON document but is malformed.
+        :raises RuntimeError: If the content is neither valid JSON nor valid YAML.
+        """
+        try:
+            return json.loads(content)
+        except json.JSONDecodeError:
+            # heuristic to confirm that the content is likely malformed JSON
+            if content.strip().startswith(("{", "[")):
+                raise
+            try:
+                return yaml.safe_load(content)
+            except yaml.YAMLError as yaml_error:
+                error_message = (
+                    "Failed to parse the OpenAPI specification. "
+                    "The content does not appear to be valid JSON or YAML.\n\n"
+                )
+                raise RuntimeError(error_message, content) from yaml_error
+
+    def _read_from_file(self, path: Union[str, Path]) -> Optional[str]:
+        """
+        Reads the content of a file, given its path.
+        :param path: The path of the file.
+        :type path: Union[str, Path]
+        :return: The content of the file or None if the file cannot be read.
+        """
+        try:
+            with open(path, "r", encoding="utf-8") as f:  # same decoding as for ByteStream sources
+                return f.read()
+        except (OSError, UnicodeDecodeError) as e:
+            logger.warning("IO error reading file: %s. Error: %s", path, e)
+            return None
+
+    def _read_from_url(self, url: str) -> Optional[str]:
+        """
+        Reads the content of a URL.
+        :param url: The URL to read.
+        :type url: str
+        :return: The content of the URL or None if the URL cannot be read.
+        """
+        try:
+            response = requests.get(url, timeout=10)
+            response.raise_for_status()
+            return response.text
+        except RequestException as e:
+            logger.warning("Error fetching URL: %s. Error: %s", url, e)
+            return None
diff --git a/releasenotes/notes/add-rag-openapi-services-f3e377c49ff0f258.yaml b/releasenotes/notes/add-rag-openapi-services-f3e377c49ff0f258.yaml
new file mode 100644
index 000000000..dab28853a
--- /dev/null
+++ b/releasenotes/notes/add-rag-openapi-services-f3e377c49ff0f258.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Adds RAG OpenAPI services integration: the `OpenAPIServiceToFunctions` converter and the `OpenAPIServiceConnector` component.
diff --git a/test/components/connectors/test_openapi_service.py b/test/components/connectors/test_openapi_service.py
new file mode 100644
index 000000000..ed8b15de7
--- /dev/null
+++ b/test/components/connectors/test_openapi_service.py
@@ -0,0 +1,80 @@
+import json
+import pytest
+from unittest.mock import MagicMock, Mock
+from openapi3 import OpenAPI
+from haystack.components.connectors import OpenAPIServiceConnector
+
+
+@pytest.fixture
+def openapi_service_mock():
+    return MagicMock(spec=OpenAPI)
+
+
+@pytest.fixture
+def service_auths():
+    return {"TestService": "auth_token"}
+
+
+class TestOpenAPIServiceConnector:
+    @pytest.fixture
+    def connector(self, service_auths):
+        return OpenAPIServiceConnector(service_auths)
+
+    def test_parse_message_invalid_json(self, connector):
+        # Content that is not JSON at all must be rejected with a ValueError
+        with pytest.raises(ValueError):
+            connector._parse_message("invalid json")
+
+    def test_parse_valid_json_message(self):
+        connector = OpenAPIServiceConnector()
+
+        # OpenAI function calling descriptor: "arguments" is itself a JSON string carrying only "parameters"
+        content = (
+            '{"name": "compare_branches","arguments": "{\\n \\"parameters\\": {\\n '
+            ' \\"basehead\\": \\"main...openapi_container_v5\\",\\n '
+            ' \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack\\"\\n }\\n}"}'
+        )
+        descriptor = connector._parse_message(content)
+
+        # Assert that the descriptor contains the expected method name and arguments
+        assert descriptor["name"] == "compare_branches"
+        assert descriptor["arguments"]["parameters"] == {
+            "basehead": "main...openapi_container_v5",
+            "owner": "deepset-ai",
+            "repo": "haystack",
+        }
+        # but not the requestBody
+        assert "requestBody" not in descriptor["arguments"]
+
+        # OpenAI function calling descriptor whose "arguments" carry only a "requestBody"
+        content = '{"name": "search","arguments": "{\\n \\"requestBody\\": {\\n \\"q\\": \\"haystack\\"\\n }\\n}"}'
+        descriptor = connector._parse_message(content)
+        assert descriptor["name"] == "search"
+        assert descriptor["arguments"]["requestBody"] == {"q": "haystack"}
+
+        # but not the parameters
+        assert "parameters" not in descriptor["arguments"]
+
+    def test_parse_message_missing_fields(self, connector):
+        # Valid JSON that lacks the "arguments" field must be rejected with a ValueError
+        with pytest.raises(ValueError):
+            connector._parse_message(json.dumps({"name": "test_method"}))
+
+    def test_authenticate_service_invalid(self, connector, openapi_service_mock):
+        # A security scheme is declared, but the mocked service title has no entry in service_auths
+        openapi_service_mock.components.securitySchemes = {"apiKey": {}}
+        with pytest.raises(ValueError):
+            connector._authenticate_service(openapi_service_mock)
+
+    def test_invoke_method_valid(self, connector, openapi_service_mock):
+        # The connector looks up `call_<name>` on the service and returns whatever that callable returns
+        method_invocation_descriptor = {"name": "test_method", "arguments": {}}
+        openapi_service_mock.call_test_method = Mock(return_value="response")
+        result = connector._invoke_method(openapi_service_mock, method_invocation_descriptor)
+        assert result == "response"
+
+    def test_invoke_method_invalid(self, connector, openapi_service_mock):
+        # No `call_invalid_method` attribute exists on the mocked service, so a RuntimeError is expected
+        method_invocation_descriptor = {"name": "invalid_method", "arguments": {}}
+        with pytest.raises(RuntimeError):
+            connector._invoke_method(openapi_service_mock, method_invocation_descriptor)
diff --git a/test/components/converters/test_openapi_functions.py b/test/components/converters/test_openapi_functions.py
new file mode 100644
index 000000000..79e202ee6
--- /dev/null
+++ b/test/components/converters/test_openapi_functions.py
@@ -0,0 +1,221 @@
+import json
+import sys
+import tempfile
+
+import pytest
+
+from haystack.components.converters import OpenAPIServiceToFunctions
+from haystack.dataclasses import ByteStream
+
+
+@pytest.fixture
+def json_serperdev_openapi_spec():
+    serper_spec = """
+    {
+      "openapi": "3.0.0",
+      "info": {
+        "title": "SerperDev",
+        "version": "1.0.0",
+        "description": "API for performing search queries"
+      },
+      "servers": [
+        {
+          "url": "https://google.serper.dev"
+        }
+      ],
+      "paths": {
+        "/search": {
+          "post": {
+            "operationId": "search",
+            "description": "Search the web with Google",
+            "requestBody": {
+              "required": true,
+              "content": {
+                "application/json": {
+                  "schema": {
+                    "type": "object",
+                    "properties": {
+                      "q": {
+                        "type": "string"
+                      }
+                    }
+                  }
+                }
+              }
+            },
+            "responses": {
+              "200": {
+                "description": "Successful response",
+                "content": {
+                  "application/json": {
+                    "schema": {
+                      "type": "object",
+                      "properties": {
+                        "searchParameters": {
+                          "type": "undefined"
+                        },
+                        "knowledgeGraph": {
+                          "type": "undefined"
+                        },
+                        "answerBox": {
+                          "type": "undefined"
+                        },
+                        "organic": {
+                          "type": "undefined"
+                        },
+                        "topStories": {
+                          "type": "undefined"
+                        },
+                        "peopleAlsoAsk": {
+                          "type": "undefined"
+                        },
+                        "relatedSearches": {
+                          "type": "undefined"
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            },
+            "security": [
+              {
+                "apikey": []
+              }
+            ]
+          }
+        }
+      },
+      "components": {
+        "securitySchemes": {
+          "apikey": {
+            "type": "apiKey",
+            "name": "x-api-key",
+            "in": "header"
+          }
+        }
+      }
+    }
+    """
+    return serper_spec
+
+
+@pytest.fixture
+def yaml_serperdev_openapi_spec():
+    serper_spec = """
+    openapi: 3.0.0
+    info:
+      title: SerperDev
+      version: 1.0.0
+      description: API for performing search queries
+    servers:
+      - url: 'https://google.serper.dev'
+    paths:
+      /search:
+        post:
+          operationId: search
+          description: Search the web with Google
+          requestBody:
+            required: true
+            content:
+              application/json:
+                schema:
+                  type: object
+                  properties:
+                    q:
+                      type: string
+          responses:
+            '200':
+              description: Successful response
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      searchParameters:
+                        type: undefined
+                      knowledgeGraph:
+                        type: undefined
+                      answerBox:
+                        type: undefined
+                      organic:
+                        type: undefined
+                      topStories:
+                        type: undefined
+                      peopleAlsoAsk:
+                        type: undefined
+                      relatedSearches:
+                        type: undefined
+          security:
+            - apikey: []
+    components:
+      securitySchemes:
+        apikey:
+          type: apiKey
+          name: x-api-key
+          in: header
+    """
+    return serper_spec
+
+
+class TestOpenAPIServiceToFunctions:
+    # test we can parse an openapi spec given as a JSON string
+    def test_openapi_spec_parsing_json(self, json_serperdev_openapi_spec):
+        service = OpenAPIServiceToFunctions()
+
+        serper_spec_json = service._parse_openapi_spec(json_serperdev_openapi_spec)
+        assert serper_spec_json["openapi"] == "3.0.0"
+        assert serper_spec_json["info"]["title"] == "SerperDev"
+
+    # test we can parse an openapi spec given as a YAML string
+    def test_openapi_spec_parsing_yaml(self, yaml_serperdev_openapi_spec):
+        service = OpenAPIServiceToFunctions()
+
+        serper_spec_yaml = service._parse_openapi_spec(yaml_serperdev_openapi_spec)
+        assert serper_spec_yaml["openapi"] == "3.0.0"
+        assert serper_spec_yaml["info"]["title"] == "SerperDev"
+
+    # test we can extract functions from an openapi spec given as a ByteStream
+    def test_run_with_bytestream_source(self, json_serperdev_openapi_spec):
+        service = OpenAPIServiceToFunctions()
+        spec_stream = ByteStream.from_string(json_serperdev_openapi_spec)
+        result = service.run(sources=[spec_stream], system_messages=["Some system message we don't care about here"])
+        assert len(result["documents"]) == 1
+        doc = result["documents"][0]
+
+        # the document content is the JSON-serialized OpenAI function definition
+        assert (
+            doc.content
+            == '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
+            '"properties": {"requestBody": {"type": "object", "properties": {"q": {"type": "string"}}}}}}'
+        )
+
+        # the metadata carries the paired system message and the parsed spec
+        assert doc.meta["system_message"] == "Some system message we don't care about here"
+        assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
+
+    @pytest.mark.skipif(
+        sys.platform in ["win32", "cygwin"],
+        reason="Can't run on Windows Github CI, need access temp file but windows does not allow it",
+    )
+    def test_run_with_file_source(self, json_serperdev_openapi_spec):
+        # test we can extract functions from an openapi spec stored in a file
+        service = OpenAPIServiceToFunctions()
+        # write the spec to NamedTemporaryFile and check that it is parsed correctly
+        with tempfile.NamedTemporaryFile() as tmp:
+            tmp.write(json_serperdev_openapi_spec.encode("utf-8"))
+            tmp.seek(0)
+            result = service.run(sources=[tmp.name], system_messages=["Some system message we don't care about here"])
+            assert len(result["documents"]) == 1
+            doc = result["documents"][0]
+
+            # the document content is the JSON-serialized OpenAI function definition
+            assert (
+                doc.content
+                == '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
+                '"properties": {"requestBody": {"type": "object", "properties": {"q": {"type": "string"}}}}}}'
+            )
+
+            # the metadata carries the paired system message and the parsed spec
+            assert doc.meta["system_message"] == "Some system message we don't care about here"
+            assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
diff --git a/test/test_requirements.txt b/test/test_requirements.txt
index 6ba925805..42fc21b8e 100644
--- a/test/test_requirements.txt
+++ b/test/test_requirements.txt
@@ -14,3 +14,7 @@ azure-ai-formrecognizer>=3.2.0b2 # AzureOCRDocumentConverter
 langdetect # TextLanguageRouter and DocumentLanguageClassifier
 sentence-transformers>=2.2.0 # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
 openai-whisper>=20231106 # LocalWhisperTranscriber
+
+# OpenAPI
+jsonref # OpenAPIServiceToFunctions
+openapi3 # OpenAPIServiceConnector