mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-29 16:59:47 +00:00
Add type hints and mypy checks (#138)
This commit is contained in:
parent
180dc8cbd6
commit
98f1a3f9a7
6
.github/workflows/ci.yml
vendored
6
.github/workflows/ci.yml
vendored
@ -39,5 +39,9 @@ jobs:
|
||||
pip install -e .
|
||||
|
||||
- name: Test with pytest
|
||||
run: cd test && pytest
|
||||
|
||||
- name: Test with mypy
|
||||
run: |
|
||||
cd test && pytest
|
||||
pip install mypy
|
||||
mypy haystack --ignore-missing-imports
|
||||
|
||||
@ -38,7 +38,7 @@ document_store = ElasticsearchDocumentStore(
|
||||
search_fields=SEARCH_FIELD_NAME,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
embedding_field=EMBEDDING_FIELD_NAME,
|
||||
excluded_meta_data=EXCLUDE_META_DATA_FIELDS,
|
||||
excluded_meta_data=EXCLUDE_META_DATA_FIELDS, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ class Feedback(BaseModel):
|
||||
|
||||
|
||||
@router.post("/doc-qa-feedback")
|
||||
def feedback(feedback: Feedback):
|
||||
def doc_qa_feedback(feedback: Feedback):
|
||||
if feedback.answer and feedback.offset_start_in_doc:
|
||||
elasticsearch_client.index(index=DB_INDEX_FEEDBACK, body=feedback.dict())
|
||||
else:
|
||||
@ -63,7 +63,7 @@ def feedback(feedback: Feedback):
|
||||
|
||||
|
||||
@router.post("/faq-qa-feedback")
|
||||
def feedback(feedback: Feedback):
|
||||
def faq_qa_feedback(feedback: Feedback):
|
||||
elasticsearch_client.index(index=DB_INDEX_FEEDBACK, body=feedback.dict())
|
||||
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from haystack.api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_C
|
||||
from haystack.api.controller.utils import RequestLimiter
|
||||
from haystack.database.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.reader.farm import FARMReader
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
from haystack.retriever.elasticsearch import ElasticsearchRetriever, EmbeddingRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -34,12 +35,12 @@ document_store = ElasticsearchDocumentStore(
|
||||
search_fields=SEARCH_FIELD_NAME,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
embedding_field=EMBEDDING_FIELD_NAME,
|
||||
excluded_meta_data=EXCLUDE_META_DATA_FIELDS,
|
||||
excluded_meta_data=EXCLUDE_META_DATA_FIELDS, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
if EMBEDDING_MODEL_PATH:
|
||||
retriever = EmbeddingRetriever(document_store=document_store, embedding_model=EMBEDDING_MODEL_PATH, gpu=USE_GPU)
|
||||
retriever = EmbeddingRetriever(document_store=document_store, embedding_model=EMBEDDING_MODEL_PATH, gpu=USE_GPU) # type: BaseRetriever
|
||||
else:
|
||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
|
||||
@ -54,7 +55,7 @@ if READER_MODEL_PATH: # for extractive doc-qa
|
||||
num_processes=MAX_PROCESSES,
|
||||
max_seq_len=MAX_SEQ_LEN,
|
||||
doc_stride=DOC_STRIDE,
|
||||
)
|
||||
) # type: Optional[FARMReader]
|
||||
else:
|
||||
reader = None # don't need one for pure FAQ matching
|
||||
|
||||
@ -66,7 +67,7 @@ FINDERS = {1: Finder(reader=reader, retriever=retriever)}
|
||||
#############################################
|
||||
class Question(BaseModel):
|
||||
questions: List[str]
|
||||
filters: Dict[str, Optional[str]] = None
|
||||
filters: Optional[Dict[str, Optional[str]]] = None
|
||||
top_k_reader: int = DEFAULT_TOP_K_READER
|
||||
top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER
|
||||
|
||||
@ -74,8 +75,8 @@ class Question(BaseModel):
|
||||
class Answer(BaseModel):
|
||||
answer: Optional[str]
|
||||
question: Optional[str]
|
||||
score: float = None
|
||||
probability: float = None
|
||||
score: Optional[float] = None
|
||||
probability: Optional[float] = None
|
||||
context: Optional[str]
|
||||
offset_start: int
|
||||
offset_end: int
|
||||
@ -112,14 +113,16 @@ def doc_qa(model_id: int, request: Question):
|
||||
for question in request.questions:
|
||||
if request.filters:
|
||||
# put filter values into a list and remove filters with null value
|
||||
request.filters = {key: [value] for key, value in request.filters.items() if value is not None}
|
||||
filters = {key: [value] for key, value in request.filters.items() if value is not None}
|
||||
logger.info(f" [{datetime.now()}] Request: {request}")
|
||||
else:
|
||||
filters = {}
|
||||
|
||||
result = finder.get_answers(
|
||||
question=question,
|
||||
top_k_retriever=request.top_k_retriever,
|
||||
top_k_reader=request.top_k_reader,
|
||||
filters=request.filters,
|
||||
filters=filters,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
@ -141,11 +144,13 @@ def faq_qa(model_id: int, request: Question):
|
||||
for question in request.questions:
|
||||
if request.filters:
|
||||
# put filter values into a list and remove filters with null value
|
||||
request.filters = {key: [value] for key, value in request.filters.items() if value is not None}
|
||||
filters = {key: [value] for key, value in request.filters.items() if value is not None}
|
||||
logger.info(f" [{datetime.now()}] Request: {request}")
|
||||
else:
|
||||
filters = {}
|
||||
|
||||
result = finder.get_answers_via_similar_questions(
|
||||
question=question, top_k_retriever=request.top_k_retriever, filters=request.filters,
|
||||
question=question, top_k_retriever=request.top_k_retriever, filters=filters,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
|
||||
@ -1,35 +1,9 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional, Dict
|
||||
from typing import Any, Optional, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BaseDocumentStore:
|
||||
"""
|
||||
Base class for implementing Document Stores.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def write_documents(self, documents):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_document_by_id(self, id):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_document_ids_by_tags(self, tag):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_document_count(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None):
|
||||
pass
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
id: str = Field(..., description="_id field from Elasticsearch")
|
||||
text: str = Field(..., description="Text of the document")
|
||||
@ -41,5 +15,35 @@ class Document(BaseModel):
|
||||
# name: Optional[str] = Field(None, description="Title of the document")
|
||||
question: Optional[str] = Field(None, description="Question text for FAQs.")
|
||||
query_score: Optional[float] = Field(None, description="Elasticsearch query score for a retrieved document")
|
||||
meta: Optional[Dict[str, Any]] = Field(None, description="")
|
||||
meta: Dict[str, Any] = Field({}, description="")
|
||||
tags: Optional[Dict[str, Any]] = Field(None, description="Tags that allow filtering of the data")
|
||||
|
||||
|
||||
class BaseDocumentStore:
|
||||
"""
|
||||
Base class for implementing Document Stores.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def write_documents(self, documents: List[dict]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all_documents(self) -> List[Document]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_document_by_id(self, id: str) -> Optional[Document]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_document_ids_by_tags(self, tag) -> List[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_document_count(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
|
||||
pass
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from string import Template
|
||||
from typing import Union
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.helpers import bulk, scan
|
||||
|
||||
@ -18,14 +18,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
username: str = "",
|
||||
password: str = "",
|
||||
index: str = "document",
|
||||
search_fields: Union[str,list] = "text",
|
||||
search_fields: Union[str, list] = "text",
|
||||
text_field: str = "text",
|
||||
name_field: str = "name",
|
||||
external_source_id_field: str = "external_source_id",
|
||||
embedding_field: str = None,
|
||||
embedding_dim: str = None,
|
||||
custom_mapping: dict = None,
|
||||
excluded_meta_data: list = None,
|
||||
embedding_field: Optional[str] = None,
|
||||
embedding_dim: Optional[str] = None,
|
||||
custom_mapping: Optional[dict] = None,
|
||||
excluded_meta_data: Optional[list] = None,
|
||||
scheme: str = "http",
|
||||
ca_certs: bool = False,
|
||||
verify_certs: bool = True,
|
||||
@ -93,14 +93,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
self.embedding_field = embedding_field
|
||||
self.excluded_meta_data = excluded_meta_data
|
||||
|
||||
def get_document_by_id(self, id: str) -> Document:
|
||||
def get_document_by_id(self, id: str) -> Optional[Document]:
|
||||
query = {"query": {"ids": {"values": [id]}}}
|
||||
result = self.client.search(index=self.index, body=query)["hits"]["hits"]
|
||||
|
||||
document = self._convert_es_hit_to_document(result[0]) if result else None
|
||||
return document
|
||||
|
||||
def get_document_ids_by_tags(self, tags: dict) -> [str]:
|
||||
def get_document_ids_by_tags(self, tags: dict) -> List[str]:
|
||||
term_queries = [{"terms": {key: value}} for key, value in tags.items()]
|
||||
query = {"query": {"bool": {"must": term_queries}}}
|
||||
logger.debug(f"Tag filter query: {query}")
|
||||
@ -110,19 +110,19 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
doc_ids.append(hit["_id"])
|
||||
return doc_ids
|
||||
|
||||
def write_documents(self, documents):
|
||||
def write_documents(self, documents: List[dict]):
|
||||
for doc in documents:
|
||||
doc["_op_type"] = "create"
|
||||
doc["_index"] = self.index
|
||||
|
||||
bulk(self.client, documents, request_timeout=300)
|
||||
|
||||
def get_document_count(self):
|
||||
def get_document_count(self) -> int:
|
||||
result = self.client.count()
|
||||
count = result["count"]
|
||||
return count
|
||||
|
||||
def get_all_documents(self):
|
||||
def get_all_documents(self) -> List[Document]:
|
||||
result = scan(self.client, query={"query": {"match_all": {}}}, index=self.index)
|
||||
documents = [self._convert_es_hit_to_document(hit) for hit in result]
|
||||
|
||||
@ -131,11 +131,11 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
def query(
|
||||
self,
|
||||
query: str,
|
||||
filters: dict = None,
|
||||
filters: Optional[dict] = None,
|
||||
top_k: int = 10,
|
||||
custom_query: str = None,
|
||||
index: str = None,
|
||||
) -> [Document]:
|
||||
custom_query: Optional[str] = None,
|
||||
index: Optional[str] = None,
|
||||
) -> List[Document]:
|
||||
|
||||
if index is None:
|
||||
index = self.index
|
||||
@ -180,7 +180,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
documents = [self._convert_es_hit_to_document(hit) for hit in result]
|
||||
return documents
|
||||
|
||||
def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
|
||||
def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
|
||||
if not self.embedding_field:
|
||||
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
|
||||
else:
|
||||
@ -198,7 +198,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} # type: Dict[str,Any]
|
||||
|
||||
if candidate_doc_ids:
|
||||
body["query"]["script_score"]["query"] = {
|
||||
@ -216,7 +216,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
documents = [self._convert_es_hit_to_document(hit, score_adjustment=-1) for hit in result]
|
||||
return documents
|
||||
|
||||
def _convert_es_hit_to_document(self, hit, score_adjustment=0) -> Document:
|
||||
def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> Document:
|
||||
# We put all additional data of the doc into meta_data and return it in the API
|
||||
meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.external_source_id_field)}
|
||||
meta_data["name"] = meta_data.pop(self.name_field, None)
|
||||
@ -286,7 +286,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
bulk(self.client, eval_docs_to_index)
|
||||
bulk(self.client, questions_to_index)
|
||||
|
||||
def get_all_documents_in_index(self, index, filters=None):
|
||||
def get_all_documents_in_index(self, index: str, filters: Optional[dict] = None):
|
||||
body = {
|
||||
"query": {
|
||||
"bool": {
|
||||
@ -295,7 +295,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if filters:
|
||||
body["query"]["bool"]["filter"] = {"term": filters}
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from haystack.database.base import BaseDocumentStore, Document
|
||||
|
||||
|
||||
@ -10,7 +12,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
self.docs = {}
|
||||
self.doc_tags = {}
|
||||
|
||||
def write_documents(self, documents):
|
||||
def write_documents(self, documents: List[dict]):
|
||||
import hashlib
|
||||
|
||||
if documents is None:
|
||||
@ -33,7 +35,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
|
||||
self._map_tags_to_ids(hash, tags)
|
||||
|
||||
def _map_tags_to_ids(self, hash, tags):
|
||||
def _map_tags_to_ids(self, hash: str, tags: List[str]):
|
||||
if isinstance(tags, list):
|
||||
for tag in tags:
|
||||
if isinstance(tag, dict):
|
||||
@ -48,10 +50,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
else:
|
||||
self.doc_tags[comp_key] = [hash]
|
||||
|
||||
def get_document_by_id(self, id):
|
||||
return self.docs[id]
|
||||
def get_document_by_id(self, id: str) -> Document:
|
||||
document = self._convert_memory_hit_to_document(self.docs[id], doc_id=id)
|
||||
return document
|
||||
|
||||
def _convert_memory_hit_to_document(self, hit, doc_id=None) -> Document:
|
||||
def _convert_memory_hit_to_document(self, hit: Tuple[Any, Any], doc_id: Optional[str] = None) -> Document:
|
||||
document = Document(
|
||||
id=doc_id,
|
||||
text=hit[0].get('text', None),
|
||||
@ -60,7 +63,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
)
|
||||
return document
|
||||
|
||||
def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
|
||||
def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
|
||||
from haystack.api import config
|
||||
from numpy import dot
|
||||
from numpy.linalg import norm
|
||||
@ -78,7 +81,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
|
||||
return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]
|
||||
|
||||
def get_document_ids_by_tags(self, tags):
|
||||
def get_document_ids_by_tags(self, tags: Union[List[Dict[str, Union[str, List[str]]]], Dict[str, Union[str, List[str]]]]) -> List[str]:
|
||||
"""
|
||||
The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
|
||||
The format for the dict is {"tag-1": ["value-1","value-2"], "tag-2": ["value-3"] ...}
|
||||
@ -88,7 +91,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
result = self._find_ids_by_tags(tags)
|
||||
return result
|
||||
|
||||
def _find_ids_by_tags(self, tags):
|
||||
def _find_ids_by_tags(self, tags: List[Dict[str, Union[str, List[str]]]]):
|
||||
result = []
|
||||
for tag in tags:
|
||||
tag_keys = tag.keys()
|
||||
@ -102,8 +105,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
result.append(self.docs.get(doc_id))
|
||||
return result
|
||||
|
||||
def get_document_count(self):
|
||||
def get_document_count(self) -> int:
|
||||
return len(self.docs.items())
|
||||
|
||||
def get_all_documents(self):
|
||||
def get_all_documents(self) -> List[Document]:
|
||||
return [Document(id=item[0], text=item[1]['text'], name=item[1]['name'], meta=item[1].get('meta', {})) for item in self.docs.items()]
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import json
|
||||
from typing import Any, Dict, Union, List, Optional
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, PickleType
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship, sessionmaker
|
||||
|
||||
from haystack.database.base import BaseDocumentStore, Document as DocumentSchema
|
||||
|
||||
Base = declarative_base()
|
||||
Base = declarative_base() # type: Any
|
||||
|
||||
|
||||
class ORMBase(Base):
|
||||
@ -43,19 +44,19 @@ class DocumentTag(ORMBase):
|
||||
|
||||
|
||||
class SQLDocumentStore(BaseDocumentStore):
|
||||
def __init__(self, url="sqlite://"):
|
||||
def __init__(self, url: str = "sqlite://"):
|
||||
engine = create_engine(url)
|
||||
ORMBase.metadata.create_all(engine)
|
||||
Session = sessionmaker(bind=engine)
|
||||
self.session = Session()
|
||||
|
||||
def get_document_by_id(self, id):
|
||||
def get_document_by_id(self, id: str) -> Optional[DocumentSchema]:
|
||||
document_row = self.session.query(Document).get(id)
|
||||
document = self._convert_sql_row_to_document(document_row)
|
||||
|
||||
return document
|
||||
|
||||
def get_all_documents(self):
|
||||
def get_all_documents(self) -> List[DocumentSchema]:
|
||||
document_rows = self.session.query(Document).all()
|
||||
documents = []
|
||||
for row in document_rows:
|
||||
@ -63,7 +64,7 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
|
||||
return documents
|
||||
|
||||
def get_document_ids_by_tags(self, tags):
|
||||
def get_document_ids_by_tags(self, tags: Dict[str, Union[str, List]]) -> List[str]:
|
||||
"""
|
||||
Get list of document ids that have tags from the given list of tags.
|
||||
|
||||
@ -91,16 +92,16 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
doc_ids = [row[0] for row in query_results]
|
||||
return doc_ids
|
||||
|
||||
def write_documents(self, documents):
|
||||
def write_documents(self, documents: List[dict]):
|
||||
for doc in documents:
|
||||
row = Document(name=doc["name"], text=doc["text"], meta_data=doc.get("meta", {}))
|
||||
self.session.add(row)
|
||||
self.session.commit()
|
||||
|
||||
def get_document_count(self):
|
||||
def get_document_count(self) -> int:
|
||||
return self.session.query(Document).count()
|
||||
|
||||
def _convert_sql_row_to_document(self, row) -> Document:
|
||||
def _convert_sql_row_to_document(self, row) -> DocumentSchema:
|
||||
document = DocumentSchema(
|
||||
id=row.id,
|
||||
text=row.text,
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
import logging
|
||||
import time
|
||||
from statistics import mean
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import numpy as np
|
||||
from scipy.special import expit
|
||||
import time
|
||||
from statistics import mean
|
||||
|
||||
from haystack.reader.base import BaseReader
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -15,13 +19,13 @@ class Finder:
|
||||
It provides an interface to predict top n answers for a given question.
|
||||
"""
|
||||
|
||||
def __init__(self, reader, retriever):
|
||||
def __init__(self, reader: Optional[BaseReader], retriever: Optional[BaseRetriever]):
|
||||
self.retriever = retriever
|
||||
self.reader = reader
|
||||
if self.reader is None and self.retriever is None:
|
||||
raise AttributeError("Finder: self.reader and self.retriever can not be both None")
|
||||
|
||||
def get_answers(self, question: str, top_k_reader: int = 1, top_k_retriever: int = 10, filters: dict = None):
|
||||
def get_answers(self, question: str, top_k_reader: int = 1, top_k_retriever: int = 10, filters: Optional[dict] = None):
|
||||
"""
|
||||
Get top k answers for a given question.
|
||||
|
||||
@ -41,8 +45,8 @@ class Finder:
|
||||
|
||||
if len(documents) == 0:
|
||||
logger.info("Retriever did not return any documents. Skipping reader ...")
|
||||
results = {"question": question, "answers": []}
|
||||
return results
|
||||
empty_result = {"question": question, "answers": []}
|
||||
return empty_result
|
||||
|
||||
# 2) Apply reader to get granular answer(s)
|
||||
len_chars = sum([len(d.text) for d in documents])
|
||||
@ -50,7 +54,7 @@ class Finder:
|
||||
|
||||
results = self.reader.predict(question=question,
|
||||
documents=documents,
|
||||
top_k=top_k_reader)
|
||||
top_k=top_k_reader) # type: Dict[str, Any]
|
||||
|
||||
# Add corresponding document_name and more meta data, if an answer contains the document_id
|
||||
for ans in results["answers"]:
|
||||
@ -61,7 +65,7 @@ class Finder:
|
||||
|
||||
return results
|
||||
|
||||
def get_answers_via_similar_questions(self, question: str, top_k_retriever: int = 10, filters: dict = None):
|
||||
def get_answers_via_similar_questions(self, question: str, top_k_retriever: int = 10, filters: Optional[dict] = None):
|
||||
"""
|
||||
Get top k answers for a given question using only a retriever.
|
||||
|
||||
@ -75,12 +79,12 @@ class Finder:
|
||||
if self.retriever is None:
|
||||
raise AttributeError("Finder.get_answers_via_similar_questions requires self.retriever")
|
||||
|
||||
results = {"question": question, "answers": []}
|
||||
results = {"question": question, "answers": []} # type: Dict[str, Any]
|
||||
|
||||
# 1) Optional: reduce the search space via document tags
|
||||
if filters:
|
||||
logging.info(f"Apply filters: {filters}")
|
||||
candidate_doc_ids = self.retriever.document_store.get_document_ids_by_tags(filters)
|
||||
candidate_doc_ids = self.retriever.document_store.get_document_ids_by_tags(filters) # type: ignore
|
||||
logger.info(f"Got candidate IDs due to filters: {candidate_doc_ids}")
|
||||
|
||||
if len(candidate_doc_ids) == 0:
|
||||
@ -88,28 +92,35 @@ class Finder:
|
||||
return results
|
||||
|
||||
else:
|
||||
candidate_doc_ids = None
|
||||
candidate_doc_ids = None # type: ignore
|
||||
|
||||
# 2) Apply retriever to match similar questions via cosine similarity of embeddings
|
||||
documents = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids)
|
||||
documents = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids) # type: ignore
|
||||
|
||||
# 3) Format response
|
||||
for doc in documents:
|
||||
#TODO proper calibration of pseudo probabilities
|
||||
cur_answer = {"question": doc.meta["question"], "answer": doc.text, "context": doc.text,
|
||||
cur_answer = {"question": doc.meta["question"], "answer": doc.text, "context": doc.text, # type: ignore
|
||||
"score": doc.query_score, "offset_start": 0, "offset_end": len(doc.text), "meta": doc.meta
|
||||
}
|
||||
if self.retriever.embedding_model:
|
||||
probability = (doc.query_score + 1) / 2
|
||||
if self.retriever.embedding_model: # type: ignore
|
||||
probability = (doc.query_score + 1) / 2 # type: ignore
|
||||
else:
|
||||
probability = float(expit(np.asarray(doc.query_score / 8)))
|
||||
probability = float(expit(np.asarray(doc.query_score / 8))) # type: ignore
|
||||
|
||||
cur_answer["probability"] = probability
|
||||
results["answers"].append(cur_answer)
|
||||
|
||||
return results
|
||||
|
||||
def eval(self, label_index: str = "feedback", doc_index: str = "eval_document", label_origin: str = "gold_label",
|
||||
top_k_retriever: int = 10, top_k_reader: int = 10):
|
||||
def eval(
|
||||
self,
|
||||
label_index: str = "feedback",
|
||||
doc_index: str = "eval_document",
|
||||
label_origin: str = "gold_label",
|
||||
top_k_retriever: int = 10,
|
||||
top_k_reader: int = 10,
|
||||
):
|
||||
"""
|
||||
Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
|
||||
of the Retriever.
|
||||
@ -155,10 +166,14 @@ class Finder:
|
||||
:param top_k_reader: How many answers to return per question
|
||||
:type top_k_reader: int
|
||||
"""
|
||||
|
||||
if not self.reader or not self.retriever:
|
||||
raise Exception("Finder needs to have a reader and retriever for the evaluation.")
|
||||
|
||||
finder_start_time = time.time()
|
||||
# extract all questions for evaluation
|
||||
filter = {"origin": label_origin}
|
||||
questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filter)
|
||||
questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filter) # type: ignore
|
||||
|
||||
correct_retrievals = 0
|
||||
summed_avg_precision_retriever = 0
|
||||
@ -192,7 +207,7 @@ class Finder:
|
||||
# check if correct doc among retrieved docs
|
||||
if doc.meta["doc_id"] == question["_source"]["doc_id"]:
|
||||
correct_retrievals += 1
|
||||
summed_avg_precision_retriever += 1 / (doc_idx + 1)
|
||||
summed_avg_precision_retriever += 1 / (doc_idx + 1) # type: ignore
|
||||
questions_with_docs.append({
|
||||
"question": question,
|
||||
"docs": retrieved_docs,
|
||||
@ -202,8 +217,8 @@ class Finder:
|
||||
number_of_questions = q_idx + 1
|
||||
|
||||
number_of_no_answer = 0
|
||||
previous_return_no_answers = self.reader.return_no_answers
|
||||
self.reader.return_no_answers = True
|
||||
previous_return_no_answers = self.reader.return_no_answers # type: ignore
|
||||
self.reader.return_no_answers = True # type: ignore
|
||||
# extract answers
|
||||
reader_start_time = time.time()
|
||||
for q_idx, question in enumerate(questions_with_docs):
|
||||
@ -260,10 +275,10 @@ class Finder:
|
||||
current_f1 = (2 * precision * recall) / (precision + recall)
|
||||
# top-1 answer
|
||||
if answer_idx == 0:
|
||||
summed_f1_top1 += current_f1
|
||||
summed_f1_top1_has_answer += current_f1
|
||||
summed_f1_top1 += current_f1 # type: ignore
|
||||
summed_f1_top1_has_answer += current_f1 # type: ignore
|
||||
if current_f1 > best_f1:
|
||||
best_f1 = current_f1
|
||||
best_f1 = current_f1 # type: ignore
|
||||
# top-k answers: use best f1-score
|
||||
summed_f1_topk += best_f1
|
||||
summed_f1_topk_has_answer += best_f1
|
||||
@ -311,7 +326,7 @@ class Finder:
|
||||
reader_top1_no_answer_accuracy = correct_no_answers_top1 / number_of_no_answer
|
||||
reader_topk_no_answer_accuracy = correct_no_answers_topk / number_of_no_answer
|
||||
|
||||
self.reader.return_no_answers = previous_return_no_answers
|
||||
self.reader.return_no_answers = previous_return_no_answers # type: ignore
|
||||
|
||||
logger.info((f"{correct_readings_topk} out of {number_of_questions} questions were correctly answered "
|
||||
f"({(correct_readings_topk/number_of_questions):.2%})."))
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
import re
|
||||
|
||||
|
||||
def clean_wiki_text(text):
|
||||
def clean_wiki_text(text: str) -> str:
|
||||
# get rid of multiple new lines
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
|
||||
# remove extremely short lines
|
||||
text = text.split("\n")
|
||||
lines = text.split("\n")
|
||||
cleaned = []
|
||||
for l in text:
|
||||
for l in lines:
|
||||
if len(l) > 30:
|
||||
cleaned.append(l)
|
||||
elif l[:2] == "==" and l[-2:] == "==":
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class BaseConverter:
|
||||
@ -9,11 +10,11 @@ class BaseConverter:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
remove_numeric_tables: bool = None,
|
||||
remove_header_footer: bool = None,
|
||||
remove_whitespace: bool = None,
|
||||
remove_empty_lines: bool = None,
|
||||
valid_languages: [str] = None,
|
||||
remove_numeric_tables: Optional[bool] = None,
|
||||
remove_header_footer: Optional[bool] = None,
|
||||
remove_whitespace: Optional[bool] = None,
|
||||
remove_empty_lines: Optional[bool] = None,
|
||||
valid_languages: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
:param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
|
||||
@ -40,5 +41,5 @@ class BaseConverter:
|
||||
self.valid_languages = valid_languages
|
||||
|
||||
@abstractmethod
|
||||
def extract_pages(self, file_path: Path) -> [str]:
|
||||
def extract_pages(self, file_path: Path) -> List[str]:
|
||||
pass
|
||||
|
||||
@ -4,6 +4,7 @@ import subprocess
|
||||
from functools import partial, reduce
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Generator, Set
|
||||
|
||||
import fitz
|
||||
import langdetect
|
||||
@ -16,11 +17,11 @@ logger = logging.getLogger(__name__)
|
||||
class PDFToTextConverter(BaseConverter):
|
||||
def __init__(
|
||||
self,
|
||||
remove_numeric_tables: bool = False,
|
||||
remove_whitespace: bool = None,
|
||||
remove_empty_lines: bool = None,
|
||||
remove_header_footer: bool = None,
|
||||
valid_languages: [str] = None,
|
||||
remove_numeric_tables: Optional[bool] = False,
|
||||
remove_whitespace: Optional[bool] = None,
|
||||
remove_empty_lines: Optional[bool] = None,
|
||||
remove_header_footer: Optional[bool] = None,
|
||||
valid_languages: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
:param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
|
||||
@ -64,7 +65,7 @@ class PDFToTextConverter(BaseConverter):
|
||||
valid_languages=valid_languages,
|
||||
)
|
||||
|
||||
def extract_pages(self, file_path: Path) -> [str]:
|
||||
def extract_pages(self, file_path: Path) -> List[str]:
|
||||
|
||||
page_count = fitz.open(file_path).pageCount
|
||||
|
||||
@ -106,13 +107,12 @@ class PDFToTextConverter(BaseConverter):
|
||||
pages.append(page)
|
||||
page_number += 1
|
||||
|
||||
if self.valid_languages:
|
||||
document_text = "".join(pages)
|
||||
if not self._validate_language(document_text):
|
||||
logger.warning(
|
||||
f"The language for {file_path} is not one of {self.valid_languages}. The file may not have "
|
||||
f"been decoded in the correct text format."
|
||||
)
|
||||
document_text = "".join(pages)
|
||||
if not self._validate_language(document_text):
|
||||
logger.warning(
|
||||
f"The language for {file_path} is not one of {self.valid_languages}. The file may not have "
|
||||
f"been decoded in the correct text format."
|
||||
)
|
||||
|
||||
if self.remove_header_footer:
|
||||
pages, header, footer = self.find_and_remove_header_footer(
|
||||
@ -122,7 +122,7 @@ class PDFToTextConverter(BaseConverter):
|
||||
|
||||
return pages
|
||||
|
||||
def _extract_page(self, file_path: Path, page_number: int, layout: bool):
|
||||
def _extract_page(self, file_path: Path, page_number: int, layout: bool) -> str:
|
||||
"""
|
||||
Extract a page from the pdf file at file_path.
|
||||
|
||||
@ -132,17 +132,20 @@ class PDFToTextConverter(BaseConverter):
|
||||
the content stream order.
|
||||
"""
|
||||
if layout:
|
||||
command = ["pdftotext", "-layout", "-f", str(page_number), "-l", str(page_number), file_path, "-"]
|
||||
command = ["pdftotext", "-layout", "-f", str(page_number), "-l", str(page_number), str(file_path), "-"]
|
||||
else:
|
||||
command = ["pdftotext", "-f", str(page_number), "-l", str(page_number), file_path, "-"]
|
||||
command = ["pdftotext", "-f", str(page_number), "-l", str(page_number), str(file_path), "-"]
|
||||
output_page = subprocess.run(command, capture_output=True, shell=False)
|
||||
page = output_page.stdout.decode(errors="ignore")
|
||||
return page
|
||||
|
||||
def _validate_language(self, text: str):
|
||||
def _validate_language(self, text: str) -> bool:
|
||||
"""
|
||||
Validate if the language of the text is one of valid languages.
|
||||
"""
|
||||
if not self.valid_languages:
|
||||
return True
|
||||
|
||||
try:
|
||||
lang = langdetect.detect(text)
|
||||
except langdetect.lang_detect_exception.LangDetectException:
|
||||
@ -153,7 +156,7 @@ class PDFToTextConverter(BaseConverter):
|
||||
else:
|
||||
return False
|
||||
|
||||
def _ngram(self, seq: str, n: int):
|
||||
def _ngram(self, seq: str, n: int) -> Generator[str, None, None]:
|
||||
"""
|
||||
Return ngram (of tokens - currently splitted by whitespace)
|
||||
:param seq: str, string from which the ngram shall be created
|
||||
@ -166,20 +169,20 @@ class PDFToTextConverter(BaseConverter):
|
||||
seq = seq.replace("\n", " \n")
|
||||
seq = seq.replace("\t", " \t")
|
||||
|
||||
seq = seq.split(" ")
|
||||
words = seq.split(" ")
|
||||
ngrams = (
|
||||
" ".join(seq[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(seq) - n + 1)
|
||||
" ".join(words[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(words) - n + 1)
|
||||
)
|
||||
|
||||
return ngrams
|
||||
|
||||
def _allngram(self, seq: str, min_ngram: int, max_ngram: int):
|
||||
def _allngram(self, seq: str, min_ngram: int, max_ngram: int) -> Set[str]:
|
||||
lengths = range(min_ngram, max_ngram) if max_ngram else range(min_ngram, len(seq))
|
||||
ngrams = map(partial(self._ngram, seq), lengths)
|
||||
res = set(chain.from_iterable(ngrams))
|
||||
return res
|
||||
|
||||
def find_longest_common_ngram(self, sequences: [str], max_ngram: int = 30, min_ngram: int = 3):
|
||||
def find_longest_common_ngram(self, sequences: List[str], max_ngram: int = 30, min_ngram: int = 3) -> Optional[str]:
|
||||
"""
|
||||
Find the longest common ngram across different text sequences (e.g. start of pages).
|
||||
Considering all ngrams between the specified range. Helpful for finding footers, headers etc.
|
||||
@ -201,8 +204,8 @@ class PDFToTextConverter(BaseConverter):
|
||||
return longest if longest.strip() else None
|
||||
|
||||
def find_and_remove_header_footer(
|
||||
self, pages: [str], n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int
|
||||
):
|
||||
self, pages: List[str], n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int
|
||||
) -> Tuple[List[str], Optional[str], Optional[str]]:
|
||||
"""
|
||||
Heuristic to find footers and headers across different pages by searching for the longest common string.
|
||||
For headers we only search in the first n_chars characters (for footer: last n_chars).
|
||||
|
||||
@ -1,16 +1,18 @@
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from farm.data_handler.utils import http_get
|
||||
import tempfile
|
||||
import tarfile
|
||||
import tempfile
|
||||
import zipfile
|
||||
from typing import Callable
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from farm.data_handler.utils import http_get
|
||||
|
||||
from haystack.indexing.file_converters.pdftotext import PDFToTextConverter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_paragraphs: bool = False) -> [dict]:
|
||||
def convert_files_to_dicts(dir_path: str, clean_func: Optional[Callable] = None, split_paragraphs: bool = False) -> List[dict]:
|
||||
"""
|
||||
Convert all files(.txt, .pdf) in the sub-directories of the given path to Python dicts that can be written to a
|
||||
Document Store.
|
||||
@ -24,7 +26,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
|
||||
|
||||
file_paths = [p for p in Path(dir_path).glob("**/*")]
|
||||
if ".pdf" in [p.suffix.lower() for p in file_paths]:
|
||||
pdf_converter = PDFToTextConverter()
|
||||
pdf_converter = PDFToTextConverter() # type: Optional[PDFToTextConverter]
|
||||
else:
|
||||
pdf_converter = None
|
||||
|
||||
@ -33,7 +35,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
|
||||
if path.suffix.lower() == ".txt":
|
||||
with open(path) as doc:
|
||||
text = doc.read()
|
||||
elif path.suffix.lower() == ".pdf":
|
||||
elif path.suffix.lower() == ".pdf" and pdf_converter:
|
||||
pages = pdf_converter.extract_pages(path)
|
||||
text = "\n".join(pages)
|
||||
else:
|
||||
@ -53,7 +55,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
|
||||
return documents
|
||||
|
||||
|
||||
def fetch_archive_from_http(url, output_dir, proxies=None):
|
||||
def fetch_archive_from_http(url: str, output_dir: str, proxies: Optional[dict] = None):
|
||||
"""
|
||||
Fetch an archive (zip or tar.gz) from a url via http and extract content to an output directory.
|
||||
|
||||
@ -86,11 +88,11 @@ def fetch_archive_from_http(url, output_dir, proxies=None):
|
||||
temp_file.seek(0) # making tempfile accessible
|
||||
# extract
|
||||
if url[-4:] == ".zip":
|
||||
archive = zipfile.ZipFile(temp_file.name)
|
||||
archive.extractall(output_dir)
|
||||
zip_archive = zipfile.ZipFile(temp_file.name)
|
||||
zip_archive.extractall(output_dir)
|
||||
elif url[-7:] == ".tar.gz":
|
||||
archive = tarfile.open(temp_file.name)
|
||||
archive.extractall(output_dir)
|
||||
tar_archive = tarfile.open(temp_file.name)
|
||||
tar_archive.extractall(output_dir)
|
||||
# temp_file gets deleted here
|
||||
return True
|
||||
|
||||
|
||||
11
haystack/reader/base.py
Normal file
11
haystack/reader/base.py
Normal file
@ -0,0 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from haystack.database.base import Document
|
||||
|
||||
|
||||
class BaseReader(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
pass
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from farm.data_handler.data_silo import DataSilo
|
||||
@ -14,11 +15,11 @@ from scipy.special import expit
|
||||
|
||||
from haystack.database.base import Document
|
||||
from haystack.database.elasticsearch import ElasticsearchDocumentStore
|
||||
|
||||
from haystack.reader.base import BaseReader
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FARMReader:
|
||||
class FARMReader(BaseReader):
|
||||
"""
|
||||
Transformer based model for extractive Question Answering using the FARM framework (https://github.com/deepset-ai/FARM).
|
||||
While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
|
||||
@ -30,16 +31,16 @@ class FARMReader:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path,
|
||||
context_window_size=150,
|
||||
batch_size=50,
|
||||
use_gpu=True,
|
||||
no_ans_boost=None,
|
||||
top_k_per_candidate=3,
|
||||
top_k_per_sample=1,
|
||||
num_processes=None,
|
||||
max_seq_len=256,
|
||||
doc_stride=128
|
||||
model_name_or_path: Union[str, Path],
|
||||
context_window_size: int = 150,
|
||||
batch_size: int = 50,
|
||||
use_gpu: bool = True,
|
||||
no_ans_boost: Optional[int] = None,
|
||||
top_k_per_candidate: int = 3,
|
||||
top_k_per_sample: int = 1,
|
||||
num_processes: Optional[int] = None,
|
||||
max_seq_len: int = 256,
|
||||
doc_stride: int = 128,
|
||||
):
|
||||
|
||||
"""
|
||||
@ -97,9 +98,22 @@ class FARMReader:
|
||||
self.max_seq_len = max_seq_len
|
||||
self.use_gpu = use_gpu
|
||||
|
||||
def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
|
||||
use_gpu=None, batch_size=10, n_epochs=2, learning_rate=1e-5,
|
||||
max_seq_len=None, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
|
||||
def train(
|
||||
self,
|
||||
data_dir: str,
|
||||
train_filename: str,
|
||||
dev_filename: Optional[str] = None,
|
||||
test_file_name: Optional[str] = None,
|
||||
use_gpu: Optional[bool] = None,
|
||||
batch_size: int = 10,
|
||||
n_epochs: int = 2,
|
||||
learning_rate: float = 1e-5,
|
||||
max_seq_len: Optional[int] = None,
|
||||
warmup_proportion: float = 0.2,
|
||||
dev_split: Optional[float] = 0.1,
|
||||
evaluate_every: int = 300,
|
||||
save_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Fine-tune a model on a QA dataset. Options:
|
||||
- Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
|
||||
@ -141,7 +155,6 @@ class FARMReader:
|
||||
|
||||
if not save_dir:
|
||||
save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
|
||||
save_dir = Path(save_dir)
|
||||
|
||||
# 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
|
||||
label_list = ["start_token", "end_token"]
|
||||
@ -184,14 +197,14 @@ class FARMReader:
|
||||
)
|
||||
# 5. Let it grow!
|
||||
self.inferencer.model = trainer.train()
|
||||
self.save(save_dir)
|
||||
self.save(Path(save_dir))
|
||||
|
||||
def save(self, directory):
|
||||
def save(self, directory: Path):
|
||||
logger.info(f"Saving reader model to {directory}")
|
||||
self.inferencer.model.save(directory)
|
||||
self.inferencer.processor.save(directory)
|
||||
|
||||
def predict(self, question: str, documents: [Document], top_k: int = None):
|
||||
def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
"""
|
||||
Use loaded QA model to find answers for a question in the supplied list of Document.
|
||||
|
||||
@ -245,7 +258,8 @@ class FARMReader:
|
||||
if a["answer"]:
|
||||
cur = {"answer": a["answer"],
|
||||
"score": a["score"],
|
||||
"probability": float(expit(np.asarray([a["score"]]) / 8)), #just a pseudo prob for now
|
||||
# just a pseudo prob for now
|
||||
"probability": float(expit(np.asarray([a["score"]]) / 8)), # type: ignore
|
||||
"context": a["context"],
|
||||
"offset_start": a["offset_answer_start"] - a["offset_context_start"],
|
||||
"offset_end": a["offset_answer_end"] - a["offset_context_start"],
|
||||
@ -316,8 +330,14 @@ class FARMReader:
|
||||
}
|
||||
return results
|
||||
|
||||
def eval(self, document_store: ElasticsearchDocumentStore, device: str, label_index: str = "feedback",
|
||||
doc_index: str = "eval_document", label_origin: str = "gold_label"):
|
||||
def eval(
|
||||
self,
|
||||
document_store: ElasticsearchDocumentStore,
|
||||
device: str,
|
||||
label_index: str = "feedback",
|
||||
doc_index: str = "eval_document",
|
||||
label_origin: str = "gold_label",
|
||||
):
|
||||
"""
|
||||
Performs evaluation on evaluation documents in Elasticsearch DocumentStore.
|
||||
|
||||
@ -386,7 +406,7 @@ class FARMReader:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _calc_no_answer(no_ans_gaps,best_score_answer):
|
||||
def _calc_no_answer(no_ans_gaps: List[float], best_score_answer: float):
|
||||
# "no answer" scores and positive answers scores are difficult to compare, because
|
||||
# + a positive answer score is related to one specific document
|
||||
# - a "no answer" score is related to all input documents
|
||||
@ -396,7 +416,8 @@ class FARMReader:
|
||||
# No_ans_gap coming from FARM mean how much no_ans_boost should change to switch predictions
|
||||
no_ans_gaps = np.array(no_ans_gaps)
|
||||
max_no_ans_gap = np.max(no_ans_gaps)
|
||||
if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)): # all passages "no answer" as top score
|
||||
# all passages "no answer" as top score
|
||||
if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)): # type: ignore
|
||||
no_ans_score = best_score_answer - max_no_ans_gap # max_no_ans_gap is negative, so it increases best pos score
|
||||
else: # case: at least one passage predicts an answer (positive no_ans_gap)
|
||||
no_ans_score = best_score_answer - max_no_ans_gap
|
||||
@ -410,7 +431,7 @@ class FARMReader:
|
||||
"document_id": None}
|
||||
return no_ans_prediction, max_no_ans_gap
|
||||
|
||||
def predict_on_texts(self, question: str, texts: [str], top_k=None):
|
||||
def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
|
||||
documents = []
|
||||
for i, text in enumerate(texts):
|
||||
documents.append(
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
from haystack.database.base import Document
|
||||
from haystack.reader.base import BaseReader
|
||||
|
||||
|
||||
class TransformersReader:
|
||||
class TransformersReader(BaseReader):
|
||||
"""
|
||||
Transformer based model for extractive Question Answering using the huggingface's transformers framework
|
||||
(https://github.com/huggingface/transformers).
|
||||
@ -15,13 +18,11 @@ class TransformersReader:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model="distilbert-base-uncased-distilled-squad",
|
||||
tokenizer="distilbert-base-uncased",
|
||||
context_window_size=30,
|
||||
#no_answer_shift=-100,
|
||||
#batch_size=16,
|
||||
use_gpu=0,
|
||||
n_best_per_passage=2
|
||||
model: str = "distilbert-base-uncased-distilled-squad",
|
||||
tokenizer: str = "distilbert-base-uncased",
|
||||
context_window_size: int = 30,
|
||||
use_gpu: int = 0,
|
||||
n_best_per_passage: int = 2,
|
||||
):
|
||||
"""
|
||||
Load a QA model from Transformers.
|
||||
@ -44,7 +45,7 @@ class TransformersReader:
|
||||
self.n_best_per_passage = n_best_per_passage
|
||||
#TODO param to modify bias for no_answer
|
||||
|
||||
def predict(self, question: str, documents: [Document], top_k: int = None):
|
||||
def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
"""
|
||||
Use loaded QA model to find answers for a question in the supplied list of Document.
|
||||
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from haystack.database.base import Document
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
@abstractmethod
|
||||
def retrieve(self, query, candidate_doc_ids=None, top_k=1):
|
||||
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
|
||||
pass
|
||||
|
||||
@ -1,15 +1,17 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
from typing import List, Union
|
||||
|
||||
from farm.infer import Inferencer
|
||||
|
||||
from haystack.database.base import Document, BaseDocumentStore
|
||||
from haystack.database.base import Document
|
||||
from haystack.database.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticsearchRetriever(BaseRetriever):
|
||||
def __init__(self, document_store: Type[BaseDocumentStore], custom_query: str = None):
|
||||
def __init__(self, document_store: ElasticsearchDocumentStore, custom_query: str = None):
|
||||
"""
|
||||
:param document_store: an instance of a DocumentStore to retrieve documents from.
|
||||
:param custom_query: query string as per Elasticsearch DSL with a mandatory question placeholder($question).
|
||||
@ -38,10 +40,10 @@ class ElasticsearchRetriever(BaseRetriever):
|
||||
self.retrieve(query="Why did the revenue increase?",
|
||||
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
|
||||
"""
|
||||
self.document_store = document_store
|
||||
self.document_store = document_store # type: ignore
|
||||
self.custom_query = custom_query
|
||||
|
||||
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> [Document]:
|
||||
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
|
||||
if index is None:
|
||||
index = self.document_store.index
|
||||
|
||||
@ -50,8 +52,13 @@ class ElasticsearchRetriever(BaseRetriever):
|
||||
|
||||
return documents
|
||||
|
||||
def eval(self, label_index: str = "feedback", doc_index: str = "eval_document", label_origin: str = "gold_label",
|
||||
top_k: int = 10) -> dict:
|
||||
def eval(
|
||||
self,
|
||||
label_index: str = "feedback",
|
||||
doc_index: str = "eval_document",
|
||||
label_origin: str = "gold_label",
|
||||
top_k: int = 10,
|
||||
) -> dict:
|
||||
"""
|
||||
Performs evaluation on the Retriever.
|
||||
Retriever is evaluated based on whether it finds the correct document given the question string and at which
|
||||
@ -81,7 +88,7 @@ class ElasticsearchRetriever(BaseRetriever):
|
||||
for doc_idx, doc in enumerate(retrieved_docs):
|
||||
if doc.meta["doc_id"] == question["_source"]["doc_id"]:
|
||||
correct_retrievals += 1
|
||||
summed_avg_precision += 1 / (doc_idx + 1)
|
||||
summed_avg_precision += 1 / (doc_idx + 1) # type: ignore
|
||||
break
|
||||
|
||||
number_of_questions = q_idx + 1
|
||||
@ -97,7 +104,7 @@ class ElasticsearchRetriever(BaseRetriever):
|
||||
class EmbeddingRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
document_store: Type[BaseDocumentStore],
|
||||
document_store: ElasticsearchDocumentStore,
|
||||
embedding_model: str,
|
||||
gpu: bool = True,
|
||||
model_format: str = "farm",
|
||||
@ -137,13 +144,13 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def retrieve(self, query: str, candidate_doc_ids: [str] = None, top_k: int = 10) -> [Document]:
|
||||
def retrieve(self, query: str, candidate_doc_ids: List[str] = None, top_k: int = 10) -> List[Document]: # type: ignore
|
||||
query_emb = self.create_embedding(texts=[query])
|
||||
documents = self.document_store.query_by_embedding(query_emb[0], top_k, candidate_doc_ids)
|
||||
|
||||
return documents
|
||||
|
||||
def create_embedding(self, texts: [str]):
|
||||
def create_embedding(self, texts: Union[List[str], str]) -> List[List[float]]:
|
||||
"""
|
||||
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
|
||||
:param texts: texts to embed
|
||||
@ -152,14 +159,15 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
|
||||
# for backward compatibility: cast pure str input
|
||||
if type(texts) == str:
|
||||
texts = [texts]
|
||||
texts = [texts] # type: ignore
|
||||
assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
|
||||
|
||||
if self.model_format == "farm":
|
||||
res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
|
||||
res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
|
||||
emb = [list(r["vec"]) for r in res] #cast from numpy
|
||||
elif self.model_format == "sentence_transformers":
|
||||
# text is single string, sentence-transformers needs a list of strings
|
||||
res = self.embedding_model.encode(texts) # get back list of numpy embedding vectors
|
||||
# get back list of numpy embedding vectors
|
||||
res = self.embedding_model.encode(texts) # type: ignore
|
||||
emb = [list(r.astype('float64')) for r in res] #cast from numpy
|
||||
return emb
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import logging
|
||||
from collections import OrderedDict, namedtuple
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from haystack.database.base import Document
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
from haystack.database.base import BaseDocumentStore, Document
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -23,7 +24,7 @@ class TfidfRetriever(BaseRetriever):
|
||||
It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
|
||||
"""
|
||||
|
||||
def __init__(self, document_store):
|
||||
def __init__(self, document_store: BaseDocumentStore):
|
||||
self.vectorizer = TfidfVectorizer(
|
||||
lowercase=True,
|
||||
stop_words=None,
|
||||
@ -36,7 +37,7 @@ class TfidfRetriever(BaseRetriever):
|
||||
self.df = None
|
||||
self.fit()
|
||||
|
||||
def _get_all_paragraphs(self):
|
||||
def _get_all_paragraphs(self) -> List[Paragraph]:
|
||||
"""
|
||||
Split the list of documents in paragraphs
|
||||
"""
|
||||
@ -55,7 +56,7 @@ class TfidfRetriever(BaseRetriever):
|
||||
logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
|
||||
return paragraphs
|
||||
|
||||
def _calc_scores(self, query):
|
||||
def _calc_scores(self, query: str) -> dict:
|
||||
question_vector = self.vectorizer.transform([query])
|
||||
|
||||
scores = self.tfidf_matrix.dot(question_vector.T).toarray()
|
||||
@ -65,21 +66,22 @@ class TfidfRetriever(BaseRetriever):
|
||||
)
|
||||
return indices_and_scores
|
||||
|
||||
def retrieve(self, query, filters=None, top_k=10, verbose=True):
|
||||
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
|
||||
if filters:
|
||||
raise NotImplementedError("Filters are not implemented in TfidfRetriever.")
|
||||
if index:
|
||||
raise NotImplementedError("Switching index is not supported in TfidfRetriever.")
|
||||
|
||||
# get scores
|
||||
indices_and_scores = self._calc_scores(query)
|
||||
|
||||
# rank paragraphs
|
||||
df_sliced = self.df.loc[indices_and_scores.keys()]
|
||||
df_sliced = self.df.loc[indices_and_scores.keys()] # type: ignore
|
||||
df_sliced = df_sliced[:top_k]
|
||||
|
||||
if verbose:
|
||||
logger.info(
|
||||
f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}"
|
||||
)
|
||||
|
||||
# get actual content for the top candidates
|
||||
paragraphs = list(df_sliced.text.values)
|
||||
|
||||
@ -2,13 +2,13 @@ import json
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import pprint
|
||||
|
||||
from typing import Dict, Any
|
||||
from haystack.database.sql import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def print_answers(results, details="all"):
|
||||
def print_answers(results: dict, details: str = "all"):
|
||||
answers = results["answers"]
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
if details != "all":
|
||||
@ -28,7 +28,7 @@ def print_answers(results, details="all"):
|
||||
pp.pprint(results)
|
||||
|
||||
|
||||
def convert_labels_to_squad(labels_file):
|
||||
def convert_labels_to_squad(labels_file: str):
|
||||
"""
|
||||
Convert the export from the labeling UI to SQuAD format for training.
|
||||
|
||||
@ -42,7 +42,7 @@ def convert_labels_to_squad(labels_file):
|
||||
for label in labels:
|
||||
labels_grouped_by_documents[label["document_id"]].append(label)
|
||||
|
||||
labels_in_squad_format = {"data": []}
|
||||
labels_in_squad_format = {"data": []} # type: Dict[str, Any]
|
||||
for document_id, labels in labels_grouped_by_documents.items():
|
||||
qas = []
|
||||
for label in labels:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user