diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 591c32534..b8952826a 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -39,5 +39,9 @@ jobs:
         pip install -e .
     - name: Test with pytest
+      run: cd test && pytest
+
+    - name: Test with mypy
       run: |
-        cd test && pytest
+        pip install mypy
+        mypy haystack --ignore-missing-imports
diff --git a/haystack/api/controller/feedback.py b/haystack/api/controller/feedback.py
index 1a13961e2..a53910bac 100644
--- a/haystack/api/controller/feedback.py
+++ b/haystack/api/controller/feedback.py
@@ -38,7 +38,7 @@ document_store = ElasticsearchDocumentStore(
     search_fields=SEARCH_FIELD_NAME,
     embedding_dim=EMBEDDING_DIM,
     embedding_field=EMBEDDING_FIELD_NAME,
-    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,
+    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,  # type: ignore
 )
@@ -52,7 +52,7 @@ class Feedback(BaseModel):


 @router.post("/doc-qa-feedback")
-def feedback(feedback: Feedback):
+def doc_qa_feedback(feedback: Feedback):
     if feedback.answer and feedback.offset_start_in_doc:
         elasticsearch_client.index(index=DB_INDEX_FEEDBACK, body=feedback.dict())
     else:
@@ -63,7 +63,7 @@ def feedback(feedback: Feedback):


 @router.post("/faq-qa-feedback")
-def feedback(feedback: Feedback):
+def faq_qa_feedback(feedback: Feedback):
     elasticsearch_client.index(index=DB_INDEX_FEEDBACK, body=feedback.dict())
diff --git a/haystack/api/controller/search.py b/haystack/api/controller/search.py
index 995cf9ad4..76785c79f 100644
--- a/haystack/api/controller/search.py
+++ b/haystack/api/controller/search.py
@@ -15,6 +15,7 @@ from haystack.api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_C
 from haystack.api.controller.utils import RequestLimiter
 from haystack.database.elasticsearch import ElasticsearchDocumentStore
 from haystack.reader.farm import FARMReader
+from haystack.retriever.base import BaseRetriever
 from haystack.retriever.elasticsearch import ElasticsearchRetriever, EmbeddingRetriever

 logger = logging.getLogger(__name__)
@@ -34,12 +35,12 @@ document_store = ElasticsearchDocumentStore(
     search_fields=SEARCH_FIELD_NAME,
     embedding_dim=EMBEDDING_DIM,
     embedding_field=EMBEDDING_FIELD_NAME,
-    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,
+    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,  # type: ignore
 )

 if EMBEDDING_MODEL_PATH:
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model=EMBEDDING_MODEL_PATH, gpu=USE_GPU)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model=EMBEDDING_MODEL_PATH, gpu=USE_GPU)  # type: BaseRetriever
 else:
     retriever = ElasticsearchRetriever(document_store=document_store)

@@ -54,7 +55,7 @@ if READER_MODEL_PATH:  # for extractive doc-qa
         num_processes=MAX_PROCESSES,
         max_seq_len=MAX_SEQ_LEN,
         doc_stride=DOC_STRIDE,
-    )
+    )  # type: Optional[FARMReader]
 else:
     reader = None  # don't need one for pure FAQ matching

@@ -66,7 +67,7 @@ FINDERS = {1: Finder(reader=reader, retriever=retriever)}
 #############################################
 class Question(BaseModel):
     questions: List[str]
-    filters: Dict[str, Optional[str]] = None
+    filters: Optional[Dict[str, Optional[str]]] = None
     top_k_reader: int = DEFAULT_TOP_K_READER
     top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER

@@ -74,8 +75,8 @@ class Question(BaseModel):
 class Answer(BaseModel):
     answer: Optional[str]
     question: Optional[str]
-    score: float = None
-    probability: float = None
+    score: Optional[float] = None
+    probability: Optional[float] = None
     context: Optional[str]
     offset_start: int
     offset_end: int
@@ -112,14 +113,16 @@ def doc_qa(model_id: int, request: Question):
         for question in request.questions:
             if request.filters:
                 # put filter values into a list and remove filters with null value
-                request.filters = {key: [value] for key, value in request.filters.items() if value is not None}
+                filters = {key: [value] for key, value in request.filters.items() if value is not None}
                 logger.info(f" [{datetime.now()}] Request: {request}")
+            else:
+                filters = {}

             result = finder.get_answers(
                 question=question, top_k_retriever=request.top_k_retriever, top_k_reader=request.top_k_reader,
-                filters=request.filters,
+                filters=filters,
             )
             results.append(result)

@@ -141,11 +144,13 @@ def faq_qa(model_id: int, request: Question):
     for question in request.questions:
         if request.filters:
             # put filter values into a list and remove filters with null value
-            request.filters = {key: [value] for key, value in request.filters.items() if value is not None}
+            filters = {key: [value] for key, value in request.filters.items() if value is not None}
             logger.info(f" [{datetime.now()}] Request: {request}")
+        else:
+            filters = {}

         result = finder.get_answers_via_similar_questions(
-            question=question, top_k_retriever=request.top_k_retriever, filters=request.filters,
+            question=question, top_k_retriever=request.top_k_retriever, filters=filters,
         )
         results.append(result)

diff --git a/haystack/database/base.py b/haystack/database/base.py
index 15a28500c..f15c919ef 100644
--- a/haystack/database/base.py
+++ b/haystack/database/base.py
@@ -1,35 +1,9 @@
 from abc import abstractmethod
-from typing import Any, Optional, Dict
+from typing import Any, Optional, Dict, List

 from pydantic import BaseModel, Field


-class BaseDocumentStore:
-    """
-    Base class for implementing Document Stores.
-    """
-
-    @abstractmethod
-    def write_documents(self, documents):
-        pass
-
-    @abstractmethod
-    def get_document_by_id(self, id):
-        pass
-
-    @abstractmethod
-    def get_document_ids_by_tags(self, tag):
-        pass
-
-    @abstractmethod
-    def get_document_count(self):
-        pass
-
-    @abstractmethod
-    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None):
-        pass
-
-
 class Document(BaseModel):
     id: str = Field(..., description="_id field from Elasticsearch")
     text: str = Field(..., description="Text of the document")
