From a8eeb2024fa9076ffaf4849c17ec4fa4e2dc817c Mon Sep 17 00:00:00 2001 From: Richard Hudson Date: Fri, 22 Nov 2024 15:45:00 +0100 Subject: [PATCH] feat: Allow unverified OpenAPI calls (#8562) * Feed through ssl_verify value to OpenAPI * Add release note * Update serialization methods * Applied black formatting --- .../components/connectors/openapi_service.py | 31 +++++++++++++++++-- ...rified-openapi-calls-46842af37464bb6d.yaml | 4 +++ .../connectors/test_openapi_service.py | 10 +++++- 3 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/allow-unverified-openapi-calls-46842af37464bb6d.yaml diff --git a/haystack/components/connectors/openapi_service.py b/haystack/components/connectors/openapi_service.py index 5238722f7..716a50124 100644 --- a/haystack/components/connectors/openapi_service.py +++ b/haystack/components/connectors/openapi_service.py @@ -7,7 +7,7 @@ from collections import defaultdict from copy import copy from typing import Any, Dict, List, Optional, Union -from haystack import component, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses import ChatMessage, ChatRole from haystack.lazy_imports import LazyImport @@ -69,11 +69,15 @@ class OpenAPIServiceConnector: """ - def __init__(self): + def __init__(self, ssl_verify: Optional[Union[bool, str]] = None): """ Initializes the OpenAPIServiceConnector instance + + :param ssl_verify: Decide if to use SSL verification to the requests or not, + in case a string is passed, will be used as the CA. """ openapi_imports.check() + self.ssl_verify = ssl_verify @component.output_types(service_response=Dict[str, Any]) def run( @@ -112,7 +116,7 @@ class OpenAPIServiceConnector: function_invocation_payloads = self._parse_message(last_message) # instantiate the OpenAPI service for the given specification - openapi_service = OpenAPI(service_openapi_spec) + openapi_service = OpenAPI(service_openapi_spec, ssl_verify=self.ssl_verify) self._authenticate_service(openapi_service, service_credentials) response_messages = [] @@ -127,6 +131,27 @@ class OpenAPIServiceConnector: return {"service_response": response_messages} + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict(self, ssl_verify=self.ssl_verify) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OpenAPIServiceConnector": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + return default_from_dict(cls, data) + def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]: """ Parses the message to extract the method invocation descriptor. diff --git a/releasenotes/notes/allow-unverified-openapi-calls-46842af37464bb6d.yaml b/releasenotes/notes/allow-unverified-openapi-calls-46842af37464bb6d.yaml new file mode 100644 index 000000000..4a1041236 --- /dev/null +++ b/releasenotes/notes/allow-unverified-openapi-calls-46842af37464bb6d.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + When making function calls via OpenAPI, allow both switching SSL verification off and specifying a certificate authority to use for it. diff --git a/test/components/connectors/test_openapi_service.py b/test/components/connectors/test_openapi_service.py index 82340bdd7..b5681f321 100644 --- a/test/components/connectors/test_openapi_service.py +++ b/test/components/connectors/test_openapi_service.py @@ -140,7 +140,7 @@ class TestOpenAPIServiceConnector: connector.run(messages=messages, service_openapi_spec=spec, service_credentials="fake_key") - openapi_mock.assert_called_once_with(spec) + openapi_mock.assert_called_once_with(spec, ssl_verify=None) mock_service.authenticate.assert_called_once_with("apikey", "fake_key") # verify call went through on the wire with the correct parameters @@ -331,3 +331,11 @@ class TestOpenAPIServiceConnector: ValueError, match="Missing requestBody parameter: 'message' required for the 'greet' operation." ): connector.run(messages=messages, service_openapi_spec=spec) + + def test_serialization(self): + for test_val in ("myvalue", True, None): + openapi_service_connector = OpenAPIServiceConnector(test_val) + serialized = openapi_service_connector.to_dict() + assert serialized["init_parameters"]["ssl_verify"] == test_val + deserialized = OpenAPIServiceConnector.from_dict(serialized) + assert deserialized.ssl_verify == test_val