feat: add optional index selection for endpoints (#5444)

* add index selection

* reformatting

* updated test script
This commit is contained in:
Muhammad Bilal 2023-08-01 10:47:46 +02:00 committed by GitHub
parent 62029ba441
commit 8920fd6939
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 22 deletions

View File

@ -1,4 +1,4 @@
from typing import List
from typing import Optional, List
import logging
@ -21,7 +21,7 @@ document_store: BaseDocumentStore = get_pipelines().get("document_store", None)
@router.post("/documents/get_by_filters", response_model=List[Document], response_model_exclude_none=True)
def get_documents(filters: FilterRequest):
def get_documents(filters: FilterRequest, index: Optional[str] = None):
"""
This endpoint allows you to retrieve documents contained in your document store.
You can filter the documents to retrieve by metadata (like the document's name),
@ -33,14 +33,14 @@ def get_documents(filters: FilterRequest):
To get all documents you should provide an empty dict, like:
`'{"filters": {}}'`
"""
docs = document_store.get_all_documents(filters=filters.filters)
docs = document_store.get_all_documents(filters=filters.filters, index=index)
for doc in docs:
doc.embedding = None
return docs
@router.post("/documents/delete_by_filters", response_model=bool)
def delete_documents(filters: FilterRequest):
def delete_documents(filters: FilterRequest, index: Optional[str] = None):
"""
This endpoint allows you to delete documents contained in your document store.
You can filter the documents to delete by metadata (like the document's name),
@ -52,5 +52,5 @@ def delete_documents(filters: FilterRequest):
To get all documents you should provide an empty dict, like:
`'{"filters": {}}'`
"""
document_store.delete_documents(filters=filters.filters)
document_store.delete_documents(filters=filters.filters, index=index)
return True

View File

@ -18,7 +18,7 @@ document_store: BaseDocumentStore = get_pipelines().get("document_store", None)
@router.post("/feedback")
def post_feedback(feedback: CreateLabelSerialized):
def post_feedback(feedback: CreateLabelSerialized, index: Optional[str] = None):
"""
With this endpoint, the API user can submit their feedback on an answer for a particular query. This feedback is then written to the label_index of the DocumentStore.
@ -31,31 +31,31 @@ def post_feedback(feedback: CreateLabelSerialized):
feedback.origin = "user-feedback"
label = Label(**feedback.dict())
document_store.write_labels([label])
document_store.write_labels([label], index=index)
@router.get("/feedback", response_model=List[Label])
def get_feedback():
def get_feedback(index: Optional[str] = None):
"""
This endpoint allows the API user to retrieve all the feedback that has been submitted through the `POST /feedback` endpoint.
"""
labels = document_store.get_all_labels()
labels = document_store.get_all_labels(index=index)
return labels
@router.delete("/feedback")
def delete_feedback():
def delete_feedback(index: Optional[str] = None):
"""
This endpoint allows the API user to delete all the feedback that has been submitted through the
`POST /feedback` endpoint.
"""
all_labels = document_store.get_all_labels()
all_labels = document_store.get_all_labels(index=index)
user_label_ids = [label.id for label in all_labels if label.origin == "user-feedback"]
document_store.delete_labels(ids=user_label_ids)
document_store.delete_labels(ids=user_label_ids, index=index)
@router.post("/eval-feedback")
def get_feedback_metrics(filters: Optional[FilterRequest] = None):
def get_feedback_metrics(filters: Optional[FilterRequest] = None, index: Optional[str] = None):
"""
This endpoint returns basic accuracy metrics based on user feedback, for example, the ratio of correct answers or correctly identified documents.
You can filter the output by document or label.
@ -73,7 +73,7 @@ def get_feedback_metrics(filters: Optional[FilterRequest] = None):
else:
filters_content = {"origin": ["user-feedback"]}
labels = document_store.get_all_labels(filters=filters_content)
labels = document_store.get_all_labels(filters=filters_content, index=index)
res: Dict[str, Optional[Union[float, int]]]
if len(labels) > 0:
@ -91,7 +91,10 @@ def get_feedback_metrics(filters: Optional[FilterRequest] = None):
@router.get("/export-feedback")
def export_feedback(
context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False
context_size: int = 100_000,
full_document_context: bool = True,
only_positive_labels: bool = False,
index: Optional[str] = None,
):
"""
This endpoint returns JSON output in the SQuAD format for question/answer pairs that were marked as "relevant" by user feedback through the `POST /feedback` endpoint.
@ -99,9 +102,11 @@ def export_feedback(
The context_size param can be used to limit response size for large documents.
"""
if only_positive_labels:
labels = document_store.get_all_labels(filters={"is_correct_answer": [True], "origin": ["user-feedback"]})
labels = document_store.get_all_labels(
filters={"is_correct_answer": [True], "origin": ["user-feedback"]}, index=index
)
else:
labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]})
labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]}, index=index)
# Filter out the labels where the passage is correct but answer is wrong (in SQuAD this matches
# neither a "positive example" nor a negative "is_impossible" one)
labels = [l for l in labels if not (l.is_correct_document is True and l.is_correct_answer is False)]

View File

@ -260,7 +260,7 @@ def test_get_all_documents(client):
response = client.post(url="/documents/get_by_filters", data='{"filters": {}}')
assert 200 == response.status_code
# Ensure `get_all_documents` was called with the expected `filters` param
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={})
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={}, index=None)
# Ensure results are part of the response body
response_json = response.json()
assert len(response_json) == 2
@ -270,21 +270,21 @@ def test_get_documents_with_filters(client):
response = client.post(url="/documents/get_by_filters", data='{"filters": {"test_index": ["2"]}}')
assert 200 == response.status_code
# Ensure `get_all_documents` was called with the expected `filters` param
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]})
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]}, index=None)
def test_delete_all_documents(client):
response = client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
assert 200 == response.status_code
# Ensure `delete_documents` was called on the Document Store instance
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={})
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={}, index=None)
def test_delete_documents_with_filters(client):
response = client.post(url="/documents/delete_by_filters", data='{"filters": {"test_index": ["1"]}}')
assert 200 == response.status_code
# Ensure `delete_documents` was called on the Document Store instance with the same params
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]})
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]}, index=None)
def test_file_upload(client):
@ -564,7 +564,7 @@ def test_delete_feedback(client, monkeypatch, feedback):
# Call the API and ensure `delete_labels` was called only on the label with id=123
response = client.delete(url="/feedback")
assert 200 == response.status_code
MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"])
MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"], index=None)
def test_export_feedback(client, monkeypatch, feedback):