mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 07:29:06 +00:00
Refactor REST APIs to use Pipelines (#922)
This commit is contained in:
parent
64ad953c6a
commit
8c68699e1c
@ -7,21 +7,9 @@ services:
|
||||
image: "deepset/haystack-cpu:latest"
|
||||
ports:
|
||||
- 8000:8000
|
||||
volumes:
|
||||
# Optional: mount your own models from disk into the container
|
||||
- "./models:/home/user/models"
|
||||
environment:
|
||||
# See rest_api/config.py for more variables that you can configure here.
|
||||
- DB_HOST=elasticsearch
|
||||
- USE_GPU=False
|
||||
- TOP_K_PER_SAMPLE=3 # how many answers can come from the same small passage (reduce value for more variety of answers)
|
||||
# Load a model from transformers' model hub or a local path into the FARMReader.
|
||||
- READER_MODEL_PATH=deepset/roberta-base-squad2
|
||||
# - READER_MODEL_PATH=home/user/models/roberta-base-squad2
|
||||
# Alternative: If you want to use the TransformersReader (e.g. for loading a local model in transformers format):
|
||||
# - READER_TYPE=TransformersReader
|
||||
# - READER_MODEL_PATH=/home/user/models/roberta-base-squad2
|
||||
# - READER_TOKENIZER=/home/user/models/roberta-base-squad2
|
||||
# See rest_api/pipelines.yaml for configurations of Search & Indexing Pipeline.
|
||||
- ELASTICSEARCHDOCUMENTSTORE_PARAMS_HOST=elasticsearch
|
||||
restart: always
|
||||
depends_on:
|
||||
- elasticsearch
|
||||
@ -36,7 +24,9 @@ services:
|
||||
environment:
|
||||
- discovery.type=single-node
|
||||
ui:
|
||||
image: "deepset/haystack-streamlit-ui"
|
||||
build:
|
||||
context: ui
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- 8501:8501
|
||||
environment:
|
||||
|
||||
@ -68,6 +68,15 @@ supplied meta data like author, url, external IDs can be supplied as a dictionar
|
||||
|
||||
Validate if the language of the text is one of valid languages.
|
||||
|
||||
<a name="base.FileTypeClassifier"></a>
|
||||
## FileTypeClassifier Objects
|
||||
|
||||
```python
|
||||
class FileTypeClassifier(BaseComponent)
|
||||
```
|
||||
|
||||
Route files in an Indexing Pipeline to corresponding file converters.
|
||||
|
||||
<a name="txt"></a>
|
||||
# Module txt
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ Add a new node to the pipeline.
|
||||
#### get\_node
|
||||
|
||||
```python
|
||||
| get_node(name: str)
|
||||
| get_node(name: str) -> Optional[BaseComponent]
|
||||
```
|
||||
|
||||
Get a node from the Pipeline.
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from haystack.file_converter.base import FileTypeClassifier
|
||||
from haystack.file_converter.docx import DocxToTextConverter
|
||||
from haystack.file_converter.markdown import MarkdownConverter
|
||||
from haystack.file_converter.pdf import PDFToTextConverter
|
||||
from haystack.file_converter.tika import TikaConverter
|
||||
from haystack.file_converter.txt import TextConverter
|
||||
from haystack.file_converter.markdown import MarkdownConverter
|
||||
|
||||
@ -89,3 +89,19 @@ class BaseConverter(BaseComponent):
|
||||
|
||||
result = {"document": document, **kwargs}
|
||||
return result, "output_1"
|
||||
|
||||
|
||||
class FileTypeClassifier(BaseComponent):
|
||||
"""
|
||||
Route files in an Indexing Pipeline to corresponding file converters.
|
||||
"""
|
||||
outgoing_edges = 5
|
||||
|
||||
def run(self, file_path: Path, **kwargs): # type: ignore
|
||||
output = {"file_path": file_path, **kwargs}
|
||||
ext = file_path.name.split(".")[-1].lower()
|
||||
try:
|
||||
index = ["txt", "pdf", "md", "docx", "html"].index(ext) + 1
|
||||
return output, f"output_{index}"
|
||||
except ValueError:
|
||||
raise Exception(f"Files with an extension '{ext}' are not supported.")
|
||||
|
||||
@ -85,13 +85,14 @@ class Pipeline(ABC):
|
||||
input_edge_name = "output_1"
|
||||
self.graph.add_edge(input_node_name, name, label=input_edge_name)
|
||||
|
||||
def get_node(self, name: str):
|
||||
def get_node(self, name: str) -> Optional[BaseComponent]:
|
||||
"""
|
||||
Get a node from the Pipeline.
|
||||
|
||||
:param name: The name of the node.
|
||||
"""
|
||||
component = self.graph.nodes[name]["component"]
|
||||
graph_node = self.graph.nodes.get(name)
|
||||
component = graph_node["component"] if graph_node else None
|
||||
return component
|
||||
|
||||
def set_node(self, name: str, component):
|
||||
@ -219,7 +220,7 @@ class Pipeline(ABC):
|
||||
else:
|
||||
pipelines_in_yaml = list(filter(lambda p: p["name"] == pipeline_name, data["pipelines"]))
|
||||
if not pipelines_in_yaml:
|
||||
raise Exception(f"Cannot find any pipeline with name '{pipeline_name}' declared in the YAML file.")
|
||||
raise KeyError(f"Cannot find any pipeline with name '{pipeline_name}' declared in the YAML file.")
|
||||
pipeline_config = pipelines_in_yaml[0]
|
||||
|
||||
definitions = {} # definitions of each component from the YAML.
|
||||
@ -252,7 +253,7 @@ class Pipeline(ABC):
|
||||
if name in components.keys(): # check if component is already loaded.
|
||||
return components[name]
|
||||
|
||||
component_params = definitions[name]["params"]
|
||||
component_params = definitions[name].get("params", {})
|
||||
component_type = definitions[name]["type"]
|
||||
logger.debug(f"Loading component `{name}` of type `{definitions[name]['type']}`")
|
||||
|
||||
|
||||
@ -37,8 +37,27 @@ class BasePreProcessor(BaseComponent):
|
||||
) -> List[Dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, document: dict, **kwargs): # type: ignore
|
||||
documents = self.process(document)
|
||||
|
||||
def run( # type: ignore
|
||||
self,
|
||||
document: dict,
|
||||
clean_whitespace: Optional[bool] = None,
|
||||
clean_header_footer: Optional[bool] = None,
|
||||
clean_empty_lines: Optional[bool] = None,
|
||||
split_by: Optional[str] = None,
|
||||
split_length: Optional[int] = None,
|
||||
split_overlap: Optional[int] = None,
|
||||
split_respect_sentence_boundary: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
documents = self.process(
|
||||
document=document,
|
||||
clean_whitespace=clean_whitespace,
|
||||
clean_header_footer=clean_header_footer,
|
||||
clean_empty_lines=clean_empty_lines,
|
||||
split_by=split_by,
|
||||
split_length=split_length,
|
||||
split_overlap=split_overlap,
|
||||
split_respect_sentence_boundary=split_respect_sentence_boundary,
|
||||
)
|
||||
result = {"documents": documents, **kwargs}
|
||||
return result, "output_1"
|
||||
|
||||
@ -1,31 +1,27 @@
|
||||
import logging
|
||||
|
||||
import uvicorn
|
||||
from elasticapm.contrib.starlette import make_apm_client, ElasticAPM
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from rest_api.config import APM_SERVER, APM_SERVICE_NAME
|
||||
from rest_api.controller.errors.http_error import http_error_handler
|
||||
from rest_api.controller.router import router as api_router
|
||||
|
||||
logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
|
||||
logging.getLogger("haystack").setLevel(logging.INFO)
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(title="Haystack-API", debug=True, version="0.1")
|
||||
|
||||
# This middleware enables allow all cross-domain requests to the API from a browser. For production
|
||||
# deployments, it could be made more restrictive.
|
||||
application.add_middleware(
|
||||
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
|
||||
)
|
||||
|
||||
if APM_SERVER:
|
||||
apm_config = {"SERVICE_NAME": APM_SERVICE_NAME, "SERVER_URL": APM_SERVER, "CAPTURE_BODY": "all"}
|
||||
elasticapm = make_apm_client(apm_config)
|
||||
application.add_middleware(ElasticAPM, client=elasticapm)
|
||||
|
||||
application.add_exception_handler(HTTPException, http_error_handler)
|
||||
|
||||
application.include_router(api_router)
|
||||
@ -38,7 +34,7 @@ app = get_application()
|
||||
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
|
||||
logger.info(
|
||||
"""
|
||||
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/models/1/doc-qa' --data '{"questions": ["What is the capital of Germany?"]}'
|
||||
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' --data '{"query": "Did Albus Dumbledore die?"}'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@ -1,74 +1,9 @@
|
||||
import ast
|
||||
import os
|
||||
|
||||
# FastAPI
|
||||
PROJECT_NAME = os.getenv("PROJECT_NAME", "FastAPI")
|
||||
PIPELINE_YAML_PATH = os.getenv("PIPELINE_YAML_PATH", "rest_api/pipelines.yaml")
|
||||
QUERY_PIPELINE_NAME = os.getenv("QUERY_PIPELINE_NAME", "query")
|
||||
INDEXING_PIPELINE_NAME = os.getenv("INDEXING_PIPELINE_NAME", "indexing")
|
||||
|
||||
# Resources / Computation
|
||||
USE_GPU = os.getenv("USE_GPU", "True").lower() == "true"
|
||||
GPU_NUMBER = int(os.getenv("GPU_NUMBER", 1))
|
||||
MAX_PROCESSES = int(os.getenv("MAX_PROCESSES", 0))
|
||||
BATCHSIZE = int(os.getenv("BATCHSIZE", 50))
|
||||
CONCURRENT_REQUEST_PER_WORKER = int(os.getenv("CONCURRENT_REQUEST_PER_WORKER", 4))
|
||||
FILE_UPLOAD_PATH = os.getenv("FILE_UPLOAD_PATH", "./file-upload")
|
||||
|
||||
# DB
|
||||
DB_HOST = os.getenv("DB_HOST", "localhost")
|
||||
DB_PORT = int(os.getenv("DB_PORT", 9200))
|
||||
DB_USER = os.getenv("DB_USER", "")
|
||||
DB_PW = os.getenv("DB_PW", "")
|
||||
DB_INDEX = os.getenv("DB_INDEX", "document")
|
||||
DB_INDEX_FEEDBACK = os.getenv("DB_INDEX_FEEDBACK", "label")
|
||||
ES_CONN_SCHEME = os.getenv("ES_CONN_SCHEME", "http")
|
||||
TEXT_FIELD_NAME = os.getenv("TEXT_FIELD_NAME", "text")
|
||||
NAME_FIELD_NAME = os.getenv("NAME_FIELD_NAME", "name")
|
||||
SEARCH_FIELD_NAME = os.getenv("SEARCH_FIELD_NAME", "text")
|
||||
FAQ_QUESTION_FIELD_NAME = os.getenv("FAQ_QUESTION_FIELD_NAME", "question")
|
||||
EMBEDDING_FIELD_NAME = os.getenv("EMBEDDING_FIELD_NAME", "embedding")
|
||||
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 768))
|
||||
VECTOR_SIMILARITY_METRIC = os.getenv("VECTOR_SIMILARITY_METRIC", "dot_product")
|
||||
CREATE_INDEX = os.getenv("CREATE_INDEX", "True").lower() == "true"
|
||||
UPDATE_EXISTING_DOCUMENTS = os.getenv("UPDATE_EXISTING_DOCUMENTS", "False").lower() == "true"
|
||||
|
||||
# Reader
|
||||
READER_MODEL_PATH = os.getenv("READER_MODEL_PATH", "deepset/roberta-base-squad2")
|
||||
READER_TYPE = os.getenv("READER_TYPE", "FARMReader") # alternative: 'TransformersReader'
|
||||
READER_TOKENIZER = os.getenv("READER_TOKENIZER", None)
|
||||
CONTEXT_WINDOW_SIZE = int(os.getenv("CONTEXT_WINDOW_SIZE", 500))
|
||||
DEFAULT_TOP_K_READER = int(os.getenv("DEFAULT_TOP_K_READER", 5)) # How many answers to return in total
|
||||
TOP_K_PER_CANDIDATE = int(os.getenv("TOP_K_PER_CANDIDATE", 3)) # How many answers can come from one indexed doc
|
||||
TOP_K_PER_SAMPLE = int(os.getenv("TOP_K_PER_SAMPLE", 1)) # How many answers can come from one passage that the reader processes at once (i.e. text of max_seq_len from the doc)
|
||||
NO_ANS_BOOST = int(os.getenv("NO_ANS_BOOST", -10))
|
||||
READER_CAN_HAVE_NO_ANSWER = os.getenv("READER_CAN_HAVE_NO_ANSWER", "True").lower() == "true"
|
||||
DOC_STRIDE = int(os.getenv("DOC_STRIDE", 128))
|
||||
MAX_SEQ_LEN = int(os.getenv("MAX_SEQ_LEN", 256))
|
||||
|
||||
# Retriever
|
||||
RETRIEVER_TYPE = os.getenv("RETRIEVER_TYPE", "ElasticsearchRetriever") # alternatives: 'EmbeddingRetriever', 'ElasticsearchRetriever', 'ElasticsearchFilterOnlyRetriever', None
|
||||
DEFAULT_TOP_K_RETRIEVER = int(os.getenv("DEFAULT_TOP_K_RETRIEVER", 5))
|
||||
EXCLUDE_META_DATA_FIELDS = os.getenv("EXCLUDE_META_DATA_FIELDS", f"['question_emb','embedding']")
|
||||
if EXCLUDE_META_DATA_FIELDS:
|
||||
EXCLUDE_META_DATA_FIELDS = ast.literal_eval(EXCLUDE_META_DATA_FIELDS)
|
||||
EMBEDDING_MODEL_PATH = os.getenv("EMBEDDING_MODEL_PATH", "deepset/sentence_bert")
|
||||
EMBEDDING_MODEL_FORMAT = os.getenv("EMBEDDING_MODEL_FORMAT", "farm")
|
||||
|
||||
# File uploads
|
||||
FILE_UPLOAD_PATH = os.getenv("FILE_UPLOAD_PATH", "file-uploads")
|
||||
REMOVE_NUMERIC_TABLES = os.getenv("REMOVE_NUMERIC_TABLES", "True").lower() == "true"
|
||||
VALID_LANGUAGES = os.getenv("VALID_LANGUAGES", None)
|
||||
if VALID_LANGUAGES:
|
||||
VALID_LANGUAGES = ast.literal_eval(VALID_LANGUAGES)
|
||||
|
||||
# Preprocessing
|
||||
REMOVE_WHITESPACE = os.getenv("REMOVE_WHITESPACE", "True").lower() == "true"
|
||||
REMOVE_EMPTY_LINES = os.getenv("REMOVE_EMPTY_LINES", "True").lower() == "true"
|
||||
REMOVE_HEADER_FOOTER = os.getenv("REMOVE_HEADER_FOOTER", "True").lower() == "true"
|
||||
SPLIT_BY = os.getenv("SPLIT_BY", "word")
|
||||
SPLIT_LENGTH = os.getenv("SPLIT_LENGTH", 1_000)
|
||||
SPLIT_OVERLAP = os.getenv("SPLIT_OVERLAP", None)
|
||||
SPLIT_RESPECT_SENTENCE_BOUNDARY = os.getenv("SPLIT_RESPECT_SENTENCE_BOUNDARY", True)
|
||||
|
||||
|
||||
# Monitoring
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||
APM_SERVER = os.getenv("APM_SERVER", None)
|
||||
APM_SERVICE_NAME = os.getenv("APM_SERVICE_NAME", "haystack-backend")
|
||||
|
||||
@ -1,62 +1,25 @@
|
||||
from typing import Optional
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Union, List, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Union, List
|
||||
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from rest_api.config import (
|
||||
DB_HOST,
|
||||
DB_PORT,
|
||||
DB_USER,
|
||||
DB_PW,
|
||||
DB_INDEX,
|
||||
DB_INDEX_FEEDBACK,
|
||||
ES_CONN_SCHEME,
|
||||
TEXT_FIELD_NAME,
|
||||
SEARCH_FIELD_NAME,
|
||||
EMBEDDING_DIM,
|
||||
EMBEDDING_FIELD_NAME,
|
||||
EXCLUDE_META_DATA_FIELDS,
|
||||
FAQ_QUESTION_FIELD_NAME,
|
||||
CREATE_INDEX,
|
||||
VECTOR_SIMILARITY_METRIC,
|
||||
UPDATE_EXISTING_DOCUMENTS
|
||||
)
|
||||
from rest_api.controller.search import PIPELINE
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
document_store = ElasticsearchDocumentStore(
|
||||
host=DB_HOST,
|
||||
port=DB_PORT,
|
||||
username=DB_USER,
|
||||
password=DB_PW,
|
||||
index=DB_INDEX,
|
||||
label_index=DB_INDEX_FEEDBACK,
|
||||
scheme=ES_CONN_SCHEME,
|
||||
ca_certs=False,
|
||||
verify_certs=False,
|
||||
text_field=TEXT_FIELD_NAME,
|
||||
search_fields=SEARCH_FIELD_NAME,
|
||||
faq_question_field=FAQ_QUESTION_FIELD_NAME,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
embedding_field=EMBEDDING_FIELD_NAME,
|
||||
excluded_meta_data=EXCLUDE_META_DATA_FIELDS, # type: ignore
|
||||
create_index=CREATE_INDEX,
|
||||
update_existing_documents=UPDATE_EXISTING_DOCUMENTS,
|
||||
similarity=VECTOR_SIMILARITY_METRIC
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO make this generic for other pipelines with different naming
|
||||
retriever = PIPELINE.get_node(name="ESRetriever")
|
||||
document_store = retriever.document_store if retriever else None
|
||||
|
||||
|
||||
class FAQQAFeedback(BaseModel):
|
||||
class ExtractiveQAFeedback(BaseModel):
|
||||
question: str = Field(..., description="The question input by the user, i.e., the query.")
|
||||
is_correct_answer: bool = Field(..., description="Whether the answer is correct or not.")
|
||||
document_id: str = Field(..., description="The document in the query result for which feedback is given.")
|
||||
model_id: Optional[int] = Field(None, description="The model used for the query.")
|
||||
|
||||
|
||||
class DocQAFeedback(FAQQAFeedback):
|
||||
is_correct_document: bool = Field(
|
||||
...,
|
||||
description="In case of negative feedback, there could be two cases; incorrect answer but correct "
|
||||
@ -67,22 +30,18 @@ class DocQAFeedback(FAQQAFeedback):
|
||||
..., description="The answer start offset in the original doc. Only required for doc-qa feedback."
|
||||
)
|
||||
|
||||
|
||||
class FilterRequest(BaseModel):
|
||||
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
||||
|
||||
@router.post("/doc-qa-feedback")
|
||||
def doc_qa_feedback(feedback: DocQAFeedback):
|
||||
|
||||
@router.post("/feedback")
|
||||
def user_feedback(feedback: ExtractiveQAFeedback):
|
||||
document_store.write_labels([{"origin": "user-feedback", **feedback.dict()}])
|
||||
|
||||
|
||||
@router.post("/faq-qa-feedback")
|
||||
def faq_qa_feedback(feedback: FAQQAFeedback):
|
||||
feedback_payload = {"is_correct_document": feedback.is_correct_answer, "answer": None, **feedback.dict()}
|
||||
document_store.write_labels([{"origin": "user-feedback-faq", **feedback_payload}])
|
||||
|
||||
|
||||
@router.post("/eval-doc-qa-feedback")
|
||||
def eval_doc_qa_feedback(filters: FilterRequest = None):
|
||||
@router.post("/eval-feedback")
|
||||
def eval_extractive_qa_feedback(filters: FilterRequest = None):
|
||||
"""
|
||||
Return basic accuracy metrics based on the user feedback.
|
||||
Which ratio of answers was correct? Which ratio of documents was correct?
|
||||
@ -90,10 +49,11 @@ def eval_doc_qa_feedback(filters: FilterRequest = None):
|
||||
|
||||
**Example:**
|
||||
|
||||
```
|
||||
| curl --location --request POST 'http://127.0.0.1:8000/eval-doc-qa-feedback' \
|
||||
| --header 'Content-Type: application/json' \
|
||||
| --data-raw '{ "filters": {"document_id": ["XRR3xnEBCYVTkbTystOB"]} }'
|
||||
```
|
||||
| curl --location --request POST 'http://127.0.0.1:8000/eval-doc-qa-feedback' \
|
||||
| --header 'Content-Type: application/json' \
|
||||
| --data-raw '{ "filters": {"document_id": ["XRR3xnEBCYVTkbTystOB"]} }'
|
||||
|
||||
"""
|
||||
|
||||
if filters:
|
||||
@ -102,86 +62,109 @@ def eval_doc_qa_feedback(filters: FilterRequest = None):
|
||||
else:
|
||||
filters = {"origin": ["user-feedback"]}
|
||||
|
||||
labels = document_store.get_all_labels(
|
||||
index=DB_INDEX_FEEDBACK,
|
||||
filters=filters
|
||||
)
|
||||
labels = document_store.get_all_labels(filters=filters)
|
||||
|
||||
if len(labels) > 0:
|
||||
answer_feedback = [1 if l.is_correct_answer else 0 for l in labels]
|
||||
doc_feedback = [1 if l.is_correct_document else 0 for l in labels]
|
||||
|
||||
answer_accuracy = sum(answer_feedback)/len(answer_feedback)
|
||||
doc_accuracy = sum(doc_feedback)/len(doc_feedback)
|
||||
answer_accuracy = sum(answer_feedback) / len(answer_feedback)
|
||||
doc_accuracy = sum(doc_feedback) / len(doc_feedback)
|
||||
|
||||
res = {"answer_accuracy": answer_accuracy,
|
||||
"document_accuracy": doc_accuracy,
|
||||
"n_feedback": len(labels)}
|
||||
res = {"answer_accuracy": answer_accuracy, "document_accuracy": doc_accuracy, "n_feedback": len(labels)}
|
||||
else:
|
||||
res = {"answer_accuracy": None,
|
||||
"document_accuracy": None,
|
||||
"n_feedback": 0}
|
||||
res = {"answer_accuracy": None, "document_accuracy": None, "n_feedback": 0}
|
||||
return res
|
||||
|
||||
@router.get("/export-doc-qa-feedback")
|
||||
def export_doc_qa_feedback(context_size: int = 2_000):
|
||||
|
||||
@router.get("/export-feedback")
|
||||
def export_extractive_qa_feedback(
|
||||
context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False
|
||||
):
|
||||
"""
|
||||
SQuAD format JSON export for question/answer pairs that were marked as "relevant".
|
||||
|
||||
The context_size param can be used to limit response size for large documents.
|
||||
"""
|
||||
labels = document_store.get_all_labels(
|
||||
index=DB_INDEX_FEEDBACK, filters={"is_correct_answer": [True], "origin": ["user-feedback"]}
|
||||
)
|
||||
if only_positive_labels:
|
||||
labels = document_store.get_all_labels(filters={"is_correct_answer": [True], "origin": ["user-feedback"]})
|
||||
else:
|
||||
labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]})
|
||||
# Filter out the labels where the passage is correct but answer is wrong (in SQuAD this matches
|
||||
# neither a "positive example" nor a negative "is_impossible" one)
|
||||
labels = [l for l in labels if not (l.is_correct_document is True and l.is_correct_answer is False)]
|
||||
|
||||
export_data = []
|
||||
|
||||
for label in labels:
|
||||
document = document_store.get_document_by_id(label.document_id)
|
||||
text = document.text
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Could not find document with id {label.document_id} for label id {label.id}"
|
||||
)
|
||||
|
||||
# the final length of context(including the answer string) is 'context_size'.
|
||||
# we try to add equal characters for context before and after the answer string.
|
||||
# if either beginning or end of text is reached, we correspondingly
|
||||
# append more context characters at the other end of answer string.
|
||||
context_to_add = int((context_size - len(label.answer)) / 2)
|
||||
if full_document_context:
|
||||
context = document.text
|
||||
answer_start = label.offset_start_in_doc
|
||||
else:
|
||||
text = document.text
|
||||
# the final length of context(including the answer string) is 'context_size'.
|
||||
# we try to add equal characters for context before and after the answer string.
|
||||
# if either beginning or end of text is reached, we correspondingly
|
||||
# append more context characters at the other end of answer string.
|
||||
context_to_add = int((context_size - len(label.answer)) / 2)
|
||||
start_pos = max(label.offset_start_in_doc - context_to_add, 0)
|
||||
additional_context_at_end = max(context_to_add - label.offset_start_in_doc, 0)
|
||||
end_pos = min(label.offset_start_in_doc + len(label.answer) + context_to_add, len(text) - 1)
|
||||
additional_context_at_start = max(
|
||||
label.offset_start_in_doc + len(label.answer) + context_to_add - len(text), 0
|
||||
)
|
||||
start_pos = max(0, start_pos - additional_context_at_start)
|
||||
end_pos = min(len(text) - 1, end_pos + additional_context_at_end)
|
||||
context = text[start_pos:end_pos]
|
||||
answer_start = label.offset_start_in_doc - start_pos
|
||||
|
||||
start_pos = max(label.offset_start_in_doc - context_to_add, 0)
|
||||
additional_context_at_end = max(context_to_add - label.offset_start_in_doc, 0)
|
||||
if label.is_correct_answer is False and label.is_correct_document is False: # No answer
|
||||
squad_label = {
|
||||
"paragraphs": [
|
||||
{
|
||||
"context": context,
|
||||
"id": label.document_id,
|
||||
"qas": [{"question": label.question, "id": label.id, "is_impossible": True, "answers": []}],
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
squad_label = {
|
||||
"paragraphs": [
|
||||
{
|
||||
"context": context,
|
||||
"id": label.document_id,
|
||||
"qas": [
|
||||
{
|
||||
"question": label.question,
|
||||
"id": label.id,
|
||||
"is_impossible": False,
|
||||
"answers": [{"text": label.answer, "answer_start": answer_start}],
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
end_pos = min(label.offset_start_in_doc + len(label.answer) + context_to_add, len(text) - 1)
|
||||
additional_context_at_start = max(label.offset_start_in_doc + len(label.answer) + context_to_add - len(text), 0)
|
||||
|
||||
start_pos = max(0, start_pos - additional_context_at_start)
|
||||
end_pos = min(len(text) - 1, end_pos + additional_context_at_end)
|
||||
|
||||
context_to_export = text[start_pos:end_pos]
|
||||
|
||||
export_data.append({"paragraphs": [{"qas": label, "context": context_to_export}]})
|
||||
|
||||
export = {"data": export_data}
|
||||
|
||||
return export
|
||||
|
||||
|
||||
@router.get("/export-faq-qa-feedback")
|
||||
def export_faq_feedback():
|
||||
"""
|
||||
Export feedback for faq-qa in JSON format.
|
||||
"""
|
||||
|
||||
labels = document_store.get_all_labels(index=DB_INDEX_FEEDBACK, filters={"origin": ["user-feedback-faq"]})
|
||||
|
||||
export_data = []
|
||||
for label in labels:
|
||||
document = document_store.get_document_by_id(label.document_id)
|
||||
feedback = {
|
||||
"question": document.question,
|
||||
"query": label.question,
|
||||
"is_correct_answer": label.is_correct_answer,
|
||||
"is_correct_document": label.is_correct_answer,
|
||||
}
|
||||
export_data.append(feedback)
|
||||
# quality check
|
||||
start = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
|
||||
answer = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["text"]
|
||||
context = squad_label["paragraphs"][0]["context"]
|
||||
if not context[start: start + len(answer)] == answer:
|
||||
logger.error(
|
||||
f"Skipping invalid squad label as string via offsets "
|
||||
f"('{context[start:start + len(answer)]}') does not match answer string ('{answer}') "
|
||||
)
|
||||
export_data.append(squad_label)
|
||||
|
||||
export = {"data": export_data}
|
||||
|
||||
with open("feedback_squad_direct.json", "w", encoding="utf8") as f:
|
||||
json.dump(export_data, f, ensure_ascii=False, sort_keys=True, indent=4)
|
||||
return export
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
@ -5,94 +6,56 @@ import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile, File, Form
|
||||
|
||||
from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, DB_INDEX_FEEDBACK, ES_CONN_SCHEME, TEXT_FIELD_NAME, \
|
||||
SEARCH_FIELD_NAME, FILE_UPLOAD_PATH, EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, VALID_LANGUAGES, \
|
||||
FAQ_QUESTION_FIELD_NAME, REMOVE_NUMERIC_TABLES, REMOVE_WHITESPACE, REMOVE_EMPTY_LINES, REMOVE_HEADER_FOOTER, \
|
||||
CREATE_INDEX, UPDATE_EXISTING_DOCUMENTS, VECTOR_SIMILARITY_METRIC, SPLIT_BY, SPLIT_LENGTH, SPLIT_OVERLAP, \
|
||||
SPLIT_RESPECT_SENTENCE_BOUNDARY
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.file_converter.pdf import PDFToTextConverter
|
||||
from haystack.file_converter.txt import TextConverter
|
||||
from haystack.preprocessor.preprocessor import PreProcessor
|
||||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
|
||||
|
||||
from haystack.pipeline import Pipeline
|
||||
from rest_api.config import PIPELINE_YAML_PATH, FILE_UPLOAD_PATH, INDEXING_PIPELINE_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
try:
|
||||
INDEXING_PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=INDEXING_PIPELINE_NAME)
|
||||
except KeyError:
|
||||
INDEXING_PIPELINE = None
|
||||
logger.info("Indexing Pipeline not found in the YAML configuration. File Upload API will not be available.")
|
||||
|
||||
document_store = ElasticsearchDocumentStore(
|
||||
host=DB_HOST,
|
||||
port=DB_PORT,
|
||||
username=DB_USER,
|
||||
password=DB_PW,
|
||||
index=DB_INDEX,
|
||||
label_index=DB_INDEX_FEEDBACK,
|
||||
scheme=ES_CONN_SCHEME,
|
||||
ca_certs=None,
|
||||
verify_certs=False,
|
||||
text_field=TEXT_FIELD_NAME,
|
||||
search_fields=SEARCH_FIELD_NAME,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
embedding_field=EMBEDDING_FIELD_NAME,
|
||||
excluded_meta_data=EXCLUDE_META_DATA_FIELDS, # type: ignore
|
||||
faq_question_field=FAQ_QUESTION_FIELD_NAME,
|
||||
create_index=CREATE_INDEX,
|
||||
update_existing_documents=UPDATE_EXISTING_DOCUMENTS,
|
||||
similarity=VECTOR_SIMILARITY_METRIC
|
||||
)
|
||||
|
||||
os.makedirs(FILE_UPLOAD_PATH, exist_ok=True) # create directory for uploading files
|
||||
|
||||
|
||||
@router.post("/file-upload")
|
||||
def upload_file_to_document_store(
|
||||
def file_upload(
|
||||
file: UploadFile = File(...),
|
||||
remove_numeric_tables: Optional[bool] = Form(REMOVE_NUMERIC_TABLES),
|
||||
remove_whitespace: Optional[bool] = Form(REMOVE_WHITESPACE),
|
||||
remove_empty_lines: Optional[bool] = Form(REMOVE_EMPTY_LINES),
|
||||
remove_header_footer: Optional[bool] = Form(REMOVE_HEADER_FOOTER),
|
||||
valid_languages: Optional[List[str]] = Form(VALID_LANGUAGES),
|
||||
split_by: Optional[str] = Form(SPLIT_BY),
|
||||
split_length: Optional[int] = Form(SPLIT_LENGTH),
|
||||
split_overlap: Optional[int] = Form(SPLIT_OVERLAP),
|
||||
split_respect_sentence_boundary: Optional[bool] = Form(SPLIT_RESPECT_SENTENCE_BOUNDARY),
|
||||
meta: Optional[str] = Form("null"), # JSON serialized string
|
||||
remove_numeric_tables: Optional[bool] = Form(None),
|
||||
remove_whitespace: Optional[bool] = Form(None),
|
||||
remove_empty_lines: Optional[bool] = Form(None),
|
||||
remove_header_footer: Optional[bool] = Form(None),
|
||||
valid_languages: Optional[List[str]] = Form(None),
|
||||
split_by: Optional[str] = Form(None),
|
||||
split_length: Optional[int] = Form(None),
|
||||
split_overlap: Optional[int] = Form(None),
|
||||
split_respect_sentence_boundary: Optional[bool] = Form(None),
|
||||
):
|
||||
if not INDEXING_PIPELINE:
|
||||
raise HTTPException(status_code=501, detail="Indexing Pipeline is not configured.")
|
||||
try:
|
||||
file_path = Path(FILE_UPLOAD_PATH) / f"{uuid.uuid4().hex}_{file.filename}"
|
||||
with file_path.open("wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
if file.filename.split(".")[-1].lower() == "pdf":
|
||||
pdf_converter = PDFToTextConverter(
|
||||
remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages
|
||||
)
|
||||
document = pdf_converter.convert(file_path)
|
||||
elif file.filename.split(".")[-1].lower() == "txt":
|
||||
txt_converter = TextConverter(
|
||||
remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages,
|
||||
)
|
||||
document = txt_converter.convert(file_path)
|
||||
else:
|
||||
raise HTTPException(status_code=415, detail=f"Only .pdf and .txt file formats are supported.")
|
||||
|
||||
document = {TEXT_FIELD_NAME: document["text"], "name": file.filename}
|
||||
|
||||
preprocessor = PreProcessor(
|
||||
clean_whitespace=remove_whitespace,
|
||||
clean_header_footer=remove_header_footer,
|
||||
clean_empty_lines=remove_empty_lines,
|
||||
INDEXING_PIPELINE.run(
|
||||
file_path=file_path,
|
||||
remove_numeric_tables=remove_numeric_tables,
|
||||
remove_whitespace=remove_whitespace,
|
||||
remove_empty_lines=remove_empty_lines,
|
||||
remove_header_footer=remove_header_footer,
|
||||
valid_languages=valid_languages,
|
||||
split_by=split_by,
|
||||
split_length=split_length,
|
||||
split_overlap=split_overlap,
|
||||
split_respect_sentence_boundary=split_respect_sentence_boundary,
|
||||
meta=json.loads(meta) or {},
|
||||
)
|
||||
|
||||
documents = preprocessor.process(document)
|
||||
document_store.write_documents(documents)
|
||||
return "File upload was successful."
|
||||
finally:
|
||||
file.file.close()
|
||||
|
||||
@ -1,73 +0,0 @@
|
||||
import sys
|
||||
from typing import Any, Collection, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from rest_api.config import DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER
|
||||
|
||||
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
|
||||
|
||||
|
||||
class Question(BaseModel):
|
||||
questions: List[str]
|
||||
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
||||
top_k_reader: int = DEFAULT_TOP_K_READER
|
||||
top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER
|
||||
|
||||
@classmethod
|
||||
def from_elastic_query_dsl(cls, query_request: Dict[str, Any], top_k_reader: int = DEFAULT_TOP_K_READER):
|
||||
|
||||
# Refer Query DSL
|
||||
# https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-bool-query.html
|
||||
# Currently do not support query matching with field parameter
|
||||
query_strings: List[str] = []
|
||||
filters: Dict[str, str] = {}
|
||||
top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER if "size" not in query_request else query_request["size"]
|
||||
|
||||
cls._iterate_dsl_request(query_request, query_strings, filters)
|
||||
|
||||
if len(query_strings) != 1:
|
||||
raise SyntaxError('Only one valid `query` field required expected, '
|
||||
'refer https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html')
|
||||
|
||||
return cls(questions=query_strings, filters=filters if len(filters) else None, top_k_retriever=top_k_retriever,
|
||||
top_k_reader=top_k_reader)
|
||||
|
||||
@classmethod
|
||||
def _iterate_dsl_request(cls, query_dsl: Any, query_strings: List[str], filters: Dict[str, str], depth: int = 0):
|
||||
if depth == MAX_RECURSION_DEPTH:
|
||||
raise RecursionError('Parsing incoming DSL reaching current value of the recursion limit')
|
||||
|
||||
# For question: Only consider values of "query" key for "match" and "multi_match" request.
|
||||
# For filter: Only consider Dict[str, str] value of "term" or "terms" key
|
||||
if isinstance(query_dsl, List):
|
||||
for item in query_dsl:
|
||||
cls._iterate_dsl_request(item, query_strings, filters, depth + 1)
|
||||
elif isinstance(query_dsl, Dict):
|
||||
for key, value in query_dsl.items():
|
||||
# "query" value should be "str" type
|
||||
if key == 'query' and isinstance(value, str):
|
||||
query_strings.append(value)
|
||||
elif key in ["filter", "filters"]:
|
||||
cls._iterate_filters(value, filters, depth + 1)
|
||||
elif isinstance(value, Collection):
|
||||
cls._iterate_dsl_request(value, query_strings, filters, depth + 1)
|
||||
|
||||
@classmethod
|
||||
def _iterate_filters(cls, filter_dsl: Any, filters: Dict[str, str], depth: int = 0):
|
||||
if depth == MAX_RECURSION_DEPTH:
|
||||
raise RecursionError('Parsing incoming DSL reaching current value of the recursion limit')
|
||||
|
||||
if isinstance(filter_dsl, List):
|
||||
for item in filter_dsl:
|
||||
cls._iterate_filters(item, filters, depth + 1)
|
||||
elif isinstance(filter_dsl, Dict):
|
||||
for key, value in filter_dsl.items():
|
||||
if key in ["term", "terms"]:
|
||||
if isinstance(value, Dict):
|
||||
for filter_key, filter_value in value.items():
|
||||
# Currently only accepting Dict[str, str]
|
||||
if isinstance(filter_value, str):
|
||||
filters[filter_key] = filter_value
|
||||
elif isinstance(value, Collection):
|
||||
cls._iterate_filters(value, filters, depth + 1)
|
||||
@ -1,40 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Answer(BaseModel):
|
||||
answer: Optional[str]
|
||||
question: Optional[str]
|
||||
score: Optional[float] = None
|
||||
probability: Optional[float] = None
|
||||
context: Optional[str]
|
||||
offset_start: int
|
||||
offset_end: int
|
||||
offset_start_in_doc: Optional[int]
|
||||
offset_end_in_doc: Optional[int]
|
||||
document_id: Optional[str] = None
|
||||
meta: Optional[Dict[str, str]]
|
||||
|
||||
|
||||
class AnswersToIndividualQuestion(BaseModel):
|
||||
question: str
|
||||
answers: List[Optional[Answer]]
|
||||
|
||||
@staticmethod
|
||||
def to_elastic_response_dsl(data: Dict[str, Any]):
|
||||
result_dsl = {'hits': {'hits': [], 'total': {'value': len(data["answers"])}}}
|
||||
for answer in data["answers"]:
|
||||
|
||||
record = {"_source": {k: v for k, v in dict(answer).items()}}
|
||||
record["_id"] = record["_source"].pop("document_id", None)
|
||||
record["_score"] = record["_source"].pop("score", None)
|
||||
|
||||
result_dsl['hits']['hits'].append(record)
|
||||
|
||||
return result_dsl
|
||||
|
||||
|
||||
class Answers(BaseModel):
|
||||
results: List[AnswersToIndividualQuestion]
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from rest_api.controller import file_upload
|
||||
from rest_api.controller import search, feedback
|
||||
from rest_api.controller import file_upload, search, feedback
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@ -1,216 +1,74 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import elasticapm
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from haystack import Finder
|
||||
from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, DB_INDEX_FEEDBACK, DEFAULT_TOP_K_READER, \
|
||||
ES_CONN_SCHEME, TEXT_FIELD_NAME, SEARCH_FIELD_NAME, EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, \
|
||||
RETRIEVER_TYPE, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, BATCHSIZE, CONTEXT_WINDOW_SIZE, \
|
||||
TOP_K_PER_CANDIDATE, NO_ANS_BOOST, READER_CAN_HAVE_NO_ANSWER, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, \
|
||||
CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME, EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, \
|
||||
GPU_NUMBER, NAME_FIELD_NAME, VECTOR_SIMILARITY_METRIC, CREATE_INDEX, LOG_LEVEL, UPDATE_EXISTING_DOCUMENTS, \
|
||||
TOP_K_PER_SAMPLE
|
||||
|
||||
from rest_api.controller.request import Question
|
||||
from rest_api.controller.response import Answers, AnswersToIndividualQuestion
|
||||
from pydantic import BaseModel
|
||||
|
||||
from haystack import Pipeline
|
||||
from rest_api.config import PIPELINE_YAML_PATH, LOG_LEVEL, QUERY_PIPELINE_NAME
|
||||
from rest_api.controller.utils import RequestLimiter
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.reader.base import BaseReader
|
||||
from haystack.reader.farm import FARMReader
|
||||
from haystack.reader.transformers import TransformersReader
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever, ElasticsearchFilterOnlyRetriever
|
||||
from haystack.retriever.dense import EmbeddingRetriever
|
||||
|
||||
logger = logging.getLogger('haystack')
|
||||
logger.setLevel(LOG_LEVEL)
|
||||
logging.getLogger("haystack").setLevel(LOG_LEVEL)
|
||||
logger = logging.getLogger("haystack")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Init global components: DocumentStore, Retriever, Reader, Finder
|
||||
document_store = ElasticsearchDocumentStore(
|
||||
host=DB_HOST,
|
||||
port=DB_PORT,
|
||||
username=DB_USER,
|
||||
password=DB_PW,
|
||||
index=DB_INDEX,
|
||||
label_index=DB_INDEX_FEEDBACK,
|
||||
scheme=ES_CONN_SCHEME,
|
||||
ca_certs=False,
|
||||
verify_certs=False,
|
||||
text_field=TEXT_FIELD_NAME,
|
||||
name_field=NAME_FIELD_NAME,
|
||||
search_fields=SEARCH_FIELD_NAME,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
embedding_field=EMBEDDING_FIELD_NAME,
|
||||
excluded_meta_data=EXCLUDE_META_DATA_FIELDS, # type: ignore
|
||||
faq_question_field=FAQ_QUESTION_FIELD_NAME,
|
||||
create_index=CREATE_INDEX,
|
||||
update_existing_documents=UPDATE_EXISTING_DOCUMENTS,
|
||||
similarity=VECTOR_SIMILARITY_METRIC
|
||||
)
|
||||
|
||||
if RETRIEVER_TYPE == "EmbeddingRetriever":
|
||||
retriever = EmbeddingRetriever(
|
||||
document_store=document_store,
|
||||
embedding_model=EMBEDDING_MODEL_PATH,
|
||||
model_format=EMBEDDING_MODEL_FORMAT,
|
||||
use_gpu=USE_GPU
|
||||
) # type: BaseRetriever
|
||||
elif RETRIEVER_TYPE == "ElasticsearchRetriever":
|
||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
elif RETRIEVER_TYPE is None or RETRIEVER_TYPE == "ElasticsearchFilterOnlyRetriever":
|
||||
retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
|
||||
else:
|
||||
raise ValueError(f"Could not load Retriever of type '{RETRIEVER_TYPE}'. "
|
||||
f"Please adjust RETRIEVER_TYPE to one of: "
|
||||
f"'EmbeddingRetriever', 'ElasticsearchRetriever', 'ElasticsearchFilterOnlyRetriever', None"
|
||||
f"OR modify rest_api/search.py to support your retriever"
|
||||
)
|
||||
|
||||
if READER_MODEL_PATH: # for extractive doc-qa
|
||||
logger.info(f"Loading reader model `{READER_MODEL_PATH}` ...")
|
||||
if READER_TYPE == "TransformersReader":
|
||||
use_gpu = -1 if not USE_GPU else GPU_NUMBER
|
||||
reader = TransformersReader(
|
||||
model_name_or_path=READER_MODEL_PATH,
|
||||
use_gpu=use_gpu,
|
||||
context_window_size=CONTEXT_WINDOW_SIZE,
|
||||
return_no_answers=READER_CAN_HAVE_NO_ANSWER,
|
||||
tokenizer=READER_TOKENIZER
|
||||
) # type: Optional[BaseReader]
|
||||
elif READER_TYPE == "FARMReader":
|
||||
reader = FARMReader(
|
||||
model_name_or_path=READER_MODEL_PATH,
|
||||
batch_size=BATCHSIZE,
|
||||
use_gpu=USE_GPU,
|
||||
context_window_size=CONTEXT_WINDOW_SIZE,
|
||||
top_k_per_candidate=TOP_K_PER_CANDIDATE,
|
||||
top_k_per_sample=TOP_K_PER_SAMPLE,
|
||||
no_ans_boost=NO_ANS_BOOST,
|
||||
num_processes=MAX_PROCESSES,
|
||||
max_seq_len=MAX_SEQ_LEN,
|
||||
doc_stride=DOC_STRIDE,
|
||||
) # type: Optional[BaseReader]
|
||||
else:
|
||||
raise ValueError(f"Could not load Reader of type '{READER_TYPE}'. "
|
||||
f"Please adjust READER_TYPE to one of: "
|
||||
f"'FARMReader', 'TransformersReader', None"
|
||||
)
|
||||
else:
|
||||
reader = None # don't need one for pure FAQ matching
|
||||
|
||||
FINDERS = {1: Finder(reader=reader, retriever=retriever)}
|
||||
class Request(BaseModel):
|
||||
query: str
|
||||
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
||||
|
||||
|
||||
#############################################
|
||||
# Endpoints
|
||||
#############################################
|
||||
doc_qa_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER)
|
||||
class Answer(BaseModel):
|
||||
answer: Optional[str]
|
||||
question: Optional[str]
|
||||
score: Optional[float] = None
|
||||
probability: Optional[float] = None
|
||||
context: Optional[str]
|
||||
offset_start: int
|
||||
offset_end: int
|
||||
offset_start_in_doc: Optional[int]
|
||||
offset_end_in_doc: Optional[int]
|
||||
document_id: Optional[str] = None
|
||||
meta: Optional[Dict[str, str]]
|
||||
|
||||
|
||||
@router.post("/models/{model_id}/doc-qa", response_model=Answers, response_model_exclude_unset=True)
|
||||
def doc_qa(model_id: int, question_request: Question):
|
||||
with doc_qa_limiter.run():
|
||||
start_time = time.time()
|
||||
finder = FINDERS.get(model_id, None)
|
||||
if not finder:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
)
|
||||
|
||||
results = search_documents(finder, question_request, start_time)
|
||||
|
||||
return {"results": results}
|
||||
class Response(BaseModel):
|
||||
query: str
|
||||
answers: List[Answer]
|
||||
|
||||
|
||||
@router.post("/models/{model_id}/faq-qa", response_model=Answers, response_model_exclude_unset=True)
|
||||
def faq_qa(model_id: int, request: Question):
|
||||
finder = FINDERS.get(model_id, None)
|
||||
if not finder:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
)
|
||||
|
||||
results = []
|
||||
for question in request.questions:
|
||||
if request.filters:
|
||||
# put filter values into a list and remove filters with null value
|
||||
filters = {}
|
||||
for key, values in request.filters.items():
|
||||
if values is None:
|
||||
continue
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
filters[key] = values
|
||||
logger.info(f" [{datetime.now()}] Request: {request}")
|
||||
else:
|
||||
filters = {}
|
||||
|
||||
result = finder.get_answers_via_similar_questions(
|
||||
question=question, top_k_retriever=request.top_k_retriever, filters=filters,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
elasticapm.set_custom_context({"results": results})
|
||||
logger.info(json.dumps({"request": request.dict(), "results": results}))
|
||||
|
||||
return {"results": results}
|
||||
PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=QUERY_PIPELINE_NAME)
|
||||
logger.info(f"Loaded pipeline nodes: {PIPELINE.graph.nodes.keys()}")
|
||||
concurrency_limiter = RequestLimiter(4)
|
||||
|
||||
|
||||
@router.post("/models/{model_id}/query", response_model=Dict[str, Any], response_model_exclude_unset=True)
|
||||
def query(model_id: int, query_request: Dict[str, Any], top_k_reader: int = DEFAULT_TOP_K_READER):
|
||||
with doc_qa_limiter.run():
|
||||
start_time = time.time()
|
||||
finder = FINDERS.get(model_id, None)
|
||||
if not finder:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
)
|
||||
|
||||
question_request = Question.from_elastic_query_dsl(query_request, top_k_reader)
|
||||
|
||||
answers = search_documents(finder, question_request, start_time)
|
||||
response: Dict[str, Any] = {}
|
||||
if answers and len(answers) > 0:
|
||||
response = AnswersToIndividualQuestion.to_elastic_response_dsl(dict(answers[0]))
|
||||
|
||||
return response
|
||||
@router.post("/query", response_model=Response)
|
||||
def query(request: Request):
|
||||
with concurrency_limiter.run():
|
||||
result = _process_request(PIPELINE, request)
|
||||
return result
|
||||
|
||||
|
||||
def search_documents(finder, question_request, start_time) -> List[AnswersToIndividualQuestion]:
|
||||
results = []
|
||||
for question in question_request.questions:
|
||||
if question_request.filters:
|
||||
# put filter values into a list and remove filters with null value
|
||||
filters = {}
|
||||
for key, values in question_request.filters.items():
|
||||
if values is None:
|
||||
continue
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
filters[key] = values
|
||||
logger.info(f" [{datetime.now()}] Request: {question_request}")
|
||||
else:
|
||||
filters = {}
|
||||
def _process_request(pipeline, request) -> Response:
|
||||
start_time = time.time()
|
||||
|
||||
filters = {}
|
||||
if request.filters:
|
||||
# put filter values into a list and remove filters with null value
|
||||
for key, values in request.filters.items():
|
||||
if values is None:
|
||||
continue
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
filters[key] = values
|
||||
|
||||
result = pipeline.run(query=request.query, filters=filters)
|
||||
|
||||
result = finder.get_answers(
|
||||
question=question,
|
||||
top_k_retriever=question_request.top_k_retriever,
|
||||
top_k_reader=question_request.top_k_reader,
|
||||
filters=filters,
|
||||
)
|
||||
results.append(result)
|
||||
elasticapm.set_custom_context({"results": results})
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
json.dumps({"request": question_request.dict(), "results": results,
|
||||
"time": f"{(end_time - start_time):.2f}"}))
|
||||
return results
|
||||
logger.info(json.dumps({"request": request.dict(), "response": result, "time": f"{(end_time - start_time):.2f}"}))
|
||||
|
||||
return result
|
||||
|
||||
47
rest_api/pipelines.yaml
Normal file
47
rest_api/pipelines.yaml
Normal file
@ -0,0 +1,47 @@
|
||||
version: '0.7'
|
||||
|
||||
components: # define all the building-blocks for Pipeline
|
||||
- name: ElasticsearchDocumentStore
|
||||
type: ElasticsearchDocumentStore
|
||||
params:
|
||||
host: localhost
|
||||
- name: ESRetriever
|
||||
type: ElasticsearchRetriever
|
||||
params:
|
||||
document_store: ElasticsearchDocumentStore # params can reference other components defined in the YAML
|
||||
top_k: 5
|
||||
- name: Reader # custom-name for the component; helpful for visualization & debugging
|
||||
type: FARMReader # Haystack Class name for the component
|
||||
params:
|
||||
model_name_or_path: deepset/roberta-base-squad2
|
||||
- name: TextFileConverter
|
||||
type: TextConverter
|
||||
- name: PDFFileConverter
|
||||
type: PDFToTextConverter
|
||||
- name: Preprocessor
|
||||
type: PreProcessor
|
||||
- name: FileTypeClassifier
|
||||
type: FileTypeClassifier
|
||||
|
||||
pipelines:
|
||||
- name: query # a sample extractive-qa Pipeline
|
||||
type: Query
|
||||
nodes:
|
||||
- name: ESRetriever
|
||||
inputs: [Query]
|
||||
- name: Reader
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: indexing
|
||||
type: Indexing
|
||||
nodes:
|
||||
- name: FileTypeClassifier
|
||||
inputs: [File]
|
||||
- name: TextFileConverter
|
||||
inputs: [FileTypeClassifier.output_1]
|
||||
- name: PDFFileConverter
|
||||
inputs: [FileTypeClassifier.output_2]
|
||||
- name: Preprocessor
|
||||
inputs: [PDFFileConverter, TextFileConverter]
|
||||
- name: ElasticsearchDocumentStore
|
||||
inputs: [Preprocessor]
|
||||
@ -1,47 +1,48 @@
|
||||
version: '0.7'
|
||||
|
||||
components:
|
||||
- name: TestReader
|
||||
- name: Reader
|
||||
type: FARMReader
|
||||
params:
|
||||
no_ans_boost: -10
|
||||
model_name_or_path: deepset/roberta-base-squad2
|
||||
- name: TestESRetriever
|
||||
- name: ESRetriever
|
||||
type: ElasticsearchRetriever
|
||||
params:
|
||||
document_store: TestDocumentStore
|
||||
document_store: DocumentStore
|
||||
custom_query: null
|
||||
- name: TestDocumentStore
|
||||
- name: DocumentStore
|
||||
type: ElasticsearchDocumentStore
|
||||
params:
|
||||
index: haystack_test_pipeline
|
||||
- name: TestPDFConverter
|
||||
index: haystack_test_document
|
||||
label_index: haystack_test_label
|
||||
- name: PDFConverter
|
||||
type: PDFToTextConverter
|
||||
params:
|
||||
remove_numeric_tables: false
|
||||
- name: TestPreprocessor
|
||||
- name: Preprocessor
|
||||
type: PreProcessor
|
||||
params:
|
||||
clean_whitespace: true
|
||||
|
||||
|
||||
pipelines:
|
||||
- name: test_query_pipeline
|
||||
- name: query_pipeline
|
||||
type: Query
|
||||
nodes:
|
||||
- name: TestESRetriever
|
||||
- name: ESRetriever
|
||||
inputs: [Query]
|
||||
- name: TestReader
|
||||
inputs: [TestESRetriever]
|
||||
- name: Reader
|
||||
inputs: [ESRetriever]
|
||||
|
||||
- name: test_indexing_pipeline
|
||||
- name: indexing_pipeline
|
||||
type: Indexing
|
||||
nodes:
|
||||
- name: TestPDFConverter
|
||||
- name: PDFConverter
|
||||
inputs: [File]
|
||||
- name: TestPreprocessor
|
||||
inputs: [TestPDFConverter]
|
||||
- name: TestESRetriever
|
||||
inputs: [TestPreprocessor]
|
||||
- name: TestDocumentStore
|
||||
inputs: [TestESRetriever]
|
||||
- name: Preprocessor
|
||||
inputs: [PDFConverter]
|
||||
- name: ESRetriever
|
||||
inputs: [Preprocessor]
|
||||
- name: DocumentStore
|
||||
inputs: [ESRetriever]
|
||||
|
||||
@ -1,201 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from rest_api.controller.request import Question
|
||||
from rest_api.controller.response import Answer, AnswersToIndividualQuestion
|
||||
|
||||
|
||||
def test_query_dsl_with_without_valid_query_field():
|
||||
query = {
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"match": {"title": "Search"}},
|
||||
{"match": {"content": "Elasticsearch"}}
|
||||
],
|
||||
"filter": [
|
||||
{"term": {"status": "published"}},
|
||||
{"range": {"publish_date": {"gte": "2015-01-01"}}}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
Question.from_elastic_query_dsl(query)
|
||||
|
||||
|
||||
def test_query_dsl_with_without_multiple_query_field():
|
||||
query = {
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"match": {"name.first": {"query": "shay", "_name": "first"}}},
|
||||
{"match": {"name.last": {"query": "banon", "_name": "last"}}}
|
||||
],
|
||||
"filter": {
|
||||
"terms": {
|
||||
"name.last": ["banon", "kimchy"],
|
||||
"_name": "test"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
Question.from_elastic_query_dsl(query)
|
||||
|
||||
|
||||
def test_query_dsl_with_single_query():
|
||||
query = {
|
||||
"query": {
|
||||
"match": {
|
||||
"message": {
|
||||
"query": "this is a test"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
question = Question.from_elastic_query_dsl(query)
|
||||
assert 1 == len(question.questions)
|
||||
assert question.questions.__contains__("this is a test")
|
||||
assert question.filters is None
|
||||
|
||||
|
||||
def test_query_dsl_with_filter():
|
||||
query = {
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"match": {"name.first": {"query": "shay", "_name": "first"}}}
|
||||
],
|
||||
"filter": {
|
||||
"terms": {
|
||||
"name.last": ["banon", "kimchy"],
|
||||
"_name": "test"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
question = Question.from_elastic_query_dsl(query)
|
||||
assert 1 == len(question.questions)
|
||||
assert question.questions.__contains__("shay")
|
||||
assert len(question.filters) == 1
|
||||
assert question.filters["_name"] == "test"
|
||||
|
||||
|
||||
def test_query_dsl_with_complex_query():
|
||||
query = {
|
||||
"size": 17,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
"multi_match": {
|
||||
"query": "I am test1",
|
||||
"type": "most_fields",
|
||||
"fields": ["text", "title"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"filter": [
|
||||
{
|
||||
"terms": {
|
||||
"year": "2020"
|
||||
}
|
||||
},
|
||||
{
|
||||
"terms": {
|
||||
"quarter": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"range": {
|
||||
"date": {
|
||||
"gte": "12-12-12"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
top_k_reader = 7
|
||||
question = Question.from_elastic_query_dsl(query, top_k_reader)
|
||||
assert 1 == len(question.questions)
|
||||
assert question.questions.__contains__("I am test1")
|
||||
assert 2 == len(question.filters)
|
||||
assert question.filters["year"] == "2020"
|
||||
assert question.filters["quarter"] == "1"
|
||||
assert 17 == question.top_k_retriever
|
||||
assert 7 == question.top_k_reader
|
||||
|
||||
|
||||
def test_response_dsl_with_empty_answers():
|
||||
sample_answer = AnswersToIndividualQuestion(question="test question", answers=[])
|
||||
response = AnswersToIndividualQuestion.to_elastic_response_dsl(sample_answer.__dict__)
|
||||
assert 0 == response['hits']['total']['value']
|
||||
assert 0 == len(response['hits']['hits'])
|
||||
|
||||
|
||||
def test_response_dsl_with_answers():
|
||||
full_answer = Answer(
|
||||
answer="answer",
|
||||
question="question",
|
||||
score=0.1234,
|
||||
probability=0.5678,
|
||||
context="context",
|
||||
offset_start=200,
|
||||
offset_end=300,
|
||||
offset_start_in_doc=2000,
|
||||
offset_end_in_doc=2100,
|
||||
document_id="id_1",
|
||||
meta={
|
||||
"meta1": "meta_value"
|
||||
}
|
||||
)
|
||||
empty_answer = Answer(
|
||||
answer=None,
|
||||
question=None,
|
||||
score=None,
|
||||
probability=None,
|
||||
context=None,
|
||||
offset_start=250,
|
||||
offset_end=350,
|
||||
offset_start_in_doc=None,
|
||||
offset_end_in_doc=None,
|
||||
document_id=None,
|
||||
meta=None
|
||||
)
|
||||
sample_answer = AnswersToIndividualQuestion(question="test question", answers=[full_answer, empty_answer])
|
||||
response = AnswersToIndividualQuestion.to_elastic_response_dsl(sample_answer.__dict__)
|
||||
|
||||
# Test number of returned answers
|
||||
assert response['hits']['total']['value'] == 2
|
||||
|
||||
# Test converted answers
|
||||
hits = response['hits']['hits']
|
||||
assert len(hits) == 2
|
||||
# Test full answer record
|
||||
assert hits[0]["_score"] == 0.1234
|
||||
assert hits[0]["_id"] == "id_1"
|
||||
assert hits[0]["_source"]["answer"] == "answer"
|
||||
assert hits[0]["_source"]["question"] == "question"
|
||||
assert hits[0]["_source"]["context"] == "context"
|
||||
assert hits[0]["_source"]["probability"] == 0.5678
|
||||
assert hits[0]["_source"]["offset_start"] == 200
|
||||
assert hits[0]["_source"]["offset_end"] == 300
|
||||
assert hits[0]["_source"]["offset_start_in_doc"] == 2000
|
||||
assert hits[0]["_source"]["offset_end_in_doc"] == 2100
|
||||
assert hits[0]["_source"]["meta"] == {"meta1": "meta_value"}
|
||||
# Test empty answer record
|
||||
assert hits[1]["_score"] is None
|
||||
assert hits[1]["_id"] is None
|
||||
assert hits[1]["_source"]["answer"] is None
|
||||
assert hits[1]["_source"]["question"] is None
|
||||
assert hits[1]["_source"]["context"] is None
|
||||
assert hits[1]["_source"]["probability"] is None
|
||||
assert hits[1]["_source"]["offset_start"] == 250
|
||||
assert hits[1]["_source"]["offset_end"] == 350
|
||||
assert hits[1]["_source"]["offset_start_in_doc"] is None
|
||||
assert hits[1]["_source"]["offset_end_in_doc"] is None
|
||||
assert hits[1]["_source"]["meta"] is None
|
||||
@ -13,11 +13,11 @@ from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
def test_load_yaml(document_store_with_docs):
|
||||
# test correct load of indexing pipeline from yaml
|
||||
pipeline = Pipeline.load_from_yaml(Path("samples/pipeline/test_pipeline.yaml"),
|
||||
pipeline_name="test_indexing_pipeline")
|
||||
pipeline_name="indexing_pipeline")
|
||||
pipeline.run(file_path=Path("samples/pdf/sample_pdf_1.pdf"), top_k_retriever=10, top_k_reader=3)
|
||||
|
||||
# test correct load of query pipeline from yaml
|
||||
pipeline = Pipeline.load_from_yaml(Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="test_query_pipeline")
|
||||
pipeline = Pipeline.load_from_yaml(Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="query_pipeline")
|
||||
prediction = pipeline.run(query="Who made the PDF specification?", top_k_retriever=10, top_k_reader=3)
|
||||
assert prediction["query"] == "Who made the PDF specification?"
|
||||
assert prediction["answers"][0]["answer"] == "Adobe Systems"
|
||||
|
||||
@ -1,110 +1,80 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
from haystack import Finder
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
|
||||
# TODO: Add integration tests for other APIs
|
||||
|
||||
|
||||
def get_test_client_and_override_dependencies(reader, document_store):
|
||||
def get_test_client_and_override_dependencies():
|
||||
import os
|
||||
os.environ["PIPELINE_YAML_PATH"] = "samples/pipeline/test_pipeline.yaml"
|
||||
os.environ["QUERY_PIPELINE_NAME"] = "query_pipeline"
|
||||
os.environ["INDEXING_PIPELINE_NAME"] = "indexing_pipeline"
|
||||
|
||||
from rest_api.application import app
|
||||
from rest_api.controller import search, file_upload
|
||||
|
||||
search.document_store = document_store
|
||||
search.retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
search.FINDERS = {1: Finder(reader=reader, retriever=search.retriever)}
|
||||
file_upload.document_store = document_store
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_qa_api_filters(reader, document_store_with_docs):
|
||||
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
|
||||
|
||||
query_with_no_filter_value = {"questions": ["Where does Carla lives?"]}
|
||||
response = client.post(url="/models/1/doc-qa", json=query_with_no_filter_value)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||
|
||||
query_with_single_filter_value = {"questions": ["Where does Carla lives?"], "filters": {"name": "filename1"}}
|
||||
response = client.post(url="/models/1/doc-qa", json=query_with_single_filter_value)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||
|
||||
query_with_a_list_of_filter_values = {
|
||||
"questions": ["Where does Carla lives?"],
|
||||
"filters": {"name": ["filename1", "filename2"]},
|
||||
}
|
||||
response = client.post(url="/models/1/doc-qa", json=query_with_a_list_of_filter_values)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||
|
||||
query_with_non_existing_filter_value = {
|
||||
"questions": ["Where does Carla lives?"],
|
||||
"filters": {"name": ["invalid-name"]},
|
||||
}
|
||||
response = client.post(url="/models/1/doc-qa", json=query_with_non_existing_filter_value)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert len(response_json["results"][0]["answers"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_query_api_filters(reader, document_store_with_docs):
|
||||
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
|
||||
|
||||
query = {
|
||||
"size": 1,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
"multi_match": {
|
||||
"query": "Where Paul lives?"
|
||||
}
|
||||
}
|
||||
],
|
||||
"filter": [
|
||||
{
|
||||
"terms": {
|
||||
"name": "filename2"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response = client.post(url="/models/1/query?top_k_reader=1", json=query)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert 1 == response_json['hits']['total']['value']
|
||||
assert 1 == len(response_json['hits']['hits'])
|
||||
assert response_json['hits']['hits'][0]["_score"] is not None
|
||||
assert response_json['hits']['hits'][0]["_source"]["meta"] is not None
|
||||
assert response_json['hits']['hits'][0]["_id"] is not None
|
||||
assert "New York" == response_json['hits']['hits'][0]["_source"]["answer"]
|
||||
assert "My name is Paul and I live in New York" == response_json['hits']['hits'][0]["_source"]["context"]
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
def test_file_upload(document_store):
|
||||
assert document_store.get_document_count() == 0
|
||||
client = get_test_client_and_override_dependencies(reader=None, document_store=document_store)
|
||||
def test_api(reader, document_store):
|
||||
client = get_test_client_and_override_dependencies()
|
||||
|
||||
# test file upload API
|
||||
file_to_upload = {'file': Path("samples/pdf/sample_pdf_1.pdf").open('rb')}
|
||||
response = client.post(url="/file-upload", files=file_to_upload)
|
||||
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value"}'})
|
||||
assert 200 == response.status_code
|
||||
assert document_store.get_document_count() > 0
|
||||
|
||||
# test query API
|
||||
query_with_no_filter_value = {"query": "Who made the PDF specification?"}
|
||||
response = client.post(url="/query", json=query_with_no_filter_value)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["answers"][0]["answer"] == "Adobe Systems"
|
||||
document_id = response_json["answers"][0]["document_id"]
|
||||
|
||||
query_with_filter = {"query": "Who made the PDF specification?", "filters": {"meta_key": "meta_value"}}
|
||||
response = client.post(url="/query", json=query_with_filter)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["answers"][0]["answer"] == "Adobe Systems"
|
||||
|
||||
query_with_filter_list = {"query": "Who made the PDF specification?", "filters": {"meta_key": ["meta_value", "another_value"]}}
|
||||
response = client.post(url="/query", json=query_with_filter_list)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["answers"][0]["answer"] == "Adobe Systems"
|
||||
|
||||
query_with_invalid_filter = {"query": "Who made the PDF specification?", "filters": {"meta_key": "invalid_value"}}
|
||||
response = client.post(url="/query", json=query_with_invalid_filter)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert len(response_json["answers"]) == 0
|
||||
|
||||
# test write feedback
|
||||
feedback = {
|
||||
"question": "Who made the PDF specification?",
|
||||
"is_correct_answer": True,
|
||||
"document_id": document_id,
|
||||
"is_correct_document": True,
|
||||
"answer": "Adobe Systems",
|
||||
"offset_start_in_doc": 60
|
||||
}
|
||||
response = client.post(url="/feedback", json=feedback)
|
||||
assert 200 == response.status_code
|
||||
|
||||
# test export feedback
|
||||
feedback_urls = [
|
||||
"/export-feedback?full_document_context=true",
|
||||
"/export-feedback?full_document_context=false&context_size=50",
|
||||
"/export-feedback?full_document_context=false&context_size=50000",
|
||||
]
|
||||
for url in feedback_urls:
|
||||
response = client.get(url=url, json=feedback)
|
||||
response_json = response.json()
|
||||
context = response_json["data"][0]["paragraphs"][0]["context"]
|
||||
answer_start = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
|
||||
answer = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["text"]
|
||||
assert context[answer_start:answer_start+len(answer)] == answer
|
||||
|
||||
|
||||
55
ui/utils.py
55
ui/utils.py
@@ -1,42 +1,25 @@
|
||||
import requests
|
||||
import streamlit as st
|
||||
import os
|
||||
|
||||
import requests
|
||||
import streamlit as st
|
||||
|
||||
# Base URL of the Haystack REST API; override via the API_ENDPOINT env var.
API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:8000")
# Kept for the legacy "/models/<id>/..." endpoint; unused by the new "/query" route.
MODEL_ID = "1"
# Route of the pipelines-based query endpoint ("doc-qa" was the pre-refactor route).
DOC_REQUEST = "query"
|
||||
|
||||
|
||||
def format_request(question, filters=None, top_k_reader=5, top_k_retriever=5):
    """Build the JSON payload for the Haystack question-answering endpoint.

    :param question: query string sent to the API
    :param filters: optional single filter value; when given it is wrapped in
        the ``{"option1": [filters]}`` structure this API version expects
    :param top_k_reader: max number of answers the reader should return
    :param top_k_retriever: max number of documents the retriever should return
    :return: dict ready to be serialized as the request body
    """
    request = {
        "questions": [question],
        "top_k_retriever": top_k_retriever,
        "top_k_reader": top_k_reader,
    }
    # Identity check (`is not None`) instead of `== None`; an explicitly
    # falsy filter value such as "" is still forwarded, matching the
    # original's branching. Building the dict once removes the duplication.
    if filters is not None:
        request["filters"] = {"option1": [filters]}
    return request
|
||||
|
||||
# NOTE(review): diff residue interleaved the removed `retrieve_doc` with the
# added `haystack_query`; consolidated here to the final (added) version.
# The @st.cache decorator appears once in the diff — presumably an unchanged
# context line, so it is kept; confirm against the upstream file.
@st.cache(show_spinner=False)
def haystack_query(query, filters=None, top_k_reader=5, top_k_retriever=5):
    """Send `query` to the Haystack REST API and format the returned answers.

    :param query: question string
    :param filters: optional metadata filters forwarded verbatim to the API
    :param top_k_reader: max number of answers requested from the reader
    :param top_k_retriever: max number of documents requested from the retriever
    :return: tuple of (formatted result dicts, raw JSON response)
    """
    url = f"{API_ENDPOINT}/{DOC_REQUEST}"
    req = {"query": query, "filters": filters, "top_k_retriever": top_k_retriever, "top_k_reader": top_k_reader}
    response_raw = requests.post(url, json=req).json()

    # Keep only non-empty answers; the API may return placeholder entries.
    result = []
    for answer_obj in response_raw["answers"]:
        answer = answer_obj["answer"]
        if answer:
            result.append(
                {
                    "context": "..." + answer_obj["context"] + "...",
                    "answer": answer,
                    # .get() because "name" may be absent from the meta dict
                    # (the removed version's ['meta']['name'] could raise KeyError)
                    "source": answer_obj["meta"].get("name"),
                    "relevance": round(answer_obj["probability"] * 100, 2),
                }
            )
    return result, response_raw
|
||||
|
||||
37
ui/webapp.py
37
ui/webapp.py
@@ -1,28 +1,35 @@
|
||||
import streamlit as st
|
||||
from utils import retrieve_doc
|
||||
from annotated_text import annotated_text
|
||||
|
||||
# NOTE(review): diff residue interleaved the removed and added definitions of
# this function; consolidated here to the final (added) version.
def annotate_answer(answer, context):
    """Render `context` with the first occurrence of `answer` highlighted.

    Uses st-annotated-text to tag the matched span as "ANSWER".
    NOTE(review): if `answer` is not found, `find()` returns -1 and the
    highlight is misplaced — assumed not to happen for answers the API
    extracted from this very context; confirm.
    """
    start_idx = context.find(answer)
    end_idx = start_idx + len(answer)
    annotated_text(context[:start_idx], (answer, "ANSWER", "#8ef"), context[end_idx:])
|
||||
|
||||
|
||||
# NOTE(review): diff residue interleaved the removed and added versions of this
# Streamlit script body; consolidated here to the final (added) version.
# --- Page layout: title, sidebar options, query input ---
st.write("# Haystack Demo")
st.sidebar.header("Options")
top_k_reader = st.sidebar.slider("Max. number of answers", min_value=1, max_value=10, value=3, step=1)
top_k_retriever = st.sidebar.slider(
    "Max. number of documents from retriever", min_value=1, max_value=10, value=3, step=1
)
question = st.text_input("Please provide your query:", value="Who is the father of Arya Starck?")
run_query = st.button("Run")
debug = st.sidebar.checkbox("Show debug info")

# --- Query execution and result rendering ---
if run_query:
    with st.spinner(
        "Performing neural search on documents... 🧠 \n "
        "Do you want to optimize speed or accuracy? \n"
        "Check out the docs: https://haystack.deepset.ai/docs/latest/optimizationmd "
    ):
        results, raw_json = haystack_query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
    st.write("## Retrieved answers:")
    for result in results:
        annotate_answer(result["answer"], result["context"])
        # Bare tuple expression: Streamlit "magic" renders these values inline.
        "**Relevance:** ", result["relevance"], "**Source:** ", result["source"]
    if debug:
        st.subheader("REST API JSON response")
        st.write(raw_json)
|
||||
Loading…
x
Reference in New Issue
Block a user