Refactor pipeline for better generalizability & Add TransformersReader (#1)

* add flag to skip writing docs to non-empty db

* change finder pipeline structure for better generalizability

* add basic TransformersReader

* update tutorials and requirements
Malte Pietsch 2020-01-13 18:56:22 +01:00 committed by GitHub
parent 0670f0a2fc
commit cab0932fab
9 changed files with 285 additions and 99 deletions

View File

@@ -38,6 +38,7 @@ class Finder:
:return:
"""
# 1) Optional: reduce the search space via document tags
if filters:
query = """
SELECT id FROM document WHERE id in (
@@ -61,49 +62,15 @@ class Finder:
else:
candidate_doc_ids = None
retrieved_scores = self.retriever.retrieve(question, top_k=top_k_retriever)
# 2) Apply retriever to get fast candidate paragraphs
paragraphs, meta_data = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids)
inference_dicts = self._convert_retrieved_text_to_reader_format(
retrieved_scores, question, candidate_doc_ids=candidate_doc_ids
)
results = self.reader.predict(inference_dicts, top_k=top_k_reader)
return results["results"]
# 3) Apply reader to get granular answer(s)
logger.info(f"Applying the reader now to look for the answer in detail ...")
results = self.reader.predict(question=question,
paragraphs=paragraphs,
meta_data_paragraphs=meta_data,
top_k=top_k_reader)
def _convert_retrieved_text_to_reader_format(
self, retrieved_scores, question, candidate_doc_ids=None, verbose=True
):
"""
The reader expects the input as:
{
"text": "FARM is a home for all species of pretrained language models (e.g. BERT) that can be adapted to
different domain languages or down-stream tasks. With FARM you can easily create SOTA NLP models for tasks
like document classification, NER or question answering.",
"document_id": 127,
"questions" : ["What can you do with FARM?"]
}
return results
:param retrieved_scores: tfidf scores as returned by the retriever
:param question: question string
:param verbose: enable verbose logging
"""
df_sliced = self.retriever.df.loc[retrieved_scores.keys()]
if verbose:
logger.info(
f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}"
)
logger.info(
f"Applying the reader now to look for the answer in detail ..."
)
inference_dicts = []
for idx, row in df_sliced.iterrows():
if candidate_doc_ids and row["document_id"] not in candidate_doc_ids:
continue
inference_dicts.append(
{
"text": row["text"],
"document_id": row["document_id"],
"questions": [question],
}
)
return inference_dicts
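Since the hunk above interleaves removed and added lines, here is a minimal sketch of how the refactored get_answers flow reads end to end. It is not the verbatim source: the signature defaults are illustrative and `_filter_docs_by_tags` is a hypothetical stand-in for the tag-filter SQL shown at the top of the hunk.

def get_answers(self, question, top_k_reader=1, top_k_retriever=10, filters=None):
    # 1) Optional: reduce the search space via document tags
    candidate_doc_ids = self._filter_docs_by_tags(filters) if filters else None

    # 2) Apply retriever to get fast candidate paragraphs (plus their meta data)
    paragraphs, meta_data = self.retriever.retrieve(
        question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids
    )

    # 3) Apply reader to get granular answer(s)
    results = self.reader.predict(question=question,
                                  paragraphs=paragraphs,
                                  meta_data_paragraphs=meta_data,
                                  top_k=top_k_reader)
    return results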

View File

@@ -9,15 +9,28 @@ import zipfile
logger = logging.getLogger(__name__)
def write_documents_to_db(document_dir, clean_func=None):
def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):
"""
Write all text files (.txt) in the sub-directories of the given path to the connected database.
:param document_dir: path for the documents to be written to the database
:return:
:param clean_func: a custom cleaning function that gets applied to each doc (input: str, output:str)
:param only_empty_db: If True, docs will only be written if the DB is completely empty.
Useful to avoid indexing the same initial docs again and again.
:return: None
"""
file_paths = Path(document_dir).glob("**/*.txt")
n_docs = 0
# check if db has already docs
if only_empty_db:
n_docs = db.session.query(Document).count()
if n_docs > 0:
logger.info(f"Skip writing documents since DB already contains {n_docs} docs ... "
"(Disable `only_empty_db`, if you want to add docs anyway.)")
return None
# read and add docs
for path in file_paths:
with open(path) as doc:
text = doc.read()
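A usage sketch of the new flag, mirroring the tutorial further down (same paths and cleaning function): with only_empty_db=True, re-running the indexing step is a no-op once the DB already holds documents.

from haystack.indexing.cleaning import clean_wiki_text
from haystack.indexing.io import write_documents_to_db

# Safe to re-run: if the DB already contains docs, nothing is written again.
write_documents_to_db(document_dir="data/article_txt_got",
                      clean_func=clean_wiki_text,
                      only_empty_db=True)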

View File

@@ -18,6 +18,7 @@ class FARMReader:
no_answer_shift=-100,
batch_size=16,
use_gpu=True,
n_best_per_passage=2
):
"""
Load a saved FARM model in Inference mode.
@@ -27,56 +28,77 @@
self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu)
self.model.model.prediction_heads[0].context_size = context_size
self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift
self.model.model.prediction_heads[0].n_best = n_best_per_passage
def predict(self, input_dicts, top_k=None):
def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None, max_processes=1):
"""
Run inference on the loaded model for the given input dicts.
Use the loaded QA model to find answers for a question in the supplied paragraphs.
:param input_dicts: list of input dicts
Returns a dict containing the question and its answers, sorted by (descending) probability.
Example:
{'question': 'Who is the father of Arya Stark?',
'answers': [
{'answer': 'Eddard,',
'context': " She travels with her father, Eddard, to King's Landing when he is ",
'offset_answer_start': 147,
'offset_answer_end': 154,
'probability': 0.9787139466668613,
'score': None,
'document_id': None
},
...
]
}
:param question: question string
:param paragraphs: list of strings in which to search for the answer
:param meta_data_paragraphs: list of dicts containing meta data for the paragraphs.
len(paragraphs) == len(meta_data_paragraphs)
:param top_k: the maximum number of answers to return
:return:
:param max_processes: max number of parallel processes
:return: dict containing question and answers
"""
results = self.model.inference_from_dicts(
dicts=input_dicts, rest_api_schema=True, use_multiprocessing=False
if meta_data_paragraphs is None:
meta_data_paragraphs = len(paragraphs) * [None]
assert len(paragraphs) == len(meta_data_paragraphs)
# convert input to FARM format
input_dicts = []
for paragraph, meta_data in zip(paragraphs, meta_data_paragraphs):
cur = {"text": paragraph,
"questions": [question],
"document_id": meta_data["document_id"]
}
input_dicts.append(cur)
# get answers from QA model (top `n_best_per_passage` per input paragraph)
predictions = self.model.inference_from_dicts(
dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
)
# The FARM Inferencer as of now does not support multi document QA.
# The QA inference is done for each text independently and the
# results are sorted descending by their `score`.
# assemble answers from all the different paragraphs & format them
answers = []
for pred in predictions:
for a in pred["predictions"][0]["answers"]:
if a["answer"]: #skip "no answer"
cur = {"answer": a["answer"],
"score": a["score"],
"probability": expit(np.asarray([a["score"]]) / 8), #just a pseudo prob for now
"context": a["context"],
"offset_start": a["offset_answer_start"] - a["offset_context_start"],
"offset_end": a["offset_answer_start"] - a["offset_context_start"],
"document_id": a["document_id"]}
answers.append(cur)
all_predictions = []
for res in results:
all_predictions.extend(res["predictions"])
all_answers = []
for pred in all_predictions:
answers = pred["answers"]
for a in answers:
# Two sets of offset fields are returned by FARM -- context level and document level.
# For the API, only context level offsets are relevant.
a["offset_start"] = a["offset_answer_start"] - a["offset_context_start"]
a["offset_end"] = a["offset_context_end"] - a["offset_answer_end"]
all_answers.extend(answers)
# remove all null answers (where an answer is not found in the text)
all_answers = [ans for ans in all_answers if ans["answer"]]
scores = np.asarray([ans["score"] for ans in all_answers])
probabilities = expit(scores / 8)
for ans, prob in zip(all_answers, probabilities):
ans["probability"] = prob
# sort answers by their `probability`
sorted_answers = sorted(
all_answers, key=lambda k: k["probability"], reverse=True
# sort answers by their `probability` and select top-k
answers = sorted(
answers, key=lambda k: k["probability"], reverse=True
)
answers = answers[:top_k]
# all predictions here are for the same question, so the metadata from
# the first prediction in the list is taken.
if all_predictions:
resp = all_predictions[0] # get the first prediction dict
resp["answers"] = sorted_answers[:top_k]
else:
resp = []
result = {"question": question,
"answers": answers}
return {"results": [resp]}
return result
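A rough usage sketch of the new signature, assuming `reader` is a FARMReader loaded as in the tutorial below (paragraph text is illustrative, document_id taken from the old docstring example), plus the pseudo-probability mapping used above.

import numpy as np
from scipy.special import expit  # sigmoid, as used for the pseudo probability above

paragraphs = ["She travels with her father, Eddard, to King's Landing when he is made Hand of the King."]
meta_data = [{"document_id": 127, "paragraph_id": 0}]

result = reader.predict(question="Who is the father of Arya Stark?",
                        paragraphs=paragraphs,
                        meta_data_paragraphs=meta_data,
                        top_k=3)
# result == {"question": "...", "answers": [{"answer": ..., "probability": ...}, ...]}

# The "probability" is just the raw FARM score squashed into (0, 1): expit(score / 8)
print(expit(np.asarray([0.0, 8.0]) / 8))  # -> [0.5, ~0.73]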

View File

@@ -0,0 +1,102 @@
from transformers import pipeline
class TransformersReader:
"""
A reader using the QA Pipeline class from huggingface's Transformers.
Easily load any of the pretrained community QA models from here: https://huggingface.co/models
"""
def __init__(
self,
model="distilbert-base-uncased-distilled-squad",
tokenizer="distilbert-base-uncased",
context_size=30,
#no_answer_shift=-100,
#batch_size=16,
use_gpu=0,
n_best_per_passage=2
):
"""
Load a QA model from Transformers.
Available models include:
- distilbert-base-uncased-distilled-squad
- bert-large-cased-whole-word-masking-finetuned-squad
- bert-large-uncased-whole-word-masking-finetuned-squad
See https://huggingface.co/models for the full list of available QA models
:param model: name of the model
:param tokenizer: name of the tokenizer (usually the same as model)
:param context_size: num of chars (before and after the answer) to return as "context" for each answer.
The context usually helps users to understand if the answer really makes sense.
:param use_gpu: < 0 -> use cpu
>= 0 -> ordinal of the gpu to use (passed as `device` to the pipeline)
"""
self.model = pipeline("question-answering", model=model, tokenizer=tokenizer, device=use_gpu)
self.context_size = context_size
self.n_best_per_passage = n_best_per_passage
#TODO param to modify bias for no_answer
def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None):
"""
Use the loaded QA model to find answers for a question in the supplied paragraphs.
Returns a dict containing the question and its answers, sorted by (descending) probability.
Example:
{'question': 'Who is the father of Arya Stark?',
'answers': [
{'answer': 'Eddard,',
'context': " She travels with her father, Eddard, to King's Landing when he is ",
'offset_answer_start': 147,
'offset_answer_end': 154,
'probability': 0.9787139466668613,
'score': None,
'document_id': None
},
...
]
}
:param question: question string
:param paragraphs: list of strings in which to search for the answer
:param meta_data_paragraphs: list of dicts containing meta data for the paragraphs.
len(paragraphs) == len(meta_data_paragraphs)
:param top_k: the maximum number of answers to return
:return: dict containing question and answers
"""
#TODO pass metadata
# get top-answers for each candidate passage
answers = []
for p in paragraphs:
query = {"context": p, "question": question}
predictions = self.model(query, topk=self.n_best_per_passage)
# assemble and format all answers
for pred in predictions:
if pred["answer"]:
context_start = max(0, pred["start"] - self.context_size)
context_end = min(len(p), pred["end"] + self.context_size)
answers.append({
"answer": pred["answer"],
"context": p[context_start:context_end],
"offset_answer_start": pred["start"],
"offset_answer_end": pred["end"],
"probability": pred["score"],
"score": None,
"document_id": None
})
# sort answers by their `probability` and select top-k
answers = sorted(
answers, key=lambda k: k["probability"], reverse=True
)
answers = answers[:top_k]
results = {"question": question,
"answers": answers}
return results
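A minimal usage sketch. The defaults download the distilled SQuAD model from the Hugging Face hub; use_gpu=-1 forces CPU, as in the tutorial below. The paragraph text is invented for illustration.

from haystack.reader.transformers import TransformersReader

reader = TransformersReader(use_gpu=-1)  # default: distilbert-base-uncased-distilled-squad

paragraphs = ["Arya Stark is the third child of Eddard and Catelyn Stark."]
result = reader.predict(question="Who is the father of Arya Stark?",
                        paragraphs=paragraphs,
                        top_k=2)
for ans in result["answers"]:
    print(ans["answer"], ans["probability"], ans["context"])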

View File

@@ -34,7 +34,7 @@ class TfidfRetriever(BaseRetriever):
Split documents into smaller units (e.g., paragraphs or pages) to reduce the
computations when text is passed on to a Reader for QA.
It uses sklearn TfidfVectorizer to compute a tf-idf matrix.
It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
"""
def __init__(self):
@@ -69,15 +69,37 @@ class TfidfRetriever(BaseRetriever):
logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
return paragraphs
def retrieve(self, query, candidate_doc_ids=None, top_k=10):
def _calc_scores(self, query):
question_vector = self.vectorizer.transform([query])
scores = self.tfidf_matrix.dot(question_vector.T).toarray()
idx_scores = [(idx, score) for idx, score in enumerate(scores)]
top_k_scores = OrderedDict(
sorted(idx_scores, key=(lambda tup: tup[1]), reverse=True)[:top_k]
indices_and_scores = OrderedDict(
sorted(idx_scores, key=(lambda tup: tup[1]), reverse=True)
)
return top_k_scores
return indices_and_scores
def retrieve(self, query, candidate_doc_ids=None, top_k=10, verbose=True):
# get scores
indices_and_scores = self._calc_scores(query)
# rank & filter paragraphs
df_sliced = self.df.loc[indices_and_scores.keys()]
if candidate_doc_ids:
df_sliced = df_sliced[df_sliced.document_id.isin(candidate_doc_ids)]
df_sliced = df_sliced[:top_k]
if verbose:
logger.info(
f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}"
)
# get actual content for the top candidates
paragraphs = list(df_sliced.text.values)
meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"]}
for idx, row in df_sliced.iterrows()]
return paragraphs, meta_data
def fit(self):
self.df = pd.DataFrame.from_dict(self.paragraphs)
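A usage sketch of the new return format, assuming the DB has been populated and the tf-idf matrix fit over its paragraphs (as the tutorial's setup does).

from haystack.retriever.tfidf import TfidfRetriever

retriever = TfidfRetriever()  # assumes the tf-idf matrix gets built over the paragraphs in the DB

paragraphs, meta_data = retriever.retrieve("Who is the father of Arya Stark?", top_k=3)
for text, meta in zip(paragraphs, meta_data):
    print(meta["document_id"], meta["paragraph_id"], text[:80])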

View File

@@ -19,7 +19,7 @@ def create_db():
def print_answers(results, details="all"):
answers = results[0]["answers"]
answers = results["answers"]
pp = pprint.PrettyPrinter(indent=4)
if details != "all":
if details == "minimal":
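For reference, the results dict handed to print_answers now comes straight from Finder.get_answers (a single dict rather than a list); an illustrative shape with invented values:

results = {
    "question": "Who is the father of Arya Stark?",
    "answers": [
        {"answer": "Eddard", "probability": 0.97,
         "context": "...travels with her father, Eddard, to King's Landing...",
         "offset_start": 22, "offset_end": 28, "score": None, "document_id": None},
    ],
}
answers = results["answers"]  # new access pattern (previously results[0]["answers"])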

View File

@@ -1,4 +1,6 @@
farm==0.3.2
# FARM (incl. transformers 2.3.0 with pipelines)
#farm -e git+https://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43
-e git://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43#egg=farm
flask
flask_cors
flask_restplus

View File

@@ -88,7 +88,7 @@
"# Now, let's write the docs to our DB. \n",
"# You can supply a cleaning function that is applied to each doc (e.g. to remove footers)\n",
"# It must take a str as input, and return a str.\n",
"write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text)"
"write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)"
]
},
{
@@ -146,7 +146,10 @@
"# Reader use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0\n",
"from haystack.indexing.io import fetch_archive_from_http\n",
"fetch_archive_from_http(url=\"https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz\", output_dir=\"model\")\n",
"reader = FARMReader(model_dir=\"model/bert-english-qa-large\", use_gpu=False)"
"reader = FARMReader(model_dir=\"model/bert-english-qa-large\", use_gpu=False)\n",
"\n",
"# OR: use alternatively a reader from huggingface's Transformers package\n",
"# reader = TransformersReader(use_gpu=-1)"
]
},
{
@@ -276,13 +279,13 @@
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"source": [],
"metadata": {
"collapsed": false
},
"source": []
}
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -0,0 +1,55 @@
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.retriever.tfidf import TfidfRetriever
from haystack import Finder
from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
from haystack.indexing.cleaning import clean_wiki_text
from haystack.utils import print_answers
## Indexing & cleaning documents
# Init a database (default: SQLite)
from haystack.database import db
db.create_all()
# Let's first get some documents that we want to query
# Here: 517 Wikipedia articles for Game of Thrones
doc_dir = "data/article_txt_got"
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
# Now, let's write the docs to our DB.
# You can supply a cleaning function that is applied to each doc (e.g. to remove footers)
# It must take a str as input, and return a str.
write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)
## Initialize Reader, Retriever & Finder
# A retriever identifies the k most promising chunks of text that might contain the answer for our question
# Retrievers use simple but fast algorithms, here: TF-IDF
retriever = TfidfRetriever()
# A reader scans the text chunks in detail and extracts the k best answers
# Readers use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on SQuAD 2.0
fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)
# OR: alternatively, use a reader from huggingface's Transformers package
# reader = TransformersReader(use_gpu=-1)
# The Finder sticks together reader and retriever in a pipeline to answer our actual questions
finder = Finder(reader, retriever)
## Voilà! Ask a question!
# You can configure how many candidates the reader and retriever shall return
# The higher top_k_retriever, the better (but also the slower) your answers.
prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
print_answers(prediction, details="minimal")