From 8002cf92d6a54c60ac8180f85d1c42b1105f547c Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Tue, 31 Jan 2023 16:31:33 +0100 Subject: [PATCH] fix: extend schema for prompt node results (#3891) * extend schema for prompt node results * extend schema * update openapi * fix mypy for test module * added 1.14 specs * reverted schema for 1.13 --------- Co-authored-by: bogdankostic Co-authored-by: Mayank Jobanputra Co-authored-by: Sebastian Co-authored-by: ZanSara --- docs/_src/api/openapi/openapi-1.14.0rc0.json | 7 +++++ docs/_src/api/openapi/openapi.json | 7 +++++ rest_api/rest_api/schema.py | 1 + rest_api/test/test_rest_api.py | 30 ++++++++++++++++++++ 4 files changed, 45 insertions(+) diff --git a/docs/_src/api/openapi/openapi-1.14.0rc0.json b/docs/_src/api/openapi/openapi-1.14.0rc0.json index c6772e257..90f1ed6a2 100644 --- a/docs/_src/api/openapi/openapi-1.14.0rc0.json +++ b/docs/_src/api/openapi/openapi-1.14.0rc0.json @@ -977,6 +977,13 @@ }, "default": [] }, + "results": { + "title": "Results", + "type": "array", + "items": { + "type": "string" + } + }, "_debug": { "title": " Debug", "type": "object" diff --git a/docs/_src/api/openapi/openapi.json b/docs/_src/api/openapi/openapi.json index c6772e257..90f1ed6a2 100644 --- a/docs/_src/api/openapi/openapi.json +++ b/docs/_src/api/openapi/openapi.json @@ -977,6 +977,13 @@ }, "default": [] }, + "results": { + "title": "Results", + "type": "array", + "items": { + "type": "string" + } + }, "_debug": { "title": " Debug", "type": "object" diff --git a/rest_api/rest_api/schema.py b/rest_api/rest_api/schema.py index be0b97c5c..659408162 100644 --- a/rest_api/rest_api/schema.py +++ b/rest_api/rest_api/schema.py @@ -61,4 +61,5 @@ class QueryResponse(BaseModel): query: str answers: List[Answer] = [] documents: List[Document] = [] + results: Optional[List[str]] = None debug: Optional[Dict] = Field(None, alias="_debug") diff --git a/rest_api/test/test_rest_api.py b/rest_api/test/test_rest_api.py index 2b288639b..124b44228 100644 --- a/rest_api/test/test_rest_api.py +++ b/rest_api/test/test_rest_api.py @@ -1,3 +1,5 @@ +# mypy: disable_error_code = "empty-body, override, union-attr" + from typing import Dict, List, Optional, Union, Generator import os @@ -449,6 +451,34 @@ def test_query_with_dataframe(client): mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False) +def test_query_with_prompt_node(client): + with mock.patch("rest_api.controller.search.query_pipeline") as mocked_pipeline: + # `run` must return a dictionary containing a `query` key + mocked_pipeline.run.return_value = { + "query": TEST_QUERY, + "documents": [ + Document( + content="test", + content_type="text", + score=0.9, + meta={"test_key": "test_value"}, + embedding=np.array([0.1, 0.2, 0.3]), + ) + ], + "results": ["test"], + } + response = client.post(url="/query", json={"query": TEST_QUERY}) + assert 200 == response.status_code + assert len(response.json()["documents"]) == 1 + assert response.json()["documents"][0]["content"] == "test" + assert response.json()["documents"][0]["content_type"] == "text" + assert response.json()["documents"][0]["embedding"] == [0.1, 0.2, 0.3] + assert len(response.json()["results"]) == 1 + assert response.json()["results"][0] == "test" + # Ensure `run` was called with the expected parameters + mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False) + + def test_write_feedback(client, feedback): response = client.post(url="/feedback", json=feedback) assert 200 == response.status_code