@@ -41,5 +15,35 @@ class Document(BaseModel):
     # name: Optional[str] = Field(None, description="Title of the document")
     question: Optional[str] = Field(None, description="Question text for FAQs.")
     query_score: Optional[float] = Field(None, description="Elasticsearch query score for a retrieved document")
-    meta: Optional[Dict[str, Any]] = Field(None, description="")
+    meta: Dict[str, Any] = Field({}, description="")
     tags: Optional[Dict[str, Any]] = Field(None, description="Tags that allow filtering of the data")
+
+
+class BaseDocumentStore:
+    """
+    Base class for implementing Document Stores.
+ """ + + @abstractmethod + def write_documents(self, documents: List[dict]): + pass + + @abstractmethod + def get_all_documents(self) -> List[Document]: + pass + + @abstractmethod + def get_document_by_id(self, id: str) -> Optional[Document]: + pass + + @abstractmethod + def get_document_ids_by_tags(self, tag) -> List[str]: + pass + + @abstractmethod + def get_document_count(self) -> int: + pass + + @abstractmethod + def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]: + pass diff --git a/haystack/database/elasticsearch.py b/haystack/database/elasticsearch.py index 74800ccce..5680b57ff 100644 --- a/haystack/database/elasticsearch.py +++ b/haystack/database/elasticsearch.py @@ -1,7 +1,7 @@ import json import logging from string import Template -from typing import Union +from typing import List, Optional, Union, Dict, Any from elasticsearch import Elasticsearch from elasticsearch.helpers import bulk, scan @@ -18,14 +18,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore): username: str = "", password: str = "", index: str = "document", - search_fields: Union[str,list] = "text", + search_fields: Union[str, list] = "text", text_field: str = "text", name_field: str = "name", external_source_id_field: str = "external_source_id", - embedding_field: str = None, - embedding_dim: str = None, - custom_mapping: dict = None, - excluded_meta_data: list = None, + embedding_field: Optional[str] = None, + embedding_dim: Optional[str] = None, + custom_mapping: Optional[dict] = None, + excluded_meta_data: Optional[list] = None, scheme: str = "http", ca_certs: bool = False, verify_certs: bool = True, @@ -93,14 +93,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore): self.embedding_field = embedding_field self.excluded_meta_data = excluded_meta_data - def get_document_by_id(self, id: str) -> Document: + def get_document_by_id(self, id: str) -> Optional[Document]: query = {"query": {"ids": {"values": [id]}}} result = self.client.search(index=self.index, body=query)["hits"]["hits"] document = self._convert_es_hit_to_document(result[0]) if result else None return document - def get_document_ids_by_tags(self, tags: dict) -> [str]: + def get_document_ids_by_tags(self, tags: dict) -> List[str]: term_queries = [{"terms": {key: value}} for key, value in tags.items()] query = {"query": {"bool": {"must": term_queries}}} logger.debug(f"Tag filter query: {query}") @@ -110,19 +110,19 @@ class ElasticsearchDocumentStore(BaseDocumentStore): doc_ids.append(hit["_id"]) return doc_ids - def write_documents(self, documents): + def write_documents(self, documents: List[dict]): for doc in documents: doc["_op_type"] = "create" doc["_index"] = self.index bulk(self.client, documents, request_timeout=300) - def get_document_count(self): + def get_document_count(self) -> int: result = self.client.count() count = result["count"] return count - def get_all_documents(self): + def get_all_documents(self) -> List[Document]: result = scan(self.client, query={"query": {"match_all": {}}}, index=self.index) documents = [self._convert_es_hit_to_document(hit) for hit in result] @@ -131,11 +131,11 @@ class ElasticsearchDocumentStore(BaseDocumentStore): def query( self, query: str, - filters: dict = None, + filters: Optional[dict] = None, top_k: int = 10, - custom_query: str = None, - index: str = None, - ) -> [Document]: + custom_query: Optional[str] = None, + index: Optional[str] = None, + ) -> List[Document]: if index is None: index = self.index @@ 
         documents = [self._convert_es_hit_to_document(hit) for hit in result]
         return documents

-    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
+    def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
         if not self.embedding_field:
             raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
         else:
@@ -198,7 +198,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                     }
                 }
             }
-        }
+        }  # type: Dict[str,Any]

         if candidate_doc_ids:
             body["query"]["script_score"]["query"] = {
@@ -216,7 +216,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         documents = [self._convert_es_hit_to_document(hit, score_adjustment=-1) for hit in result]
         return documents

-    def _convert_es_hit_to_document(self, hit, score_adjustment=0) -> Document:
+    def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> Document:
         # We put all additional data of the doc into meta_data and return it in the API
         meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.external_source_id_field)}
         meta_data["name"] = meta_data.pop(self.name_field, None)
@@ -286,7 +286,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         bulk(self.client, eval_docs_to_index)
         bulk(self.client, questions_to_index)

-    def get_all_documents_in_index(self, index, filters=None):
+    def get_all_documents_in_index(self, index: str, filters: Optional[dict] = None):
         body = {
             "query": {
                 "bool": {
@@ -295,7 +295,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                     }
                 }
             }
-        }
+        }  # type: Dict[str, Any]
         if filters:
             body["query"]["bool"]["filter"] = {"term": filters}

diff --git a/haystack/database/memory.py b/haystack/database/memory.py
index ed620e441..313f29cb2 100644
--- a/haystack/database/memory.py
+++ b/haystack/database/memory.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict, List, Optional, Union, Tuple
+
 from haystack.database.base import BaseDocumentStore, Document


@@ -10,7 +12,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
         self.docs = {}
         self.doc_tags = {}

-    def write_documents(self, documents):
+    def write_documents(self, documents: List[dict]):
         import hashlib

         if documents is None:
@@ -33,7 +35,7 @@ class InMemoryDocumentStore(BaseDocumentStore):

                 self._map_tags_to_ids(hash, tags)

-    def _map_tags_to_ids(self, hash, tags):
+    def _map_tags_to_ids(self, hash: str, tags: List[str]):
         if isinstance(tags, list):
             for tag in tags:
                 if isinstance(tag, dict):
@@ -48,10 +50,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
                     else:
                         self.doc_tags[comp_key] = [hash]

