mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 04:27:15 +00:00
feat: Add RAG based OpenAPI service integration (#6555)
* Add OpenAPIServiceConnector and OpenAPIServiceToFunctions * Add release note * Add test deps * Better docs on OpenAPI spec reqs, improve tests * Silvano PR feedback
This commit is contained in:
parent
94cfe5d9ae
commit
2dd5a94b04
3
haystack/components/connectors/__init__.py
Normal file
3
haystack/components/connectors/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from haystack.components.connectors.openapi_service import OpenAPIServiceConnector
|
||||
|
||||
__all__ = ["OpenAPIServiceConnector"]
|
||||
128
haystack/components/connectors/openapi_service.py
Normal file
128
haystack/components/connectors/openapi_service.py
Normal file
@ -0,0 +1,128 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from haystack import component
|
||||
from haystack.dataclasses import ChatMessage, ChatRole
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with LazyImport("Run 'pip install openapi3'") as openapi_imports:
|
||||
from openapi3 import OpenAPI
|
||||
|
||||
|
||||
@component
|
||||
class OpenAPIServiceConnector:
|
||||
"""
|
||||
OpenAPIServiceConnector connects to OpenAPI services, allowing for the invocation of methods specified in
|
||||
an OpenAPI specification of that service. It integrates with ChatMessage interface, where messages are used to
|
||||
determine the method to be called and the parameters to be passed. The message payload should be a JSON formatted
|
||||
string consisting of the method name and the parameters to be passed to the method. The method name and parameters
|
||||
are then used to invoke the method on the OpenAPI service. The response from the service is returned as a
|
||||
ChatMessage.
|
||||
|
||||
Before using this component, one needs to register functions from the OpenAPI specification with LLM.
|
||||
This can be done using the OpenAPIServiceToFunctions component.
|
||||
"""
|
||||
|
||||
def __init__(self, service_auths: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Initializes the OpenAPIServiceConnector instance
|
||||
:param service_auths: A dictionary containing the service name and token to be used for authentication.
|
||||
"""
|
||||
openapi_imports.check()
|
||||
self.service_authentications = service_auths or {}
|
||||
|
||||
@component.output_types(service_response=Dict[str, Any])
|
||||
def run(self, messages: List[ChatMessage], service_openapi_spec: Dict[str, Any]) -> Dict[str, List[ChatMessage]]:
|
||||
"""
|
||||
Processes a list of chat messages to invoke a method on an OpenAPI service. It parses the last message in the
|
||||
list, expecting it to contain an OpenAI function calling descriptor (name & parameters) in JSON format.
|
||||
|
||||
:param messages: A list of `ChatMessage` objects representing the chat history.
|
||||
:type messages: List[ChatMessage]
|
||||
:param service_openapi_spec: The OpenAPI JSON specification object of the service.
|
||||
:type service_openapi_spec: JSON object
|
||||
:return: A dictionary with a key `"service_response"`, containing the response from the OpenAPI service.
|
||||
:rtype: Dict[str, List[ChatMessage]]
|
||||
:raises ValueError: If the last message is not from the assistant or if it does not contain the correct payload
|
||||
to invoke a method on the service.
|
||||
"""
|
||||
|
||||
last_message = messages[-1]
|
||||
if not last_message.is_from(ChatRole.ASSISTANT):
|
||||
raise ValueError(f"{last_message} is not from the assistant.")
|
||||
|
||||
method_invocation_descriptor = self._parse_message(last_message.content)
|
||||
|
||||
# instantiate the OpenAPI service for the given specification
|
||||
openapi_service = OpenAPI(service_openapi_spec)
|
||||
self._authenticate_service(openapi_service)
|
||||
|
||||
service_response = self._invoke_method(openapi_service, method_invocation_descriptor)
|
||||
return {"service_response": [ChatMessage.from_user(str(service_response))]}
|
||||
|
||||
def _parse_message(self, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parses the message content to extract the method invocation descriptor.
|
||||
|
||||
:param content: The JSON string content of the message.
|
||||
:type content: str
|
||||
:return: A dictionary with method name and arguments.
|
||||
:rtype: Dict[str, Any]
|
||||
:raises ValueError: If the content is not valid JSON or lacks required fields.
|
||||
"""
|
||||
try:
|
||||
method_invocation_descriptor = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid JSON content, cannot parse invocation message.", content)
|
||||
|
||||
if "name" not in method_invocation_descriptor or "arguments" not in method_invocation_descriptor:
|
||||
raise ValueError("Missing required fields in the invocation message content.", content)
|
||||
|
||||
method_invocation_descriptor["arguments"] = json.loads(method_invocation_descriptor["arguments"])
|
||||
return method_invocation_descriptor
|
||||
|
||||
def _authenticate_service(self, openapi_service: OpenAPI):
|
||||
"""
|
||||
Authenticates with the OpenAPI service if required.
|
||||
|
||||
:param openapi_service: The OpenAPI service instance.
|
||||
:type openapi_service: OpenAPI
|
||||
:raises ValueError: If authentication fails or is not found.
|
||||
"""
|
||||
if openapi_service.components.securitySchemes:
|
||||
auth_method = list(openapi_service.components.securitySchemes.keys())[0]
|
||||
service_title = openapi_service.info.title
|
||||
if service_title not in self.service_authentications:
|
||||
raise ValueError(f"Service {service_title} not found in service_authentications.")
|
||||
openapi_service.authenticate(auth_method, self.service_authentications[service_title])
|
||||
|
||||
def _invoke_method(self, openapi_service: OpenAPI, method_invocation_descriptor: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Invokes the specified method on the OpenAPI service.
|
||||
|
||||
:param openapi_service: The OpenAPI service instance.
|
||||
:type openapi_service: OpenAPI
|
||||
:param method_invocation_descriptor: The method name and arguments.
|
||||
:type method_invocation_descriptor: Dict[str, Any]
|
||||
:return: A service JSON response.
|
||||
:rtype: Any
|
||||
:raises RuntimeError: If the method is not found or invocation fails.
|
||||
"""
|
||||
name = method_invocation_descriptor["name"]
|
||||
# a bit convoluted, but we need to pass parameters, data, or both to the method
|
||||
# depending on the openapi operation specification, can't use None as a default value
|
||||
method_call_params = {}
|
||||
if (parameters := method_invocation_descriptor["arguments"].get("parameters")) is not None:
|
||||
method_call_params["parameters"] = parameters
|
||||
if (arguments := method_invocation_descriptor["arguments"].get("requestBody")) is not None:
|
||||
method_call_params["data"] = arguments
|
||||
|
||||
method_to_call = getattr(openapi_service, f"call_{name}", None)
|
||||
if not callable(method_to_call):
|
||||
raise RuntimeError(f"Operation {name} not found in OpenAPI specification {openapi_service.info.title}")
|
||||
|
||||
# this will call the underlying service REST API
|
||||
return method_to_call(**method_call_params)
|
||||
@ -4,6 +4,7 @@ from haystack.components.converters.azure import AzureOCRDocumentConverter
|
||||
from haystack.components.converters.pypdf import PyPDFToDocument
|
||||
from haystack.components.converters.html import HTMLToDocument
|
||||
from haystack.components.converters.markdown import MarkdownToDocument
|
||||
from haystack.components.converters.openapi_functions import OpenAPIServiceToFunctions
|
||||
|
||||
__all__ = [
|
||||
"TextFileToDocument",
|
||||
@ -12,4 +13,5 @@ __all__ = [
|
||||
"PyPDFToDocument",
|
||||
"HTMLToDocument",
|
||||
"MarkdownToDocument",
|
||||
"OpenAPIServiceToFunctions",
|
||||
]
|
||||
|
||||
216
haystack/components/converters/openapi_functions.py
Normal file
216
haystack/components/converters/openapi_functions.py
Normal file
@ -0,0 +1,216 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Union, Optional
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
from requests import RequestException
|
||||
|
||||
from haystack import component, Document
|
||||
from haystack.dataclasses.byte_stream import ByteStream
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with LazyImport("Run 'pip install jsonref'") as openapi_imports:
|
||||
import jsonref
|
||||
|
||||
|
||||
@component
|
||||
class OpenAPIServiceToFunctions:
|
||||
"""
|
||||
OpenAPIServiceToFunctions is responsible for converting an OpenAPI service specification into a format suitable
|
||||
for OpenAI function calling, based on the provided OpenAPI specification. Given an OpenAPI specification,
|
||||
OpenAPIServiceToFunctions processes it, and extracts function definitions that can be invoked via OpenAI's
|
||||
function calling mechanism. The format of the extracted functions is compatible with OpenAI's function calling
|
||||
JSON format.
|
||||
|
||||
Minimal requirements for OpenAPI specification:
|
||||
- OpenAPI version 3.0.0 or higher
|
||||
- Each function must have a unique operationId
|
||||
- Each function must have a description
|
||||
- Each function must have a requestBody or parameters or both
|
||||
- Each function must have a schema for the requestBody and/or parameters
|
||||
|
||||
|
||||
See https://github.com/OAI/OpenAPI-Specification for more details on OpenAPI specification.
|
||||
See https://platform.openai.com/docs/guides/function-calling for more details on OpenAI function calling.
|
||||
"""
|
||||
|
||||
MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the OpenAPIServiceToFunctions instance
|
||||
"""
|
||||
openapi_imports.check()
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(self, sources: List[Union[str, Path, ByteStream]], system_messages: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Processes OpenAPI specification URLs or files to extract functions that can be invoked via OpenAI function
|
||||
calling mechanism. Each source is paired with a system message in one-to-one correspondence. The system message
|
||||
is used to assist LLM in the response generation.
|
||||
|
||||
:param sources: A list of OpenAPI specification sources, which can be URLs, file paths, or ByteStream objects.
|
||||
:type sources: List[Union[str, Path, ByteStream]]
|
||||
:param system_messages: A list of system messages corresponding to each source.
|
||||
:type system_messages: List[str]
|
||||
:return: A dictionary with a key 'documents' containing a list of Document objects. Each Document object
|
||||
encapsulates a function definition and relevant metadata.
|
||||
:rtype: Dict[str, Any]
|
||||
:raises RuntimeError: If the OpenAPI specification cannot be downloaded or processed.
|
||||
:raises ValueError: If the source type is not recognized or no functions are found in the OpenAPI specification.
|
||||
"""
|
||||
documents: List[Document] = []
|
||||
for source, system_message in zip(sources, system_messages):
|
||||
openapi_spec_content = None
|
||||
if isinstance(source, (str, Path)):
|
||||
# check if the source is a file path or a URL
|
||||
if os.path.exists(source):
|
||||
openapi_spec_content = self._read_from_file(source)
|
||||
else:
|
||||
openapi_spec_content = self._read_from_url(str(source))
|
||||
elif isinstance(source, ByteStream):
|
||||
openapi_spec_content = source.data.decode("utf-8")
|
||||
else:
|
||||
logger.warning("Invalid source type %s. Only str, Path, and ByteStream are supported.", type(source))
|
||||
continue
|
||||
|
||||
if openapi_spec_content:
|
||||
try:
|
||||
service_openapi_spec = self._parse_openapi_spec(openapi_spec_content)
|
||||
functions: List[Dict[str, Any]] = self._openapi_to_functions(service_openapi_spec)
|
||||
docs = [
|
||||
Document(
|
||||
content=json.dumps(function),
|
||||
meta={"spec": service_openapi_spec, "system_message": system_message},
|
||||
)
|
||||
for function in functions
|
||||
]
|
||||
documents.extend(docs)
|
||||
except Exception as e:
|
||||
logger.error("Error processing OpenAPI specification from source %s: %s", source, e)
|
||||
|
||||
return {"documents": documents}
|
||||
|
||||
def _openapi_to_functions(self, service_openapi_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extracts functions from the OpenAPI specification of the service and converts them into a format
|
||||
suitable for OpenAI function calling.
|
||||
|
||||
:param service_openapi_spec: The OpenAPI specification from which functions are to be extracted.
|
||||
:type service_openapi_spec: Dict[str, Any]
|
||||
:return: A list of dictionaries, each representing a function. Each dictionary includes the function's
|
||||
name, description, and a schema of its parameters.
|
||||
:rtype: List[Dict[str, Any]]
|
||||
"""
|
||||
|
||||
# Doesn't enforce rigid spec validation because that would require a lot of dependencies
|
||||
# We check the version and require minimal fields to be present, so we can extract functions
|
||||
spec_version = service_openapi_spec.get("openapi")
|
||||
if not spec_version:
|
||||
raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}")
|
||||
service_openapi_spec_version = int(spec_version.split(".")[0])
|
||||
|
||||
# Compare the versions
|
||||
if service_openapi_spec_version < OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION:
|
||||
raise ValueError(
|
||||
f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be "
|
||||
f"at least {OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
|
||||
)
|
||||
|
||||
functions: List[Dict[str, Any]] = []
|
||||
for path_methods in service_openapi_spec["paths"].values():
|
||||
for method_specification in path_methods.values():
|
||||
resolved_spec = jsonref.replace_refs(method_specification)
|
||||
if isinstance(resolved_spec, dict):
|
||||
function_name = resolved_spec.get("operationId")
|
||||
desc = resolved_spec.get("description") or resolved_spec.get("summary", "")
|
||||
|
||||
schema: Dict[str, Any] = {"type": "object", "properties": {}}
|
||||
|
||||
req_body = (
|
||||
resolved_spec.get("requestBody", {})
|
||||
.get("content", {})
|
||||
.get("application/json", {})
|
||||
.get("schema")
|
||||
)
|
||||
if req_body:
|
||||
schema["properties"]["requestBody"] = req_body
|
||||
|
||||
params = resolved_spec.get("parameters", [])
|
||||
if params:
|
||||
param_properties = {param["name"]: param["schema"] for param in params if "schema" in param}
|
||||
schema["properties"]["parameters"] = {"type": "object", "properties": param_properties}
|
||||
|
||||
# these three fields are minimal requirement for OpenAI function calling
|
||||
if function_name and desc and schema:
|
||||
functions.append({"name": function_name, "description": desc, "parameters": schema})
|
||||
else:
|
||||
logger.warning(
|
||||
"Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec
|
||||
)
|
||||
|
||||
return functions
|
||||
|
||||
def _parse_openapi_spec(self, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parses OpenAPI specification content, supporting both JSON and YAML formats.
|
||||
|
||||
:param content: The content of the OpenAPI specification.
|
||||
:type content: str
|
||||
:return: The parsed OpenAPI specification.
|
||||
:rtype: Dict[str, Any]
|
||||
"""
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError as json_error:
|
||||
# heuristic to confirm that the content is likely malformed JSON
|
||||
if content.strip().startswith(("{", "[")):
|
||||
raise json_error
|
||||
|
||||
try:
|
||||
return yaml.safe_load(content)
|
||||
except yaml.YAMLError:
|
||||
error_message = (
|
||||
"Failed to parse the OpenAPI specification. "
|
||||
"The content does not appear to be valid JSON or YAML.\n\n"
|
||||
)
|
||||
raise RuntimeError(error_message, content)
|
||||
|
||||
def _read_from_file(self, path: Union[str, Path]) -> Optional[str]:
|
||||
"""
|
||||
Reads the content of a file, given its path.
|
||||
:param path: The path of the file.
|
||||
:type path: Union[str, Path]
|
||||
:return: The content of the file or None if the file cannot be read.
|
||||
"""
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
return f.read()
|
||||
except IOError as e:
|
||||
logger.warning("IO error reading file: %s. Error: %s", path, e)
|
||||
return None
|
||||
|
||||
def _read_from_url(self, url: str) -> Optional[str]:
|
||||
"""
|
||||
Reads the content of a URL.
|
||||
:param url: The URL to read.
|
||||
:type url: str
|
||||
:return: The content of the URL or None if the URL cannot be read.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except RequestException as e:
|
||||
logger.warning("Error fetching URL: %s. Error: %s", url, e)
|
||||
return None
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Adds RAG OpenAPI services integration.
|
||||
80
test/components/connectors/test_openapi_service.py
Normal file
80
test/components/connectors/test_openapi_service.py
Normal file
@ -0,0 +1,80 @@
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, Mock
|
||||
from openapi3 import OpenAPI
|
||||
from haystack.components.connectors import OpenAPIServiceConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_service_mock():
|
||||
return MagicMock(spec=OpenAPI)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_auths():
|
||||
return {"TestService": "auth_token"}
|
||||
|
||||
|
||||
class TestOpenAPIServiceConnector:
|
||||
@pytest.fixture
|
||||
def connector(self, service_auths):
|
||||
return OpenAPIServiceConnector(service_auths)
|
||||
|
||||
def test_parse_message_invalid_json(self, connector):
|
||||
# Test invalid JSON content
|
||||
with pytest.raises(ValueError):
|
||||
connector._parse_message("invalid json")
|
||||
|
||||
def test_parse_valid_json_message(self):
|
||||
connector = OpenAPIServiceConnector()
|
||||
|
||||
# The content format here is OpenAI function calling descriptor
|
||||
content = (
|
||||
'{"name": "compare_branches","arguments": "{\\n \\"parameters\\": {\\n '
|
||||
' \\"basehead\\": \\"main...openapi_container_v5\\",\\n '
|
||||
' \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack\\"\\n }\\n}"}'
|
||||
)
|
||||
descriptor = connector._parse_message(content)
|
||||
|
||||
# Assert that the descriptor contains the expected method name and arguments
|
||||
assert descriptor["name"] == "compare_branches"
|
||||
assert descriptor["arguments"]["parameters"] == {
|
||||
"basehead": "main...openapi_container_v5",
|
||||
"owner": "deepset-ai",
|
||||
"repo": "haystack",
|
||||
}
|
||||
# but not the requestBody
|
||||
assert "requestBody" not in descriptor["arguments"]
|
||||
|
||||
# The content format here is OpenAI function calling descriptor
|
||||
content = '{"name": "search","arguments": "{\\n \\"requestBody\\": {\\n \\"q\\": \\"haystack\\"\\n }\\n}"}'
|
||||
descriptor = connector._parse_message(content)
|
||||
assert descriptor["name"] == "search"
|
||||
assert descriptor["arguments"]["requestBody"] == {"q": "haystack"}
|
||||
|
||||
# but not the parameters
|
||||
assert "parameters" not in descriptor["arguments"]
|
||||
|
||||
def test_parse_message_missing_fields(self, connector):
|
||||
# Test JSON content with missing fields
|
||||
with pytest.raises(ValueError):
|
||||
connector._parse_message(json.dumps({"name": "test_method"}))
|
||||
|
||||
def test_authenticate_service_invalid(self, connector, openapi_service_mock):
|
||||
# Test invalid or missing authentication
|
||||
openapi_service_mock.components.securitySchemes = {"apiKey": {}}
|
||||
with pytest.raises(ValueError):
|
||||
connector._authenticate_service(openapi_service_mock)
|
||||
|
||||
def test_invoke_method_valid(self, connector, openapi_service_mock):
|
||||
# Test valid method invocation
|
||||
method_invocation_descriptor = {"name": "test_method", "arguments": {}}
|
||||
openapi_service_mock.call_test_method = Mock(return_value="response")
|
||||
result = connector._invoke_method(openapi_service_mock, method_invocation_descriptor)
|
||||
assert result == "response"
|
||||
|
||||
def test_invoke_method_invalid(self, connector, openapi_service_mock):
|
||||
# Test invalid method invocation
|
||||
method_invocation_descriptor = {"name": "invalid_method", "arguments": {}}
|
||||
with pytest.raises(RuntimeError):
|
||||
connector._invoke_method(openapi_service_mock, method_invocation_descriptor)
|
||||
221
test/components/converters/test_openapi_functions.py
Normal file
221
test/components/converters/test_openapi_functions.py
Normal file
@ -0,0 +1,221 @@
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.components.converters import OpenAPIServiceToFunctions
|
||||
from haystack.dataclasses import ByteStream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_serperdev_openapi_spec():
|
||||
serper_spec = """
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": "SerperDev",
|
||||
"version": "1.0.0",
|
||||
"description": "API for performing search queries"
|
||||
},
|
||||
"servers": [
|
||||
{
|
||||
"url": "https://google.serper.dev"
|
||||
}
|
||||
],
|
||||
"paths": {
|
||||
"/search": {
|
||||
"post": {
|
||||
"operationId": "search",
|
||||
"description": "Search the web with Google",
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"q": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"searchParameters": {
|
||||
"type": "undefined"
|
||||
},
|
||||
"knowledgeGraph": {
|
||||
"type": "undefined"
|
||||
},
|
||||
"answerBox": {
|
||||
"type": "undefined"
|
||||
},
|
||||
"organic": {
|
||||
"type": "undefined"
|
||||
},
|
||||
"topStories": {
|
||||
"type": "undefined"
|
||||
},
|
||||
"peopleAlsoAsk": {
|
||||
"type": "undefined"
|
||||
},
|
||||
"relatedSearches": {
|
||||
"type": "undefined"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"apikey": []
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"securitySchemes": {
|
||||
"apikey": {
|
||||
"type": "apiKey",
|
||||
"name": "x-api-key",
|
||||
"in": "header"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
return serper_spec
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yaml_serperdev_openapi_spec():
|
||||
serper_spec = """
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
title: SerperDev
|
||||
version: 1.0.0
|
||||
description: API for performing search queries
|
||||
servers:
|
||||
- url: 'https://google.serper.dev'
|
||||
paths:
|
||||
/search:
|
||||
post:
|
||||
operationId: search
|
||||
description: Search the web with Google
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
q:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
searchParameters:
|
||||
type: undefined
|
||||
knowledgeGraph:
|
||||
type: undefined
|
||||
answerBox:
|
||||
type: undefined
|
||||
organic:
|
||||
type: undefined
|
||||
topStories:
|
||||
type: undefined
|
||||
peopleAlsoAsk:
|
||||
type: undefined
|
||||
relatedSearches:
|
||||
type: undefined
|
||||
security:
|
||||
- apikey: []
|
||||
components:
|
||||
securitySchemes:
|
||||
apikey:
|
||||
type: apiKey
|
||||
name: x-api-key
|
||||
in: header
|
||||
"""
|
||||
return serper_spec
|
||||
|
||||
|
||||
class TestOpenAPIServiceToFunctions:
|
||||
# test we can parse openapi spec given in json
|
||||
def test_openapi_spec_parsing_json(self, json_serperdev_openapi_spec):
|
||||
service = OpenAPIServiceToFunctions()
|
||||
|
||||
serper_spec_json = service._parse_openapi_spec(json_serperdev_openapi_spec)
|
||||
assert serper_spec_json["openapi"] == "3.0.0"
|
||||
assert serper_spec_json["info"]["title"] == "SerperDev"
|
||||
|
||||
# test we can parse openapi spec given in yaml
|
||||
def test_openapi_spec_parsing_yaml(self, yaml_serperdev_openapi_spec):
|
||||
service = OpenAPIServiceToFunctions()
|
||||
|
||||
serper_spec_yaml = service._parse_openapi_spec(yaml_serperdev_openapi_spec)
|
||||
assert serper_spec_yaml["openapi"] == "3.0.0"
|
||||
assert serper_spec_yaml["info"]["title"] == "SerperDev"
|
||||
|
||||
# test we can extract functions from openapi spec given
|
||||
def test_run_with_bytestream_source(self, json_serperdev_openapi_spec):
|
||||
service = OpenAPIServiceToFunctions()
|
||||
spec_stream = ByteStream.from_string(json_serperdev_openapi_spec)
|
||||
result = service.run(sources=[spec_stream], system_messages=["Some system message we don't care about here"])
|
||||
assert len(result["documents"]) == 1
|
||||
doc = result["documents"][0]
|
||||
|
||||
# check that the content is as expected
|
||||
assert (
|
||||
doc.content
|
||||
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
|
||||
'"properties": {"requestBody": {"type": "object", "properties": {"q": {"type": "string"}}}}}}'
|
||||
)
|
||||
|
||||
# check that the metadata is as expected
|
||||
assert doc.meta["system_message"] == "Some system message we don't care about here"
|
||||
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform in ["win32", "cygwin"],
|
||||
reason="Can't run on Windows Github CI, need access temp file but windows does not allow it",
|
||||
)
|
||||
def test_run_with_file_source(self, json_serperdev_openapi_spec):
|
||||
# test we can extract functions from openapi spec given in file
|
||||
service = OpenAPIServiceToFunctions()
|
||||
# write the spec to NamedTemporaryFile and check that it is parsed correctly
|
||||
with tempfile.NamedTemporaryFile() as tmp:
|
||||
tmp.write(json_serperdev_openapi_spec.encode("utf-8"))
|
||||
tmp.seek(0)
|
||||
result = service.run(sources=[tmp.name], system_messages=["Some system message we don't care about here"])
|
||||
assert len(result["documents"]) == 1
|
||||
doc = result["documents"][0]
|
||||
|
||||
# check that the content is as expected
|
||||
assert (
|
||||
doc.content
|
||||
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
|
||||
'"properties": {"requestBody": {"type": "object", "properties": {"q": {"type": "string"}}}}}}'
|
||||
)
|
||||
|
||||
# check that the metadata is as expected
|
||||
assert doc.meta["system_message"] == "Some system message we don't care about here"
|
||||
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
|
||||
@ -14,3 +14,7 @@ azure-ai-formrecognizer>=3.2.0b2 # AzureOCRDocumentConverter
|
||||
langdetect # TextLanguageRouter and DocumentLanguageClassifier
|
||||
sentence-transformers>=2.2.0 # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
|
||||
openai-whisper>=20231106 # LocalWhisperTranscriber
|
||||
|
||||
# OpenAPI
|
||||
jsonref # OpenAPIServiceConnector, OpenAPIServiceToFunctions
|
||||
openapi3
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user