From 37d9de3c4e011fb4aac3644a6b7ca5ee541fe088 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 9 Feb 2024 16:03:27 +0100 Subject: [PATCH] feat: Add service_credentials to OpenAPIServiceConnector run (#6962) * Add service_credentials to OpenAPIServiceConnector run * PR feedback Silvano --- .github/workflows/tests.yml | 1 + .../components/connectors/openapi_service.py | 73 ++- ...tor-auth-enhancement-a78e0666d3cf6353.yaml | 4 + .../connectors/test_openapi_service.py | 85 ++- .../github_compare_branch_openapi_spec.json | 562 ++++++++++++++++++ 5 files changed, 704 insertions(+), 21 deletions(-) create mode 100644 releasenotes/notes/openapi-connector-auth-enhancement-a78e0666d3cf6353.yaml create mode 100644 test/test_files/json/github_compare_branch_openapi_spec.json diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e0b2caf86..88b5832a6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,6 +25,7 @@ env: CORE_AZURE_CS_API_KEY: ${{ secrets.CORE_AZURE_CS_API_KEY }} AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }} AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PYTHON_VERSION: "3.8" jobs: diff --git a/haystack/components/connectors/openapi_service.py b/haystack/components/connectors/openapi_service.py index d4b070bc1..db41076e7 100644 --- a/haystack/components/connectors/openapi_service.py +++ b/haystack/components/connectors/openapi_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Union from haystack import component from haystack.dataclasses import ChatMessage, ChatRole @@ -26,16 +26,19 @@ class OpenAPIServiceConnector: This can be done using the OpenAPIServiceToFunctions component. """ - def __init__(self, service_auths: Optional[Dict[str, Any]] = None): + def __init__(self): """ Initializes the OpenAPIServiceConnector instance - :param service_auths: A dictionary containing the service name and token to be used for authentication. """ openapi_imports.check() - self.service_authentications = service_auths or {} @component.output_types(service_response=Dict[str, Any]) - def run(self, messages: List[ChatMessage], service_openapi_spec: Dict[str, Any]) -> Dict[str, List[ChatMessage]]: + def run( + self, + messages: List[ChatMessage], + service_openapi_spec: Dict[str, Any], + service_credentials: Optional[Union[dict, str]] = None, + ) -> Dict[str, List[ChatMessage]]: """ Processes a list of chat messages to invoke a method on an OpenAPI service. It parses the last message in the list, expecting it to contain an OpenAI function calling descriptor (name & parameters) in JSON format. @@ -46,6 +49,9 @@ class OpenAPIServiceConnector: :type service_openapi_spec: JSON object :return: A dictionary with a key `"service_response"`, containing the response from the OpenAPI service. :rtype: Dict[str, List[ChatMessage]] + :param service_credentials: The credentials to be used for authentication with the service. + Currently, only the http and apiKey schemes are supported. See _authenticate_service method for more details. + :type service_credentials: Optional[Union[dict, str]] :raises ValueError: If the last message is not from the assistant or if it does not contain the correct payload to invoke a method on the service. """ @@ -58,7 +64,7 @@ class OpenAPIServiceConnector: # instantiate the OpenAPI service for the given specification openapi_service = OpenAPI(service_openapi_spec) - self._authenticate_service(openapi_service) + self._authenticate_service(openapi_service, service_credentials) response_messages = [] for method_invocation_descriptor in function_invocation_payloads: @@ -101,20 +107,59 @@ class OpenAPIServiceConnector: ) return function_payloads - def _authenticate_service(self, openapi_service: OpenAPI): + def _authenticate_service(self, openapi_service: OpenAPI, credentials: Optional[Union[dict, str]] = None): """ - Authenticates with the OpenAPI service if required. + Authenticates with the OpenAPI service if required, supporting both single (str) and multiple + authentication methods (dict). + + OpenAPI spec v3 supports the following security schemes: + http – for Basic, Bearer and other HTTP authentications schemes + apiKey – for API keys and cookie authentication + oauth2 – for OAuth 2 + openIdConnect – for OpenID Connect Discovery + + Currently, only the http and apiKey schemes are supported. Multiple security schemes can be defined in the + OpenAPI spec, and the credentials should be provided as a dictionary with keys matching the security scheme + names. If only one security scheme is defined, the credentials can be provided as a simple string. :param openapi_service: The OpenAPI service instance. :type openapi_service: OpenAPI - :raises ValueError: If authentication fails or is not found. + :param credentials: Credentials for authentication, which can be either a string (e.g. token) or a dictionary + with keys matching the authentication method names. + :type credentials: dict | str, optional + :raises ValueError: If authentication fails, is not found, or if appropriate credentials are missing. """ + service_name = openapi_service.info.title if openapi_service.components.securitySchemes: - auth_method = list(openapi_service.components.securitySchemes.keys())[0] - service_title = openapi_service.info.title - if service_title not in self.service_authentications: - raise ValueError(f"Service {service_title} not found in service_authentications.") - openapi_service.authenticate(auth_method, self.service_authentications[service_title]) + if not credentials: + raise ValueError(f"Service {service_name} requires authentication but no credentials were provided.") + + # a dictionary of security schemes defined in the OpenAPI spec + # each key is the name of the security scheme, and the value is the scheme definition + security_schemes = openapi_service.components.securitySchemes.raw_element + supported_schemes = ["http", "apiKey"] # todo: add support for oauth2 and openIdConnect + + authenticated = False + for scheme_name, scheme in security_schemes.items(): + if scheme["type"] in supported_schemes: + auth_credentials = None + if isinstance(credentials, str): + auth_credentials = credentials + elif isinstance(credentials, dict) and scheme_name in credentials: + auth_credentials = credentials[scheme_name] + if auth_credentials: + openapi_service.authenticate(scheme_name, auth_credentials) + authenticated = True + else: + raise ValueError( + f"Service {service_name} requires {scheme_name} security scheme but no " + f"credentials were provided for it. Check the service configuration and credentials." + ) + if not authenticated: + raise ValueError( + f"Service {service_name} requires authentication but no credentials were provided " + f"for it. Check the service configuration and credentials." + ) def _invoke_method(self, openapi_service: OpenAPI, method_invocation_descriptor: Dict[str, Any]) -> Any: """ diff --git a/releasenotes/notes/openapi-connector-auth-enhancement-a78e0666d3cf6353.yaml b/releasenotes/notes/openapi-connector-auth-enhancement-a78e0666d3cf6353.yaml new file mode 100644 index 000000000..8575bc14f --- /dev/null +++ b/releasenotes/notes/openapi-connector-auth-enhancement-a78e0666d3cf6353.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Enhanced the OpenAPIServiceConnector to support dynamic authentication handling. With this update, service credentials are now dynamically provided at each run invocation, eliminating the need for pre-configuring a known set of service authentications. This flexibility allows for the introduction of new services on-the-fly, each with its unique authentication, streamlining the integration process. This modification not only simplifies the initial setup of the OpenAPIServiceConnector but also ensures a more transparent and straightforward authentication process for each interaction with different OpenAPI services. diff --git a/test/components/connectors/test_openapi_service.py b/test/components/connectors/test_openapi_service.py index 2ac94b5bf..29441d1f7 100644 --- a/test/components/connectors/test_openapi_service.py +++ b/test/components/connectors/test_openapi_service.py @@ -1,6 +1,10 @@ import json +import os + import pytest from unittest.mock import MagicMock, Mock + +import requests from openapi3 import OpenAPI from openapi3.schemas import Model from haystack.components.connectors import OpenAPIServiceConnector @@ -13,14 +17,40 @@ def openapi_service_mock(): @pytest.fixture -def service_auths(): - return {"TestService": "auth_token"} +def random_open_pull_request_head_branch() -> str: + token = os.getenv("GITHUB_TOKEN") + headers = {"Accept": "application/vnd.github.v3+json", "Authorization": f"token {token}"} + response = requests.get("https://api.github.com/repos/deepset-ai/haystack/pulls?state=open", headers=headers) + + if response.status_code == 200: + pull_requests = response.json() + for pr in pull_requests: + if pr["base"]["ref"] == "main": + return pr["head"]["ref"] + else: + raise Exception(f"Failed to fetch pull requests. Status code: {response.status_code}") + + +@pytest.fixture +def genuine_fc_message(random_open_pull_request_head_branch): + basehead = "main..." + random_open_pull_request_head_branch + # arguments, see below, are always passed as a string representation of a JSON object + params = '{"parameters": {"basehead": "' + basehead + '", "owner": "deepset-ai", "repo": "haystack"}}' + payload_json = [ + { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": {"arguments": params, "name": "compare_branches"}, + "type": "function", + } + ] + + return json.dumps(payload_json) class TestOpenAPIServiceConnector: @pytest.fixture - def connector(self, service_auths): - return OpenAPIServiceConnector(service_auths) + def connector(self): + return OpenAPIServiceConnector() def test_parse_message_invalid_json(self, connector): # Test invalid JSON content @@ -62,12 +92,33 @@ class TestOpenAPIServiceConnector: with pytest.raises(ValueError): connector._parse_message(ChatMessage.from_assistant('[{"function": {"name": "test_method"}}]')) - def test_authenticate_service_invalid(self, connector, openapi_service_mock): - # Test invalid or missing authentication - openapi_service_mock.components.securitySchemes = {"apiKey": {}} + def test_authenticate_service_missing_authentication_token(self, connector, openapi_service_mock): + securitySchemes_mock = MagicMock() + securitySchemes_mock.raw_element = {"apiKey": {"in": "header", "name": "x-api-key", "type": "apiKey"}} with pytest.raises(ValueError): connector._authenticate_service(openapi_service_mock) + def test_authenticate_service_having_authentication_token(self, connector, openapi_service_mock): + securitySchemes_mock = MagicMock() + securitySchemes_mock.raw_element = {"apiKey": {"in": "header", "name": "x-api-key", "type": "apiKey"}} + openapi_service_mock.components.securitySchemes = securitySchemes_mock + connector._authenticate_service(openapi_service_mock, "some_fake_token") + + def test_authenticate_service_having_authentication_dict(self, connector, openapi_service_mock): + securitySchemes_mock = MagicMock() + securitySchemes_mock.raw_element = {"apiKey": {"in": "header", "name": "x-api-key", "type": "apiKey"}} + openapi_service_mock.components.securitySchemes = securitySchemes_mock + connector._authenticate_service(openapi_service_mock, {"apiKey": "some_fake_token"}) + + def test_authenticate_service_having_authentication_dict_but_unsupported_auth( + self, connector, openapi_service_mock + ): + securitySchemes_mock = MagicMock() + securitySchemes_mock.raw_element = {"oauth2": {"type": "oauth2"}} + openapi_service_mock.components.securitySchemes = securitySchemes_mock + with pytest.raises(ValueError): + connector._authenticate_service(openapi_service_mock, {"apiKey": "some_fake_token"}) + def test_invoke_method_valid(self, connector, openapi_service_mock): # Test valid method invocation method_invocation_descriptor = {"name": "test_method", "arguments": {}} @@ -88,3 +139,23 @@ class TestOpenAPIServiceConnector: "openapi3 changed. Model should have a _raw_data field, we rely on it in OpenAPIServiceConnector" " to get the raw data from the service response" ) + + @pytest.mark.integration + @pytest.mark.skipif(not os.getenv("GITHUB_TOKEN"), reason="GITHUB_TOKEN is not set") + def test_run(self, genuine_fc_message, test_files_path): + openapi_service = OpenAPIServiceConnector() + + open_api_spec_path = test_files_path / "json" / "github_compare_branch_openapi_spec.json" + with open(open_api_spec_path, "r") as file: + github_compare_schema = json.load(file) + messages = [ChatMessage.from_assistant(genuine_fc_message)] + + # genuine call to the GitHub OpenAPI service + result = openapi_service.run(messages, github_compare_schema, os.getenv("GITHUB_TOKEN")) + assert result + + # load json from the service response + service_payload = json.loads(result["service_response"][0].content) + + # verify that the service response contains the expected fields + assert "url" in service_payload and "files" in service_payload diff --git a/test/test_files/json/github_compare_branch_openapi_spec.json b/test/test_files/json/github_compare_branch_openapi_spec.json new file mode 100644 index 000000000..8bd7c7ed8 --- /dev/null +++ b/test/test_files/json/github_compare_branch_openapi_spec.json @@ -0,0 +1,562 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "Github API", + "description": "Enables interaction with OpenAPI", + "version": "v1.0.0" + }, + "servers": [ + { + "url": "https://api.github.com" + } + ], + "paths": { + "/repos/{owner}/{repo}/compare/{basehead}": { + "get": { + "summary": "Compare two branches", + "description": "Compares two branches against one another.", + "tags": [ + "repos" + ], + "operationId": "compare_branches", + "externalDocs": { + "description": "API method documentation", + "url": "https://docs.github.com/enterprise-server@3.9/rest/commits/commits#compare-two-commits" + }, + "parameters": [ + { + "name": "basehead", + "description": "The base branch and head branch to compare. This parameter expects the format `BASE...HEAD`", + "in": "path", + "required": true, + "x-multi-segment": true, + "schema": { + "type": "string" + } + }, + { + "name": "owner", + "description": "The repository owner, usually a company or orgnization", + "in": "path", + "required": true, + "x-multi-segment": true, + "schema": { + "type": "string" + } + }, + { + "name": "repo", + "description": "The repository itself, the project", + "in": "path", + "required": true, + "x-multi-segment": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/commit-comparison" + } + } + } + } + }, + "x-github": { + "githubCloudOnly": false, + "enabledForGitHubApps": true, + "category": "commits", + "subcategory": "commits" + } + } + } + }, + "components": { + "schemas": { + "commit-comparison": { + "title": "Commit Comparison", + "description": "Commit Comparison", + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/repos/octocat/Hello-World/compare/master...topic" + }, + "html_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat/Hello-World/compare/master...topic" + }, + "permalink_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat/Hello-World/compare/octocat:bbcd538c8e72b8c175046e27cc8f907076331401...octocat:0328041d1152db8ae77652d1618a02e57f745f17" + }, + "diff_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat/Hello-World/compare/master...topic.diff" + }, + "patch_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat/Hello-World/compare/master...topic.patch" + }, + "base_commit": { + "$ref": "#/components/schemas/commit" + }, + "merge_base_commit": { + "$ref": "#/components/schemas/commit" + }, + "status": { + "type": "string", + "enum": [ + "diverged", + "ahead", + "behind", + "identical" + ], + "example": "ahead" + }, + "ahead_by": { + "type": "integer", + "example": 4 + }, + "behind_by": { + "type": "integer", + "example": 5 + }, + "total_commits": { + "type": "integer", + "example": 6 + }, + "commits": { + "type": "array", + "items": { + "$ref": "#/components/schemas/commit" + } + }, + "files": { + "type": "array", + "items": { + "$ref": "#/components/schemas/diff-entry" + } + } + }, + "required": [ + "url", + "html_url", + "permalink_url", + "diff_url", + "patch_url", + "base_commit", + "merge_base_commit", + "status", + "ahead_by", + "behind_by", + "total_commits", + "commits" + ] + }, + "nullable-git-user": { + "title": "Git User", + "description": "Metaproperties for Git author/committer information.", + "type": "object", + "properties": { + "name": { + "type": "string", + "example": "\"Chris Wanstrath\"" + }, + "email": { + "type": "string", + "example": "\"chris@ozmm.org\"" + }, + "date": { + "type": "string", + "example": "\"2007-10-29T02:42:39.000-07:00\"" + } + }, + "nullable": true + }, + "nullable-simple-user": { + "title": "Simple User", + "description": "A GitHub user.", + "type": "object", + "properties": { + "name": { + "nullable": true, + "type": "string" + }, + "email": { + "nullable": true, + "type": "string" + }, + "login": { + "type": "string", + "example": "octocat" + }, + "id": { + "type": "integer", + "example": 1 + }, + "node_id": { + "type": "string", + "example": "MDQ6VXNlcjE=" + }, + "avatar_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/images/error/octocat_happy.gif" + }, + "gravatar_id": { + "type": "string", + "example": "41d064eb2195891e12d0413f63227ea7", + "nullable": true + }, + "url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/users/octocat" + }, + "html_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat" + }, + "followers_url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/users/octocat/followers" + }, + "following_url": { + "type": "string", + "example": "https://api.github.com/users/octocat/following{/other_user}" + }, + "gists_url": { + "type": "string", + "example": "https://api.github.com/users/octocat/gists{/gist_id}" + }, + "starred_url": { + "type": "string", + "example": "https://api.github.com/users/octocat/starred{/owner}{/repo}" + }, + "subscriptions_url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/users/octocat/subscriptions" + }, + "organizations_url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/users/octocat/orgs" + }, + "repos_url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/users/octocat/repos" + }, + "events_url": { + "type": "string", + "example": "https://api.github.com/users/octocat/events{/privacy}" + }, + "received_events_url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/users/octocat/received_events" + }, + "type": { + "type": "string", + "example": "User" + }, + "site_admin": { + "type": "boolean" + }, + "starred_at": { + "type": "string", + "example": "\"2020-07-09T00:17:55Z\"" + } + }, + "required": [ + "avatar_url", + "events_url", + "followers_url", + "following_url", + "gists_url", + "gravatar_id", + "html_url", + "id", + "node_id", + "login", + "organizations_url", + "received_events_url", + "repos_url", + "site_admin", + "starred_url", + "subscriptions_url", + "type", + "url" + ], + "nullable": true + }, + "verification": { + "title": "Verification", + "type": "object", + "properties": { + "verified": { + "type": "boolean" + }, + "reason": { + "type": "string" + }, + "payload": { + "type": "string", + "nullable": true + }, + "signature": { + "type": "string", + "nullable": true + } + }, + "required": [ + "verified", + "reason", + "payload", + "signature" + ] + }, + "diff-entry": { + "title": "Diff Entry", + "description": "Diff Entry", + "type": "object", + "properties": { + "sha": { + "type": "string", + "example": "bbcd538c8e72b8c175046e27cc8f907076331401" + }, + "filename": { + "type": "string", + "example": "file1.txt" + }, + "status": { + "type": "string", + "enum": [ + "added", + "removed", + "modified", + "renamed", + "copied", + "changed", + "unchanged" + ], + "example": "added" + }, + "additions": { + "type": "integer", + "example": 103 + }, + "deletions": { + "type": "integer", + "example": 21 + }, + "changes": { + "type": "integer", + "example": 124 + }, + "blob_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat/Hello-World/blob/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt" + }, + "raw_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat/Hello-World/raw/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt" + }, + "contents_url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/repos/octocat/Hello-World/contents/file1.txt?ref=6dcb09b5b57875f334f61aebed695e2e4193db5e" + }, + "patch": { + "type": "string", + "example": "@@ -132,7 +132,7 @@ module Test @@ -1000,7 +1000,7 @@ module Test" + }, + "previous_filename": { + "type": "string", + "example": "file.txt" + } + }, + "required": [ + "additions", + "blob_url", + "changes", + "contents_url", + "deletions", + "filename", + "raw_url", + "sha", + "status" + ] + }, + "commit": { + "title": "Commit", + "description": "Commit", + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e" + }, + "sha": { + "type": "string", + "example": "6dcb09b5b57875f334f61aebed695e2e4193db5e" + }, + "node_id": { + "type": "string", + "example": "MDY6Q29tbWl0NmRjYjA5YjViNTc4NzVmMzM0ZjYxYWViZWQ2OTVlMmU0MTkzZGI1ZQ==" + }, + "html_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat/Hello-World/commit/6dcb09b5b57875f334f61aebed695e2e4193db5e" + }, + "comments_url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e/comments" + }, + "commit": { + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e" + }, + "author": { + "$ref": "#/components/schemas/nullable-git-user" + }, + "committer": { + "$ref": "#/components/schemas/nullable-git-user" + }, + "message": { + "type": "string", + "example": "Fix all the bugs" + }, + "comment_count": { + "type": "integer", + "example": 0 + }, + "tree": { + "type": "object", + "properties": { + "sha": { + "type": "string", + "example": "827efc6d56897b048c772eb4087f854f46256132" + }, + "url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/repos/octocat/Hello-World/tree/827efc6d56897b048c772eb4087f854f46256132" + } + }, + "required": [ + "sha", + "url" + ] + }, + "verification": { + "$ref": "#/components/schemas/verification" + } + }, + "required": [ + "author", + "committer", + "comment_count", + "message", + "tree", + "url" + ] + }, + "author": { + "$ref": "#/components/schemas/nullable-simple-user" + }, + "committer": { + "$ref": "#/components/schemas/nullable-simple-user" + }, + "parents": { + "type": "array", + "items": { + "type": "object", + "properties": { + "sha": { + "type": "string", + "example": "7638417db6d59f3c431d3e1f261cc637155684cd" + }, + "url": { + "type": "string", + "format": "uri", + "example": "https://api.github.com/repos/octocat/Hello-World/commits/7638417db6d59f3c431d3e1f261cc637155684cd" + }, + "html_url": { + "type": "string", + "format": "uri", + "example": "https://github.com/octocat/Hello-World/commit/7638417db6d59f3c431d3e1f261cc637155684cd" + } + }, + "required": [ + "sha", + "url" + ] + } + }, + "stats": { + "type": "object", + "properties": { + "additions": { + "type": "integer" + }, + "deletions": { + "type": "integer" + }, + "total": { + "type": "integer" + } + } + }, + "files": { + "type": "array", + "items": { + "$ref": "#/components/schemas/diff-entry" + } + } + }, + "required": [ + "url", + "sha", + "node_id", + "html_url", + "comments_url", + "commit", + "author", + "committer", + "parents" + ] + } + }, + "securitySchemes": { + "apikey": { + "type": "apiKey", + "name": "x-api-key", + "in": "header" + } + } + } +}