-    def get_document_by_id(self, id):
-        return self.docs[id]
+    def get_document_by_id(self, id: str) -> Document:
+        document = self._convert_memory_hit_to_document(self.docs[id], doc_id=id)
+        return document

-    def _convert_memory_hit_to_document(self, hit, doc_id=None) -> Document:
+    def _convert_memory_hit_to_document(self, hit: Tuple[Any, Any], doc_id: Optional[str] = None) -> Document:
         document = Document(
             id=doc_id,
             text=hit[0].get('text', None),
@@ -60,7 +63,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
         )
         return document

-    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
+    def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
         from haystack.api import config
         from numpy import dot
         from numpy.linalg import norm
@@ -78,7 +81,7 @@ class InMemoryDocumentStore(BaseDocumentStore):

         return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]

-    def get_document_ids_by_tags(self, tags):
+    def get_document_ids_by_tags(self, tags: Union[List[Dict[str, Union[str, List[str]]]], Dict[str, Union[str, List[str]]]]) -> List[str]:
        """
        The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
        The format for the dict is {"tag-1": ["value-1","value-2"], "tag-2": ["value-3]" ...}
@@ -88,7 +91,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
         result = self._find_ids_by_tags(tags)
         return result

-    def _find_ids_by_tags(self, tags):
+    def _find_ids_by_tags(self, tags: List[Dict[str, Union[str, List[str]]]]):
         result = []
         for tag in tags:
             tag_keys = tag.keys()
@@ -102,8 +105,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
             result.append(self.docs.get(doc_id))
         return result

-    def get_document_count(self):
+    def get_document_count(self) -> int:
         return len(self.docs.items())

-    def get_all_documents(self):
+    def get_all_documents(self) -> List[Document]:
         return [Document(id=item[0], text=item[1]['text'], name=item[1]['name'], meta=item[1].get('meta', {})) for item in self.docs.items()]
diff --git a/haystack/database/sql.py b/haystack/database/sql.py
index cb381b69f..0be9d1cba 100644
--- a/haystack/database/sql.py
+++ b/haystack/database/sql.py
@@ -1,11 +1,12 @@
-import json
+from typing import Any, Dict, Union, List, Optional
+
 from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, PickleType
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, sessionmaker

 from haystack.database.base import BaseDocumentStore, Document as DocumentSchema

-Base = declarative_base()
+Base = declarative_base()  # type: Any


 class ORMBase(Base):
@@ -43,19 +44,19 @@ class DocumentTag(ORMBase):


 class SQLDocumentStore(BaseDocumentStore):
-    def __init__(self, url="sqlite://"):
+    def __init__(self, url: str = "sqlite://"):
         engine = create_engine(url)
         ORMBase.metadata.create_all(engine)
         Session = sessionmaker(bind=engine)
         self.session = Session()

-    def get_document_by_id(self, id):
+    def get_document_by_id(self, id: str) -> Optional[DocumentSchema]:
         document_row = self.session.query(Document).get(id)
         document = self._convert_sql_row_to_document(document_row)
         return document

-    def get_all_documents(self):
+    def get_all_documents(self) -> List[DocumentSchema]:
         document_rows = self.session.query(Document).all()
         documents = []
         for row in document_rows:
@@ -63,7 +64,7 @@ class SQLDocumentStore(BaseDocumentStore):

         return documents

-    def get_document_ids_by_tags(self, tags):
+    def get_document_ids_by_tags(self, tags: Dict[str, Union[str, List]]) -> List[str]:
         """
         Get list of document ids that have tags from the given list of tags.
@@ -91,16 +92,16 @@ class SQLDocumentStore(BaseDocumentStore):
         doc_ids = [row[0] for row in query_results]
         return doc_ids

-    def write_documents(self, documents):
+    def write_documents(self, documents: List[dict]):
         for doc in documents:
             row = Document(name=doc["name"], text=doc["text"], meta_data=doc.get("meta", {}))
             self.session.add(row)
         self.session.commit()

-    def get_document_count(self):
+    def get_document_count(self) -> int:
         return self.session.query(Document).count()

-    def _convert_sql_row_to_document(self, row) -> Document:
+    def _convert_sql_row_to_document(self, row) -> DocumentSchema:
         document = DocumentSchema(
             id=row.id,
             text=row.text,
diff --git a/haystack/finder.py b/haystack/finder.py
index 435f590d8..f549e68ac 100644
--- a/haystack/finder.py
+++ b/haystack/finder.py
@@ -1,9 +1,13 @@
 import logging
+import time
+from statistics import mean
+from typing import Optional, Dict, Any

 import numpy as np
 from scipy.special import expit
-import time
-from statistics import mean
+
+from haystack.reader.base import BaseReader
+from haystack.retriever.base import BaseRetriever

 logger = logging.getLogger(__name__)
@@ -15,13 +19,13 @@ class Finder:
     It provides an interface to predict top n answers for a given question.
     """

-    def __init__(self, reader, retriever):
+    def __init__(self, reader: Optional[BaseReader], retriever: Optional[BaseRetriever]):
         self.retriever = retriever
         self.reader = reader
         if self.reader is None and self.retriever is None:
             raise AttributeError("Finder: self.reader and self.retriever can not be both None")

-    def get_answers(self, question: str, top_k_reader: int = 1, top_k_retriever: int = 10, filters: dict = None):
+    def get_answers(self, question: str, top_k_reader: int = 1, top_k_retriever: int = 10, filters: Optional[dict] = None):
         """
         Get top k answers for a given question.
@@ -41,8 +45,8 @@ class Finder:

         if len(documents) == 0:
             logger.info("Retriever did not return any documents. Skipping reader ...")
-            results = {"question": question, "answers": []}
-            return results
+            empty_result = {"question": question, "answers": []}
+            return empty_result

         # 2) Apply reader to get granular answer(s)
         len_chars = sum([len(d.text) for d in documents])
@@ -50,7 +54,7 @@ class Finder:

         results = self.reader.predict(question=question,
                                       documents=documents,
-                                      top_k=top_k_reader)
+                                      top_k=top_k_reader)  # type: Dict[str, Any]

         # Add corresponding document_name and more meta data, if an answer contains the document_id
         for ans in results["answers"]:
