Add type hints and mypy checks (#138)

Tanay Soni 2020-06-10 17:22:37 +02:00 committed by GitHub
parent 180dc8cbd6
commit 98f1a3f9a7
19 changed files with 300 additions and 216 deletions

View File

@@ -39,5 +39,9 @@ jobs:
         pip install -e .
     - name: Test with pytest
-      run: |
-        cd test && pytest
+      run: cd test && pytest
+    - name: Test with mypy
+      run: |
+        pip install mypy
+        mypy haystack --ignore-missing-imports

View File

@@ -38,7 +38,7 @@ document_store = ElasticsearchDocumentStore(
     search_fields=SEARCH_FIELD_NAME,
     embedding_dim=EMBEDDING_DIM,
     embedding_field=EMBEDDING_FIELD_NAME,
-    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,
+    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,  # type: ignore
 )
@@ -52,7 +52,7 @@ class Feedback(BaseModel):
 @router.post("/doc-qa-feedback")
-def feedback(feedback: Feedback):
+def doc_qa_feedback(feedback: Feedback):
     if feedback.answer and feedback.offset_start_in_doc:
         elasticsearch_client.index(index=DB_INDEX_FEEDBACK, body=feedback.dict())
     else:
@@ -63,7 +63,7 @@ def feedback(feedback: Feedback):
 @router.post("/faq-qa-feedback")
-def feedback(feedback: Feedback):
+def faq_qa_feedback(feedback: Feedback):
     elasticsearch_client.index(index=DB_INDEX_FEEDBACK, body=feedback.dict())

View File

@@ -15,6 +15,7 @@ from haystack.api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_C
 from haystack.api.controller.utils import RequestLimiter
 from haystack.database.elasticsearch import ElasticsearchDocumentStore
 from haystack.reader.farm import FARMReader
+from haystack.retriever.base import BaseRetriever
 from haystack.retriever.elasticsearch import ElasticsearchRetriever, EmbeddingRetriever

 logger = logging.getLogger(__name__)
@@ -34,12 +35,12 @@ document_store = ElasticsearchDocumentStore(
     search_fields=SEARCH_FIELD_NAME,
     embedding_dim=EMBEDDING_DIM,
     embedding_field=EMBEDDING_FIELD_NAME,
-    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,
+    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,  # type: ignore
 )

 if EMBEDDING_MODEL_PATH:
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model=EMBEDDING_MODEL_PATH, gpu=USE_GPU)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model=EMBEDDING_MODEL_PATH, gpu=USE_GPU)  # type: BaseRetriever
 else:
     retriever = ElasticsearchRetriever(document_store=document_store)
@@ -54,7 +55,7 @@ if READER_MODEL_PATH:  # for extractive doc-qa
         num_processes=MAX_PROCESSES,
         max_seq_len=MAX_SEQ_LEN,
         doc_stride=DOC_STRIDE,
-    )
+    )  # type: Optional[FARMReader]
 else:
     reader = None  # don't need one for pure FAQ matching
@@ -66,7 +67,7 @@ FINDERS = {1: Finder(reader=reader, retriever=retriever)}
 #############################################
 class Question(BaseModel):
     questions: List[str]
-    filters: Dict[str, Optional[str]] = None
+    filters: Optional[Dict[str, Optional[str]]] = None
     top_k_reader: int = DEFAULT_TOP_K_READER
     top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER
@@ -74,8 +75,8 @@ class Question(BaseModel):
 class Answer(BaseModel):
     answer: Optional[str]
     question: Optional[str]
-    score: float = None
-    probability: float = None
+    score: Optional[float] = None
+    probability: Optional[float] = None
     context: Optional[str]
     offset_start: int
     offset_end: int
@@ -112,14 +113,16 @@ def doc_qa(model_id: int, request: Question):
     for question in request.questions:
         if request.filters:
             # put filter values into a list and remove filters with null value
-            request.filters = {key: [value] for key, value in request.filters.items() if value is not None}
+            filters = {key: [value] for key, value in request.filters.items() if value is not None}
             logger.info(f" [{datetime.now()}] Request: {request}")
+        else:
+            filters = {}

         result = finder.get_answers(
             question=question,
             top_k_retriever=request.top_k_retriever,
             top_k_reader=request.top_k_reader,
-            filters=request.filters,
+            filters=filters,
         )
         results.append(result)
@@ -141,11 +144,13 @@ def faq_qa(model_id: int, request: Question):
     for question in request.questions:
         if request.filters:
             # put filter values into a list and remove filters with null value
-            request.filters = {key: [value] for key, value in request.filters.items() if value is not None}
+            filters = {key: [value] for key, value in request.filters.items() if value is not None}
             logger.info(f" [{datetime.now()}] Request: {request}")
+        else:
+            filters = {}

         result = finder.get_answers_via_similar_questions(
-            question=question, top_k_retriever=request.top_k_retriever, filters=request.filters,
+            question=question, top_k_retriever=request.top_k_retriever, filters=filters,
         )
         results.append(result)

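Aside: the `# type: BaseRetriever` and `# type: Optional[FARMReader]` comments above are mypy type comments, declaring a single type for a name that is assigned in two branches. A minimal sketch of the same idea with Python 3.6+ inline variable annotations (the stub classes below are illustrative, not part of this commit):

from typing import Optional

class BaseRetriever: ...
class EmbeddingRetriever(BaseRetriever): ...
class ElasticsearchRetriever(BaseRetriever): ...

use_embeddings = True

# Declaring the name once with the common base type lets both branches type-check.
retriever: BaseRetriever
if use_embeddings:
    retriever = EmbeddingRetriever()
else:
    retriever = ElasticsearchRetriever()

The same reasoning drives the pydantic changes above: a None default is only consistent with a field declared Optional, hence `score: Optional[float] = None` and `filters: Optional[Dict[str, Optional[str]]] = None`.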
View File

@@ -1,35 +1,9 @@
 from abc import abstractmethod
-from typing import Any, Optional, Dict
+from typing import Any, Optional, Dict, List

 from pydantic import BaseModel, Field

-class BaseDocumentStore:
-    """
-    Base class for implementing Document Stores.
-    """
-    @abstractmethod
-    def write_documents(self, documents):
-        pass
-
-    @abstractmethod
-    def get_document_by_id(self, id):
-        pass
-
-    @abstractmethod
-    def get_document_ids_by_tags(self, tag):
-        pass
-
-    @abstractmethod
-    def get_document_count(self):
-        pass
-
-    @abstractmethod
-    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None):
-        pass
-
 class Document(BaseModel):
     id: str = Field(..., description="_id field from Elasticsearch")
     text: str = Field(..., description="Text of the document")
@@ -41,5 +15,35 @@ class Document(BaseModel):
     # name: Optional[str] = Field(None, description="Title of the document")
     question: Optional[str] = Field(None, description="Question text for FAQs.")
     query_score: Optional[float] = Field(None, description="Elasticsearch query score for a retrieved document")
-    meta: Optional[Dict[str, Any]] = Field(None, description="")
+    meta: Dict[str, Any] = Field({}, description="")
     tags: Optional[Dict[str, Any]] = Field(None, description="Tags that allow filtering of the data")
+
+
+class BaseDocumentStore:
+    """
+    Base class for implementing Document Stores.
+    """
+    @abstractmethod
+    def write_documents(self, documents: List[dict]):
+        pass
+
+    @abstractmethod
+    def get_all_documents(self) -> List[Document]:
+        pass
+
+    @abstractmethod
+    def get_document_by_id(self, id: str) -> Optional[Document]:
+        pass
+
+    @abstractmethod
+    def get_document_ids_by_tags(self, tag) -> List[str]:
+        pass
+
+    @abstractmethod
+    def get_document_count(self) -> int:
+        pass
+
+    @abstractmethod
+    def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
+        pass

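With the abstract interface above fully typed, mypy can now check concrete stores against it. A rough sketch of a store that satisfies the new signatures (illustrative only, not part of this commit; it assumes the Document and BaseDocumentStore definitions shown above):

from typing import List, Optional

from haystack.database.base import BaseDocumentStore, Document


class DictDocumentStore(BaseDocumentStore):
    """Toy in-memory store used here only to illustrate the typed interface."""

    def __init__(self):
        self.docs: dict = {}

    def write_documents(self, documents: List[dict]):
        for i, doc in enumerate(documents, start=len(self.docs)):
            self.docs[str(i)] = Document(id=str(i), text=doc["text"], meta=doc.get("meta", {}))

    def get_all_documents(self) -> List[Document]:
        return list(self.docs.values())

    def get_document_by_id(self, id: str) -> Optional[Document]:
        return self.docs.get(id)

    def get_document_ids_by_tags(self, tag) -> List[str]:
        return [doc.id for doc in self.docs.values() if doc.tags and tag in doc.tags]

    def get_document_count(self) -> int:
        return len(self.docs)

    def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
        return []  # embeddings are out of scope for this sketch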
View File

@@ -1,7 +1,7 @@
 import json
 import logging
 from string import Template
-from typing import Union
+from typing import List, Optional, Union, Dict, Any

 from elasticsearch import Elasticsearch
 from elasticsearch.helpers import bulk, scan
@@ -18,14 +18,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         username: str = "",
         password: str = "",
         index: str = "document",
-        search_fields: Union[str,list] = "text",
+        search_fields: Union[str, list] = "text",
         text_field: str = "text",
         name_field: str = "name",
         external_source_id_field: str = "external_source_id",
-        embedding_field: str = None,
-        embedding_dim: str = None,
-        custom_mapping: dict = None,
-        excluded_meta_data: list = None,
+        embedding_field: Optional[str] = None,
+        embedding_dim: Optional[str] = None,
+        custom_mapping: Optional[dict] = None,
+        excluded_meta_data: Optional[list] = None,
         scheme: str = "http",
         ca_certs: bool = False,
         verify_certs: bool = True,
@@ -93,14 +93,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         self.embedding_field = embedding_field
         self.excluded_meta_data = excluded_meta_data

-    def get_document_by_id(self, id: str) -> Document:
+    def get_document_by_id(self, id: str) -> Optional[Document]:
         query = {"query": {"ids": {"values": [id]}}}
         result = self.client.search(index=self.index, body=query)["hits"]["hits"]
         document = self._convert_es_hit_to_document(result[0]) if result else None
         return document

-    def get_document_ids_by_tags(self, tags: dict) -> [str]:
+    def get_document_ids_by_tags(self, tags: dict) -> List[str]:
         term_queries = [{"terms": {key: value}} for key, value in tags.items()]
         query = {"query": {"bool": {"must": term_queries}}}
         logger.debug(f"Tag filter query: {query}")
@@ -110,19 +110,19 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             doc_ids.append(hit["_id"])
         return doc_ids

-    def write_documents(self, documents):
+    def write_documents(self, documents: List[dict]):
         for doc in documents:
             doc["_op_type"] = "create"
             doc["_index"] = self.index

         bulk(self.client, documents, request_timeout=300)

-    def get_document_count(self):
+    def get_document_count(self) -> int:
         result = self.client.count()
         count = result["count"]
         return count

-    def get_all_documents(self):
+    def get_all_documents(self) -> List[Document]:
         result = scan(self.client, query={"query": {"match_all": {}}}, index=self.index)
         documents = [self._convert_es_hit_to_document(hit) for hit in result]
@@ -131,11 +131,11 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
     def query(
         self,
         query: str,
-        filters: dict = None,
+        filters: Optional[dict] = None,
         top_k: int = 10,
-        custom_query: str = None,
-        index: str = None,
-    ) -> [Document]:
+        custom_query: Optional[str] = None,
+        index: Optional[str] = None,
+    ) -> List[Document]:
         if index is None:
             index = self.index
@@ -180,7 +180,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         documents = [self._convert_es_hit_to_document(hit) for hit in result]
         return documents

-    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
+    def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
         if not self.embedding_field:
             raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
         else:
@@ -198,7 +198,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                         }
                     }
                 }
-            }
+            }  # type: Dict[str,Any]

             if candidate_doc_ids:
                 body["query"]["script_score"]["query"] = {
@@ -216,7 +216,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         documents = [self._convert_es_hit_to_document(hit, score_adjustment=-1) for hit in result]
         return documents

-    def _convert_es_hit_to_document(self, hit, score_adjustment=0) -> Document:
+    def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> Document:
         # We put all additional data of the doc into meta_data and return it in the API
         meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.external_source_id_field)}
         meta_data["name"] = meta_data.pop(self.name_field, None)
@@ -286,7 +286,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         bulk(self.client, eval_docs_to_index)
         bulk(self.client, questions_to_index)

-    def get_all_documents_in_index(self, index, filters=None):
+    def get_all_documents_in_index(self, index: str, filters: Optional[dict] = None):
         body = {
             "query": {
                 "bool": {
@@ -295,7 +295,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                     }
                 }
             }
-        }
+        }  # type: Dict[str, Any]

         if filters:
             body["query"]["bool"]["filter"] = {"term": filters}

View File

@@ -1,3 +1,5 @@
+from typing import Any, Dict, List, Optional, Union, Tuple
+
 from haystack.database.base import BaseDocumentStore, Document
@@ -10,7 +12,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
         self.docs = {}
         self.doc_tags = {}

-    def write_documents(self, documents):
+    def write_documents(self, documents: List[dict]):
         import hashlib

         if documents is None:
@@ -33,7 +35,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
             self._map_tags_to_ids(hash, tags)

-    def _map_tags_to_ids(self, hash, tags):
+    def _map_tags_to_ids(self, hash: str, tags: List[str]):
         if isinstance(tags, list):
             for tag in tags:
                 if isinstance(tag, dict):
@@ -48,10 +50,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
                 else:
                     self.doc_tags[comp_key] = [hash]

-    def get_document_by_id(self, id):
-        return self.docs[id]
+    def get_document_by_id(self, id: str) -> Document:
+        document = self._convert_memory_hit_to_document(self.docs[id], doc_id=id)
+        return document

-    def _convert_memory_hit_to_document(self, hit, doc_id=None) -> Document:
+    def _convert_memory_hit_to_document(self, hit: Tuple[Any, Any], doc_id: Optional[str] = None) -> Document:
         document = Document(
             id=doc_id,
             text=hit[0].get('text', None),
@@ -60,7 +63,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
         )
         return document

-    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
+    def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
         from haystack.api import config
         from numpy import dot
         from numpy.linalg import norm
@@ -78,7 +81,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
         return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]

-    def get_document_ids_by_tags(self, tags):
+    def get_document_ids_by_tags(self, tags: Union[List[Dict[str, Union[str, List[str]]]], Dict[str, Union[str, List[str]]]]) -> List[str]:
         """
         The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
         The format for the dict is {"tag-1": ["value-1","value-2"], "tag-2": ["value-3]" ...}
@@ -88,7 +91,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
         result = self._find_ids_by_tags(tags)
         return result

-    def _find_ids_by_tags(self, tags):
+    def _find_ids_by_tags(self, tags: List[Dict[str, Union[str, List[str]]]]):
         result = []
         for tag in tags:
             tag_keys = tag.keys()
@@ -102,8 +105,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
                 result.append(self.docs.get(doc_id))
         return result

-    def get_document_count(self):
+    def get_document_count(self) -> int:
         return len(self.docs.items())

-    def get_all_documents(self):
+    def get_all_documents(self) -> List[Document]:
         return [Document(id=item[0], text=item[1]['text'], name=item[1]['name'], meta=item[1].get('meta', {})) for item in self.docs.items()]

View File

@@ -1,11 +1,12 @@
-import json
+from typing import Any, Dict, Union, List, Optional

 from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, PickleType
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, sessionmaker

 from haystack.database.base import BaseDocumentStore, Document as DocumentSchema

-Base = declarative_base()
+Base = declarative_base()  # type: Any

 class ORMBase(Base):
@@ -43,19 +44,19 @@ class DocumentTag(ORMBase):
 class SQLDocumentStore(BaseDocumentStore):
-    def __init__(self, url="sqlite://"):
+    def __init__(self, url: str = "sqlite://"):
         engine = create_engine(url)
         ORMBase.metadata.create_all(engine)
         Session = sessionmaker(bind=engine)
         self.session = Session()

-    def get_document_by_id(self, id):
+    def get_document_by_id(self, id: str) -> Optional[DocumentSchema]:
         document_row = self.session.query(Document).get(id)
         document = self._convert_sql_row_to_document(document_row)
         return document

-    def get_all_documents(self):
+    def get_all_documents(self) -> List[DocumentSchema]:
         document_rows = self.session.query(Document).all()
         documents = []
         for row in document_rows:
@@ -63,7 +64,7 @@ class SQLDocumentStore(BaseDocumentStore):
         return documents

-    def get_document_ids_by_tags(self, tags):
+    def get_document_ids_by_tags(self, tags: Dict[str, Union[str, List]]) -> List[str]:
         """
         Get list of document ids that have tags from the given list of tags.
@@ -91,16 +92,16 @@ class SQLDocumentStore(BaseDocumentStore):
         doc_ids = [row[0] for row in query_results]
         return doc_ids

-    def write_documents(self, documents):
+    def write_documents(self, documents: List[dict]):
         for doc in documents:
             row = Document(name=doc["name"], text=doc["text"], meta_data=doc.get("meta", {}))
             self.session.add(row)
         self.session.commit()

-    def get_document_count(self):
+    def get_document_count(self) -> int:
         return self.session.query(Document).count()

-    def _convert_sql_row_to_document(self, row) -> Document:
+    def _convert_sql_row_to_document(self, row) -> DocumentSchema:
         document = DocumentSchema(
             id=row.id,
             text=row.text,

View File

@@ -1,9 +1,13 @@
 import logging
+import time
+from statistics import mean
+from typing import Optional, Dict, Any

 import numpy as np
 from scipy.special import expit
-import time
-from statistics import mean
+
+from haystack.reader.base import BaseReader
+from haystack.retriever.base import BaseRetriever

 logger = logging.getLogger(__name__)
@@ -15,13 +19,13 @@ class Finder:
     It provides an interface to predict top n answers for a given question.
     """

-    def __init__(self, reader, retriever):
+    def __init__(self, reader: Optional[BaseReader], retriever: Optional[BaseRetriever]):
         self.retriever = retriever
         self.reader = reader
         if self.reader is None and self.retriever is None:
             raise AttributeError("Finder: self.reader and self.retriever can not be both None")

-    def get_answers(self, question: str, top_k_reader: int = 1, top_k_retriever: int = 10, filters: dict = None):
+    def get_answers(self, question: str, top_k_reader: int = 1, top_k_retriever: int = 10, filters: Optional[dict] = None):
         """
         Get top k answers for a given question.
@@ -41,8 +45,8 @@ class Finder:
         if len(documents) == 0:
             logger.info("Retriever did not return any documents. Skipping reader ...")
-            results = {"question": question, "answers": []}
-            return results
+            empty_result = {"question": question, "answers": []}
+            return empty_result

         # 2) Apply reader to get granular answer(s)
         len_chars = sum([len(d.text) for d in documents])
@@ -50,7 +54,7 @@ class Finder:
         results = self.reader.predict(question=question,
                                       documents=documents,
-                                      top_k=top_k_reader)
+                                      top_k=top_k_reader)  # type: Dict[str, Any]

         # Add corresponding document_name and more meta data, if an answer contains the document_id
         for ans in results["answers"]:
@@ -61,7 +65,7 @@ class Finder:
         return results

-    def get_answers_via_similar_questions(self, question: str, top_k_retriever: int = 10, filters: dict = None):
+    def get_answers_via_similar_questions(self, question: str, top_k_retriever: int = 10, filters: Optional[dict] = None):
         """
         Get top k answers for a given question using only a retriever.
@@ -75,12 +79,12 @@ class Finder:
         if self.retriever is None:
             raise AttributeError("Finder.get_answers_via_similar_questions requires self.retriever")

-        results = {"question": question, "answers": []}
+        results = {"question": question, "answers": []}  # type: Dict[str, Any]

         # 1) Optional: reduce the search space via document tags
         if filters:
             logging.info(f"Apply filters: {filters}")
-            candidate_doc_ids = self.retriever.document_store.get_document_ids_by_tags(filters)
+            candidate_doc_ids = self.retriever.document_store.get_document_ids_by_tags(filters)  # type: ignore
             logger.info(f"Got candidate IDs due to filters: {candidate_doc_ids}")

             if len(candidate_doc_ids) == 0:
@@ -88,28 +92,35 @@ class Finder:
                 return results
         else:
-            candidate_doc_ids = None
+            candidate_doc_ids = None  # type: ignore

         # 2) Apply retriever to match similar questions via cosine similarity of embeddings
-        documents = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids)
+        documents = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids)  # type: ignore

         # 3) Format response
         for doc in documents:
             #TODO proper calibratation of pseudo probabilities
-            cur_answer = {"question": doc.meta["question"], "answer": doc.text, "context": doc.text,
+            cur_answer = {"question": doc.meta["question"], "answer": doc.text, "context": doc.text,  # type: ignore
                           "score": doc.query_score, "offset_start": 0, "offset_end": len(doc.text), "meta": doc.meta
                           }
-            if self.retriever.embedding_model:
-                probability = (doc.query_score + 1) / 2
+            if self.retriever.embedding_model:  # type: ignore
+                probability = (doc.query_score + 1) / 2  # type: ignore
             else:
-                probability = float(expit(np.asarray(doc.query_score / 8)))
+                probability = float(expit(np.asarray(doc.query_score / 8)))  # type: ignore
             cur_answer["probability"] = probability
             results["answers"].append(cur_answer)

         return results

-    def eval(self, label_index: str = "feedback", doc_index: str = "eval_document", label_origin: str = "gold_label",
-             top_k_retriever: int = 10, top_k_reader: int = 10):
+    def eval(
+        self,
+        label_index: str = "feedback",
+        doc_index: str = "eval_document",
+        label_origin: str = "gold_label",
+        top_k_retriever: int = 10,
+        top_k_reader: int = 10,
+    ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
         of the Retriever.
@@ -155,10 +166,14 @@ class Finder:
         :param top_k_reader: How many answers to return per question
         :type top_k_reader: int
         """
+        if not self.reader or not self.retriever:
+            raise Exception("Finder needs to have a reader and retriever for the evaluation.")
+
         finder_start_time = time.time()
         # extract all questions for evaluation
         filter = {"origin": label_origin}
-        questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filter)
+        questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filter)  # type: ignore

         correct_retrievals = 0
         summed_avg_precision_retriever = 0
@@ -192,7 +207,7 @@ class Finder:
                 # check if correct doc among retrieved docs
                 if doc.meta["doc_id"] == question["_source"]["doc_id"]:
                     correct_retrievals += 1
-                    summed_avg_precision_retriever += 1 / (doc_idx + 1)
+                    summed_avg_precision_retriever += 1 / (doc_idx + 1)  # type: ignore
             questions_with_docs.append({
                 "question": question,
                 "docs": retrieved_docs,
@@ -202,8 +217,8 @@ class Finder:
         number_of_questions = q_idx + 1
         number_of_no_answer = 0
-        previous_return_no_answers = self.reader.return_no_answers
-        self.reader.return_no_answers = True
+        previous_return_no_answers = self.reader.return_no_answers  # type: ignore
+        self.reader.return_no_answers = True  # type: ignore
         # extract answers
         reader_start_time = time.time()
         for q_idx, question in enumerate(questions_with_docs):
@@ -260,10 +275,10 @@ class Finder:
                     current_f1 = (2 * precision * recall) / (precision + recall)
                     # top-1 answer
                     if answer_idx == 0:
-                        summed_f1_top1 += current_f1
-                        summed_f1_top1_has_answer += current_f1
+                        summed_f1_top1 += current_f1  # type: ignore
+                        summed_f1_top1_has_answer += current_f1  # type: ignore
                     if current_f1 > best_f1:
-                        best_f1 = current_f1
+                        best_f1 = current_f1  # type: ignore
                 # top-k answers: use best f1-score
                 summed_f1_topk += best_f1
                 summed_f1_topk_has_answer += best_f1
@@ -311,7 +326,7 @@ class Finder:
         reader_top1_no_answer_accuracy = correct_no_answers_top1 / number_of_no_answer
         reader_topk_no_answer_accuracy = correct_no_answers_topk / number_of_no_answer

-        self.reader.return_no_answers = previous_return_no_answers
+        self.reader.return_no_answers = previous_return_no_answers  # type: ignore
         logger.info((f"{correct_readings_topk} out of {number_of_questions} questions were correctly answered "
                      f"({(correct_readings_topk/number_of_questions):.2%})."))

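Aside: most of the `# type: ignore` markers in Finder exist because `self.reader` and `self.retriever` are Optional, so mypy cannot prove they are non-None on every path. The guard added to `eval()` above already narrows them; a hedged sketch of the general pattern, using stub classes rather than the real Haystack types:

from typing import Optional

class BaseReader: ...
class BaseRetriever: ...

class Finder:
    def __init__(self, reader: Optional[BaseReader], retriever: Optional[BaseRetriever]):
        self.reader = reader
        self.retriever = retriever

    def eval(self) -> None:
        # Raising early narrows Optional[BaseReader] / Optional[BaseRetriever] to the
        # concrete types for the rest of the method, so attribute access below would
        # not need "# type: ignore".
        if not self.reader or not self.retriever:
            raise Exception("Finder needs to have a reader and retriever for the evaluation.")
        reader: BaseReader = self.reader        # narrowed, accepted by mypy
        retriever: BaseRetriever = self.retriever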
View File

@@ -1,15 +1,15 @@
 import re

-def clean_wiki_text(text):
+def clean_wiki_text(text: str) -> str:
     # get rid of multiple new lines
     while "\n\n" in text:
         text = text.replace("\n\n", "\n")

     # remove extremely short lines
-    text = text.split("\n")
+    lines = text.split("\n")
     cleaned = []
-    for l in text:
+    for l in lines:
         if len(l) > 30:
             cleaned.append(l)
         elif l[:2] == "==" and l[-2:] == "==":

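The `text` to `lines` rename above is what makes this function type-check: mypy infers `text` as str from the parameter annotation and rejects re-binding it to the List[str] returned by split(). A small, self-contained illustration of the same pattern (not code from the repo):

from typing import List

def shout_lines(text: str) -> List[str]:
    # text = text.split("\n")   # mypy would flag this: a str variable gets a List[str] value
    lines = text.split("\n")    # a fresh name keeps each variable at a single, consistent type
    return [line.upper() for line in lines]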
View File

@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from pathlib import Path
+from typing import List, Optional

 class BaseConverter:
@@ -9,11 +10,11 @@ class BaseConverter:
     def __init__(
         self,
-        remove_numeric_tables: bool = None,
-        remove_header_footer: bool = None,
-        remove_whitespace: bool = None,
-        remove_empty_lines: bool = None,
-        valid_languages: [str] = None,
+        remove_numeric_tables: Optional[bool] = None,
+        remove_header_footer: Optional[bool] = None,
+        remove_whitespace: Optional[bool] = None,
+        remove_empty_lines: Optional[bool] = None,
+        valid_languages: Optional[List[str]] = None,
     ):
         """
         :param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
@@ -40,5 +41,5 @@ class BaseConverter:
         self.valid_languages = valid_languages

     @abstractmethod
-    def extract_pages(self, file_path: Path) -> [str]:
+    def extract_pages(self, file_path: Path) -> List[str]:
         pass

View File

@@ -4,6 +4,7 @@ import subprocess
 from functools import partial, reduce
 from itertools import chain
 from pathlib import Path
+from typing import List, Optional, Tuple, Generator, Set

 import fitz
 import langdetect
@@ -16,11 +17,11 @@ logger = logging.getLogger(__name__)
 class PDFToTextConverter(BaseConverter):
     def __init__(
         self,
-        remove_numeric_tables: bool = False,
-        remove_whitespace: bool = None,
-        remove_empty_lines: bool = None,
-        remove_header_footer: bool = None,
-        valid_languages: [str] = None,
+        remove_numeric_tables: Optional[bool] = False,
+        remove_whitespace: Optional[bool] = None,
+        remove_empty_lines: Optional[bool] = None,
+        remove_header_footer: Optional[bool] = None,
+        valid_languages: Optional[List[str]] = None,
     ):
         """
         :param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
@@ -64,7 +65,7 @@ class PDFToTextConverter(BaseConverter):
             valid_languages=valid_languages,
         )

-    def extract_pages(self, file_path: Path) -> [str]:
+    def extract_pages(self, file_path: Path) -> List[str]:

         page_count = fitz.open(file_path).pageCount
@@ -106,7 +107,6 @@ class PDFToTextConverter(BaseConverter):
             pages.append(page)
             page_number += 1

-        if self.valid_languages:
         document_text = "".join(pages)
         if not self._validate_language(document_text):
             logger.warning(
@@ -122,7 +122,7 @@ class PDFToTextConverter(BaseConverter):
         return pages

-    def _extract_page(self, file_path: Path, page_number: int, layout: bool):
+    def _extract_page(self, file_path: Path, page_number: int, layout: bool) -> str:
         """
         Extract a page from the pdf file at file_path.
@@ -132,17 +132,20 @@ class PDFToTextConverter(BaseConverter):
                        the content stream order.
         """
         if layout:
-            command = ["pdftotext", "-layout", "-f", str(page_number), "-l", str(page_number), file_path, "-"]
+            command = ["pdftotext", "-layout", "-f", str(page_number), "-l", str(page_number), str(file_path), "-"]
         else:
-            command = ["pdftotext", "-f", str(page_number), "-l", str(page_number), file_path, "-"]
+            command = ["pdftotext", "-f", str(page_number), "-l", str(page_number), str(file_path), "-"]
         output_page = subprocess.run(command, capture_output=True, shell=False)
         page = output_page.stdout.decode(errors="ignore")
         return page

-    def _validate_language(self, text: str):
+    def _validate_language(self, text: str) -> bool:
         """
         Validate if the language of the text is one of valid languages.
         """
+        if not self.valid_languages:
+            return True
+
         try:
             lang = langdetect.detect(text)
         except langdetect.lang_detect_exception.LangDetectException:
@@ -153,7 +156,7 @@ class PDFToTextConverter(BaseConverter):
         else:
             return False

-    def _ngram(self, seq: str, n: int):
+    def _ngram(self, seq: str, n: int) -> Generator[str, None, None]:
         """
         Return ngram (of tokens - currently splitted by whitespace)
         :param seq: str, string from which the ngram shall be created
@@ -166,20 +169,20 @@ class PDFToTextConverter(BaseConverter):
         seq = seq.replace("\n", " \n")
         seq = seq.replace("\t", " \t")
-        seq = seq.split(" ")
+        words = seq.split(" ")

         ngrams = (
-            " ".join(seq[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(seq) - n + 1)
+            " ".join(words[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(words) - n + 1)
         )
         return ngrams

-    def _allngram(self, seq: str, min_ngram: int, max_ngram: int):
+    def _allngram(self, seq: str, min_ngram: int, max_ngram: int) -> Set[str]:
         lengths = range(min_ngram, max_ngram) if max_ngram else range(min_ngram, len(seq))
         ngrams = map(partial(self._ngram, seq), lengths)
         res = set(chain.from_iterable(ngrams))
         return res

-    def find_longest_common_ngram(self, sequences: [str], max_ngram: int = 30, min_ngram: int = 3):
+    def find_longest_common_ngram(self, sequences: List[str], max_ngram: int = 30, min_ngram: int = 3) -> Optional[str]:
         """
         Find the longest common ngram across different text sequences (e.g. start of pages).
         Considering all ngrams between the specified range. Helpful for finding footers, headers etc.
@@ -201,8 +204,8 @@ class PDFToTextConverter(BaseConverter):
         return longest if longest.strip() else None

     def find_and_remove_header_footer(
-        self, pages: [str], n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int
-    ):
+        self, pages: List[str], n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int
+    ) -> Tuple[List[str], Optional[str], Optional[str]]:
         """
         Heuristic to find footers and headers across different pages by searching for the longest common string.
         For headers we only search in the first n_chars characters (for footer: last n_chars).

View File

@@ -1,16 +1,18 @@
-from pathlib import Path
 import logging
-from farm.data_handler.utils import http_get
-import tempfile
 import tarfile
+import tempfile
 import zipfile
-from typing import Callable
+from pathlib import Path
+from typing import Callable, List, Optional
+
+from farm.data_handler.utils import http_get

 from haystack.indexing.file_converters.pdftotext import PDFToTextConverter

 logger = logging.getLogger(__name__)

-def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_paragraphs: bool = False) -> [dict]:
+def convert_files_to_dicts(dir_path: str, clean_func: Optional[Callable] = None, split_paragraphs: bool = False) -> List[dict]:
     """
     Convert all files(.txt, .pdf) in the sub-directories of the given path to Python dicts that can be written to a
     Document Store.
@@ -24,7 +26,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
     file_paths = [p for p in Path(dir_path).glob("**/*")]

     if ".pdf" in [p.suffix.lower() for p in file_paths]:
-        pdf_converter = PDFToTextConverter()
+        pdf_converter = PDFToTextConverter()  # type: Optional[PDFToTextConverter]
     else:
         pdf_converter = None
@@ -33,7 +35,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
         if path.suffix.lower() == ".txt":
             with open(path) as doc:
                 text = doc.read()
-        elif path.suffix.lower() == ".pdf":
+        elif path.suffix.lower() == ".pdf" and pdf_converter:
             pages = pdf_converter.extract_pages(path)
             text = "\n".join(pages)
         else:
@@ -53,7 +55,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
     return documents

-def fetch_archive_from_http(url, output_dir, proxies=None):
+def fetch_archive_from_http(url: str, output_dir: str, proxies: Optional[dict] = None):
     """
     Fetch an archive (zip or tar.gz) from a url via http and extract content to an output directory.
@@ -86,11 +88,11 @@ def fetch_archive_from_http(url, output_dir, proxies=None):
         temp_file.seek(0)  # making tempfile accessible
         # extract
         if url[-4:] == ".zip":
-            archive = zipfile.ZipFile(temp_file.name)
-            archive.extractall(output_dir)
+            zip_archive = zipfile.ZipFile(temp_file.name)
+            zip_archive.extractall(output_dir)
         elif url[-7:] == ".tar.gz":
-            archive = tarfile.open(temp_file.name)
-            archive.extractall(output_dir)
+            tar_archive = tarfile.open(temp_file.name)
+            tar_archive.extractall(output_dir)
         # temp_file gets deleted here
     return True

haystack/reader/base.py (new file, 11 lines)
View File

@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+from haystack.database.base import Document
+
+
+class BaseReader(ABC):
+
+    @abstractmethod
+    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
+        pass

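The new BaseReader ABC gives FARMReader and TransformersReader (updated below) a shared, typed predict contract. A toy subclass, purely illustrative and not part of the commit; the returned dict shape mirrors the `{"question": ..., "answers": [...]}` structure the concrete readers use:

from typing import List, Optional

from haystack.database.base import Document
from haystack.reader.base import BaseReader


class EchoReader(BaseReader):
    """Illustrative reader that 'answers' with the raw text of the top documents."""

    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
        answers = [{"answer": doc.text, "context": doc.text} for doc in documents[: top_k or 1]]
        return {"question": question, "answers": answers}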
View File

@@ -1,5 +1,6 @@
 import logging
 from pathlib import Path
+from typing import List, Optional, Union

 import numpy as np
 from farm.data_handler.data_silo import DataSilo
@@ -14,11 +15,11 @@ from scipy.special import expit

 from haystack.database.base import Document
 from haystack.database.elasticsearch import ElasticsearchDocumentStore
+from haystack.reader.base import BaseReader

 logger = logging.getLogger(__name__)

-class FARMReader:
+class FARMReader(BaseReader):
     """
     Transformer based model for extractive Question Answering using the FARM framework (https://github.com/deepset-ai/FARM).
     While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
@@ -30,16 +31,16 @@ class FARMReader:
     def __init__(
         self,
-        model_name_or_path,
-        context_window_size=150,
-        batch_size=50,
-        use_gpu=True,
-        no_ans_boost=None,
-        top_k_per_candidate=3,
-        top_k_per_sample=1,
-        num_processes=None,
-        max_seq_len=256,
-        doc_stride=128
+        model_name_or_path: Union[str, Path],
+        context_window_size: int = 150,
+        batch_size: int = 50,
+        use_gpu: bool = True,
+        no_ans_boost: Optional[int] = None,
+        top_k_per_candidate: int = 3,
+        top_k_per_sample: int = 1,
+        num_processes: Optional[int] = None,
+        max_seq_len: int = 256,
+        doc_stride: int = 128,
     ):
         """
@@ -97,9 +98,22 @@ class FARMReader:
         self.max_seq_len = max_seq_len
         self.use_gpu = use_gpu

-    def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
-              use_gpu=None, batch_size=10, n_epochs=2, learning_rate=1e-5,
-              max_seq_len=None, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
+    def train(
+        self,
+        data_dir: str,
+        train_filename: str,
+        dev_filename: Optional[str] = None,
+        test_file_name: Optional[str] = None,
+        use_gpu: Optional[bool] = None,
+        batch_size: int = 10,
+        n_epochs: int = 2,
+        learning_rate: float = 1e-5,
+        max_seq_len: Optional[int] = None,
+        warmup_proportion: float = 0.2,
+        dev_split: Optional[float] = 0.1,
+        evaluate_every: int = 300,
+        save_dir: Optional[str] = None,
+    ):
         """
         Fine-tune a model on a QA dataset. Options:
         - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
@@ -141,7 +155,6 @@ class FARMReader:
         if not save_dir:
             save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
-        save_dir = Path(save_dir)

         # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
         label_list = ["start_token", "end_token"]
@@ -184,14 +197,14 @@ class FARMReader:
         )
         # 5. Let it grow!
         self.inferencer.model = trainer.train()
-        self.save(save_dir)
+        self.save(Path(save_dir))

-    def save(self, directory):
+    def save(self, directory: Path):
         logger.info(f"Saving reader model to {directory}")
         self.inferencer.model.save(directory)
         self.inferencer.processor.save(directory)

-    def predict(self, question: str, documents: [Document], top_k: int = None):
+    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
         """
         Use loaded QA model to find answers for a question in the supplied list of Document.
@@ -245,7 +258,8 @@ class FARMReader:
                 if a["answer"]:
                     cur = {"answer": a["answer"],
                            "score": a["score"],
-                           "probability": float(expit(np.asarray([a["score"]]) / 8)),  #just a pseudo prob for now
+                           # just a pseudo prob for now
+                           "probability": float(expit(np.asarray([a["score"]]) / 8)),  # type: ignore
                            "context": a["context"],
                            "offset_start": a["offset_answer_start"] - a["offset_context_start"],
                            "offset_end": a["offset_answer_end"] - a["offset_context_start"],
@@ -316,8 +330,14 @@ class FARMReader:
         }
         return results

-    def eval(self, document_store: ElasticsearchDocumentStore, device: str, label_index: str = "feedback",
-             doc_index: str = "eval_document", label_origin: str = "gold_label"):
+    def eval(
+        self,
+        document_store: ElasticsearchDocumentStore,
+        device: str,
+        label_index: str = "feedback",
+        doc_index: str = "eval_document",
+        label_origin: str = "gold_label",
+    ):
         """
         Performs evaluation on evaluation documents in Elasticsearch DocumentStore.
@@ -386,7 +406,7 @@ class FARMReader:
         return results

     @staticmethod
-    def _calc_no_answer(no_ans_gaps,best_score_answer):
+    def _calc_no_answer(no_ans_gaps: List[float], best_score_answer: float):
         # "no answer" scores and positive answers scores are difficult to compare, because
         # + a positive answer score is related to one specific document
         # - a "no answer" score is related to all input documents
@@ -396,7 +416,8 @@ class FARMReader:
         # No_ans_gap coming from FARM mean how much no_ans_boost should change to switch predictions
         no_ans_gaps = np.array(no_ans_gaps)
         max_no_ans_gap = np.max(no_ans_gaps)
-        if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # all passages "no answer" as top score
+        # all passages "no answer" as top score
+        if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # type: ignore
             no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so it increases best pos score
         else:  # case: at least one passage predicts an answer (positive no_ans_gap)
             no_ans_score = best_score_answer - max_no_ans_gap
@@ -410,7 +431,7 @@ class FARMReader:
                              "document_id": None}
         return no_ans_prediction, max_no_ans_gap

-    def predict_on_texts(self, question: str, texts: [str], top_k=None):
+    def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
         documents = []
         for i, text in enumerate(texts):
             documents.append(

View File

@@ -1,9 +1,12 @@
+from typing import List, Optional
+
 from transformers import pipeline

 from haystack.database.base import Document
+from haystack.reader.base import BaseReader

-class TransformersReader:
+class TransformersReader(BaseReader):
     """
     Transformer based model for extractive Question Answering using the huggingface's transformers framework
     (https://github.com/huggingface/transformers).
@@ -15,13 +18,11 @@ class TransformersReader:
     def __init__(
         self,
-        model="distilbert-base-uncased-distilled-squad",
-        tokenizer="distilbert-base-uncased",
-        context_window_size=30,
-        #no_answer_shift=-100,
-        #batch_size=16,
-        use_gpu=0,
-        n_best_per_passage=2
+        model: str = "distilbert-base-uncased-distilled-squad",
+        tokenizer: str = "distilbert-base-uncased",
+        context_window_size: int = 30,
+        use_gpu: int = 0,
+        n_best_per_passage: int = 2,
     ):
         """
         Load a QA model from Transformers.
@@ -44,7 +45,7 @@ class TransformersReader:
         self.n_best_per_passage = n_best_per_passage
         #TODO param to modify bias for no_answer

-    def predict(self, question: str, documents: [Document], top_k: int = None):
+    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
         """
         Use loaded QA model to find answers for a question in the supplied list of Document.

View File

@@ -1,7 +1,10 @@
 from abc import ABC, abstractmethod
+from typing import List
+
+from haystack.database.base import Document

 class BaseRetriever(ABC):
     @abstractmethod
-    def retrieve(self, query, candidate_doc_ids=None, top_k=1):
+    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
         pass

View File

@@ -1,15 +1,17 @@
 import logging
-from typing import Type
+from typing import List, Union

 from farm.infer import Inferencer

-from haystack.database.base import Document, BaseDocumentStore
+from haystack.database.base import Document
+from haystack.database.elasticsearch import ElasticsearchDocumentStore
 from haystack.retriever.base import BaseRetriever

 logger = logging.getLogger(__name__)

 class ElasticsearchRetriever(BaseRetriever):
-    def __init__(self, document_store: Type[BaseDocumentStore], custom_query: str = None):
+    def __init__(self, document_store: ElasticsearchDocumentStore, custom_query: str = None):
         """
         :param document_store: an instance of a DocumentStore to retrieve documents from.
         :param custom_query: query string as per Elasticsearch DSL with a mandatory question placeholder($question).
@@ -38,10 +40,10 @@ class ElasticsearchRetriever(BaseRetriever):
             self.retrieve(query="Why did the revenue increase?",
                           filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
         """
-        self.document_store = document_store
+        self.document_store = document_store  # type: ignore
         self.custom_query = custom_query

-    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> [Document]:
+    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
         if index is None:
             index = self.document_store.index
@@ -50,8 +52,13 @@ class ElasticsearchRetriever(BaseRetriever):
         return documents

-    def eval(self, label_index: str = "feedback", doc_index: str = "eval_document", label_origin: str = "gold_label",
-             top_k: int = 10) -> dict:
+    def eval(
+        self,
+        label_index: str = "feedback",
+        doc_index: str = "eval_document",
+        label_origin: str = "gold_label",
+        top_k: int = 10,
+    ) -> dict:
         """
         Performs evaluation on the Retriever.
         Retriever is evaluated based on whether it finds the correct document given the question string and at which
@@ -81,7 +88,7 @@ class ElasticsearchRetriever(BaseRetriever):
             for doc_idx, doc in enumerate(retrieved_docs):
                 if doc.meta["doc_id"] == question["_source"]["doc_id"]:
                     correct_retrievals += 1
-                    summed_avg_precision += 1 / (doc_idx + 1)
+                    summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
                     break

         number_of_questions = q_idx + 1
@@ -97,7 +104,7 @@ class ElasticsearchRetriever(BaseRetriever):
 class EmbeddingRetriever(BaseRetriever):
     def __init__(
         self,
-        document_store: Type[BaseDocumentStore],
+        document_store: ElasticsearchDocumentStore,
         embedding_model: str,
         gpu: bool = True,
         model_format: str = "farm",
@@ -137,13 +144,13 @@ class EmbeddingRetriever(BaseRetriever):
         else:
             raise NotImplementedError

-    def retrieve(self, query: str, candidate_doc_ids: [str] = None, top_k: int = 10) -> [Document]:
+    def retrieve(self, query: str, candidate_doc_ids: List[str] = None, top_k: int = 10) -> List[Document]:  # type: ignore
         query_emb = self.create_embedding(texts=[query])
         documents = self.document_store.query_by_embedding(query_emb[0], top_k, candidate_doc_ids)

         return documents

-    def create_embedding(self, texts: [str]):
+    def create_embedding(self, texts: Union[List[str], str]) -> List[List[float]]:
         """
         Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
         :param texts: texts to embed
@@ -152,14 +159,15 @@ class EmbeddingRetriever(BaseRetriever):
         # for backward compatibility: cast pure str input
         if type(texts) == str:
-            texts = [texts]
+            texts = [texts]  # type: ignore
         assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"

         if self.model_format == "farm":
-            res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
+            res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])  # type: ignore
             emb = [list(r["vec"]) for r in res]  #cast from numpy
         elif self.model_format == "sentence_transformers":
             # text is single string, sentence-transformers needs a list of strings
-            res = self.embedding_model.encode(texts)  # get back list of numpy embedding vectors
+            # get back list of numpy embedding vectors
+            res = self.embedding_model.encode(texts)  # type: ignore
             emb = [list(r.astype('float64')) for r in res]  #cast from numpy
         return emb

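Aside: create_embedding now declares its backward-compatible `Union[List[str], str]` input but still silences the `texts = [texts]` re-binding with `# type: ignore`. A hedged alternative (not what the commit does) is to narrow the union with isinstance, which mypy understands without an ignore:

from typing import List, Union


def normalize_texts(texts: Union[List[str], str]) -> List[str]:
    # isinstance() narrows the Union for mypy, so no cast or ignore is required.
    if isinstance(texts, str):
        return [texts]
    return texts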
View File

@@ -1,11 +1,12 @@
 import logging
 from collections import OrderedDict, namedtuple
+from typing import List

 import pandas as pd
-from haystack.database.base import Document
-from haystack.retriever.base import BaseRetriever
 from sklearn.feature_extraction.text import TfidfVectorizer

+from haystack.database.base import BaseDocumentStore, Document
+from haystack.retriever.base import BaseRetriever

 logger = logging.getLogger(__name__)
@@ -23,7 +24,7 @@ class TfidfRetriever(BaseRetriever):
     It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
     """

-    def __init__(self, document_store):
+    def __init__(self, document_store: BaseDocumentStore):
         self.vectorizer = TfidfVectorizer(
             lowercase=True,
             stop_words=None,
@@ -36,7 +37,7 @@ class TfidfRetriever(BaseRetriever):
         self.df = None
         self.fit()

-    def _get_all_paragraphs(self):
+    def _get_all_paragraphs(self) -> List[Paragraph]:
         """
         Split the list of documents in paragraphs
         """
@@ -55,7 +56,7 @@ class TfidfRetriever(BaseRetriever):
         logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
         return paragraphs

-    def _calc_scores(self, query):
+    def _calc_scores(self, query: str) -> dict:
         question_vector = self.vectorizer.transform([query])

         scores = self.tfidf_matrix.dot(question_vector.T).toarray()
@@ -65,19 +66,20 @@ class TfidfRetriever(BaseRetriever):
         )
         return indices_and_scores

-    def retrieve(self, query, filters=None, top_k=10, verbose=True):
+    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
         if filters:
             raise NotImplementedError("Filters are not implemented in TfidfRetriever.")
+        if index:
+            raise NotImplementedError("Switching index is not supported in TfidfRetriever.")

         # get scores
         indices_and_scores = self._calc_scores(query)

         # rank paragraphs
-        df_sliced = self.df.loc[indices_and_scores.keys()]
+        df_sliced = self.df.loc[indices_and_scores.keys()]  # type: ignore
         df_sliced = df_sliced[:top_k]

-        if verbose:
-            logger.info(
-                f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}"
-            )
+        logger.debug(
+            f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}"
+        )

View File

@@ -2,13 +2,13 @@ import json
 from collections import defaultdict
 import logging
 import pprint
+from typing import Dict, Any

 from haystack.database.sql import Document

 logger = logging.getLogger(__name__)

-def print_answers(results, details="all"):
+def print_answers(results: dict, details: str = "all"):
     answers = results["answers"]
     pp = pprint.PrettyPrinter(indent=4)
     if details != "all":
@@ -28,7 +28,7 @@ def print_answers(results, details="all"):
         pp.pprint(results)

-def convert_labels_to_squad(labels_file):
+def convert_labels_to_squad(labels_file: str):
     """
     Convert the export from the labeling UI to SQuAD format for training.
@@ -42,7 +42,7 @@ def convert_labels_to_squad(labels_file):
     for label in labels:
         labels_grouped_by_documents[label["document_id"]].append(label)

-    labels_in_squad_format = {"data": []}
+    labels_in_squad_format = {"data": []}  # type: Dict[str, Any]

     for document_id, labels in labels_grouped_by_documents.items():
         qas = []
         for label in labels: