diff --git a/haystack/__init__.py b/haystack/__init__.py index 156a1264b..6f0c875fc 100644 --- a/haystack/__init__.py +++ b/haystack/__init__.py @@ -38,6 +38,7 @@ class Finder: :return: """ + # 1) Optional: reduce the search space via document tags if filters: query = """ SELECT id FROM document WHERE id in ( @@ -61,49 +62,15 @@ class Finder: else: candidate_doc_ids = None - retrieved_scores = self.retriever.retrieve(question, top_k=top_k_retriever) + # 2) Apply retriever to get fast candidate paragraphs + paragraphs, meta_data = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids) - inference_dicts = self._convert_retrieved_text_to_reader_format( - retrieved_scores, question, candidate_doc_ids=candidate_doc_ids - ) - results = self.reader.predict(inference_dicts, top_k=top_k_reader) - return results["results"] + # 3) Apply reader to get granular answer(s) + logger.info(f"Applying the reader now to look for the answer in detail ...") + results = self.reader.predict(question=question, + paragrahps=paragraphs, + meta_data_paragraphs=meta_data, + top_k=top_k_reader) - def _convert_retrieved_text_to_reader_format( - self, retrieved_scores, question, candidate_doc_ids=None, verbose=True - ): - """ - The reader expect the input as: - { - "text": "FARM is a home for all species of pretrained language models (e.g. BERT) that can be adapted to - different domain languages or down-stream tasks. With FARM you can easily create SOTA NLP models for tasks - like document classification, NER or question answering.", - "document_id": 127, - "questions" : ["What can you do with FARM?"] - } + return results - :param retrieved_scores: tfidf scores as returned by the retriever - :param question: question string - :param verbose: enable verbose logging - """ - df_sliced = self.retriever.df.loc[retrieved_scores.keys()] - if verbose: - logger.info( - f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}" - ) - logger.info( - f"Applying the reader now to look for the answer in detail ..." - ) - inference_dicts = [] - for idx, row in df_sliced.iterrows(): - if candidate_doc_ids and row["document_id"] not in candidate_doc_ids: - continue - inference_dicts.append( - { - "text": row["text"], - "document_id": row["document_id"], - "questions": [question], - } - ) - - return inference_dicts diff --git a/haystack/indexing/io.py b/haystack/indexing/io.py index 2080c0c11..68cbbb0d5 100644 --- a/haystack/indexing/io.py +++ b/haystack/indexing/io.py @@ -9,15 +9,28 @@ import zipfile logger = logging.getLogger(__name__) -def write_documents_to_db(document_dir, clean_func=None): +def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False): """ Write all text files(.txt) in the sub-directories of the given path to the connected database. :param document_dir: path for the documents to be written to the database - :return: + :param clean_func: a custom cleaning function that gets applied to each doc (input: str, output:str) + :param only_empty_db: If true, docs will only be written if db is completely empty. + Useful to avoid indexing the same initial docs again and again. + :return: None """ file_paths = Path(document_dir).glob("**/*.txt") n_docs = 0 + + # check if db has already docs + if only_empty_db: + n_docs = db.session.query(Document).count() + if n_docs > 0: + logger.info(f"Skip writing documents since DB already contains {n_docs} docs ... 
" + "(Disable `only_empty_db`, if you want to add docs anyway.)") + return None + + # read and add docs for path in file_paths: with open(path) as doc: text = doc.read() diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index 5e25156d2..5c7995ba7 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -18,6 +18,7 @@ class FARMReader: no_answer_shift=-100, batch_size=16, use_gpu=True, + n_best_per_passage=2 ): """ Load a saved FARM model in Inference mode. @@ -27,56 +28,77 @@ class FARMReader: self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu) self.model.model.prediction_heads[0].context_size = context_size self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift + self.model.model.prediction_heads[0].n_best = n_best_per_passage - def predict(self, input_dicts, top_k=None): + + def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1): """ - Run inference on the loaded model for the given input dicts. + Use loaded QA model to find answers for a question in the supplied paragraphs. - :param input_dicts: list of input dicts + Returns dictionaries containing answers sorted by (desc.) probability + Example: + {'question': 'Who is the father of Arya Stark?', + 'answers': [ + {'answer': 'Eddard,', + 'context': " She travels with her father, Eddard, to King's Landing when he is ", + 'offset_answer_start': 147, + 'offset_answer_end': 154, + 'probability': 0.9787139466668613, + 'score': None, + 'document_id': None + }, + ... + ] + } + + :param question: question string + :param paragraphs: list of strings in which to search for the answer + :param meta_data_paragraphs: list of dicts containing meta data for the paragraphs. + len(paragraphs) == len(meta_data_paragraphs) :param top_k: the maximum number of answers to return - :return: + :param max_processes: max number of parallel processes + :return: dict containing question and answers """ - results = self.model.inference_from_dicts( - dicts=input_dicts, rest_api_schema=True, use_multiprocessing=False + + if meta_data_paragraphs is None: + meta_data_paragraphs = len(paragrahps) * [None] + assert len(paragrahps) == len(meta_data_paragraphs) + + # convert input to FARM format + input_dicts = [] + for paragraph, meta_data in zip(paragrahps, meta_data_paragraphs): + cur = {"text": paragraph, + "questions": [question], + "document_id": meta_data["document_id"] + } + input_dicts.append(cur) + + # get answers from QA model (Top 5 per input paragraph) + predictions = self.model.inference_from_dicts( + dicts=input_dicts, rest_api_schema=True, max_processes=max_processes ) - # The FARM Inferencer as of now do not support multi document QA. - # The QA inference is done for each text independently and the - # results are sorted descending by their `score`. 
+ # assemble answers from all the different paragraphs & format them + answers = [] + for pred in predictions: + for a in pred["predictions"][0]["answers"]: + if a["answer"]: #skip "no answer" + cur = {"answer": a["answer"], + "score": a["score"], + "probability": expit(np.asarray([a["score"]]) / 8), #just a pseudo prob for now + "context": a["context"], + "offset_start": a["offset_answer_start"] - a["offset_context_start"], + "offset_end": a["offset_answer_start"] - a["offset_context_start"], + "document_id": a["document_id"]} + answers.append(cur) - all_predictions = [] - for res in results: - all_predictions.extend(res["predictions"]) - - all_answers = [] - for pred in all_predictions: - answers = pred["answers"] - for a in answers: - # Two sets of offset fields are returned by FARM -- context level and document level. - # For the API, only context level offsets are relevant. - a["offset_start"] = a["offset_answer_start"] - a["offset_context_start"] - a["offset_end"] = a["offset_context_end"] - a["offset_answer_end"] - all_answers.extend(answers) - - # remove all null answers (where an answers in not found in the text) - all_answers = [ans for ans in all_answers if ans["answer"]] - - scores = np.asarray([ans["score"] for ans in all_answers]) - probabilities = expit(scores / 8) - for ans, prob in zip(all_answers, probabilities): - ans["probability"] = prob - - # sort answers by their `probability` - sorted_answers = sorted( - all_answers, key=lambda k: k["probability"], reverse=True + # sort answers by their `probability` and select top-k + answers = sorted( + answers, key=lambda k: k["probability"], reverse=True ) + answers = answers[:top_k] - # all predictions here are for the same questions, so the the metadata from - # the first prediction in the list is taken. - if all_predictions: - resp = all_predictions[0] # get the first prediction dict - resp["answers"] = sorted_answers[:top_k] - else: - resp = [] + result = {"question": question, + "answers": answers} - return {"results": [resp]} + return result diff --git a/haystack/reader/transformers.py b/haystack/reader/transformers.py new file mode 100644 index 000000000..0c8520fb7 --- /dev/null +++ b/haystack/reader/transformers.py @@ -0,0 +1,102 @@ +from transformers import pipeline + + +class TransformersReader: + """ + A reader using the QA Pipeline class from huggingface's Transformers. + Easily load any of the pretrained community QA models from here: https://huggingface.co/models + """ + + def __init__( + self, + model="distilbert-base-uncased-distilled-squad", + tokenizer="distilbert-base-uncased", + context_size=30, + #no_answer_shift=-100, + #batch_size=16, + use_gpu=0, + n_best_per_passage=2 + ): + """ + Load a QA model from Transformers. + Available models include: + - distilbert-base-uncased-distilled-squad + - bert-large-cased-whole-word-masking-finetuned-squad + - bert-large-uncased-whole-word-masking-finetuned-squad + + See https://huggingface.co/models for full list of available QA models + + :param model: name of the model + :param tokenizer: name of the tokenizer (usually the same as model) + :param context_size: num of chars (before and after the answer) to return as "context" for each answer. + The context usually helps users to understand if the answer really makes sense. 
+ :param use_gpu: < 1 -> use cpu + >= 1 -> num of gpus to use + """ + self.model = pipeline("question-answering", model=model, tokenizer=tokenizer, device=use_gpu) + self.context_size = context_size + self.n_best_per_passage = n_best_per_passage + #TODO param to modify bias for no_answer + + + def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None): + """ + Use loaded QA model to find answers for a question in the supplied paragraphs. + + Returns dictionaries containing answers sorted by (desc.) probability + Example: + {'question': 'Who is the father of Arya Stark?', + 'answers': [ + {'answer': 'Eddard,', + 'context': " She travels with her father, Eddard, to King's Landing when he is ", + 'offset_answer_start': 147, + 'offset_answer_end': 154, + 'probability': 0.9787139466668613, + 'score': None, + 'document_id': None + }, + ... + ] + } + + :param question: question string + :param paragraphs: list of strings in which to search for the answer + :param meta_data_paragraphs: list of dicts containing meta data for the paragraphs. + len(paragraphs) == len(meta_data_paragraphs) + :param top_k: the maximum number of answers to return + :param max_processes: max number of parallel processes + :return: dict containing question and answers + + """ + #TODO pass metadata + + # get top-answers for each candidate passage + answers = [] + for p in paragrahps: + query = {"context": p, "question": question} + predictions = self.model(query, topk=self.n_best_per_passage) + # assemble and format all answers + for pred in predictions: + if pred["answer"]: + context_start = max(0, pred["start"] - self.context_size) + context_end = min(len(p), pred["end"] + self.context_size) + answers.append({ + "answer": pred["answer"], + "context": p[context_start:context_end], + "offset_answer_start": pred["start"], + "offset_answer_end": pred["end"], + "probability": pred["score"], + "score": None, + "document_id": None + }) + + # sort answers by their `probability` and select top-k + answers = sorted( + answers, key=lambda k: k["probability"], reverse=True + ) + answers = answers[:top_k] + + results = {"question": question, + "answers": answers} + + return results diff --git a/haystack/retriever/tfidf.py b/haystack/retriever/tfidf.py index 9c6c4a920..603af0967 100644 --- a/haystack/retriever/tfidf.py +++ b/haystack/retriever/tfidf.py @@ -34,7 +34,7 @@ class TfidfRetriever(BaseRetriever): Split documents into smaller units (eg, paragraphs or pages) to reduce the computations when text is passed on to a Reader for QA. - It uses sklearn TfidfVectorizer to compute a tf-idf matrix. + It uses sklearn's TfidfVectorizer to compute a tf-idf matrix. 
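+
+    Rough usage sketch (mirroring the tutorial; assumes documents have already been written to the DB):
+
+        retriever = TfidfRetriever()
+        paragraphs, meta_data = retriever.retrieve("Who is the father of Arya Stark?", top_k=10)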
""" def __init__(self): @@ -69,15 +69,37 @@ class TfidfRetriever(BaseRetriever): logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB") return paragraphs - def retrieve(self, query, candidate_doc_ids=None, top_k=10): + def _calc_scores(self, query): question_vector = self.vectorizer.transform([query]) scores = self.tfidf_matrix.dot(question_vector.T).toarray() idx_scores = [(idx, score) for idx, score in enumerate(scores)] - top_k_scores = OrderedDict( - sorted(idx_scores, key=(lambda tup: tup[1]), reverse=True)[:top_k] + indices_and_scores = OrderedDict( + sorted(idx_scores, key=(lambda tup: tup[1]), reverse=True) ) - return top_k_scores + return indices_and_scores + + def retrieve(self, query, candidate_doc_ids=None, top_k=10, verbose=True): + # get scores + indices_and_scores = self._calc_scores(query) + + # rank & filter paragraphs + df_sliced = self.df.loc[indices_and_scores.keys()] + if candidate_doc_ids: + df_sliced = df_sliced[df_sliced.document_id.isin(candidate_doc_ids)] + df_sliced = df_sliced[:top_k] + + if verbose: + logger.info( + f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}" + ) + + # get actual content for the top candidates + paragraphs = list(df_sliced.text.values) + meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"]} + for idx, row in df_sliced.iterrows()] + + return paragraphs, meta_data def fit(self): self.df = pd.DataFrame.from_dict(self.paragraphs) diff --git a/haystack/utils.py b/haystack/utils.py index 0bb191b3a..e4aeb90d2 100644 --- a/haystack/utils.py +++ b/haystack/utils.py @@ -19,7 +19,7 @@ def create_db(): def print_answers(results, details="all"): - answers = results[0]["answers"] + answers = results["answers"] pp = pprint.PrettyPrinter(indent=4) if details != "all": if details == "minimal": diff --git a/requirements.txt b/requirements.txt index c9adb9c71..2cdb0954a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ -farm==0.3.2 +# FARM (incl. transformers 2.3.0 with pipelines) +#farm -e git+https://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43 +-e git://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43#egg=farm flask flask_cors flask_restplus diff --git a/tutorials/Tutorial1_Basic_QA_Pipeline.ipynb b/tutorials/Tutorial1_Basic_QA_Pipeline.ipynb index 82fb41734..b989a8208 100644 --- a/tutorials/Tutorial1_Basic_QA_Pipeline.ipynb +++ b/tutorials/Tutorial1_Basic_QA_Pipeline.ipynb @@ -88,7 +88,7 @@ "# Now, let's write the docs to our DB. \n", "# You can supply a cleaning function that is applied to each doc (e.g. 
to remove footers)\n",
    "# It must take a str as input, and return a str.\n",
-    "write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text)"
+    "write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)"
   ]
  },
  {
@@ -146,7 +146,10 @@
    "# Reader use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0\n",
    "from haystack.indexing.io import fetch_archive_from_http\n",
    "fetch_archive_from_http(url=\"https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz\", output_dir=\"model\")\n",
-    "reader = FARMReader(model_dir=\"model/bert-english-qa-large\", use_gpu=False)"
+    "reader = FARMReader(model_dir=\"model/bert-english-qa-large\", use_gpu=False)\n",
+    "\n",
+    "# OR: alternatively use a reader from huggingface's Transformers package\n",
+    "# reader = TransformersReader(use_gpu=-1)"
   ]
  },
  {
@@ -276,13 +279,13 @@
   "pycharm": {
    "stem_cell": {
     "cell_type": "raw",
+     "source": [],
     "metadata": {
      "collapsed": false
-     },
-     "source": []
+     }
    }
   }
  },
 "nbformat": 4,
 "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/tutorials/Tutorial1_Basic_QA_Pipeline.py b/tutorials/Tutorial1_Basic_QA_Pipeline.py
new file mode 100755
index 000000000..6648535a5
--- /dev/null
+++ b/tutorials/Tutorial1_Basic_QA_Pipeline.py
@@ -0,0 +1,55 @@
+from haystack.reader.farm import FARMReader
+from haystack.reader.transformers import TransformersReader
+from haystack.retriever.tfidf import TfidfRetriever
+from haystack import Finder
+from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
+from haystack.indexing.cleaning import clean_wiki_text
+from haystack.utils import print_answers
+
+
+## Indexing & cleaning documents
+# Init a database (default: sqlite)
+from haystack.database import db
+db.create_all()
+
+# Let's first get some documents that we want to query
+# Here: 517 Wikipedia articles for Game of Thrones
+doc_dir = "data/article_txt_got"
+s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
+fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
+
+# Now, let's write the docs to our DB.
+# You can supply a cleaning function that is applied to each doc (e.g. to remove footers)
+# It must take a str as input, and return a str.
+write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)
+
+
+## Initialize Reader, Retriever & Finder
+
+# A retriever identifies the k most promising chunks of text that might contain the answer for our question
+# Retrievers use a simple but fast algorithm, here: TF-IDF
+retriever = TfidfRetriever()
+
+# A reader scans the text chunks in detail and extracts the k best answers
+# Readers use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on SQuAD 2.0
+fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
+reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)
+
+# OR: alternatively use a reader from huggingface's Transformers package
+# reader = TransformersReader(use_gpu=-1)
+
+# The Finder sticks together reader and retriever in a pipeline to answer our actual questions
+finder = Finder(reader, retriever)
+
+## Voilà! Ask a question!
+# You can configure how many candidates the reader and retriever shall return
+# The higher top_k_retriever, the better (but also the slower) your answers.
+prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
+
+#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
+#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
+
+print_answers(prediction, details="minimal")
+
+
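+# The returned prediction is a plain dict (cf. the reader docstrings), roughly:
+# {"question": "...", "answers": [{"answer": ..., "context": ..., "probability": ..., ...}, ...]}
+# so instead of print_answers you could also post-process it yourself, e.g.:
+#
+# for answer in prediction["answers"]:
+#     print(answer["answer"], "-", answer["probability"])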