@@ -61,7 +65,7 @@ class Finder:

         return results

-    def get_answers_via_similar_questions(self, question: str, top_k_retriever: int = 10, filters: dict = None):
+    def get_answers_via_similar_questions(self, question: str, top_k_retriever: int = 10, filters: Optional[dict] = None):
         """
         Get top k answers for a given question using only a retriever.
@@ -75,12 +79,12 @@ class Finder:
         if self.retriever is None:
             raise AttributeError("Finder.get_answers_via_similar_questions requires self.retriever")

-        results = {"question": question, "answers": []}
+        results = {"question": question, "answers": []}  # type: Dict[str, Any]

         # 1) Optional: reduce the search space via document tags
         if filters:
             logging.info(f"Apply filters: {filters}")
-            candidate_doc_ids = self.retriever.document_store.get_document_ids_by_tags(filters)
+            candidate_doc_ids = self.retriever.document_store.get_document_ids_by_tags(filters)  # type: ignore
             logger.info(f"Got candidate IDs due to filters: {candidate_doc_ids}")

             if len(candidate_doc_ids) == 0:
@@ -88,28 +92,35 @@ class Finder:
                 return results

         else:
-            candidate_doc_ids = None
+            candidate_doc_ids = None  # type: ignore

         # 2) Apply retriever to match similar questions via cosine similarity of embeddings
-        documents = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids)
+        documents = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids)  # type: ignore

         # 3) Format response
         for doc in documents:
             #TODO proper calibratation of pseudo probabilities
-            cur_answer = {"question": doc.meta["question"], "answer": doc.text, "context": doc.text,
+            cur_answer = {"question": doc.meta["question"], "answer": doc.text, "context": doc.text,  # type: ignore
                           "score": doc.query_score, "offset_start": 0, "offset_end": len(doc.text),
                           "meta": doc.meta
                          }
-            if self.retriever.embedding_model:
-                probability = (doc.query_score + 1) / 2
+            if self.retriever.embedding_model:  # type: ignore
+                probability = (doc.query_score + 1) / 2  # type: ignore
             else:
-                probability = float(expit(np.asarray(doc.query_score / 8)))
+                probability = float(expit(np.asarray(doc.query_score / 8)))  # type: ignore
+
             cur_answer["probability"] = probability
             results["answers"].append(cur_answer)

         return results

-    def eval(self, label_index: str = "feedback", doc_index: str = "eval_document", label_origin: str = "gold_label",
-             top_k_retriever: int = 10, top_k_reader: int = 10):
+    def eval(
+        self,
+        label_index: str = "feedback",
+        doc_index: str = "eval_document",
+        label_origin: str = "gold_label",
+        top_k_retriever: int = 10,
+        top_k_reader: int = 10,
+    ):
         """
         Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result of the Retriever.
@@ -155,10 +166,14 @@ class Finder:
         :param top_k_reader: How many answers to return per question
         :type top_k_reader: int
         """
+
+        if not self.reader or not self.retriever:
+            raise Exception("Finder needs to have a reader and retriever for the evaluation.")
+
         finder_start_time = time.time()
         # extract all questions for evaluation
         filter = {"origin": label_origin}
-        questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filter)
+        questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filter)  # type: ignore

         correct_retrievals = 0
         summed_avg_precision_retriever = 0
@@ -192,7 +207,7 @@ class Finder:
                 # check if correct doc among retrieved docs
                 if doc.meta["doc_id"] == question["_source"]["doc_id"]:
                     correct_retrievals += 1
-                    summed_avg_precision_retriever += 1 / (doc_idx + 1)
+                    summed_avg_precision_retriever += 1 / (doc_idx + 1)  # type: ignore
                     questions_with_docs.append({
                         "question": question,
                         "docs": retrieved_docs,
@@ -202,8 +217,8 @@ class Finder:

         number_of_questions = q_idx + 1
         number_of_no_answer = 0
-        previous_return_no_answers = self.reader.return_no_answers
-        self.reader.return_no_answers = True
+        previous_return_no_answers = self.reader.return_no_answers  # type: ignore
+        self.reader.return_no_answers = True  # type: ignore
         # extract answers
         reader_start_time = time.time()
         for q_idx, question in enumerate(questions_with_docs):
@@ -260,10 +275,10 @@ class Finder:
                     current_f1 = (2 * precision * recall) / (precision + recall)
                     # top-1 answer
                     if answer_idx == 0:
-                        summed_f1_top1 += current_f1
-                        summed_f1_top1_has_answer += current_f1
+                        summed_f1_top1 += current_f1  # type: ignore
+                        summed_f1_top1_has_answer += current_f1  # type: ignore
                     if current_f1 > best_f1:
-                        best_f1 = current_f1
+                        best_f1 = current_f1  # type: ignore
                 # top-k answers: use best f1-score
                 summed_f1_topk += best_f1
                 summed_f1_topk_has_answer += best_f1
@@ -311,7 +326,7 @@ class Finder:
         reader_top1_no_answer_accuracy = correct_no_answers_top1 / number_of_no_answer
         reader_topk_no_answer_accuracy = correct_no_answers_topk / number_of_no_answer

-        self.reader.return_no_answers = previous_return_no_answers
+        self.reader.return_no_answers = previous_return_no_answers  # type: ignore

         logger.info((f"{correct_readings_topk} out of {number_of_questions} questions were correctly answered "
                      f"({(correct_readings_topk/number_of_questions):.2%})."))
diff --git a/haystack/indexing/cleaning.py b/haystack/indexing/cleaning.py
index dfe1939d2..f84cf8f8c 100644
--- a/haystack/indexing/cleaning.py
+++ b/haystack/indexing/cleaning.py
@@ -1,15 +1,15 @@
 import re


-def clean_wiki_text(text):
+def clean_wiki_text(text: str) -> str:
     # get rid of multiple new lines
     while "\n\n" in text:
         text = text.replace("\n\n", "\n")

     # remove extremely short lines
-    text = text.split("\n")
+    lines = text.split("\n")
     cleaned = []
-    for l in text:
+    for l in lines:
         if len(l) > 30:
             cleaned.append(l)
         elif l[:2] == "==" and l[-2:] == "==":
diff --git a/haystack/indexing/file_converters/base.py b/haystack/indexing/file_converters/base.py
index c4121da47..89d8e6d16 100644
--- a/haystack/indexing/file_converters/base.py
+++ b/haystack/indexing/file_converters/base.py
@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from pathlib import Path
+from typing import List, Optional


 class BaseConverter:
@@ -9,11 +10,11 @@ class BaseConverter:

     def __init__(
         self,
-        remove_numeric_tables: bool = None,
-        remove_header_footer: bool = None,
-        remove_whitespace: bool = None,
-        remove_empty_lines: bool = None,
-        valid_languages: [str] = None,
+        remove_numeric_tables: Optional[bool] = None,
+        remove_header_footer: Optional[bool] = None,
+        remove_whitespace: Optional[bool] = None,
+        remove_empty_lines: Optional[bool] = None,
+        valid_languages: Optional[List[str]] = None,
     ):
         """
         :param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
