mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-18 02:53:42 +00:00
feat: Allow unverified OpenAPI calls (#8562)
* Feed through ssl_verify value to OpenAPI * Add release note * Update serialization methods * Applied black formatting
This commit is contained in:
parent
4e6c7967d9
commit
a8eeb2024f
@ -7,7 +7,7 @@ from collections import defaultdict
|
|||||||
from copy import copy
|
from copy import copy
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from haystack import component, logging
|
from haystack import component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.dataclasses import ChatMessage, ChatRole
|
from haystack.dataclasses import ChatMessage, ChatRole
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
|
|
||||||
@ -69,11 +69,15 @@ class OpenAPIServiceConnector:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, ssl_verify: Optional[Union[bool, str]] = None):
|
||||||
"""
|
"""
|
||||||
Initializes the OpenAPIServiceConnector instance
|
Initializes the OpenAPIServiceConnector instance
|
||||||
|
|
||||||
|
:param ssl_verify: Decide if to use SSL verification to the requests or not,
|
||||||
|
in case a string is passed, will be used as the CA.
|
||||||
"""
|
"""
|
||||||
openapi_imports.check()
|
openapi_imports.check()
|
||||||
|
self.ssl_verify = ssl_verify
|
||||||
|
|
||||||
@component.output_types(service_response=Dict[str, Any])
|
@component.output_types(service_response=Dict[str, Any])
|
||||||
def run(
|
def run(
|
||||||
@ -112,7 +116,7 @@ class OpenAPIServiceConnector:
|
|||||||
function_invocation_payloads = self._parse_message(last_message)
|
function_invocation_payloads = self._parse_message(last_message)
|
||||||
|
|
||||||
# instantiate the OpenAPI service for the given specification
|
# instantiate the OpenAPI service for the given specification
|
||||||
openapi_service = OpenAPI(service_openapi_spec)
|
openapi_service = OpenAPI(service_openapi_spec, ssl_verify=self.ssl_verify)
|
||||||
self._authenticate_service(openapi_service, service_credentials)
|
self._authenticate_service(openapi_service, service_credentials)
|
||||||
|
|
||||||
response_messages = []
|
response_messages = []
|
||||||
@ -127,6 +131,27 @@ class OpenAPIServiceConnector:
|
|||||||
|
|
||||||
return {"service_response": response_messages}
|
return {"service_response": response_messages}
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serializes the component to a dictionary.
|
||||||
|
|
||||||
|
:returns:
|
||||||
|
Dictionary with serialized data.
|
||||||
|
"""
|
||||||
|
return default_to_dict(self, ssl_verify=self.ssl_verify)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "OpenAPIServiceConnector":
|
||||||
|
"""
|
||||||
|
Deserializes the component from a dictionary.
|
||||||
|
|
||||||
|
:param data:
|
||||||
|
The dictionary to deserialize from.
|
||||||
|
:returns:
|
||||||
|
The deserialized component.
|
||||||
|
"""
|
||||||
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]:
|
def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Parses the message to extract the method invocation descriptor.
|
Parses the message to extract the method invocation descriptor.
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
When making function calls via OpenAPI, allow both switching SSL verification off and specifying a certificate authority to use for it.
|
||||||
@ -140,7 +140,7 @@ class TestOpenAPIServiceConnector:
|
|||||||
|
|
||||||
connector.run(messages=messages, service_openapi_spec=spec, service_credentials="fake_key")
|
connector.run(messages=messages, service_openapi_spec=spec, service_credentials="fake_key")
|
||||||
|
|
||||||
openapi_mock.assert_called_once_with(spec)
|
openapi_mock.assert_called_once_with(spec, ssl_verify=None)
|
||||||
mock_service.authenticate.assert_called_once_with("apikey", "fake_key")
|
mock_service.authenticate.assert_called_once_with("apikey", "fake_key")
|
||||||
|
|
||||||
# verify call went through on the wire with the correct parameters
|
# verify call went through on the wire with the correct parameters
|
||||||
@ -331,3 +331,11 @@ class TestOpenAPIServiceConnector:
|
|||||||
ValueError, match="Missing requestBody parameter: 'message' required for the 'greet' operation."
|
ValueError, match="Missing requestBody parameter: 'message' required for the 'greet' operation."
|
||||||
):
|
):
|
||||||
connector.run(messages=messages, service_openapi_spec=spec)
|
connector.run(messages=messages, service_openapi_spec=spec)
|
||||||
|
|
||||||
|
def test_serialization(self):
|
||||||
|
for test_val in ("myvalue", True, None):
|
||||||
|
openapi_service_connector = OpenAPIServiceConnector(test_val)
|
||||||
|
serialized = openapi_service_connector.to_dict()
|
||||||
|
assert serialized["init_parameters"]["ssl_verify"] == test_val
|
||||||
|
deserialized = OpenAPIServiceConnector.from_dict(serialized)
|
||||||
|
assert deserialized.ssl_verify == test_val
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user