import json
import logging
from typing import Dict, Union, List, Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from rest_api.controller.search import PIPELINE

router = APIRouter()

logger = logging.getLogger(__name__)

# TODO make this generic for other pipelines with different naming
retriever = PIPELINE.get_node(name="ESRetriever")
document_store = retriever.document_store if retriever else None
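# Note: if the loaded pipeline has no node named "ESRetriever", document_store stays None
# and the endpoints below will fail with an AttributeError as soon as they touch the store.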


class ExtractiveQAFeedback(BaseModel):
    question: str = Field(..., description="The question input by the user, i.e., the query.")
    is_correct_answer: bool = Field(..., description="Whether the answer is correct or not.")
    document_id: str = Field(..., description="The document in the query result for which feedback is given.")
    model_id: Optional[int] = Field(None, description="The model used for the query.")
    is_correct_document: bool = Field(
        ...,
        description="In case of negative feedback, there could be two cases; incorrect answer but correct "
        "document & incorrect document. This flag denotes if the returned document was correct.",
    )
    answer: str = Field(..., description="The answer string.")
    offset_start_in_doc: int = Field(
        ..., description="The answer start offset in the original doc. Only required for doc-qa feedback."
    )
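

# Illustrative feedback payload matching the model above (all values are made up;
# the document id is just an example string):
# {
#     "question": "Who is the father of Arya Stark?",
#     "is_correct_answer": true,
#     "document_id": "XRR3xnEBCYVTkbTystOB",
#     "model_id": 1,
#     "is_correct_document": true,
#     "answer": "Eddard Stark",
#     "offset_start_in_doc": 42
# }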


class FilterRequest(BaseModel):
    filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
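

# A FilterRequest body maps metadata field names to a single value or a list of values,
# e.g. {"filters": {"document_id": ["XRR3xnEBCYVTkbTystOB"]}} as in the curl example in
# the docstring below. The available field names depend on the indexed documents.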


@router.post("/feedback")
def user_feedback(feedback: ExtractiveQAFeedback):
    document_store.write_labels([{"origin": "user-feedback", **feedback.dict()}])
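

# A minimal sketch of a request to this endpoint, assuming the API runs on
# 127.0.0.1:8000 as in the docstring example below; the JSON body follows the
# illustrative ExtractiveQAFeedback payload sketched above:
#
#     curl --location --request POST 'http://127.0.0.1:8000/feedback' \
#     --header 'Content-Type: application/json' \
#     --data-raw '{"question": "Who is the father of Arya Stark?", "is_correct_answer": true, "document_id": "XRR3xnEBCYVTkbTystOB", "is_correct_document": true, "answer": "Eddard Stark", "offset_start_in_doc": 42}'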


@router.post("/eval-feedback")
def eval_extractive_qa_feedback(filters: FilterRequest = None):
    """
    Return basic accuracy metrics based on the user feedback.
    Which ratio of answers was correct? Which ratio of documents was correct?
    You can supply filters in the request to only use a certain subset of labels.

    **Example:**

        ```
            | curl --location --request POST 'http://127.0.0.1:8000/eval-feedback' \
            | --header 'Content-Type: application/json' \
            | --data-raw '{ "filters": {"document_id": ["XRR3xnEBCYVTkbTystOB"]} }'
        ```
    """

    if filters:
        filters = filters.filters
        filters["origin"] = ["user-feedback"]
    else:
        filters = {"origin": ["user-feedback"]}

    labels = document_store.get_all_labels(filters=filters)

    if len(labels) > 0:
        answer_feedback = [1 if l.is_correct_answer else 0 for l in labels]
        doc_feedback = [1 if l.is_correct_document else 0 for l in labels]

        answer_accuracy = sum(answer_feedback) / len(answer_feedback)
        doc_accuracy = sum(doc_feedback) / len(doc_feedback)

        res = {"answer_accuracy": answer_accuracy, "document_accuracy": doc_accuracy, "n_feedback": len(labels)}
    else:
        res = {"answer_accuracy": None, "document_accuracy": None, "n_feedback": 0}
    return res
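

# Shape of the response (the numbers are illustrative):
#     {"answer_accuracy": 0.75, "document_accuracy": 0.875, "n_feedback": 8}
# If no matching feedback labels exist, both accuracies are null and n_feedback is 0.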


@router.get("/export-feedback")
def export_extractive_qa_feedback(
    context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False
):
    """
    SQuAD format JSON export for question/answer pairs that were marked as "relevant".

    The context_size param can be used to limit response size for large documents.
    It only takes effect if full_document_context is set to False.
    """
    if only_positive_labels:
        labels = document_store.get_all_labels(filters={"is_correct_answer": [True], "origin": ["user-feedback"]})
    else:
        labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]})
        # Filter out the labels where the passage is correct but answer is wrong (in SQuAD this matches
        # neither a "positive example" nor a negative "is_impossible" one)
        labels = [l for l in labels if not (l.is_correct_document is True and l.is_correct_answer is False)]

    export_data = []

    for label in labels:
        document = document_store.get_document_by_id(label.document_id)
        if document is None:
            raise HTTPException(
                status_code=500, detail=f"Could not find document with id {label.document_id} for label id {label.id}"
            )

        if full_document_context:
            context = document.text
            answer_start = label.offset_start_in_doc
        else:
            text = document.text
            # the final length of context (including the answer string) is 'context_size'.
            # we try to add equal characters for context before and after the answer string.
            # if either beginning or end of text is reached, we correspondingly
            # append more context characters at the other end of answer string.
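            # Illustrative trace of the logic below (hypothetical numbers): with
            # context_size=20, a 4-character answer at offset_start_in_doc=2 and a
            # 100-character document, context_to_add=8, start_pos=0 and end_pos=20,
            # so the exported context is text[0:20] and answer_start becomes 2.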
            context_to_add = int((context_size - len(label.answer)) / 2)
            start_pos = max(label.offset_start_in_doc - context_to_add, 0)
            additional_context_at_end = max(context_to_add - label.offset_start_in_doc, 0)
            end_pos = min(label.offset_start_in_doc + len(label.answer) + context_to_add, len(text))
            additional_context_at_start = max(
                label.offset_start_in_doc + len(label.answer) + context_to_add - len(text), 0
            )
            start_pos = max(0, start_pos - additional_context_at_start)
            end_pos = min(len(text), end_pos + additional_context_at_end)
            context = text[start_pos:end_pos]
            answer_start = label.offset_start_in_doc - start_pos

        if label.is_correct_answer is False and label.is_correct_document is False:  # No answer
            squad_label = {
                "paragraphs": [
                    {
                        "context": context,
                        "id": label.document_id,
                        "qas": [{"question": label.question, "id": label.id, "is_impossible": True, "answers": []}],
                    }
                ]
            }
        else:
            squad_label = {
                "paragraphs": [
                    {
                        "context": context,
                        "id": label.document_id,
                        "qas": [
                            {
                                "question": label.question,
                                "id": label.id,
                                "is_impossible": False,
                                "answers": [{"text": label.answer, "answer_start": answer_start}],
                            }
                        ],
                    }
                ]
            }

            # quality check
            start = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
            answer = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["text"]
            context = squad_label["paragraphs"][0]["context"]
            if context[start:start + len(answer)] != answer:
                logger.error(
                    f"Skipping invalid squad label as string via offsets "
                    f"('{context[start:start + len(answer)]}') does not match answer string ('{answer}')"
                )
                continue
        export_data.append(squad_label)

    export = {"data": export_data}

    with open("feedback_squad_direct.json", "w", encoding="utf8") as f:
        json.dump(export, f, ensure_ascii=False, sort_keys=True, indent=4)
    return export
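

# A minimal sketch of how the export endpoint could be called (host and port are
# assumptions, matching the docstring example above; the query parameters mirror
# the function signature):
#
#     curl --location --request GET 'http://127.0.0.1:8000/export-feedback?full_document_context=false&context_size=1000&only_positive_labels=true'
#
# Besides returning the SQuAD-style dict, the endpoint writes the same data to
# "feedback_squad_direct.json" in the current working directory.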