@@ -40,5 +41,5 @@ class BaseConverter:
         self.valid_languages = valid_languages

     @abstractmethod
-    def extract_pages(self, file_path: Path) -> [str]:
+    def extract_pages(self, file_path: Path) -> List[str]:
         pass
diff --git a/haystack/indexing/file_converters/pdftotext.py b/haystack/indexing/file_converters/pdftotext.py
index 2dbc09b7b..5413bbd7a 100644
--- a/haystack/indexing/file_converters/pdftotext.py
+++ b/haystack/indexing/file_converters/pdftotext.py
@@ -4,6 +4,7 @@ import subprocess
 from functools import partial, reduce
 from itertools import chain
 from pathlib import Path
+from typing import List, Optional, Tuple, Generator, Set

 import fitz
 import langdetect
@@ -16,11 +17,11 @@ logger = logging.getLogger(__name__)
 class PDFToTextConverter(BaseConverter):
     def __init__(
         self,
-        remove_numeric_tables: bool = False,
-        remove_whitespace: bool = None,
-        remove_empty_lines: bool = None,
-        remove_header_footer: bool = None,
-        valid_languages: [str] = None,
+        remove_numeric_tables: Optional[bool] = False,
+        remove_whitespace: Optional[bool] = None,
+        remove_empty_lines: Optional[bool] = None,
+        remove_header_footer: Optional[bool] = None,
+        valid_languages: Optional[List[str]] = None,
     ):
         """
         :param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
@@ -64,7 +65,7 @@ class PDFToTextConverter(BaseConverter):
             valid_languages=valid_languages,
         )

-    def extract_pages(self, file_path: Path) -> [str]:
+    def extract_pages(self, file_path: Path) -> List[str]:

         page_count = fitz.open(file_path).pageCount

@@ -106,13 +107,12 @@ class PDFToTextConverter(BaseConverter):
             pages.append(page)
             page_number += 1

-        if self.valid_languages:
-            document_text = "".join(pages)
-            if not self._validate_language(document_text):
-                logger.warning(
-                    f"The language for {file_path} is not one of {self.valid_languages}. The file may not have "
-                    f"been decoded in the correct text format."
-                )
+        document_text = "".join(pages)
+        if not self._validate_language(document_text):
+            logger.warning(
+                f"The language for {file_path} is not one of {self.valid_languages}. The file may not have "
+                f"been decoded in the correct text format."
+            )

         if self.remove_header_footer:
             pages, header, footer = self.find_and_remove_header_footer(
@@ -122,7 +122,7 @@ class PDFToTextConverter(BaseConverter):

         return pages

-    def _extract_page(self, file_path: Path, page_number: int, layout: bool):
+    def _extract_page(self, file_path: Path, page_number: int, layout: bool) -> str:
         """
         Extract a page from the pdf file at file_path.

@@ -132,17 +132,20 @@ class PDFToTextConverter(BaseConverter):
            the content stream order.
