feat: Add service_credentials to OpenAPIServiceConnector run (#6962)

* Add service_credentials to OpenAPIServiceConnector run
* PR feedback Silvano
This commit is contained in:
Vladimir Blagojevic 2024-02-09 16:03:27 +01:00 committed by GitHub
parent d2d01f9fe1
commit 37d9de3c4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 704 additions and 21 deletions

View File

@ -25,6 +25,7 @@ env:
CORE_AZURE_CS_API_KEY: ${{ secrets.CORE_AZURE_CS_API_KEY }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PYTHON_VERSION: "3.8"
jobs:

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Union
from haystack import component
from haystack.dataclasses import ChatMessage, ChatRole
@ -26,16 +26,19 @@ class OpenAPIServiceConnector:
This can be done using the OpenAPIServiceToFunctions component.
"""
def __init__(self, service_auths: Optional[Dict[str, Any]] = None):
def __init__(self):
"""
Initializes the OpenAPIServiceConnector instance
:param service_auths: A dictionary containing the service name and token to be used for authentication.
"""
openapi_imports.check()
self.service_authentications = service_auths or {}
@component.output_types(service_response=Dict[str, Any])
def run(self, messages: List[ChatMessage], service_openapi_spec: Dict[str, Any]) -> Dict[str, List[ChatMessage]]:
def run(
self,
messages: List[ChatMessage],
service_openapi_spec: Dict[str, Any],
service_credentials: Optional[Union[dict, str]] = None,
) -> Dict[str, List[ChatMessage]]:
"""
Processes a list of chat messages to invoke a method on an OpenAPI service. It parses the last message in the
list, expecting it to contain an OpenAI function calling descriptor (name & parameters) in JSON format.
@ -46,6 +49,9 @@ class OpenAPIServiceConnector:
:type service_openapi_spec: JSON object
:return: A dictionary with a key `"service_response"`, containing the response from the OpenAPI service.
:rtype: Dict[str, List[ChatMessage]]
:param service_credentials: The credentials to be used for authentication with the service.
Currently, only the http and apiKey schemes are supported. See _authenticate_service method for more details.
:type service_credentials: Optional[Union[dict, str]]
:raises ValueError: If the last message is not from the assistant or if it does not contain the correct payload
to invoke a method on the service.
"""
@ -58,7 +64,7 @@ class OpenAPIServiceConnector:
# instantiate the OpenAPI service for the given specification
openapi_service = OpenAPI(service_openapi_spec)
self._authenticate_service(openapi_service)
self._authenticate_service(openapi_service, service_credentials)
response_messages = []
for method_invocation_descriptor in function_invocation_payloads:
@ -101,20 +107,59 @@ class OpenAPIServiceConnector:
)
return function_payloads
def _authenticate_service(self, openapi_service: OpenAPI):
def _authenticate_service(self, openapi_service: OpenAPI, credentials: Optional[Union[dict, str]] = None):
"""
Authenticates with the OpenAPI service if required.
Authenticates with the OpenAPI service if required, supporting both single (str) and multiple
authentication methods (dict).
OpenAPI spec v3 supports the following security schemes:
http for Basic, Bearer and other HTTP authentications schemes
apiKey for API keys and cookie authentication
oauth2 for OAuth 2
openIdConnect for OpenID Connect Discovery
Currently, only the http and apiKey schemes are supported. Multiple security schemes can be defined in the
OpenAPI spec, and the credentials should be provided as a dictionary with keys matching the security scheme
names. If only one security scheme is defined, the credentials can be provided as a simple string.
:param openapi_service: The OpenAPI service instance.
:type openapi_service: OpenAPI
:raises ValueError: If authentication fails or is not found.
:param credentials: Credentials for authentication, which can be either a string (e.g. token) or a dictionary
with keys matching the authentication method names.
:type credentials: dict | str, optional
:raises ValueError: If authentication fails, is not found, or if appropriate credentials are missing.
"""
service_name = openapi_service.info.title
if openapi_service.components.securitySchemes:
auth_method = list(openapi_service.components.securitySchemes.keys())[0]
service_title = openapi_service.info.title
if service_title not in self.service_authentications:
raise ValueError(f"Service {service_title} not found in service_authentications.")
openapi_service.authenticate(auth_method, self.service_authentications[service_title])
if not credentials:
raise ValueError(f"Service {service_name} requires authentication but no credentials were provided.")
# a dictionary of security schemes defined in the OpenAPI spec
# each key is the name of the security scheme, and the value is the scheme definition
security_schemes = openapi_service.components.securitySchemes.raw_element
supported_schemes = ["http", "apiKey"] # todo: add support for oauth2 and openIdConnect
authenticated = False
for scheme_name, scheme in security_schemes.items():
if scheme["type"] in supported_schemes:
auth_credentials = None
if isinstance(credentials, str):
auth_credentials = credentials
elif isinstance(credentials, dict) and scheme_name in credentials:
auth_credentials = credentials[scheme_name]
if auth_credentials:
openapi_service.authenticate(scheme_name, auth_credentials)
authenticated = True
else:
raise ValueError(
f"Service {service_name} requires {scheme_name} security scheme but no "
f"credentials were provided for it. Check the service configuration and credentials."
)
if not authenticated:
raise ValueError(
f"Service {service_name} requires authentication but no credentials were provided "
f"for it. Check the service configuration and credentials."
)
def _invoke_method(self, openapi_service: OpenAPI, method_invocation_descriptor: Dict[str, Any]) -> Any:
"""

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Enhanced the OpenAPIServiceConnector to support dynamic authentication handling. With this update, service credentials are now dynamically provided at each run invocation, eliminating the need for pre-configuring a known set of service authentications. This flexibility allows for the introduction of new services on-the-fly, each with its unique authentication, streamlining the integration process. This modification not only simplifies the initial setup of the OpenAPIServiceConnector but also ensures a more transparent and straightforward authentication process for each interaction with different OpenAPI services.

View File

@ -1,6 +1,10 @@
import json
import os
import pytest
from unittest.mock import MagicMock, Mock
import requests
from openapi3 import OpenAPI
from openapi3.schemas import Model
from haystack.components.connectors import OpenAPIServiceConnector
@ -13,14 +17,40 @@ def openapi_service_mock():
@pytest.fixture
def service_auths():
return {"TestService": "auth_token"}
def random_open_pull_request_head_branch() -> str:
token = os.getenv("GITHUB_TOKEN")
headers = {"Accept": "application/vnd.github.v3+json", "Authorization": f"token {token}"}
response = requests.get("https://api.github.com/repos/deepset-ai/haystack/pulls?state=open", headers=headers)
if response.status_code == 200:
pull_requests = response.json()
for pr in pull_requests:
if pr["base"]["ref"] == "main":
return pr["head"]["ref"]
else:
raise Exception(f"Failed to fetch pull requests. Status code: {response.status_code}")
@pytest.fixture
def genuine_fc_message(random_open_pull_request_head_branch):
basehead = "main..." + random_open_pull_request_head_branch
# arguments, see below, are always passed as a string representation of a JSON object
params = '{"parameters": {"basehead": "' + basehead + '", "owner": "deepset-ai", "repo": "haystack"}}'
payload_json = [
{
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
"function": {"arguments": params, "name": "compare_branches"},
"type": "function",
}
]
return json.dumps(payload_json)
class TestOpenAPIServiceConnector:
@pytest.fixture
def connector(self, service_auths):
return OpenAPIServiceConnector(service_auths)
def connector(self):
return OpenAPIServiceConnector()
def test_parse_message_invalid_json(self, connector):
# Test invalid JSON content
@ -62,12 +92,33 @@ class TestOpenAPIServiceConnector:
with pytest.raises(ValueError):
connector._parse_message(ChatMessage.from_assistant('[{"function": {"name": "test_method"}}]'))
def test_authenticate_service_invalid(self, connector, openapi_service_mock):
# Test invalid or missing authentication
openapi_service_mock.components.securitySchemes = {"apiKey": {}}
def test_authenticate_service_missing_authentication_token(self, connector, openapi_service_mock):
securitySchemes_mock = MagicMock()
securitySchemes_mock.raw_element = {"apiKey": {"in": "header", "name": "x-api-key", "type": "apiKey"}}
with pytest.raises(ValueError):
connector._authenticate_service(openapi_service_mock)
def test_authenticate_service_having_authentication_token(self, connector, openapi_service_mock):
securitySchemes_mock = MagicMock()
securitySchemes_mock.raw_element = {"apiKey": {"in": "header", "name": "x-api-key", "type": "apiKey"}}
openapi_service_mock.components.securitySchemes = securitySchemes_mock
connector._authenticate_service(openapi_service_mock, "some_fake_token")
def test_authenticate_service_having_authentication_dict(self, connector, openapi_service_mock):
securitySchemes_mock = MagicMock()
securitySchemes_mock.raw_element = {"apiKey": {"in": "header", "name": "x-api-key", "type": "apiKey"}}
openapi_service_mock.components.securitySchemes = securitySchemes_mock
connector._authenticate_service(openapi_service_mock, {"apiKey": "some_fake_token"})
def test_authenticate_service_having_authentication_dict_but_unsupported_auth(
self, connector, openapi_service_mock
):
securitySchemes_mock = MagicMock()
securitySchemes_mock.raw_element = {"oauth2": {"type": "oauth2"}}
openapi_service_mock.components.securitySchemes = securitySchemes_mock
with pytest.raises(ValueError):
connector._authenticate_service(openapi_service_mock, {"apiKey": "some_fake_token"})
def test_invoke_method_valid(self, connector, openapi_service_mock):
# Test valid method invocation
method_invocation_descriptor = {"name": "test_method", "arguments": {}}
@ -88,3 +139,23 @@ class TestOpenAPIServiceConnector:
"openapi3 changed. Model should have a _raw_data field, we rely on it in OpenAPIServiceConnector"
" to get the raw data from the service response"
)
@pytest.mark.integration
@pytest.mark.skipif(not os.getenv("GITHUB_TOKEN"), reason="GITHUB_TOKEN is not set")
def test_run(self, genuine_fc_message, test_files_path):
openapi_service = OpenAPIServiceConnector()
open_api_spec_path = test_files_path / "json" / "github_compare_branch_openapi_spec.json"
with open(open_api_spec_path, "r") as file:
github_compare_schema = json.load(file)
messages = [ChatMessage.from_assistant(genuine_fc_message)]
# genuine call to the GitHub OpenAPI service
result = openapi_service.run(messages, github_compare_schema, os.getenv("GITHUB_TOKEN"))
assert result
# load json from the service response
service_payload = json.loads(result["service_response"][0].content)
# verify that the service response contains the expected fields
assert "url" in service_payload and "files" in service_payload

View File

@ -0,0 +1,562 @@
{
"openapi": "3.1.0",
"info": {
"title": "Github API",
"description": "Enables interaction with OpenAPI",
"version": "v1.0.0"
},
"servers": [
{
"url": "https://api.github.com"
}
],
"paths": {
"/repos/{owner}/{repo}/compare/{basehead}": {
"get": {
"summary": "Compare two branches",
"description": "Compares two branches against one another.",
"tags": [
"repos"
],
"operationId": "compare_branches",
"externalDocs": {
"description": "API method documentation",
"url": "https://docs.github.com/enterprise-server@3.9/rest/commits/commits#compare-two-commits"
},
"parameters": [
{
"name": "basehead",
"description": "The base branch and head branch to compare. This parameter expects the format `BASE...HEAD`",
"in": "path",
"required": true,
"x-multi-segment": true,
"schema": {
"type": "string"
}
},
{
"name": "owner",
"description": "The repository owner, usually a company or orgnization",
"in": "path",
"required": true,
"x-multi-segment": true,
"schema": {
"type": "string"
}
},
{
"name": "repo",
"description": "The repository itself, the project",
"in": "path",
"required": true,
"x-multi-segment": true,
"schema": {
"type": "string"
}
}
],
"responses": {
"200": {
"description": "Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/commit-comparison"
}
}
}
}
},
"x-github": {
"githubCloudOnly": false,
"enabledForGitHubApps": true,
"category": "commits",
"subcategory": "commits"
}
}
}
},
"components": {
"schemas": {
"commit-comparison": {
"title": "Commit Comparison",
"description": "Commit Comparison",
"type": "object",
"properties": {
"url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/repos/octocat/Hello-World/compare/master...topic"
},
"html_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat/Hello-World/compare/master...topic"
},
"permalink_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat/Hello-World/compare/octocat:bbcd538c8e72b8c175046e27cc8f907076331401...octocat:0328041d1152db8ae77652d1618a02e57f745f17"
},
"diff_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat/Hello-World/compare/master...topic.diff"
},
"patch_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat/Hello-World/compare/master...topic.patch"
},
"base_commit": {
"$ref": "#/components/schemas/commit"
},
"merge_base_commit": {
"$ref": "#/components/schemas/commit"
},
"status": {
"type": "string",
"enum": [
"diverged",
"ahead",
"behind",
"identical"
],
"example": "ahead"
},
"ahead_by": {
"type": "integer",
"example": 4
},
"behind_by": {
"type": "integer",
"example": 5
},
"total_commits": {
"type": "integer",
"example": 6
},
"commits": {
"type": "array",
"items": {
"$ref": "#/components/schemas/commit"
}
},
"files": {
"type": "array",
"items": {
"$ref": "#/components/schemas/diff-entry"
}
}
},
"required": [
"url",
"html_url",
"permalink_url",
"diff_url",
"patch_url",
"base_commit",
"merge_base_commit",
"status",
"ahead_by",
"behind_by",
"total_commits",
"commits"
]
},
"nullable-git-user": {
"title": "Git User",
"description": "Metaproperties for Git author/committer information.",
"type": "object",
"properties": {
"name": {
"type": "string",
"example": "\"Chris Wanstrath\""
},
"email": {
"type": "string",
"example": "\"chris@ozmm.org\""
},
"date": {
"type": "string",
"example": "\"2007-10-29T02:42:39.000-07:00\""
}
},
"nullable": true
},
"nullable-simple-user": {
"title": "Simple User",
"description": "A GitHub user.",
"type": "object",
"properties": {
"name": {
"nullable": true,
"type": "string"
},
"email": {
"nullable": true,
"type": "string"
},
"login": {
"type": "string",
"example": "octocat"
},
"id": {
"type": "integer",
"example": 1
},
"node_id": {
"type": "string",
"example": "MDQ6VXNlcjE="
},
"avatar_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/images/error/octocat_happy.gif"
},
"gravatar_id": {
"type": "string",
"example": "41d064eb2195891e12d0413f63227ea7",
"nullable": true
},
"url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/users/octocat"
},
"html_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat"
},
"followers_url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/users/octocat/followers"
},
"following_url": {
"type": "string",
"example": "https://api.github.com/users/octocat/following{/other_user}"
},
"gists_url": {
"type": "string",
"example": "https://api.github.com/users/octocat/gists{/gist_id}"
},
"starred_url": {
"type": "string",
"example": "https://api.github.com/users/octocat/starred{/owner}{/repo}"
},
"subscriptions_url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/users/octocat/subscriptions"
},
"organizations_url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/users/octocat/orgs"
},
"repos_url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/users/octocat/repos"
},
"events_url": {
"type": "string",
"example": "https://api.github.com/users/octocat/events{/privacy}"
},
"received_events_url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/users/octocat/received_events"
},
"type": {
"type": "string",
"example": "User"
},
"site_admin": {
"type": "boolean"
},
"starred_at": {
"type": "string",
"example": "\"2020-07-09T00:17:55Z\""
}
},
"required": [
"avatar_url",
"events_url",
"followers_url",
"following_url",
"gists_url",
"gravatar_id",
"html_url",
"id",
"node_id",
"login",
"organizations_url",
"received_events_url",
"repos_url",
"site_admin",
"starred_url",
"subscriptions_url",
"type",
"url"
],
"nullable": true
},
"verification": {
"title": "Verification",
"type": "object",
"properties": {
"verified": {
"type": "boolean"
},
"reason": {
"type": "string"
},
"payload": {
"type": "string",
"nullable": true
},
"signature": {
"type": "string",
"nullable": true
}
},
"required": [
"verified",
"reason",
"payload",
"signature"
]
},
"diff-entry": {
"title": "Diff Entry",
"description": "Diff Entry",
"type": "object",
"properties": {
"sha": {
"type": "string",
"example": "bbcd538c8e72b8c175046e27cc8f907076331401"
},
"filename": {
"type": "string",
"example": "file1.txt"
},
"status": {
"type": "string",
"enum": [
"added",
"removed",
"modified",
"renamed",
"copied",
"changed",
"unchanged"
],
"example": "added"
},
"additions": {
"type": "integer",
"example": 103
},
"deletions": {
"type": "integer",
"example": 21
},
"changes": {
"type": "integer",
"example": 124
},
"blob_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat/Hello-World/blob/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt"
},
"raw_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat/Hello-World/raw/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt"
},
"contents_url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/repos/octocat/Hello-World/contents/file1.txt?ref=6dcb09b5b57875f334f61aebed695e2e4193db5e"
},
"patch": {
"type": "string",
"example": "@@ -132,7 +132,7 @@ module Test @@ -1000,7 +1000,7 @@ module Test"
},
"previous_filename": {
"type": "string",
"example": "file.txt"
}
},
"required": [
"additions",
"blob_url",
"changes",
"contents_url",
"deletions",
"filename",
"raw_url",
"sha",
"status"
]
},
"commit": {
"title": "Commit",
"description": "Commit",
"type": "object",
"properties": {
"url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e"
},
"sha": {
"type": "string",
"example": "6dcb09b5b57875f334f61aebed695e2e4193db5e"
},
"node_id": {
"type": "string",
"example": "MDY6Q29tbWl0NmRjYjA5YjViNTc4NzVmMzM0ZjYxYWViZWQ2OTVlMmU0MTkzZGI1ZQ=="
},
"html_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat/Hello-World/commit/6dcb09b5b57875f334f61aebed695e2e4193db5e"
},
"comments_url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e/comments"
},
"commit": {
"type": "object",
"properties": {
"url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e"
},
"author": {
"$ref": "#/components/schemas/nullable-git-user"
},
"committer": {
"$ref": "#/components/schemas/nullable-git-user"
},
"message": {
"type": "string",
"example": "Fix all the bugs"
},
"comment_count": {
"type": "integer",
"example": 0
},
"tree": {
"type": "object",
"properties": {
"sha": {
"type": "string",
"example": "827efc6d56897b048c772eb4087f854f46256132"
},
"url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/repos/octocat/Hello-World/tree/827efc6d56897b048c772eb4087f854f46256132"
}
},
"required": [
"sha",
"url"
]
},
"verification": {
"$ref": "#/components/schemas/verification"
}
},
"required": [
"author",
"committer",
"comment_count",
"message",
"tree",
"url"
]
},
"author": {
"$ref": "#/components/schemas/nullable-simple-user"
},
"committer": {
"$ref": "#/components/schemas/nullable-simple-user"
},
"parents": {
"type": "array",
"items": {
"type": "object",
"properties": {
"sha": {
"type": "string",
"example": "7638417db6d59f3c431d3e1f261cc637155684cd"
},
"url": {
"type": "string",
"format": "uri",
"example": "https://api.github.com/repos/octocat/Hello-World/commits/7638417db6d59f3c431d3e1f261cc637155684cd"
},
"html_url": {
"type": "string",
"format": "uri",
"example": "https://github.com/octocat/Hello-World/commit/7638417db6d59f3c431d3e1f261cc637155684cd"
}
},
"required": [
"sha",
"url"
]
}
},
"stats": {
"type": "object",
"properties": {
"additions": {
"type": "integer"
},
"deletions": {
"type": "integer"
},
"total": {
"type": "integer"
}
}
},
"files": {
"type": "array",
"items": {
"$ref": "#/components/schemas/diff-entry"
}
}
},
"required": [
"url",
"sha",
"node_id",
"html_url",
"comments_url",
"commit",
"author",
"committer",
"parents"
]
}
},
"securitySchemes": {
"apikey": {
"type": "apiKey",
"name": "x-api-key",
"in": "header"
}
}
}
}