mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 15:08:43 +00:00
feat: add optional index selection for endpoints (#5444)
* add index selection * reformatting * updated test script
This commit is contained in:
parent
62029ba441
commit
8920fd6939
@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import Optional, List
|
||||
|
||||
import logging
|
||||
|
||||
@ -21,7 +21,7 @@ document_store: BaseDocumentStore = get_pipelines().get("document_store", None)
|
||||
|
||||
|
||||
@router.post("/documents/get_by_filters", response_model=List[Document], response_model_exclude_none=True)
|
||||
def get_documents(filters: FilterRequest):
|
||||
def get_documents(filters: FilterRequest, index: Optional[str] = None):
|
||||
"""
|
||||
This endpoint allows you to retrieve documents contained in your document store.
|
||||
You can filter the documents to retrieve by metadata (like the document's name),
|
||||
@ -33,14 +33,14 @@ def get_documents(filters: FilterRequest):
|
||||
To get all documents you should provide an empty dict, like:
|
||||
`'{"filters": {}}'`
|
||||
"""
|
||||
docs = document_store.get_all_documents(filters=filters.filters)
|
||||
docs = document_store.get_all_documents(filters=filters.filters, index=index)
|
||||
for doc in docs:
|
||||
doc.embedding = None
|
||||
return docs
|
||||
|
||||
|
||||
@router.post("/documents/delete_by_filters", response_model=bool)
|
||||
def delete_documents(filters: FilterRequest):
|
||||
def delete_documents(filters: FilterRequest, index: Optional[str] = None):
|
||||
"""
|
||||
This endpoint allows you to delete documents contained in your document store.
|
||||
You can filter the documents to delete by metadata (like the document's name),
|
||||
@ -52,5 +52,5 @@ def delete_documents(filters: FilterRequest):
|
||||
To get all documents you should provide an empty dict, like:
|
||||
`'{"filters": {}}'`
|
||||
"""
|
||||
document_store.delete_documents(filters=filters.filters)
|
||||
document_store.delete_documents(filters=filters.filters, index=index)
|
||||
return True
|
||||
|
||||
@ -18,7 +18,7 @@ document_store: BaseDocumentStore = get_pipelines().get("document_store", None)
|
||||
|
||||
|
||||
@router.post("/feedback")
|
||||
def post_feedback(feedback: CreateLabelSerialized):
|
||||
def post_feedback(feedback: CreateLabelSerialized, index: Optional[str] = None):
|
||||
"""
|
||||
With this endpoint, the API user can submit their feedback on an answer for a particular query. This feedback is then written to the label_index of the DocumentStore.
|
||||
|
||||
@ -31,31 +31,31 @@ def post_feedback(feedback: CreateLabelSerialized):
|
||||
feedback.origin = "user-feedback"
|
||||
|
||||
label = Label(**feedback.dict())
|
||||
document_store.write_labels([label])
|
||||
document_store.write_labels([label], index=index)
|
||||
|
||||
|
||||
@router.get("/feedback", response_model=List[Label])
|
||||
def get_feedback():
|
||||
def get_feedback(index: Optional[str] = None):
|
||||
"""
|
||||
This endpoint allows the API user to retrieve all the feedback that has been submitted through the `POST /feedback` endpoint.
|
||||
"""
|
||||
labels = document_store.get_all_labels()
|
||||
labels = document_store.get_all_labels(index=index)
|
||||
return labels
|
||||
|
||||
|
||||
@router.delete("/feedback")
|
||||
def delete_feedback():
|
||||
def delete_feedback(index: Optional[str] = None):
|
||||
"""
|
||||
This endpoint allows the API user to delete all the feedback that has been sumbitted through the
|
||||
`POST /feedback` endpoint.
|
||||
"""
|
||||
all_labels = document_store.get_all_labels()
|
||||
all_labels = document_store.get_all_labels(index=index)
|
||||
user_label_ids = [label.id for label in all_labels if label.origin == "user-feedback"]
|
||||
document_store.delete_labels(ids=user_label_ids)
|
||||
document_store.delete_labels(ids=user_label_ids, index=index)
|
||||
|
||||
|
||||
@router.post("/eval-feedback")
|
||||
def get_feedback_metrics(filters: Optional[FilterRequest] = None):
|
||||
def get_feedback_metrics(filters: Optional[FilterRequest] = None, index: Optional[str] = None):
|
||||
"""
|
||||
This endpoint returns basic accuracy metrics based on user feedback, for example, the ratio of correct answers or correctly identified documents.
|
||||
You can filter the output by document or label.
|
||||
@ -73,7 +73,7 @@ def get_feedback_metrics(filters: Optional[FilterRequest] = None):
|
||||
else:
|
||||
filters_content = {"origin": ["user-feedback"]}
|
||||
|
||||
labels = document_store.get_all_labels(filters=filters_content)
|
||||
labels = document_store.get_all_labels(filters=filters_content, index=index)
|
||||
|
||||
res: Dict[str, Optional[Union[float, int]]]
|
||||
if len(labels) > 0:
|
||||
@ -91,7 +91,10 @@ def get_feedback_metrics(filters: Optional[FilterRequest] = None):
|
||||
|
||||
@router.get("/export-feedback")
|
||||
def export_feedback(
|
||||
context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False
|
||||
context_size: int = 100_000,
|
||||
full_document_context: bool = True,
|
||||
only_positive_labels: bool = False,
|
||||
index: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
This endpoint returns JSON output in the SQuAD format for question/answer pairs that were marked as "relevant" by user feedback through the `POST /feedback` endpoint.
|
||||
@ -99,9 +102,11 @@ def export_feedback(
|
||||
The context_size param can be used to limit response size for large documents.
|
||||
"""
|
||||
if only_positive_labels:
|
||||
labels = document_store.get_all_labels(filters={"is_correct_answer": [True], "origin": ["user-feedback"]})
|
||||
labels = document_store.get_all_labels(
|
||||
filters={"is_correct_answer": [True], "origin": ["user-feedback"]}, index=index
|
||||
)
|
||||
else:
|
||||
labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]})
|
||||
labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]}, index=index)
|
||||
# Filter out the labels where the passage is correct but answer is wrong (in SQuAD this matches
|
||||
# neither a "positive example" nor a negative "is_impossible" one)
|
||||
labels = [l for l in labels if not (l.is_correct_document is True and l.is_correct_answer is False)]
|
||||
|
||||
@ -260,7 +260,7 @@ def test_get_all_documents(client):
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {}}')
|
||||
assert 200 == response.status_code
|
||||
# Ensure `get_all_documents` was called with the expected `filters` param
|
||||
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={})
|
||||
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={}, index=None)
|
||||
# Ensure results are part of the response body
|
||||
response_json = response.json()
|
||||
assert len(response_json) == 2
|
||||
@ -270,21 +270,21 @@ def test_get_documents_with_filters(client):
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {"test_index": ["2"]}}')
|
||||
assert 200 == response.status_code
|
||||
# Ensure `get_all_documents` was called with the expected `filters` param
|
||||
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]})
|
||||
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]}, index=None)
|
||||
|
||||
|
||||
def test_delete_all_documents(client):
|
||||
response = client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
assert 200 == response.status_code
|
||||
# Ensure `delete_documents` was called on the Document Store instance
|
||||
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={})
|
||||
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={}, index=None)
|
||||
|
||||
|
||||
def test_delete_documents_with_filters(client):
|
||||
response = client.post(url="/documents/delete_by_filters", data='{"filters": {"test_index": ["1"]}}')
|
||||
assert 200 == response.status_code
|
||||
# Ensure `delete_documents` was called on the Document Store instance with the same params
|
||||
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]})
|
||||
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]}, index=None)
|
||||
|
||||
|
||||
def test_file_upload(client):
|
||||
@ -564,7 +564,7 @@ def test_delete_feedback(client, monkeypatch, feedback):
|
||||
# Call the API and ensure `delete_labels` was called only on the label with id=123
|
||||
response = client.delete(url="/feedback")
|
||||
assert 200 == response.status_code
|
||||
MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"])
|
||||
MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"], index=None)
|
||||
|
||||
|
||||
def test_export_feedback(client, monkeypatch, feedback):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user