""" if layout: - command = ["pdftotext", "-layout", "-f", str(page_number), "-l", str(page_number), file_path, "-"] + command = ["pdftotext", "-layout", "-f", str(page_number), "-l", str(page_number), str(file_path), "-"] else: - command = ["pdftotext", "-f", str(page_number), "-l", str(page_number), file_path, "-"] + command = ["pdftotext", "-f", str(page_number), "-l", str(page_number), str(file_path), "-"] output_page = subprocess.run(command, capture_output=True, shell=False) page = output_page.stdout.decode(errors="ignore") return page - def _validate_language(self, text: str): + def _validate_language(self, text: str) -> bool: """ Validate if the language of the text is one of valid languages. """ + if not self.valid_languages: + return True + try: lang = langdetect.detect(text) except langdetect.lang_detect_exception.LangDetectException: @@ -153,7 +156,7 @@ class PDFToTextConverter(BaseConverter): else: return False - def _ngram(self, seq: str, n: int): + def _ngram(self, seq: str, n: int) -> Generator[str, None, None]: """ Return ngram (of tokens - currently splitted by whitespace) :param seq: str, string from which the ngram shall be created @@ -166,20 +169,20 @@ class PDFToTextConverter(BaseConverter): seq = seq.replace("\n", " \n") seq = seq.replace("\t", " \t") - seq = seq.split(" ") + words = seq.split(" ") ngrams = ( - " ".join(seq[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(seq) - n + 1) + " ".join(words[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(words) - n + 1) ) return ngrams - def _allngram(self, seq: str, min_ngram: int, max_ngram: int): + def _allngram(self, seq: str, min_ngram: int, max_ngram: int) -> Set[str]: lengths = range(min_ngram, max_ngram) if max_ngram else range(min_ngram, len(seq)) ngrams = map(partial(self._ngram, seq), lengths) res = set(chain.from_iterable(ngrams)) return res - def find_longest_common_ngram(self, sequences: [str], max_ngram: int = 30, min_ngram: int = 3): + def find_longest_common_ngram(self, sequences: List[str], max_ngram: int = 30, min_ngram: int = 3) -> Optional[str]: """ Find the longest common ngram across different text sequences (e.g. start of pages). Considering all ngrams between the specified range. Helpful for finding footers, headers etc. @@ -201,8 +204,8 @@ class PDFToTextConverter(BaseConverter): return longest if longest.strip() else None def find_and_remove_header_footer( - self, pages: [str], n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int - ): + self, pages: List[str], n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int + ) -> Tuple[List[str], Optional[str], Optional[str]]: """ Heuristic to find footers and headers across different pages by searching for the longest common string. For headers we only search in the first n_chars characters (for footer: last n_chars). 
diff --git a/haystack/indexing/utils.py b/haystack/indexing/utils.py
index e8ec48485..517d6b159 100644
--- a/haystack/indexing/utils.py
+++ b/haystack/indexing/utils.py
@@ -1,16 +1,18 @@
-from pathlib import Path
 import logging
-from farm.data_handler.utils import http_get
-import tempfile
 import tarfile
+import tempfile
 import zipfile
-from typing import Callable
+from pathlib import Path
+from typing import Callable, List, Optional
+
+from farm.data_handler.utils import http_get
+
 from haystack.indexing.file_converters.pdftotext import PDFToTextConverter

 logger = logging.getLogger(__name__)


-def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_paragraphs: bool = False) -> [dict]:
+def convert_files_to_dicts(dir_path: str, clean_func: Optional[Callable] = None, split_paragraphs: bool = False) -> List[dict]:
     """
     Convert all files(.txt, .pdf) in the sub-directories of the given path to Python dicts that can be written to a
     Document Store.
@@ -24,7 +26,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
     file_paths = [p for p in Path(dir_path).glob("**/*")]

     if ".pdf" in [p.suffix.lower() for p in file_paths]:
-        pdf_converter = PDFToTextConverter()
+        pdf_converter = PDFToTextConverter()  # type: Optional[PDFToTextConverter]
     else:
         pdf_converter = None

@@ -33,7 +35,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
         if path.suffix.lower() == ".txt":
             with open(path) as doc:
                 text = doc.read()
-        elif path.suffix.lower() == ".pdf":
+        elif path.suffix.lower() == ".pdf" and pdf_converter:
             pages = pdf_converter.extract_pages(path)
             text = "\n".join(pages)
         else:
@@ -53,7 +55,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Callable = None, split_par
     return documents


-def fetch_archive_from_http(url, output_dir, proxies=None):
+def fetch_archive_from_http(url: str, output_dir: str, proxies: Optional[dict] = None):
     """
     Fetch an archive (zip or tar.gz) from a url via http and extract content to an output directory.
@@ -86,11 +88,11 @@ def fetch_archive_from_http(url, output_dir, proxies=None):
         temp_file.seek(0)  # making tempfile accessible
         # extract
         if url[-4:] == ".zip":
-            archive = zipfile.ZipFile(temp_file.name)
-            archive.extractall(output_dir)
+            zip_archive = zipfile.ZipFile(temp_file.name)
+            zip_archive.extractall(output_dir)
         elif url[-7:] == ".tar.gz":
-            archive = tarfile.open(temp_file.name)
-            archive.extractall(output_dir)
+            tar_archive = tarfile.open(temp_file.name)
+            tar_archive.extractall(output_dir)
         # temp_file gets deleted here
         return True
diff --git a/haystack/reader/base.py b/haystack/reader/base.py
new file mode 100644
index 000000000..11b7946e4
--- /dev/null
+++ b/haystack/reader/base.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+from haystack.database.base import Document
+
+
+class BaseReader(ABC):
+
+    @abstractmethod
+    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
+        pass
diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 75eeb0f1a..c6e16a87f 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -1,5 +1,6 @@
 import logging
 from pathlib import Path
+from typing import List, Optional, Union

 import numpy as np
 from farm.data_handler.data_silo import DataSilo
@@ -14,11 +15,11 @@ from scipy.special import expit

 from haystack.database.base import Document
 from haystack.database.elasticsearch import ElasticsearchDocumentStore
-
+from haystack.reader.base import BaseReader

 logger = logging.getLogger(__name__)


-class FARMReader:
+class FARMReader(BaseReader):
     """
     Transformer based model for extractive Question Answering using the FARM framework (https://github.com/deepset-ai/FARM).
     While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
@@ -30,16 +31,16 @@ def __init__(
         self,
-        model_name_or_path,
-        context_window_size=150,
-        batch_size=50,
-        use_gpu=True,
-        no_ans_boost=None,
-        top_k_per_candidate=3,
-        top_k_per_sample=1,
-        num_processes=None,
-        max_seq_len=256,
-        doc_stride=128
+        model_name_or_path: Union[str, Path],
+        context_window_size: int = 150,
+        batch_size: int = 50,
+        use_gpu: bool = True,
+        no_ans_boost: Optional[int] = None,
+        top_k_per_candidate: int = 3,
+        top_k_per_sample: int = 1,
+        num_processes: Optional[int] = None,
+        max_seq_len: int = 256,
+        doc_stride: int = 128,
     ):
         """
@@ -97,9 +98,22 @@ class FARMReader:
         self.max_seq_len = max_seq_len
         self.use_gpu = use_gpu

-    def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
-              use_gpu=None, batch_size=10, n_epochs=2, learning_rate=1e-5,
-              max_seq_len=None, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
+    def train(
+        self,
+        data_dir: str,
+        train_filename: str,
+        dev_filename: Optional[str] = None,
+        test_file_name: Optional[str] = None,
+        use_gpu: Optional[bool] = None,
+        batch_size: int = 10,
+        n_epochs: int = 2,
+        learning_rate: float = 1e-5,
+        max_seq_len: Optional[int] = None,
+        warmup_proportion: float = 0.2,
+        dev_split: Optional[float] = 0.1,
+        evaluate_every: int = 300,
+        save_dir: Optional[str] = None,
+    ):
         """
         Fine-tune a model on a QA dataset. Options:
         - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
@@ -141,7 +155,7 @@ class FARMReader:

         if not save_dir:
             save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
-        save_dir = Path(save_dir)
         # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
         label_list = ["start_token", "end_token"]
@@ -184,14 +197,14 @@ class FARMReader:
         )
         # 5. Let it grow!
         self.inferencer.model = trainer.train()
-        self.save(save_dir)
+        self.save(Path(save_dir))

-    def save(self, directory):
+    def save(self, directory: Path):
         logger.info(f"Saving reader model to {directory}")
         self.inferencer.model.save(directory)
         self.inferencer.processor.save(directory)

-    def predict(self, question: str, documents: [Document], top_k: int = None):
+    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
         """
         Use loaded QA model to find answers for a question in the supplied list of Document.

@@ -245,7 +258,8 @@ class FARMReader:
                     if a["answer"]:
                         cur = {"answer": a["answer"],
                                "score": a["score"],
-                               "probability": float(expit(np.asarray([a["score"]]) / 8)),  #just a pseudo prob for now
+                               # just a pseudo prob for now
+                               "probability": float(expit(np.asarray([a["score"]]) / 8)),  # type: ignore
                                "context": a["context"],
                                "offset_start": a["offset_answer_start"] - a["offset_context_start"],
                                "offset_end": a["offset_answer_end"] - a["offset_context_start"],
@@ -316,8 +330,14 @@ class FARMReader:
         }
         return results

-    def eval(self, document_store: ElasticsearchDocumentStore, device: str, label_index: str = "feedback",
-             doc_index: str = "eval_document", label_origin: str = "gold_label"):
+    def eval(
+        self,
+        document_store: ElasticsearchDocumentStore,
+        device: str,
+        label_index: str = "feedback",
+        doc_index: str = "eval_document",
+        label_origin: str = "gold_label",
+    ):
         """
         Performs evaluation on evaluation documents in Elasticsearch DocumentStore.

@@ -386,7 +406,7 @@ class FARMReader:
         return results

     @staticmethod
-    def _calc_no_answer(no_ans_gaps,best_score_answer):
+    def _calc_no_answer(no_ans_gaps: List[float], best_score_answer: float):
         # "no answer" scores and positive answers scores are difficult to compare, because
         # + a positive answer score is related to one specific document
         # - a "no answer" score is related to all input documents
@@ -396,7 +416,7 @@ class FARMReader:
         # No_ans_gap coming from FARM mean how much no_ans_boost should change to switch predictions
         no_ans_gaps = np.array(no_ans_gaps)
         max_no_ans_gap = np.max(no_ans_gaps)
-        if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # all passages "no answer" as top score
+        # all passages "no answer" as top score
+        if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)):  # type: ignore
             no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so it increases best pos score
         else:  # case: at least one passage predicts an answer (positive no_ans_gap)
             no_ans_score = best_score_answer - max_no_ans_gap
@@ -410,7 +431,7 @@ class FARMReader:
                              "document_id": None}
         return no_ans_prediction, max_no_ans_gap

-    def predict_on_texts(self, question: str, texts: [str], top_k=None):
+    def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
         documents = []
         for i, text in enumerate(texts):
             documents.append(
diff --git a/haystack/reader/transformers.py b/haystack/reader/transformers.py
index e362637be..ec780e6c0 100644
--- a/haystack/reader/transformers.py
+++ b/haystack/reader/transformers.py
@@ -1,9 +1,12 @@
+from typing import List, Optional
+
 from transformers import pipeline

 from haystack.database.base import Document
+from haystack.reader.base import BaseReader


-class TransformersReader:
+class TransformersReader(BaseReader):
     """
     Transformer based model for extractive Question Answering using the huggingface's transformers framework
     (https://github.com/huggingface/transformers).
@@ -15,13 +18,11 @@ class TransformersReader:

     def __init__(
         self,
-        model="distilbert-base-uncased-distilled-squad",
-        tokenizer="distilbert-base-uncased",
-        context_window_size=30,
-        #no_answer_shift=-100,
-        #batch_size=16,
-        use_gpu=0,
-        n_best_per_passage=2
+        model: str = "distilbert-base-uncased-distilled-squad",
+        tokenizer: str = "distilbert-base-uncased",
+        context_window_size: int = 30,
+        use_gpu: int = 0,
+        n_best_per_passage: int = 2,
     ):
         """
         Load a QA model from Transformers.
@@ -44,7 +45,7 @@ class TransformersReader:
         self.n_best_per_passage = n_best_per_passage
         #TODO param to modify bias for no_answer

-    def predict(self, question: str, documents: [Document], top_k: int = None):
+    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
         """
         Use loaded QA model to find answers for a question in the supplied list of Document.

diff --git a/haystack/retriever/base.py b/haystack/retriever/base.py
index ce3c29cb9..ca2210c3d 100644
--- a/haystack/retriever/base.py
+++ b/haystack/retriever/base.py
@@ -1,7 +1,10 @@
 from abc import ABC, abstractmethod
+from typing import List
+
+from haystack.database.base import Document


 class BaseRetriever(ABC):
     @abstractmethod
-    def retrieve(self, query, candidate_doc_ids=None, top_k=1):
+    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
         pass
diff --git a/haystack/retriever/elasticsearch.py b/haystack/retriever/elasticsearch.py
index 2b175c642..a82e16495 100644
--- a/haystack/retriever/elasticsearch.py
+++ b/haystack/retriever/elasticsearch.py
@@ -1,15 +1,17 @@
 import logging
-from typing import Type
+from typing import List, Union
+
 from farm.infer import Inferencer

-from haystack.database.base import Document, BaseDocumentStore
+from haystack.database.base import Document
+from haystack.database.elasticsearch import ElasticsearchDocumentStore
 from haystack.retriever.base import BaseRetriever

 logger = logging.getLogger(__name__)


 class ElasticsearchRetriever(BaseRetriever):
-    def __init__(self, document_store: Type[BaseDocumentStore], custom_query: str = None):
+    def __init__(self, document_store: ElasticsearchDocumentStore, custom_query: str = None):
         """
         :param document_store: an instance of a DocumentStore to retrieve documents from.
         :param custom_query: query string as per Elasticsearch DSL with a mandatory question placeholder($question).
@@ -38,10 +40,10 @@ class ElasticsearchRetriever(BaseRetriever):
             self.retrieve(query="Why did the revenue increase?",
                           filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
         """
-        self.document_store = document_store
+        self.document_store = document_store  # type: ignore
         self.custom_query = custom_query

-    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> [Document]:
+    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
         if index is None:
             index = self.document_store.index

@@ -50,8 +52,13 @@ class ElasticsearchRetriever(BaseRetriever):

         return documents

-    def eval(self, label_index: str = "feedback", doc_index: str = "eval_document", label_origin: str = "gold_label",
-             top_k: int = 10) -> dict:
+    def eval(
+        self,
+        label_index: str = "feedback",
+        doc_index: str = "eval_document",
+        label_origin: str = "gold_label",
+        top_k: int = 10,
+    ) -> dict:
         """
         Performs evaluation on the Retriever.
         Retriever is evaluated based on whether it finds the correct document given the question string and at which
@@ -81,7 +88,7 @@ class ElasticsearchRetriever(BaseRetriever):
             for doc_idx, doc in enumerate(retrieved_docs):
                 if doc.meta["doc_id"] == question["_source"]["doc_id"]:
                     correct_retrievals += 1
-                    summed_avg_precision += 1 / (doc_idx + 1)
+                    summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
                     break

         number_of_questions = q_idx + 1
@@ -97,7 +104,7 @@ class ElasticsearchRetriever(BaseRetriever):
 class EmbeddingRetriever(BaseRetriever):
     def __init__(
         self,
-        document_store: Type[BaseDocumentStore],
+        document_store: ElasticsearchDocumentStore,
         embedding_model: str,
         gpu: bool = True,
         model_format: str = "farm",
@@ -137,13 +144,13 @@ class EmbeddingRetriever(BaseRetriever):
         else:
             raise NotImplementedError

-    def retrieve(self, query: str, candidate_doc_ids: [str] = None, top_k: int = 10) -> [Document]:
+    def retrieve(self, query: str, candidate_doc_ids: List[str] = None, top_k: int = 10) -> List[Document]:  # type: ignore
         query_emb = self.create_embedding(texts=[query])
         documents = self.document_store.query_by_embedding(query_emb[0], top_k, candidate_doc_ids)

         return documents

-    def create_embedding(self, texts: [str]):
+    def create_embedding(self, texts: Union[List[str], str]) -> List[List[float]]:
         """
         Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
         :param texts: texts to embed
@@ -152,14 +159,15 @@ class EmbeddingRetriever(BaseRetriever):

         # for backward compatibility: cast pure str input
         if type(texts) == str:
-            texts = [texts]
+            texts = [texts]  # type: ignore
         assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"

         if self.model_format == "farm":
-            res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
+            res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])  # type: ignore
             emb = [list(r["vec"]) for r in res] #cast from numpy
         elif self.model_format == "sentence_transformers":
             # text is single string, sentence-transformers needs a list of strings
