mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-12 15:27:06 +00:00
fix: Update OpenAPIServiceConnector to new ChatMessage (#8817)
* Update OpenAPIServiceConnector to new ChatMessage, bypass model response validation * Add reno * Lint fixes * Add serde pipeline test * Update haystack/components/connectors/openapi_service.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/connectors/openapi_service.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/connectors/openapi_service.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/connectors/openapi_service.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/connectors/openapi_service.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Add edge case unit test --------- Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
fd5040108a
commit
b6ebd3cd77
@ -14,7 +14,135 @@ from haystack.lazy_imports import LazyImport
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with LazyImport("Run 'pip install openapi3'") as openapi_imports:
|
||||
import requests
|
||||
from openapi3 import OpenAPI
|
||||
from openapi3.errors import UnexpectedResponseError
|
||||
from openapi3.paths import Operation
|
||||
|
||||
# Patch the request method to add support for the proper raw_response handling
|
||||
# If you see that https://github.com/Dorthu/openapi3/pull/124/
|
||||
# is merged, we can remove this patch - notify authors of this code
|
||||
def patch_request(
|
||||
self,
|
||||
base_url: str,
|
||||
*,
|
||||
data: Optional[Any] = None,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
raw_response: bool = False,
|
||||
security: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Any] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
Sends an HTTP request as described by this path.
|
||||
|
||||
:param base_url: The URL to append this operation's path to when making
|
||||
the call.
|
||||
:param data: The request body to send.
|
||||
:param parameters: The parameters used to create the path.
|
||||
:param raw_response: If true, return the raw response instead of validating
|
||||
and exterpolating it.
|
||||
:param security: The security scheme to use, and the values it needs to
|
||||
process successfully.
|
||||
:param session: A persistent request session.
|
||||
:param verify: If we should do an ssl verification on the request or not.
|
||||
In case str was provided, will use that as the CA.
|
||||
:return: The response data, either raw or processed depending on raw_response flag.
|
||||
"""
|
||||
# Set request method (e.g. 'GET')
|
||||
self._request = requests.Request(self.path[-1])
|
||||
|
||||
# Set self._request.url to base_url w/ path
|
||||
self._request.url = base_url + self.path[-2]
|
||||
|
||||
parameters = parameters or {}
|
||||
security = security or {}
|
||||
|
||||
if security and self.security:
|
||||
security_requirement = None
|
||||
for scheme, value in security.items():
|
||||
security_requirement = None
|
||||
for r in self.security:
|
||||
if r.name == scheme:
|
||||
security_requirement = r
|
||||
self._request_handle_secschemes(r, value)
|
||||
|
||||
if security_requirement is None:
|
||||
err_msg = """No security requirement satisfied (accepts {}) \
|
||||
""".format(", ".join(self.security.keys()))
|
||||
raise ValueError(err_msg)
|
||||
|
||||
if self.requestBody:
|
||||
if self.requestBody.required and data is None:
|
||||
err_msg = "Request Body is required but none was provided."
|
||||
raise ValueError(err_msg)
|
||||
|
||||
self._request_handle_body(data)
|
||||
|
||||
self._request_handle_parameters(parameters)
|
||||
|
||||
if session is None:
|
||||
session = self._session
|
||||
|
||||
# send the prepared request
|
||||
result = session.send(self._request.prepare(), verify=verify)
|
||||
|
||||
# spec enforces these are strings
|
||||
status_code = str(result.status_code)
|
||||
|
||||
# find the response model in spec we received
|
||||
expected_response = None
|
||||
if status_code in self.responses:
|
||||
expected_response = self.responses[status_code]
|
||||
elif "default" in self.responses:
|
||||
expected_response = self.responses["default"]
|
||||
|
||||
if expected_response is None:
|
||||
raise UnexpectedResponseError(result, self)
|
||||
|
||||
# if we got back a valid response code (or there was a default) and no
|
||||
# response content was expected, return None
|
||||
if expected_response.content is None:
|
||||
return None
|
||||
|
||||
content_type = result.headers["Content-Type"]
|
||||
if ";" in content_type:
|
||||
# if the content type that came in included an encoding, we'll ignore
|
||||
# it for now (requests has already parsed it for us) and only look at
|
||||
# the MIME type when determining if an expected content type was returned.
|
||||
content_type = content_type.split(";")[0].strip()
|
||||
|
||||
expected_media = expected_response.content.get(content_type, None)
|
||||
|
||||
# If raw_response is True, return the raw text or json based on content type
|
||||
if raw_response:
|
||||
if "application/json" in content_type:
|
||||
return result.json()
|
||||
return result.text
|
||||
|
||||
if expected_media is None and "/" in content_type:
|
||||
# accept media type ranges in the spec. the most specific matching
|
||||
# type should always be chosen, but if we do not have a match here
|
||||
# a generic range should be accepted if one if provided
|
||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#response-object
|
||||
|
||||
generic_type = content_type.split("/")[0] + "/*"
|
||||
expected_media = expected_response.content.get(generic_type, None)
|
||||
|
||||
if expected_media is None:
|
||||
err_msg = """Unexpected Content-Type {} returned for operation {} \
|
||||
(expected one of {})"""
|
||||
err_var = result.headers["Content-Type"], self.operationId, ",".join(expected_response.content.keys())
|
||||
|
||||
raise RuntimeError(err_msg.format(*err_var))
|
||||
|
||||
if content_type.lower() == "application/json":
|
||||
return expected_media.schema.model(result.json())
|
||||
|
||||
raise NotImplementedError("Only application/json content type is supported")
|
||||
|
||||
# Apply the patch
|
||||
Operation.request = patch_request
|
||||
|
||||
|
||||
@component
|
||||
@ -89,12 +217,10 @@ class OpenAPIServiceConnector:
|
||||
"""
|
||||
Processes a list of chat messages to invoke a method on an OpenAPI service.
|
||||
|
||||
It parses the last message in the list, expecting it to contain an OpenAI function calling descriptor
|
||||
(name & parameters) in JSON format.
|
||||
It parses the last message in the list, expecting it to contain tool calls.
|
||||
|
||||
:param messages: A list of `ChatMessage` objects containing the messages to be processed. The last message
|
||||
should contain the function invocation payload in OpenAI function calling format. See the example in the class
|
||||
docstring for the expected format.
|
||||
should contain the tool calls.
|
||||
:param service_openapi_spec: The OpenAPI JSON specification object of the service to be invoked. All the refs
|
||||
should already be resolved.
|
||||
:param service_credentials: The credentials to be used for authentication with the service.
|
||||
@ -105,29 +231,34 @@ class OpenAPIServiceConnector:
|
||||
response is in JSON format, and the `content` attribute of the `ChatMessage` contains
|
||||
the JSON string.
|
||||
|
||||
:raises ValueError: If the last message is not from the assistant or if it does not contain the correct payload
|
||||
to invoke a method on the service.
|
||||
:raises ValueError: If the last message is not from the assistant or if it does not contain tool calls.
|
||||
"""
|
||||
|
||||
last_message = messages[-1]
|
||||
if not last_message.is_from(ChatRole.ASSISTANT):
|
||||
raise ValueError(f"{last_message} is not from the assistant.")
|
||||
|
||||
function_invocation_payloads = self._parse_message(last_message)
|
||||
tool_calls = last_message.tool_calls
|
||||
if not tool_calls:
|
||||
raise ValueError(f"The provided ChatMessage has no tool calls.\nChatMessage: {last_message}")
|
||||
|
||||
function_payloads = []
|
||||
for tool_call in tool_calls:
|
||||
function_payloads.append({"arguments": tool_call.arguments, "name": tool_call.tool_name})
|
||||
|
||||
# instantiate the OpenAPI service for the given specification
|
||||
openapi_service = OpenAPI(service_openapi_spec, ssl_verify=self.ssl_verify)
|
||||
self._authenticate_service(openapi_service, service_credentials)
|
||||
|
||||
response_messages = []
|
||||
for method_invocation_descriptor in function_invocation_payloads:
|
||||
for method_invocation_descriptor in function_payloads:
|
||||
service_response = self._invoke_method(openapi_service, method_invocation_descriptor)
|
||||
# openapi3 parses the JSON service response into a model object, which is not our focus at the moment.
|
||||
# Instead, we require direct access to the raw JSON data of the response, rather than the model objects
|
||||
# provided by the openapi3 library. This approach helps us avoid issues related to (de)serialization.
|
||||
# By accessing the raw JSON response through `service_response._raw_data`, we can serialize this data
|
||||
# into a string. Finally, we use this string to create a ChatMessage object.
|
||||
response_messages.append(ChatMessage.from_user(json.dumps(service_response._raw_data)))
|
||||
response_messages.append(ChatMessage.from_user(json.dumps(service_response)))
|
||||
|
||||
return {"service_response": response_messages}
|
||||
|
||||
@ -152,35 +283,6 @@ class OpenAPIServiceConnector:
|
||||
"""
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parses the message to extract the method invocation descriptor.
|
||||
|
||||
:param message: ChatMessage containing the tools calls
|
||||
:return: A list of function invocation payloads
|
||||
:raises ValueError: If the content is not valid JSON or lacks required fields.
|
||||
"""
|
||||
function_payloads = []
|
||||
if message.text is None:
|
||||
raise ValueError(f"The provided ChatMessage has no text.\nChatMessage: {message}")
|
||||
try:
|
||||
tool_calls = json.loads(message.text)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid JSON content, expected OpenAI tools message.", message.text)
|
||||
|
||||
for tool_call in tool_calls:
|
||||
# this should never happen, but just in case do a sanity check
|
||||
if "type" not in tool_call:
|
||||
raise ValueError("Message payload doesn't seem to be a tool invocation descriptor", message.text)
|
||||
|
||||
# In OpenAPIServiceConnector we know how to handle functions tools only
|
||||
if tool_call["type"] == "function":
|
||||
function_call = tool_call["function"]
|
||||
function_payloads.append(
|
||||
{"arguments": json.loads(function_call["arguments"]), "name": function_call["name"]}
|
||||
)
|
||||
return function_payloads
|
||||
|
||||
def _authenticate_service(self, openapi_service: "OpenAPI", credentials: Optional[Union[dict, str]] = None):
|
||||
"""
|
||||
Authentication with an OpenAPI service.
|
||||
@ -294,4 +396,4 @@ class OpenAPIServiceConnector:
|
||||
f"Missing requestBody parameter: '{param_name}' required for the '{name}' operation."
|
||||
)
|
||||
# call the underlying service REST API with the parameters
|
||||
return method_to_call(**method_call_params)
|
||||
return method_to_call(**method_call_params, raw_response=True)
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Enhanced `OpenAPIServiceConnector` to support and be compatible with the new ChatMessage format.
|
||||
@ -2,14 +2,24 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import json
|
||||
from unittest.mock import MagicMock, Mock, patch, PropertyMock
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import requests
|
||||
|
||||
from haystack import Pipeline
|
||||
import pytest
|
||||
from haystack.components.converters.openapi_functions import OpenAPIServiceToFunctions
|
||||
from haystack.components.converters.output_adapter import OutputAdapter
|
||||
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
||||
from haystack.components.generators.utils import print_streaming_chunk
|
||||
from haystack.dataclasses.byte_stream import ByteStream
|
||||
from openapi3 import OpenAPI
|
||||
from openapi3.schemas import Model
|
||||
|
||||
from haystack.components.connectors import OpenAPIServiceConnector
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.components.connectors.openapi_service import patch_request
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -22,45 +32,15 @@ class TestOpenAPIServiceConnector:
|
||||
def connector(self):
|
||||
return OpenAPIServiceConnector()
|
||||
|
||||
def test_parse_message_invalid_json(self, connector):
|
||||
# Test invalid JSON content
|
||||
with pytest.raises(ValueError):
|
||||
connector._parse_message(ChatMessage.from_assistant("invalid json"))
|
||||
def test_run_without_tool_calls(self, connector):
|
||||
message = ChatMessage.from_assistant(text="Just a regular message")
|
||||
with pytest.raises(ValueError, match="has no tool calls"):
|
||||
connector.run(messages=[message], service_openapi_spec={})
|
||||
|
||||
def test_parse_valid_json_message(self):
|
||||
connector = OpenAPIServiceConnector()
|
||||
|
||||
# The content format here is OpenAI function calling descriptor
|
||||
content = (
|
||||
'[{"function":{"name": "compare_branches","arguments": "{\\n \\"parameters\\": {\\n '
|
||||
' \\"basehead\\": \\"main...openapi_container_v5\\",\\n '
|
||||
' \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack\\"\\n }\\n}"}, "type": "function"}]'
|
||||
)
|
||||
descriptors = connector._parse_message(ChatMessage.from_assistant(content))
|
||||
|
||||
# Assert that the descriptor contains the expected method name and arguments
|
||||
assert descriptors[0]["name"] == "compare_branches"
|
||||
assert descriptors[0]["arguments"]["parameters"] == {
|
||||
"basehead": "main...openapi_container_v5",
|
||||
"owner": "deepset-ai",
|
||||
"repo": "haystack",
|
||||
}
|
||||
# but not the requestBody
|
||||
assert "requestBody" not in descriptors[0]["arguments"]
|
||||
|
||||
# The content format here is OpenAI function calling descriptor
|
||||
content = '[{"function": {"name": "search","arguments": "{\\n \\"requestBody\\": {\\n \\"q\\": \\"haystack\\"\\n }\\n}"}, "type": "function"}]'
|
||||
descriptors = connector._parse_message(ChatMessage.from_assistant(content))
|
||||
assert descriptors[0]["name"] == "search"
|
||||
assert descriptors[0]["arguments"]["requestBody"] == {"q": "haystack"}
|
||||
|
||||
# but not the parameters
|
||||
assert "parameters" not in descriptors[0]["arguments"]
|
||||
|
||||
def test_parse_message_missing_fields(self, connector):
|
||||
# Test JSON content with missing fields
|
||||
with pytest.raises(ValueError):
|
||||
connector._parse_message(ChatMessage.from_assistant('[{"function": {"name": "test_method"}}]'))
|
||||
def test_run_with_non_assistant_message(self, connector):
|
||||
message = ChatMessage.from_user(text="User message")
|
||||
with pytest.raises(ValueError, match="is not from the assistant"):
|
||||
connector.run(messages=[message], service_openapi_spec={})
|
||||
|
||||
def test_authenticate_service_missing_authentication_token(self, connector, openapi_service_mock):
|
||||
security_schemes_dict = {
|
||||
@ -68,7 +48,7 @@ class TestOpenAPIServiceConnector:
|
||||
}
|
||||
openapi_service_mock.raw_element = security_schemes_dict
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="requires authentication but no credentials were provided"):
|
||||
connector._authenticate_service(openapi_service_mock)
|
||||
|
||||
def test_authenticate_service_having_authentication_token(self, connector, openapi_service_mock):
|
||||
@ -80,6 +60,7 @@ class TestOpenAPIServiceConnector:
|
||||
"apiKey": {"in": "header", "name": "x-api-key", "type": "apiKey"}
|
||||
}
|
||||
connector._authenticate_service(openapi_service_mock, "some_fake_token")
|
||||
openapi_service_mock.authenticate.assert_called_once_with("apiKey", "some_fake_token")
|
||||
|
||||
def test_authenticate_service_having_authentication_dict(self, connector, openapi_service_mock):
|
||||
security_schemes_dict = {
|
||||
@ -90,80 +71,49 @@ class TestOpenAPIServiceConnector:
|
||||
"apiKey": {"in": "header", "name": "x-api-key", "type": "apiKey"}
|
||||
}
|
||||
connector._authenticate_service(openapi_service_mock, {"apiKey": "some_fake_token"})
|
||||
openapi_service_mock.authenticate.assert_called_once_with("apiKey", "some_fake_token")
|
||||
|
||||
def test_authenticate_service_having_authentication_dict_but_unsupported_auth(
|
||||
self, connector, openapi_service_mock
|
||||
):
|
||||
def test_authenticate_service_having_unsupported_auth(self, connector, openapi_service_mock):
|
||||
security_schemes_dict = {"components": {"securitySchemes": {"oauth2": {"type": "oauth2"}}}}
|
||||
openapi_service_mock.raw_element = security_schemes_dict
|
||||
openapi_service_mock.components.securitySchemes.raw_element = {"oauth2": {"type": "oauth2"}}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Check the service configuration and credentials"):
|
||||
connector._authenticate_service(openapi_service_mock, {"apiKey": "some_fake_token"})
|
||||
|
||||
def test_for_internal_raw_data_field(self):
|
||||
# see https://github.com/deepset-ai/haystack/pull/6772 for details
|
||||
model = Model(data={}, schema={})
|
||||
assert hasattr(model, "_raw_data"), (
|
||||
"openapi3 changed. Model should have a _raw_data field, we rely on it in OpenAPIServiceConnector"
|
||||
" to get the raw data from the service response"
|
||||
)
|
||||
|
||||
@patch("haystack.components.connectors.openapi_service.OpenAPI")
|
||||
def test_run(self, openapi_mock, test_files_path):
|
||||
def test_run_with_parameters(self, openapi_mock):
|
||||
connector = OpenAPIServiceConnector()
|
||||
spec_path = test_files_path / "json" / "github_compare_branch_openapi_spec.json"
|
||||
spec = json.loads((spec_path).read_text())
|
||||
|
||||
mock_message = json.dumps(
|
||||
[
|
||||
{
|
||||
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
|
||||
"function": {
|
||||
"arguments": '{"basehead": "main...some_branch", "owner": "deepset-ai", "repo": "haystack"}',
|
||||
"name": "compare_branches",
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
tool_call = ToolCall(
|
||||
tool_name="compare_branches",
|
||||
arguments={"basehead": "main...some_branch", "owner": "deepset-ai", "repo": "haystack"},
|
||||
)
|
||||
messages = [ChatMessage.from_assistant(mock_message)]
|
||||
call_compare_branches = Mock(return_value=Mock(_raw_data="some_data"))
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
|
||||
# Mock the OpenAPI service
|
||||
call_compare_branches = Mock(return_value={"status": "success"})
|
||||
call_compare_branches.operation.__self__ = Mock()
|
||||
call_compare_branches.operation.__self__.raw_element = {
|
||||
"parameters": [{"name": "basehead"}, {"name": "owner"}, {"name": "repo"}]
|
||||
}
|
||||
mock_service = Mock(
|
||||
call_compare_branches=call_compare_branches,
|
||||
components=Mock(securitySchemes=Mock(raw_element={"apikey": {"type": "apiKey"}})),
|
||||
)
|
||||
mock_service = Mock(call_compare_branches=call_compare_branches, raw_element={})
|
||||
openapi_mock.return_value = mock_service
|
||||
|
||||
connector.run(messages=messages, service_openapi_spec=spec, service_credentials="fake_key")
|
||||
result = connector.run(messages=[message], service_openapi_spec={})
|
||||
|
||||
openapi_mock.assert_called_once_with(spec, ssl_verify=None)
|
||||
mock_service.authenticate.assert_called_once_with("apikey", "fake_key")
|
||||
|
||||
# verify call went through on the wire with the correct parameters
|
||||
# Verify the service call
|
||||
mock_service.call_compare_branches.assert_called_once_with(
|
||||
parameters={"basehead": "main...some_branch", "owner": "deepset-ai", "repo": "haystack"}
|
||||
parameters={"basehead": "main...some_branch", "owner": "deepset-ai", "repo": "haystack"}, raw_response=True
|
||||
)
|
||||
assert json.loads(result["service_response"][0].text) == {"status": "success"}
|
||||
|
||||
@patch("haystack.components.connectors.openapi_service.OpenAPI")
|
||||
def test_run_with_mix_params_request_body(self, openapi_mock, test_files_path):
|
||||
def test_run_with_request_body(self, openapi_mock):
|
||||
connector = OpenAPIServiceConnector()
|
||||
spec_path = test_files_path / "yaml" / "openapi_greeting_service.yml"
|
||||
with open(spec_path, "r") as file:
|
||||
spec = json.loads(file.read())
|
||||
mock_message = json.dumps(
|
||||
[
|
||||
{
|
||||
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
|
||||
"function": {"arguments": '{"name": "John", "message": "Hello"}', "name": "greet"},
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
)
|
||||
call_greet = Mock(return_value=Mock(_raw_data="Hello, John"))
|
||||
tool_call = ToolCall(tool_name="greet", arguments={"message": "Hello", "name": "John"})
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
|
||||
# Mock the OpenAPI service
|
||||
call_greet = Mock(return_value="Hello, John")
|
||||
call_greet.operation.__self__ = Mock()
|
||||
call_greet.operation.__self__.raw_element = {
|
||||
"parameters": [{"name": "name"}],
|
||||
@ -171,113 +121,28 @@ class TestOpenAPIServiceConnector:
|
||||
"content": {"application/json": {"schema": {"properties": {"message": {"type": "string"}}}}}
|
||||
},
|
||||
}
|
||||
|
||||
mock_service = Mock(call_greet=call_greet)
|
||||
mock_service.raw_element = {}
|
||||
mock_service = Mock(call_greet=call_greet, raw_element={})
|
||||
openapi_mock.return_value = mock_service
|
||||
|
||||
messages = [ChatMessage.from_assistant(mock_message)]
|
||||
result = connector.run(messages=messages, service_openapi_spec=spec)
|
||||
result = connector.run(messages=[message], service_openapi_spec={})
|
||||
|
||||
# verify call went through on the wire
|
||||
mock_service.call_greet.assert_called_once_with(parameters={"name": "John"}, data={"message": "Hello"})
|
||||
|
||||
response = json.loads(result["service_response"][0].text)
|
||||
assert response == "Hello, John"
|
||||
# Verify the service call
|
||||
mock_service.call_greet.assert_called_once_with(
|
||||
parameters={"name": "John"}, data={"message": "Hello"}, raw_response=True
|
||||
)
|
||||
assert json.loads(result["service_response"][0].text) == "Hello, John"
|
||||
|
||||
@patch("haystack.components.connectors.openapi_service.OpenAPI")
|
||||
def test_run_with_complex_types(self, openapi_mock, test_files_path):
|
||||
def test_run_with_missing_required_parameter(self, openapi_mock):
|
||||
connector = OpenAPIServiceConnector()
|
||||
spec_path = test_files_path / "json" / "complex_types_openapi_service.json"
|
||||
with open(spec_path, "r") as file:
|
||||
spec = json.loads(file.read())
|
||||
mock_message = json.dumps(
|
||||
[
|
||||
{
|
||||
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
|
||||
"function": {
|
||||
"arguments": '{"transaction_amount": 150.75, "description": "Monthly subscription fee", "payment_method_id": "visa_ending_in_1234", "payer": {"name": "Alex Smith", "email": "alex.smith@example.com", "identification": {"type": "Driver\'s License", "number": "D12345678"}}}',
|
||||
"name": "processPayment",
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
tool_call = ToolCall(
|
||||
tool_name="greet",
|
||||
arguments={"message": "Hello"}, # missing required 'name' parameter
|
||||
)
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
|
||||
call_processPayment = Mock(return_value=Mock(_raw_data={"result": "accepted"}))
|
||||
call_processPayment.operation.__self__ = Mock()
|
||||
call_processPayment.operation.__self__.raw_element = {
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"properties": {
|
||||
"transaction_amount": {"type": "number", "example": 150.75},
|
||||
"description": {"type": "string", "example": "Monthly subscription fee"},
|
||||
"payment_method_id": {"type": "string", "example": "visa_ending_in_1234"},
|
||||
"payer": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "example": "Alex Smith"},
|
||||
"email": {"type": "string", "example": "alex.smith@example.com"},
|
||||
"identification": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string", "example": "Driver's License"},
|
||||
"number": {"type": "string", "example": "D12345678"},
|
||||
},
|
||||
"required": ["type", "number"],
|
||||
},
|
||||
},
|
||||
"required": ["name", "email", "identification"],
|
||||
},
|
||||
},
|
||||
"required": ["transaction_amount", "description", "payment_method_id", "payer"],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_service = Mock(call_processPayment=call_processPayment)
|
||||
mock_service.raw_element = {}
|
||||
openapi_mock.return_value = mock_service
|
||||
|
||||
messages = [ChatMessage.from_assistant(mock_message)]
|
||||
result = connector.run(messages=messages, service_openapi_spec=spec)
|
||||
|
||||
# verify call went through on the wire
|
||||
mock_service.call_processPayment.assert_called_once_with(
|
||||
data={
|
||||
"transaction_amount": 150.75,
|
||||
"description": "Monthly subscription fee",
|
||||
"payment_method_id": "visa_ending_in_1234",
|
||||
"payer": {
|
||||
"name": "Alex Smith",
|
||||
"email": "alex.smith@example.com",
|
||||
"identification": {"type": "Driver's License", "number": "D12345678"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = json.loads(result["service_response"][0].text)
|
||||
assert response == {"result": "accepted"}
|
||||
|
||||
@patch("haystack.components.connectors.openapi_service.OpenAPI")
|
||||
def test_run_with_request_params_missing_in_invocation_args(self, openapi_mock, test_files_path):
|
||||
connector = OpenAPIServiceConnector()
|
||||
spec_path = test_files_path / "yaml" / "openapi_greeting_service.yml"
|
||||
with open(spec_path, "r") as file:
|
||||
spec = json.loads(file.read())
|
||||
mock_message = json.dumps(
|
||||
[
|
||||
{
|
||||
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
|
||||
"function": {"arguments": '{"message": "Hello"}', "name": "greet"},
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
)
|
||||
call_greet = Mock(return_value=Mock(_raw_data="Hello, John"))
|
||||
# Mock the OpenAPI service
|
||||
call_greet = Mock()
|
||||
call_greet.operation.__self__ = Mock()
|
||||
call_greet.operation.__self__.raw_element = {
|
||||
"parameters": [{"name": "name", "required": True}],
|
||||
@ -285,57 +150,139 @@ class TestOpenAPIServiceConnector:
|
||||
"content": {"application/json": {"schema": {"properties": {"message": {"type": "string"}}}}}
|
||||
},
|
||||
}
|
||||
|
||||
mock_service = Mock(call_greet=call_greet)
|
||||
mock_service.raw_element = {}
|
||||
mock_service = Mock(call_greet=call_greet, raw_element={})
|
||||
openapi_mock.return_value = mock_service
|
||||
|
||||
messages = [ChatMessage.from_assistant(mock_message)]
|
||||
with pytest.raises(ValueError, match="Missing parameter: 'name' required for the 'greet' operation."):
|
||||
connector.run(messages=messages, service_openapi_spec=spec)
|
||||
with pytest.raises(ValueError, match="Missing parameter: 'name' required for the 'greet' operation"):
|
||||
connector.run(messages=[message], service_openapi_spec={})
|
||||
|
||||
@patch("haystack.components.connectors.openapi_service.OpenAPI")
|
||||
def test_run_with_body_properties_missing_in_invocation_args(self, openapi_mock, test_files_path):
|
||||
def test_run_with_missing_required_parameters_in_request_body(self, openapi_mock):
|
||||
"""
|
||||
Test that the connector raises a ValueError when the request body is missing required parameters.
|
||||
"""
|
||||
connector = OpenAPIServiceConnector()
|
||||
spec_path = test_files_path / "yaml" / "openapi_greeting_service.yml"
|
||||
with open(spec_path, "r") as file:
|
||||
spec = json.loads(file.read())
|
||||
mock_message = json.dumps(
|
||||
[
|
||||
{
|
||||
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
|
||||
"function": {"arguments": '{"name": "John"}', "name": "greet"},
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
tool_call = ToolCall(
|
||||
tool_name="post_message",
|
||||
arguments={"recipient": "John"}, # only providing URL parameter, no request body data
|
||||
)
|
||||
call_greet = Mock(return_value=Mock(_raw_data="Hello, John"))
|
||||
call_greet.operation.__self__ = Mock()
|
||||
call_greet.operation.__self__.raw_element = {
|
||||
"parameters": [{"name": "name"}],
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
|
||||
# Mock the OpenAPI service
|
||||
call_post_message = Mock()
|
||||
call_post_message.operation.__self__ = Mock()
|
||||
call_post_message.operation.__self__.raw_element = {
|
||||
"parameters": [{"name": "recipient"}],
|
||||
"requestBody": {
|
||||
"required": True,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {"properties": {"message": {"type": "string"}}, "required": ["message"]}
|
||||
"schema": {
|
||||
"required": ["message"], # Mark message as required in schema
|
||||
"properties": {"message": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mock_service = Mock(call_greet=call_greet)
|
||||
mock_service.raw_element = {}
|
||||
mock_service = Mock(call_post_message=call_post_message, raw_element={})
|
||||
openapi_mock.return_value = mock_service
|
||||
|
||||
messages = [ChatMessage.from_assistant(mock_message)]
|
||||
with pytest.raises(
|
||||
ValueError, match="Missing requestBody parameter: 'message' required for the 'greet' operation."
|
||||
ValueError, match="Missing requestBody parameter: 'message' required for the 'post_message' operation"
|
||||
):
|
||||
connector.run(messages=messages, service_openapi_spec=spec)
|
||||
connector.run(messages=[message], service_openapi_spec={})
|
||||
|
||||
# Verify that the service was never called since validation failed
|
||||
call_post_message.assert_not_called()
|
||||
|
||||
def test_serialization(self):
|
||||
for test_val in ("myvalue", True, None):
|
||||
openapi_service_connector = OpenAPIServiceConnector(test_val)
|
||||
serialized = openapi_service_connector.to_dict()
|
||||
connector = OpenAPIServiceConnector(test_val)
|
||||
serialized = connector.to_dict()
|
||||
assert serialized["init_parameters"]["ssl_verify"] == test_val
|
||||
deserialized = OpenAPIServiceConnector.from_dict(serialized)
|
||||
assert deserialized.ssl_verify == test_val
|
||||
|
||||
def test_serde_in_pipeline(self):
|
||||
"""
|
||||
Test serialization/deserialization of OpenAPIServiceConnector in a Pipeline,
|
||||
including YAML conversion and detailed dictionary validation
|
||||
"""
|
||||
connector = OpenAPIServiceConnector(ssl_verify=True)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("connector", connector)
|
||||
|
||||
pipeline_dict = pipeline.to_dict()
|
||||
assert pipeline_dict == {
|
||||
"metadata": {},
|
||||
"max_runs_per_component": 100,
|
||||
"components": {
|
||||
"connector": {
|
||||
"type": "haystack.components.connectors.openapi_service.OpenAPIServiceConnector",
|
||||
"init_parameters": {"ssl_verify": True},
|
||||
}
|
||||
},
|
||||
"connections": [],
|
||||
}
|
||||
|
||||
pipeline_yaml = pipeline.dumps()
|
||||
new_pipeline = Pipeline.loads(pipeline_yaml)
|
||||
assert new_pipeline == pipeline
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("SERPERDEV_API_KEY"), reason="SERPERDEV_API_KEY is not set")
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY is not set")
|
||||
@pytest.mark.integration
|
||||
def test_run_live(self):
|
||||
# An OutputAdapter filter we'll use to setup function calling
|
||||
def prepare_fc_params(openai_functions_schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"tools": [{"type": "function", "function": openai_functions_schema}],
|
||||
"tool_choice": {"type": "function", "function": {"name": openai_functions_schema["name"]}},
|
||||
}
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("spec_to_functions", OpenAPIServiceToFunctions())
|
||||
pipe.add_component("functions_llm", OpenAIChatGenerator(model="gpt-4o-mini"))
|
||||
|
||||
pipe.add_component("openapi_container", OpenAPIServiceConnector())
|
||||
pipe.add_component(
|
||||
"prepare_fc_adapter",
|
||||
OutputAdapter("{{functions[0] | prepare_fc}}", Dict[str, Any], {"prepare_fc": prepare_fc_params}),
|
||||
)
|
||||
pipe.add_component("openapi_spec_adapter", OutputAdapter("{{specs[0]}}", Dict[str, Any], unsafe=True))
|
||||
pipe.add_component(
|
||||
"final_prompt_adapter",
|
||||
OutputAdapter("{{system_message + service_response}}", List[ChatMessage], unsafe=True),
|
||||
)
|
||||
pipe.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=print_streaming_chunk))
|
||||
|
||||
pipe.connect("spec_to_functions.functions", "prepare_fc_adapter.functions")
|
||||
pipe.connect("spec_to_functions.openapi_specs", "openapi_spec_adapter.specs")
|
||||
pipe.connect("prepare_fc_adapter", "functions_llm.generation_kwargs")
|
||||
pipe.connect("functions_llm.replies", "openapi_container.messages")
|
||||
pipe.connect("openapi_spec_adapter", "openapi_container.service_openapi_spec")
|
||||
pipe.connect("openapi_container.service_response", "final_prompt_adapter.service_response")
|
||||
pipe.connect("final_prompt_adapter", "llm.messages")
|
||||
|
||||
serperdev_spec = requests.get(
|
||||
"https://gist.githubusercontent.com/vblagoje/241a000f2a77c76be6efba71d49e2856/raw/722ccc7fe6170a744afce3e3fb3a30fdd095c184/serper.json"
|
||||
).json()
|
||||
system_prompt = requests.get("https://bit.ly/serper_dev_system").text
|
||||
|
||||
query = "Why did Elon Musk sue OpenAI?"
|
||||
|
||||
result = pipe.run(
|
||||
data={
|
||||
"functions_llm": {
|
||||
"messages": [ChatMessage.from_system("Only do tool/function calling"), ChatMessage.from_user(query)]
|
||||
},
|
||||
"openapi_container": {"service_credentials": os.getenv("SERPERDEV_API_KEY")},
|
||||
"spec_to_functions": {"sources": [ByteStream.from_string(json.dumps(serperdev_spec))]},
|
||||
"final_prompt_adapter": {"system_message": [ChatMessage.from_system(system_prompt)]},
|
||||
}
|
||||
)
|
||||
assert isinstance(result["llm"]["replies"][0], ChatMessage)
|
||||
assert "Elon" in result["llm"]["replies"][0].text
|
||||
assert "OpenAI" in result["llm"]["replies"][0].text
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user