mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-21 22:23:23 +00:00
feat: Improve OpenAPIServiceToFunctions signature (#7257)
* Convert OpenAPIServiceToFunctions run interface

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent
721691c036
commit
0e7c41be5e
@ -3,11 +3,9 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import requests
|
|
||||||
import yaml
|
import yaml
|
||||||
from requests import RequestException
|
|
||||||
|
|
||||||
from haystack import Document, component, logging
|
from haystack import component, logging
|
||||||
from haystack.dataclasses.byte_stream import ByteStream
|
from haystack.dataclasses.byte_stream import ByteStream
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
|
|
||||||
@ -38,7 +36,7 @@ class OpenAPIServiceToFunctions:
|
|||||||
|
|
||||||
converter = OpenAPIServiceToFunctions()
|
converter = OpenAPIServiceToFunctions()
|
||||||
result = converter.run(sources=["path/to/openapi_definition.yaml"])
|
result = converter.run(sources=["path/to/openapi_definition.yaml"])
|
||||||
assert result["documents"]
|
assert result["functions"]
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -46,43 +44,50 @@ class OpenAPIServiceToFunctions:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
Create a OpenAPIServiceToFunctions component.
|
Create an OpenAPIServiceToFunctions component.
|
||||||
"""
|
"""
|
||||||
openapi_imports.check()
|
openapi_imports.check()
|
||||||
|
|
||||||
@component.output_types(documents=List[Document])
|
@component.output_types(functions=List[Dict[str, Any]], openapi_specs=List[Dict[str, Any]])
|
||||||
def run(
|
def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, Any]:
|
||||||
self, sources: List[Union[str, Path, ByteStream]], system_messages: Optional[List[str]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
Converts OpenAPI definitions in OpenAI function calling format.
|
Converts OpenAPI definitions in OpenAI function calling format.
|
||||||
|
|
||||||
:param sources:
|
:param sources:
|
||||||
File paths, URLs or ByteStream objects of OpenAPI definitions.
|
File paths or ByteStream objects of OpenAPI definitions (in JSON or YAML format).
|
||||||
:param system_messages:
|
|
||||||
Optional system messages for each source.
|
|
||||||
|
|
||||||
:returns:
|
:returns:
|
||||||
A dictionary with the following keys:
|
A dictionary with the following keys:
|
||||||
- documents: Documents containing a function definition and relevant metadata
|
- functions: Function definitions in JSON object format
|
||||||
|
- openapi_specs: OpenAPI specs in JSON/YAML object format with resolved references
|
||||||
|
|
||||||
:raises RuntimeError:
|
:raises RuntimeError:
|
||||||
If the OpenAPI definitions cannot be downloaded or processed.
|
If the OpenAPI definitions cannot be downloaded or processed.
|
||||||
:raises ValueError:
|
:raises ValueError:
|
||||||
If the source type is not recognized or no functions are found in the OpenAPI definitions.
|
If the source type is not recognized or no functions are found in the OpenAPI definitions.
|
||||||
"""
|
"""
|
||||||
documents: List[Document] = []
|
all_extracted_fc_definitions: List[Dict[str, Any]] = []
|
||||||
system_messages = system_messages or [""] * len(sources)
|
all_openapi_specs = []
|
||||||
for source, system_message in zip(sources, system_messages):
|
for source in sources:
|
||||||
openapi_spec_content = None
|
openapi_spec_content = None
|
||||||
if isinstance(source, (str, Path)):
|
if isinstance(source, (str, Path)):
|
||||||
# check if the source is a file path or a URL
|
|
||||||
if os.path.exists(source):
|
if os.path.exists(source):
|
||||||
openapi_spec_content = self._read_from_file(source)
|
try:
|
||||||
|
with open(source, "r") as f:
|
||||||
|
openapi_spec_content = f.read()
|
||||||
|
except IOError as e:
|
||||||
|
logger.warning(
|
||||||
|
"IO error reading OpenAPI specification file: {source}. Error: {e}", source=source, e=e
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
openapi_spec_content = self._read_from_url(str(source))
|
logger.warning(f"OpenAPI specification file not found: {source}")
|
||||||
elif isinstance(source, ByteStream):
|
elif isinstance(source, ByteStream):
|
||||||
openapi_spec_content = source.data.decode("utf-8")
|
openapi_spec_content = source.data.decode("utf-8")
|
||||||
|
if not openapi_spec_content:
|
||||||
|
logger.warning(
|
||||||
|
"Invalid OpenAPI specification content provided: {openapi_spec_content}",
|
||||||
|
openapi_spec_content=openapi_spec_content,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Invalid source type {source}. Only str, Path, and ByteStream are supported.", source=type(source)
|
"Invalid source type {source}. Only str, Path, and ByteStream are supported.", source=type(source)
|
||||||
@ -93,18 +98,17 @@ class OpenAPIServiceToFunctions:
|
|||||||
try:
|
try:
|
||||||
service_openapi_spec = self._parse_openapi_spec(openapi_spec_content)
|
service_openapi_spec = self._parse_openapi_spec(openapi_spec_content)
|
||||||
functions: List[Dict[str, Any]] = self._openapi_to_functions(service_openapi_spec)
|
functions: List[Dict[str, Any]] = self._openapi_to_functions(service_openapi_spec)
|
||||||
for function in functions:
|
all_extracted_fc_definitions.extend(functions)
|
||||||
meta: Dict[str, Any] = {"spec": service_openapi_spec}
|
all_openapi_specs.append(service_openapi_spec)
|
||||||
if system_message:
|
|
||||||
meta["system_message"] = system_message
|
|
||||||
doc = Document(content=json.dumps(function), meta=meta)
|
|
||||||
documents.append(doc)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Error processing OpenAPI specification from source {source}: {error}", source=source, error=e
|
"Error processing OpenAPI specification from source {source}: {error}", source=source, error=e
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"documents": documents}
|
if not all_extracted_fc_definitions:
|
||||||
|
logger.warning("No OpenAI function definitions extracted from the provided OpenAPI specification sources.")
|
||||||
|
|
||||||
|
return {"functions": all_extracted_fc_definitions, "openapi_specs": all_openapi_specs}
|
||||||
|
|
||||||
def _openapi_to_functions(self, service_openapi_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
|
def _openapi_to_functions(self, service_openapi_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@ -244,32 +248,3 @@ class OpenAPIServiceToFunctions:
|
|||||||
|
|
||||||
# Replace references in the object with their resolved values, if any
|
# Replace references in the object with their resolved values, if any
|
||||||
return jsonref.replace_refs(open_api_spec_content)
|
return jsonref.replace_refs(open_api_spec_content)
|
||||||
|
|
||||||
def _read_from_file(self, path: Union[str, Path]) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Reads the content of a file, given its path.
|
|
||||||
:param path: The path of the file.
|
|
||||||
:type path: Union[str, Path]
|
|
||||||
:return: The content of the file or None if the file cannot be read.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with open(path, "r") as f:
|
|
||||||
return f.read()
|
|
||||||
except IOError as e:
|
|
||||||
logger.warning("IO error reading file: {path}. Error: {error}", path=path, error=e)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _read_from_url(self, url: str) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Reads the content of a URL.
|
|
||||||
:param url: The URL to read.
|
|
||||||
:type url: str
|
|
||||||
:return: The content of the URL or None if the URL cannot be read.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = requests.get(url, timeout=10)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.text
|
|
||||||
except RequestException as e:
|
|
||||||
logger.warning("Error fetching URL: {url}. Error: {error}", url=url, error=e)
|
|
||||||
return None
|
|
||||||
|
@ -179,20 +179,16 @@ class TestOpenAPIServiceToFunctions:
|
|||||||
def test_run_with_bytestream_source(self, json_serperdev_openapi_spec):
|
def test_run_with_bytestream_source(self, json_serperdev_openapi_spec):
|
||||||
service = OpenAPIServiceToFunctions()
|
service = OpenAPIServiceToFunctions()
|
||||||
spec_stream = ByteStream.from_string(json_serperdev_openapi_spec)
|
spec_stream = ByteStream.from_string(json_serperdev_openapi_spec)
|
||||||
result = service.run(sources=[spec_stream], system_messages=["Some system message we don't care about here"])
|
result = service.run(sources=[spec_stream])
|
||||||
assert len(result["documents"]) == 1
|
assert len(result["functions"]) == 1
|
||||||
doc = result["documents"][0]
|
fc = result["functions"][0]
|
||||||
|
|
||||||
# check that the content is as expected
|
# check that fc definition is as expected
|
||||||
assert (
|
assert fc == {
|
||||||
doc.content
|
"name": "search",
|
||||||
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
|
"description": "Search the web with Google",
|
||||||
'"properties": {"q": {"type": "string"}}}}'
|
"parameters": {"type": "object", "properties": {"q": {"type": "string"}}},
|
||||||
)
|
}
|
||||||
|
|
||||||
# check that the metadata is as expected
|
|
||||||
assert doc.meta["system_message"] == "Some system message we don't care about here"
|
|
||||||
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform in ["win32", "cygwin"],
|
sys.platform in ["win32", "cygwin"],
|
||||||
@ -205,47 +201,56 @@ class TestOpenAPIServiceToFunctions:
|
|||||||
with tempfile.NamedTemporaryFile() as tmp:
|
with tempfile.NamedTemporaryFile() as tmp:
|
||||||
tmp.write(json_serperdev_openapi_spec.encode("utf-8"))
|
tmp.write(json_serperdev_openapi_spec.encode("utf-8"))
|
||||||
tmp.seek(0)
|
tmp.seek(0)
|
||||||
result = service.run(sources=[tmp.name], system_messages=["Some system message we don't care about here"])
|
result = service.run(sources=[tmp.name])
|
||||||
assert len(result["documents"]) == 1
|
assert len(result["functions"]) == 1
|
||||||
doc = result["documents"][0]
|
fc = result["functions"][0]
|
||||||
|
|
||||||
# check that the content is as expected
|
# check that fc definition is as expected
|
||||||
assert (
|
assert fc == {
|
||||||
doc.content
|
"name": "search",
|
||||||
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
|
"description": "Search the web with Google",
|
||||||
'"properties": {"q": {"type": "string"}}}}'
|
"parameters": {"type": "object", "properties": {"q": {"type": "string"}}},
|
||||||
)
|
}
|
||||||
|
|
||||||
# check that the metadata is as expected
|
def test_run_with_invalid_file_source(self, caplog):
|
||||||
assert doc.meta["system_message"] == "Some system message we don't care about here"
|
# test invalid source
|
||||||
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
|
|
||||||
|
|
||||||
def test_run_with_file_source_and_none_system_messages(self, json_serperdev_openapi_spec):
|
|
||||||
service = OpenAPIServiceToFunctions()
|
service = OpenAPIServiceToFunctions()
|
||||||
spec_stream = ByteStream.from_string(json_serperdev_openapi_spec)
|
result = service.run(sources=["invalid_source"])
|
||||||
|
assert result["functions"] == []
|
||||||
|
assert "not found" in caplog.text
|
||||||
|
|
||||||
# we now omit the system_messages argument
|
def test_run_with_invalid_bytestream_source(self, caplog):
|
||||||
result = service.run(sources=[spec_stream])
|
# test invalid source
|
||||||
assert len(result["documents"]) == 1
|
service = OpenAPIServiceToFunctions()
|
||||||
doc = result["documents"][0]
|
result = service.run(sources=[ByteStream.from_string("")])
|
||||||
|
assert result["functions"] == []
|
||||||
# check that the content is as expected
|
assert "Invalid OpenAPI specification" in caplog.text
|
||||||
assert (
|
|
||||||
doc.content
|
|
||||||
== '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", '
|
|
||||||
'"properties": {"q": {"type": "string"}}}}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# check that the metadata is as expected, system_message should not be present
|
|
||||||
assert "system_message" not in doc.meta
|
|
||||||
assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)
|
|
||||||
|
|
||||||
def test_complex_types_conversion(self, test_files_path):
|
def test_complex_types_conversion(self, test_files_path):
|
||||||
# ensure that complex types from OpenAPI spec are converted to the expected format in OpenAI function calling
|
# ensure that complex types from OpenAPI spec are converted to the expected format in OpenAI function calling
|
||||||
service = OpenAPIServiceToFunctions()
|
service = OpenAPIServiceToFunctions()
|
||||||
result = service.run(sources=[test_files_path / "json" / "complex_types_openapi_service.json"])
|
result = service.run(sources=[test_files_path / "json" / "complex_types_openapi_service.json"])
|
||||||
assert len(result["documents"]) == 1
|
assert len(result["functions"]) == 1
|
||||||
|
|
||||||
with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file:
|
with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file:
|
||||||
desired_output = json.load(openai_spec_file)
|
desired_output = json.load(openai_spec_file)
|
||||||
assert result["documents"][0].content == json.dumps(desired_output)
|
assert result["functions"][0] == desired_output
|
||||||
|
|
||||||
|
def test_simple_and_complex_at_once(self, test_files_path, json_serperdev_openapi_spec):
|
||||||
|
# ensure multiple functions are extracted from multiple paths in OpenAPI spec
|
||||||
|
service = OpenAPIServiceToFunctions()
|
||||||
|
sources = [
|
||||||
|
ByteStream.from_string(json_serperdev_openapi_spec),
|
||||||
|
test_files_path / "json" / "complex_types_openapi_service.json",
|
||||||
|
]
|
||||||
|
result = service.run(sources=sources)
|
||||||
|
assert len(result["functions"]) == 2
|
||||||
|
|
||||||
|
with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file:
|
||||||
|
desired_output = json.load(openai_spec_file)
|
||||||
|
assert result["functions"][0] == {
|
||||||
|
"name": "search",
|
||||||
|
"description": "Search the web with Google",
|
||||||
|
"parameters": {"type": "object", "properties": {"q": {"type": "string"}}},
|
||||||
|
}
|
||||||
|
assert result["functions"][1] == desired_output
|
||||||
|
Loading…
x
Reference in New Issue
Block a user