-            res = self.embedding_model.encode(texts)  # get back list of numpy embedding vectors
+            # get back list of numpy embedding vectors
+            res = self.embedding_model.encode(texts)  # type: ignore
             emb = [list(r.astype('float64')) for r in res] #cast from numpy
         return emb
diff --git a/haystack/retriever/tfidf.py b/haystack/retriever/tfidf.py
index a94fdaa1e..41d3c9720 100644
--- a/haystack/retriever/tfidf.py
+++ b/haystack/retriever/tfidf.py
@@ -1,11 +1,12 @@
 import logging
 from collections import OrderedDict, namedtuple
+from typing import List

 import pandas as pd
-from haystack.database.base import Document
-from haystack.retriever.base import BaseRetriever
 from sklearn.feature_extraction.text import TfidfVectorizer
+from haystack.database.base import BaseDocumentStore, Document
+from haystack.retriever.base import BaseRetriever

 logger = logging.getLogger(__name__)
@@ -23,7 +24,7 @@ class TfidfRetriever(BaseRetriever):
     It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
""" - def __init__(self, document_store): + def __init__(self, document_store: BaseDocumentStore): self.vectorizer = TfidfVectorizer( lowercase=True, stop_words=None, @@ -36,7 +37,7 @@ class TfidfRetriever(BaseRetriever): self.df = None self.fit() - def _get_all_paragraphs(self): + def _get_all_paragraphs(self) -> List[Paragraph]: """ Split the list of documents in paragraphs """ @@ -55,7 +56,7 @@ class TfidfRetriever(BaseRetriever): logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB") return paragraphs - def _calc_scores(self, query): + def _calc_scores(self, query: str) -> dict: question_vector = self.vectorizer.transform([query]) scores = self.tfidf_matrix.dot(question_vector.T).toarray() @@ -65,21 +66,22 @@ class TfidfRetriever(BaseRetriever): ) return indices_and_scores - def retrieve(self, query, filters=None, top_k=10, verbose=True): + def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]: if filters: raise NotImplementedError("Filters are not implemented in TfidfRetriever.") + if index: + raise NotImplementedError("Switching index is not supported in TfidfRetriever.") # get scores indices_and_scores = self._calc_scores(query) # rank paragraphs - df_sliced = self.df.loc[indices_and_scores.keys()] + df_sliced = self.df.loc[indices_and_scores.keys()] # type: ignore df_sliced = df_sliced[:top_k] - if verbose: - logger.info( - f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}" - ) + logger.debug( + f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}" + ) # get actual content for the top candidates paragraphs = list(df_sliced.text.values) diff --git a/haystack/utils.py b/haystack/utils.py index 5a4ca2d7f..d45da0bb1 100644 --- a/haystack/utils.py +++ b/haystack/utils.py @@ -2,13 +2,13 @@ import json from collections import defaultdict import logging import pprint - +from typing import Dict, Any from haystack.database.sql import Document logger = logging.getLogger(__name__) -def print_answers(results, details="all"): +def print_answers(results: dict, details: str = "all"): answers = results["answers"] pp = pprint.PrettyPrinter(indent=4) if details != "all": @@ -28,7 +28,7 @@ def print_answers(results, details="all"): pp.pprint(results) -def convert_labels_to_squad(labels_file): +def convert_labels_to_squad(labels_file: str): """ Convert the export from the labeling UI to SQuAD format for training. @@ -42,7 +42,7 @@ def convert_labels_to_squad(labels_file): for label in labels: labels_grouped_by_documents[label["document_id"]].append(label) - labels_in_squad_format = {"data": []} + labels_in_squad_format = {"data": []} # type: Dict[str, Any] for document_id, labels in labels_grouped_by_documents.items(): qas = [] for label in labels: