fix: extend schema for prompt node results (#3891)

* extend schema for prompt node results

* extend schema

* update openapi

* fix mypy for test module

* added 1.14 specs

* reverted schema for 1.13

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
Co-authored-by: Mayank Jobanputra <mayankjobanputra@gmail.com>
Co-authored-by: Sebastian <sjrl@users.noreply.github.com>
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
tstadel 2023-01-31 16:31:33 +01:00 committed by GitHub
parent c855e18d78
commit 8002cf92d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 0 deletions

View File

@ -977,6 +977,13 @@
}, },
"default": [] "default": []
}, },
"results": {
"title": "Results",
"type": "array",
"items": {
"type": "string"
}
},
"_debug": { "_debug": {
"title": " Debug", "title": " Debug",
"type": "object" "type": "object"

View File

@ -977,6 +977,13 @@
}, },
"default": [] "default": []
}, },
"results": {
"title": "Results",
"type": "array",
"items": {
"type": "string"
}
},
"_debug": { "_debug": {
"title": " Debug", "title": " Debug",
"type": "object" "type": "object"

View File

@ -61,4 +61,5 @@ class QueryResponse(BaseModel):
query: str query: str
answers: List[Answer] = [] answers: List[Answer] = []
documents: List[Document] = [] documents: List[Document] = []
results: Optional[List[str]] = None
debug: Optional[Dict] = Field(None, alias="_debug") debug: Optional[Dict] = Field(None, alias="_debug")

View File

@ -1,3 +1,5 @@
# mypy: disable_error_code = "empty-body, override, union-attr"
from typing import Dict, List, Optional, Union, Generator from typing import Dict, List, Optional, Union, Generator
import os import os
@ -449,6 +451,34 @@ def test_query_with_dataframe(client):
mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False) mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False)
def test_query_with_prompt_node(client):
with mock.patch("rest_api.controller.search.query_pipeline") as mocked_pipeline:
# `run` must return a dictionary containing a `query` key
mocked_pipeline.run.return_value = {
"query": TEST_QUERY,
"documents": [
Document(
content="test",
content_type="text",
score=0.9,
meta={"test_key": "test_value"},
embedding=np.array([0.1, 0.2, 0.3]),
)
],
"results": ["test"],
}
response = client.post(url="/query", json={"query": TEST_QUERY})
assert 200 == response.status_code
assert len(response.json()["documents"]) == 1
assert response.json()["documents"][0]["content"] == "test"
assert response.json()["documents"][0]["content_type"] == "text"
assert response.json()["documents"][0]["embedding"] == [0.1, 0.2, 0.3]
assert len(response.json()["results"]) == 1
assert response.json()["results"][0] == "test"
# Ensure `run` was called with the expected parameters
mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False)
def test_write_feedback(client, feedback): def test_write_feedback(client, feedback):
response = client.post(url="/feedback", json=feedback) response = client.post(url="/feedback", json=feedback)
assert 200 == response.status_code assert 200 == response.status_code