From 8920fd693965e9011084c87cee9afd565fdcecbf Mon Sep 17 00:00:00 2001 From: Muhammad Bilal <40235535+bilalsattar@users.noreply.github.com> Date: Tue, 1 Aug 2023 10:47:46 +0200 Subject: [PATCH] feat: add optional index selection for endpoints (#5444) * add index selection * reformatting * updated test script --- rest_api/rest_api/controller/document.py | 10 ++++---- rest_api/rest_api/controller/feedback.py | 29 ++++++++++++++---------- rest_api/test/test_rest_api.py | 10 ++++---- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/rest_api/rest_api/controller/document.py b/rest_api/rest_api/controller/document.py index 45200c4ce..a1f0bcb21 100644 --- a/rest_api/rest_api/controller/document.py +++ b/rest_api/rest_api/controller/document.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Optional, List import logging @@ -21,7 +21,7 @@ document_store: BaseDocumentStore = get_pipelines().get("document_store", None) @router.post("/documents/get_by_filters", response_model=List[Document], response_model_exclude_none=True) -def get_documents(filters: FilterRequest): +def get_documents(filters: FilterRequest, index: Optional[str] = None): """ This endpoint allows you to retrieve documents contained in your document store. You can filter the documents to retrieve by metadata (like the document's name), @@ -33,14 +33,14 @@ def get_documents(filters: FilterRequest): To get all documents you should provide an empty dict, like: `'{"filters": {}}'` """ - docs = document_store.get_all_documents(filters=filters.filters) + docs = document_store.get_all_documents(filters=filters.filters, index=index) for doc in docs: doc.embedding = None return docs @router.post("/documents/delete_by_filters", response_model=bool) -def delete_documents(filters: FilterRequest): +def delete_documents(filters: FilterRequest, index: Optional[str] = None): """ This endpoint allows you to delete documents contained in your document store. You can filter the documents to delete by metadata (like the document's name), @@ -52,5 +52,5 @@ def delete_documents(filters: FilterRequest): To get all documents you should provide an empty dict, like: `'{"filters": {}}'` """ - document_store.delete_documents(filters=filters.filters) + document_store.delete_documents(filters=filters.filters, index=index) return True diff --git a/rest_api/rest_api/controller/feedback.py b/rest_api/rest_api/controller/feedback.py index af8da58e0..bf38513ea 100644 --- a/rest_api/rest_api/controller/feedback.py +++ b/rest_api/rest_api/controller/feedback.py @@ -18,7 +18,7 @@ document_store: BaseDocumentStore = get_pipelines().get("document_store", None) @router.post("/feedback") -def post_feedback(feedback: CreateLabelSerialized): +def post_feedback(feedback: CreateLabelSerialized, index: Optional[str] = None): """ With this endpoint, the API user can submit their feedback on an answer for a particular query. This feedback is then written to the label_index of the DocumentStore. @@ -31,31 +31,31 @@ def post_feedback(feedback: CreateLabelSerialized): feedback.origin = "user-feedback" label = Label(**feedback.dict()) - document_store.write_labels([label]) + document_store.write_labels([label], index=index) @router.get("/feedback", response_model=List[Label]) -def get_feedback(): +def get_feedback(index: Optional[str] = None): """ This endpoint allows the API user to retrieve all the feedback that has been submitted through the `POST /feedback` endpoint. """ - labels = document_store.get_all_labels() + labels = document_store.get_all_labels(index=index) return labels @router.delete("/feedback") -def delete_feedback(): +def delete_feedback(index: Optional[str] = None): """ This endpoint allows the API user to delete all the feedback that has been sumbitted through the `POST /feedback` endpoint. """ - all_labels = document_store.get_all_labels() + all_labels = document_store.get_all_labels(index=index) user_label_ids = [label.id for label in all_labels if label.origin == "user-feedback"] - document_store.delete_labels(ids=user_label_ids) + document_store.delete_labels(ids=user_label_ids, index=index) @router.post("/eval-feedback") -def get_feedback_metrics(filters: Optional[FilterRequest] = None): +def get_feedback_metrics(filters: Optional[FilterRequest] = None, index: Optional[str] = None): """ This endpoint returns basic accuracy metrics based on user feedback, for example, the ratio of correct answers or correctly identified documents. You can filter the output by document or label. @@ -73,7 +73,7 @@ def get_feedback_metrics(filters: Optional[FilterRequest] = None): else: filters_content = {"origin": ["user-feedback"]} - labels = document_store.get_all_labels(filters=filters_content) + labels = document_store.get_all_labels(filters=filters_content, index=index) res: Dict[str, Optional[Union[float, int]]] if len(labels) > 0: @@ -91,7 +91,10 @@ def get_feedback_metrics(filters: Optional[FilterRequest] = None): @router.get("/export-feedback") def export_feedback( - context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False + context_size: int = 100_000, + full_document_context: bool = True, + only_positive_labels: bool = False, + index: Optional[str] = None, ): """ This endpoint returns JSON output in the SQuAD format for question/answer pairs that were marked as "relevant" by user feedback through the `POST /feedback` endpoint. @@ -99,9 +102,11 @@ def export_feedback( The context_size param can be used to limit response size for large documents. """ if only_positive_labels: - labels = document_store.get_all_labels(filters={"is_correct_answer": [True], "origin": ["user-feedback"]}) + labels = document_store.get_all_labels( + filters={"is_correct_answer": [True], "origin": ["user-feedback"]}, index=index + ) else: - labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]}) + labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]}, index=index) # Filter out the labels where the passage is correct but answer is wrong (in SQuAD this matches # neither a "positive example" nor a negative "is_impossible" one) labels = [l for l in labels if not (l.is_correct_document is True and l.is_correct_answer is False)] diff --git a/rest_api/test/test_rest_api.py b/rest_api/test/test_rest_api.py index ee212bf44..fd0153b05 100644 --- a/rest_api/test/test_rest_api.py +++ b/rest_api/test/test_rest_api.py @@ -260,7 +260,7 @@ def test_get_all_documents(client): response = client.post(url="/documents/get_by_filters", data='{"filters": {}}') assert 200 == response.status_code # Ensure `get_all_documents` was called with the expected `filters` param - MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={}) + MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={}, index=None) # Ensure results are part of the response body response_json = response.json() assert len(response_json) == 2 @@ -270,21 +270,21 @@ def test_get_documents_with_filters(client): response = client.post(url="/documents/get_by_filters", data='{"filters": {"test_index": ["2"]}}') assert 200 == response.status_code # Ensure `get_all_documents` was called with the expected `filters` param - MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]}) + MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]}, index=None) def test_delete_all_documents(client): response = client.post(url="/documents/delete_by_filters", data='{"filters": {}}') assert 200 == response.status_code # Ensure `delete_documents` was called on the Document Store instance - MockDocumentStore.mocker.delete_documents.assert_called_with(filters={}) + MockDocumentStore.mocker.delete_documents.assert_called_with(filters={}, index=None) def test_delete_documents_with_filters(client): response = client.post(url="/documents/delete_by_filters", data='{"filters": {"test_index": ["1"]}}') assert 200 == response.status_code # Ensure `delete_documents` was called on the Document Store instance with the same params - MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]}) + MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]}, index=None) def test_file_upload(client): @@ -564,7 +564,7 @@ def test_delete_feedback(client, monkeypatch, feedback): # Call the API and ensure `delete_labels` was called only on the label with id=123 response = client.delete(url="/feedback") assert 200 == response.status_code - MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"]) + MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"], index=None) def test_export_feedback(client, monkeypatch, feedback):