mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-21 14:13:26 +00:00
feat: Improve OpenAPIServiceToFunctions signature (#7257)
* Convert OpenAPIServiceToFunctions run interface
---------
Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent
721691c036
commit
0e7c41be5e
@ -3,11 +3,9 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
from requests import RequestException
|
||||
|
||||
from haystack import Document, component, logging
|
||||
from haystack import component, logging
|
||||
from haystack.dataclasses.byte_stream import ByteStream
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
@ -38,7 +36,7 @@ class OpenAPIServiceToFunctions:
|
||||
|
||||
converter = OpenAPIServiceToFunctions()
|
||||
result = converter.run(sources=["path/to/openapi_definition.yaml"])
|
||||
assert result["documents"]
|
||||
assert result["functions"]
|
||||
```
|
||||
"""
|
||||
|
||||
@ -46,43 +44,50 @@ class OpenAPIServiceToFunctions:
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Create a OpenAPIServiceToFunctions component.
|
||||
Create an OpenAPIServiceToFunctions component.
|
||||
"""
|
||||
openapi_imports.check()
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(
|
||||
self, sources: List[Union[str, Path, ByteStream]], system_messages: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
@component.output_types(functions=List[Dict[str, Any]], openapi_specs=List[Dict[str, Any]])
|
||||
def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Converts OpenAPI definitions into the OpenAI function calling format.
|
||||
|
||||
:param sources:
|
||||
File paths, URLs or ByteStream objects of OpenAPI definitions.
|
||||
:param system_messages:
|
||||
Optional system messages for each source.
|
||||
File paths or ByteStream objects of OpenAPI definitions (in JSON or YAML format).
|
||||
|
||||
:returns:
|
||||
A dictionary with the following keys:
|
||||
- documents: Documents containing a function definition and relevant metadata
|
||||
- functions: Function definitions in JSON object format
|
||||
- openapi_specs: OpenAPI specs in JSON/YAML object format with resolved references
|
||||
|
||||
:raises RuntimeError:
|
||||
If the OpenAPI definitions cannot be downloaded or processed.
|
||||
:raises ValueError:
|
||||
If the source type is not recognized or no functions are found in the OpenAPI definitions.
|
||||
"""
|
||||
documents: List[Document] = []
|
||||
system_messages = system_messages or [""] * len(sources)
|
||||
for source, system_message in zip(sources, system_messages):
|
||||
all_extracted_fc_definitions: List[Dict[str, Any]] = []
|
||||
all_openapi_specs = []
|
||||
for source in sources:
|
||||
openapi_spec_content = None
|
||||
if isinstance(source, (str, Path)):
|
||||
# check if the source is a file path or a URL
|
||||
if os.path.exists(source):
|
||||
openapi_spec_content = self._read_from_file(source)
|
||||
try:
|
||||
with open(source, "r") as f:
|
||||
openapi_spec_content = f.read()
|
||||
except IOError as e:
|
||||
logger.warning(
|
||||
"IO error reading OpenAPI specification file: {source}. Error: {e}", source=source, e=e
|
||||
)
|
||||
else:
|
||||
openapi_spec_content = self._read_from_url(str(source))
|
||||
logger.warning(f"OpenAPI specification file not found: {source}")
|
||||
elif isinstance(source, ByteStream):
|
||||
openapi_spec_content = source.data.decode("utf-8")
|
||||
if not openapi_spec_content:
|
||||
logger.warning(
|
||||
"Invalid OpenAPI specification content provided: {openapi_spec_content}",
|
||||
openapi_spec_content=openapi_spec_content,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Invalid source type {source}. Only str, Path, and ByteStream are supported.", source=type(source)
|
||||
@ -93,18 +98,17 @@ class OpenAPIServiceToFunctions:
|
||||
try:
|
||||
service_openapi_spec = self._parse_openapi_spec(openapi_spec_content)
|
||||
functions: List[Dict[str, Any]] = self._openapi_to_functions(service_openapi_spec)
|
||||
for function in functions:
|
||||
meta: Dict[str, Any] = {"spec": service_openapi_spec}
|
||||
if system_message:
|
||||
meta["system_message"] = system_message
|
||||
doc = Document(content=json.dumps(function), meta=meta)
|
||||
documents.append(doc)
|
||||
all_extracted_fc_definitions.extend(functions)
|
||||
all_openapi_specs.append(service_openapi_spec)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error processing OpenAPI specification from source {source}: {error}", source=source, error=e
|
||||
)
|
||||
|
||||
return {"documents": documents}
|
||||
if not all_extracted_fc_definitions:
|
||||
logger.warning("No OpenAI function definitions extracted from the provided OpenAPI specification sources.")
|
||||
|
||||
return {"functions": all_extracted_fc_definitions, "openapi_specs": all_openapi_specs}
|
||||
|
||||
def _openapi_to_functions(self, service_openapi_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@ -244,32 +248,3 @@ class OpenAPIServiceToFunctions:
|
||||
|
||||
# Replace references in the object with their resolved values, if any
|
||||
return jsonref.replace_refs(open_api_spec_content)
|
||||
|
||||
def _read_from_file(self, path: Union[str, Path]) -> Optional[str]:
|
||||
"""
|
||||
Reads the content of a file, given its path.
|
||||
:param path: The path of the file.
|
||||
:type path: Union[str, Path]
|
||||
:return: The content of the file or None if the file cannot be read.
|
||||
"""
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
return f.read()
|
||||
except IOError as e:
|
||||
logger.warning("IO error reading file: {path}. Error: {error}", path=path, error=e)
|
||||
return None
|
||||
|
||||
def _read_from_url(self, url: str) -> Optional[str]:
|
||||
"""
|
||||
Reads the content of a URL.
|
||||
:param url: The URL to read.
|
||||
:type url: str
|
||||
:return: The content of the URL or None if the URL cannot be read.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except RequestException as e:
|
||||
logger.warning("Error fetching URL: {url}. Error: {error}", url=url, error=e)
|
||||
return None
|
||||
|
@ -179,20 +179,16 @@ class TestOpenAPIServiceToFunctions:
|
||||
def test_run_with_bytestream_source(self, json_serperdev_openapi_spec):
|
||||
service = OpenAPIServiceToFunctions()
|
||||
spec_stream = ByteStream.from_string(json_serperdev_openapi_spec)
|
||||
result = service.run(sources=[spec_stream], system_messages=["Some system message we don't care about here"])
|
||||
assert len(result["documents"]) == 1
|
||||
doc = result["documents"][0]
|
||||
result = service.run(sources=[spec_stream])
|
||||
assert len(result["functions"]) == 1
|
||||
fc = result["functions"][0]
|
||||
|
||||
# check that the content is as expected
|
||||
assert (
|
||||
doc.content
|
||||
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
|
||||
'"properties": {"q": {"type": "string"}}}}'
|
||||
)
|
||||
|
||||
# check that the metadata is as expected
|
||||
assert doc.meta["system_message"] == "Some system message we don't care about here"
|
||||
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
|
||||
# check that fc definition is as expected
|
||||
assert fc == {
|
||||
"name": "search",
|
||||
"description": "Search the web with Google",
|
||||
"parameters": {"type": "object", "properties": {"q": {"type": "string"}}},
|
||||
}
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform in ["win32", "cygwin"],
|
||||
@ -205,47 +201,56 @@ class TestOpenAPIServiceToFunctions:
|
||||
with tempfile.NamedTemporaryFile() as tmp:
|
||||
tmp.write(json_serperdev_openapi_spec.encode("utf-8"))
|
||||
tmp.seek(0)
|
||||
result = service.run(sources=[tmp.name], system_messages=["Some system message we don't care about here"])
|
||||
assert len(result["documents"]) == 1
|
||||
doc = result["documents"][0]
|
||||
result = service.run(sources=[tmp.name])
|
||||
assert len(result["functions"]) == 1
|
||||
fc = result["functions"][0]
|
||||
|
||||
# check that the content is as expected
|
||||
assert (
|
||||
doc.content
|
||||
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
|
||||
'"properties": {"q": {"type": "string"}}}}'
|
||||
)
|
||||
# check that fc definition is as expected
|
||||
assert fc == {
|
||||
"name": "search",
|
||||
"description": "Search the web with Google",
|
||||
"parameters": {"type": "object", "properties": {"q": {"type": "string"}}},
|
||||
}
|
||||
|
||||
# check that the metadata is as expected
|
||||
assert doc.meta["system_message"] == "Some system message we don't care about here"
|
||||
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
|
||||
|
||||
def test_run_with_file_source_and_none_system_messages(self, json_serperdev_openapi_spec):
|
||||
def test_run_with_invalid_file_source(self, caplog):
|
||||
# test invalid source
|
||||
service = OpenAPIServiceToFunctions()
|
||||
spec_stream = ByteStream.from_string(json_serperdev_openapi_spec)
|
||||
result = service.run(sources=["invalid_source"])
|
||||
assert result["functions"] == []
|
||||
assert "not found" in caplog.text
|
||||
|
||||
# we now omit the system_messages argument
|
||||
result = service.run(sources=[spec_stream])
|
||||
assert len(result["documents"]) == 1
|
||||
doc = result["documents"][0]
|
||||
|
||||
# check that the content is as expected
|
||||
assert (
|
||||
doc.content
|
||||
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
|
||||
'"properties": {"q": {"type": "string"}}}}'
|
||||
)
|
||||
|
||||
# check that the metadata is as expected, system_message should not be present
|
||||
assert "system_message" not in doc.meta
|
||||
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
|
||||
def test_run_with_invalid_bytestream_source(self, caplog):
|
||||
# test invalid source
|
||||
service = OpenAPIServiceToFunctions()
|
||||
result = service.run(sources=[ByteStream.from_string("")])
|
||||
assert result["functions"] == []
|
||||
assert "Invalid OpenAPI specification" in caplog.text
|
||||
|
||||
def test_complex_types_conversion(self, test_files_path):
|
||||
# ensure that complex types from OpenAPI spec are converted to the expected format in OpenAI function calling
|
||||
service = OpenAPIServiceToFunctions()
|
||||
result = service.run(sources=[test_files_path / "json" / "complex_types_openapi_service.json"])
|
||||
assert len(result["documents"]) == 1
|
||||
assert len(result["functions"]) == 1
|
||||
|
||||
with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file:
|
||||
desired_output = json.load(openai_spec_file)
|
||||
assert result["documents"][0].content == json.dumps(desired_output)
|
||||
assert result["functions"][0] == desired_output
|
||||
|
||||
def test_simple_and_complex_at_once(self, test_files_path, json_serperdev_openapi_spec):
|
||||
# ensure functions are extracted from multiple OpenAPI spec sources (ByteStream and file path) in a single run
|
||||
service = OpenAPIServiceToFunctions()
|
||||
sources = [
|
||||
ByteStream.from_string(json_serperdev_openapi_spec),
|
||||
test_files_path / "json" / "complex_types_openapi_service.json",
|
||||
]
|
||||
result = service.run(sources=sources)
|
||||
assert len(result["functions"]) == 2
|
||||
|
||||
with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file:
|
||||
desired_output = json.load(openai_spec_file)
|
||||
assert result["functions"][0] == {
|
||||
"name": "search",
|
||||
"description": "Search the web with Google",
|
||||
"parameters": {"type": "object", "properties": {"q": {"type": "string"}}},
|
||||
}
|
||||
assert result["functions"][1] == desired_output
|
||||
|
Loading…
x
Reference in New Issue
Block a user