feat: Allow unverified OpenAPI calls (#8562)

* Feed through ssl_verify value to OpenAPI

* Add release note

* Update serialization methods

* Applied black formatting
This commit is contained in:
Richard Hudson 2024-11-22 15:45:00 +01:00 committed by GitHub
parent 4e6c7967d9
commit a8eeb2024f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 4 deletions

View File

@ -7,7 +7,7 @@ from collections import defaultdict
from copy import copy
from typing import Any, Dict, List, Optional, Union
from haystack import component, logging
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.lazy_imports import LazyImport
@ -69,11 +69,15 @@ class OpenAPIServiceConnector:
"""
def __init__(self):
def __init__(self, ssl_verify: Optional[Union[bool, str]] = None):
"""
Initializes the OpenAPIServiceConnector instance
:param ssl_verify: Decide if to use SSL verification to the requests or not,
in case a string is passed, will be used as the CA.
"""
openapi_imports.check()
self.ssl_verify = ssl_verify
@component.output_types(service_response=Dict[str, Any])
def run(
@ -112,7 +116,7 @@ class OpenAPIServiceConnector:
function_invocation_payloads = self._parse_message(last_message)
# instantiate the OpenAPI service for the given specification
openapi_service = OpenAPI(service_openapi_spec)
openapi_service = OpenAPI(service_openapi_spec, ssl_verify=self.ssl_verify)
self._authenticate_service(openapi_service, service_credentials)
response_messages = []
@ -127,6 +131,27 @@ class OpenAPIServiceConnector:
return {"service_response": response_messages}
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(self, ssl_verify=self.ssl_verify)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OpenAPIServiceConnector":
"""
Deserializes the component from a dictionary.
:param data:
The dictionary to deserialize from.
:returns:
The deserialized component.
"""
return default_from_dict(cls, data)
def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]:
"""
Parses the message to extract the method invocation descriptor.

View File

@ -0,0 +1,4 @@
---
features:
- |
When making function calls via OpenAPI, allow both switching SSL verification off and specifying a certificate authority to use for it.

View File

@ -140,7 +140,7 @@ class TestOpenAPIServiceConnector:
connector.run(messages=messages, service_openapi_spec=spec, service_credentials="fake_key")
openapi_mock.assert_called_once_with(spec)
openapi_mock.assert_called_once_with(spec, ssl_verify=None)
mock_service.authenticate.assert_called_once_with("apikey", "fake_key")
# verify call went through on the wire with the correct parameters
@ -331,3 +331,11 @@ class TestOpenAPIServiceConnector:
ValueError, match="Missing requestBody parameter: 'message' required for the 'greet' operation."
):
connector.run(messages=messages, service_openapi_spec=spec)
def test_serialization(self):
for test_val in ("myvalue", True, None):
openapi_service_connector = OpenAPIServiceConnector(test_val)
serialized = openapi_service_connector.to_dict()
assert serialized["init_parameters"]["ssl_verify"] == test_val
deserialized = OpenAPIServiceConnector.from_dict(serialized)
assert deserialized.ssl_verify == test_val