feat: Allow unverified OpenAPI calls (#8562)

* Feed through ssl_verify value to OpenAPI

* Add release note

* Update serialization methods

* Applied black formatting
This commit is contained in:
Richard Hudson 2024-11-22 15:45:00 +01:00 committed by GitHub
parent 4e6c7967d9
commit a8eeb2024f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 4 deletions

View File

@ -7,7 +7,7 @@ from collections import defaultdict
from copy import copy from copy import copy
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from haystack import component, logging from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ChatRole from haystack.dataclasses import ChatMessage, ChatRole
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
@ -69,11 +69,15 @@ class OpenAPIServiceConnector:
""" """
def __init__(self): def __init__(self, ssl_verify: Optional[Union[bool, str]] = None):
""" """
Initializes the OpenAPIServiceConnector instance Initializes the OpenAPIServiceConnector instance
:param ssl_verify: Decide if to use SSL verification to the requests or not,
in case a string is passed, will be used as the CA.
""" """
openapi_imports.check() openapi_imports.check()
self.ssl_verify = ssl_verify
@component.output_types(service_response=Dict[str, Any]) @component.output_types(service_response=Dict[str, Any])
def run( def run(
@ -112,7 +116,7 @@ class OpenAPIServiceConnector:
function_invocation_payloads = self._parse_message(last_message) function_invocation_payloads = self._parse_message(last_message)
# instantiate the OpenAPI service for the given specification # instantiate the OpenAPI service for the given specification
openapi_service = OpenAPI(service_openapi_spec) openapi_service = OpenAPI(service_openapi_spec, ssl_verify=self.ssl_verify)
self._authenticate_service(openapi_service, service_credentials) self._authenticate_service(openapi_service, service_credentials)
response_messages = [] response_messages = []
@ -127,6 +131,27 @@ class OpenAPIServiceConnector:
return {"service_response": response_messages} return {"service_response": response_messages}
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(self, ssl_verify=self.ssl_verify)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OpenAPIServiceConnector":
"""
Deserializes the component from a dictionary.
:param data:
The dictionary to deserialize from.
:returns:
The deserialized component.
"""
return default_from_dict(cls, data)
def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]: def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]:
""" """
Parses the message to extract the method invocation descriptor. Parses the message to extract the method invocation descriptor.

View File

@ -0,0 +1,4 @@
---
features:
- |
When making function calls via OpenAPI, allow both switching SSL verification off and specifying a certificate authority to use for it.

View File

@ -140,7 +140,7 @@ class TestOpenAPIServiceConnector:
connector.run(messages=messages, service_openapi_spec=spec, service_credentials="fake_key") connector.run(messages=messages, service_openapi_spec=spec, service_credentials="fake_key")
openapi_mock.assert_called_once_with(spec) openapi_mock.assert_called_once_with(spec, ssl_verify=None)
mock_service.authenticate.assert_called_once_with("apikey", "fake_key") mock_service.authenticate.assert_called_once_with("apikey", "fake_key")
# verify call went through on the wire with the correct parameters # verify call went through on the wire with the correct parameters
@ -331,3 +331,11 @@ class TestOpenAPIServiceConnector:
ValueError, match="Missing requestBody parameter: 'message' required for the 'greet' operation." ValueError, match="Missing requestBody parameter: 'message' required for the 'greet' operation."
): ):
connector.run(messages=messages, service_openapi_spec=spec) connector.run(messages=messages, service_openapi_spec=spec)
def test_serialization(self):
for test_val in ("myvalue", True, None):
openapi_service_connector = OpenAPIServiceConnector(test_val)
serialized = openapi_service_connector.to_dict()
assert serialized["init_parameters"]["ssl_verify"] == test_val
deserialized = OpenAPIServiceConnector.from_dict(serialized)
assert deserialized.ssl_verify == test_val