Refactor pipeline for better generalizability & Add TransformersReader (#1)

* add flag to skip writing docs to non-empty db

* change finder pipeline structure for better generalizability

* add basic TransformersReader

* update tutorials and requirements
Malte Pietsch authored on 2020-01-13 18:56:22 +01:00 · committed by GitHub
parent 0670f0a2fc
commit cab0932fab
9 changed files with 285 additions and 99 deletions


@@ -38,6 +38,7 @@ class Finder:
         :return:
         """
 
+        # 1) Optional: reduce the search space via document tags
         if filters:
             query = """
                 SELECT id FROM document WHERE id in (
@@ -61,49 +62,15 @@ class Finder:
         else:
             candidate_doc_ids = None
 
-        retrieved_scores = self.retriever.retrieve(question, top_k=top_k_retriever)
-        inference_dicts = self._convert_retrieved_text_to_reader_format(
-            retrieved_scores, question, candidate_doc_ids=candidate_doc_ids
-        )
-        results = self.reader.predict(inference_dicts, top_k=top_k_reader)
-        return results["results"]
-
-    def _convert_retrieved_text_to_reader_format(
-        self, retrieved_scores, question, candidate_doc_ids=None, verbose=True
-    ):
-        """
-        The reader expect the input as:
-            {
-                "text": "FARM is a home for all species of pretrained language models (e.g. BERT) that can be adapted to
-                    different domain languages or down-stream tasks. With FARM you can easily create SOTA NLP models for tasks
-                    like document classification, NER or question answering.",
-                "document_id": 127,
-                "questions" : ["What can you do with FARM?"]
-            }
-
-        :param retrieved_scores: tfidf scores as returned by the retriever
-        :param question: question string
-        :param verbose: enable verbose logging
-        """
-        df_sliced = self.retriever.df.loc[retrieved_scores.keys()]
-        if verbose:
-            logger.info(
-                f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}"
-            )
-            logger.info(
-                f"Applying the reader now to look for the answer in detail ..."
-            )
-        inference_dicts = []
-        for idx, row in df_sliced.iterrows():
-            if candidate_doc_ids and row["document_id"] not in candidate_doc_ids:
-                continue
-            inference_dicts.append(
-                {
-                    "text": row["text"],
-                    "document_id": row["document_id"],
-                    "questions": [question],
-                }
-            )
-
-        return inference_dicts
+        # 2) Apply retriever to get fast candidate paragraphs
+        paragraphs, meta_data = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids)
+
+        # 3) Apply reader to get granular answer(s)
+        logger.info(f"Applying the reader now to look for the answer in detail ...")
+        results = self.reader.predict(question=question,
+                                      paragraphs=paragraphs,
+                                      meta_data_paragraphs=meta_data,
+                                      top_k=top_k_reader)
+        return results
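
Note: the refactor changes the contract between Finder, retriever and reader: retrievers now return the candidate paragraphs plus a parallel list of meta data dicts, and readers consume them directly. A minimal sketch of that contract, with illustrative stand-in classes (DummyRetriever and DummyReader are not part of this commit):

class DummyRetriever:
    def retrieve(self, question, top_k=10, candidate_doc_ids=None):
        # must return a list of paragraph strings and a parallel list of meta dicts
        paragraphs = ["Arya Stark is a daughter of Eddard Stark."]
        meta_data = [{"document_id": 1, "paragraph_id": 0}]
        return paragraphs, meta_data

class DummyReader:
    def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None):
        # must return a single dict with the question and its (sorted) answers
        return {"question": question, "answers": []}

retriever, reader = DummyRetriever(), DummyReader()
paragraphs, meta_data = retriever.retrieve("Who is the father of Arya Stark?", top_k=10)
result = reader.predict(question="Who is the father of Arya Stark?",
                        paragraphs=paragraphs,
                        meta_data_paragraphs=meta_data,
                        top_k=5)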


@@ -9,15 +9,28 @@ import zipfile
 
 logger = logging.getLogger(__name__)
 
 
-def write_documents_to_db(document_dir, clean_func=None):
+def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):
     """
     Write all text files(.txt) in the sub-directories of the given path to the connected database.
 
     :param document_dir: path for the documents to be written to the database
-    :return:
+    :param clean_func: a custom cleaning function that gets applied to each doc (input: str, output: str)
+    :param only_empty_db: If true, docs will only be written if db is completely empty.
+                          Useful to avoid indexing the same initial docs again and again.
+    :return: None
     """
     file_paths = Path(document_dir).glob("**/*.txt")
     n_docs = 0
+
+    # check if db has already docs
+    if only_empty_db:
+        n_docs = db.session.query(Document).count()
+        if n_docs > 0:
+            logger.info(f"Skip writing documents since DB already contains {n_docs} docs ... "
+                        "(Disable `only_empty_db`, if you want to add docs anyway.)")
+            return None
+
+    # read and add docs
     for path in file_paths:
         with open(path) as doc:
             text = doc.read()
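
Note: with the new flag, re-running an indexing script becomes a no-op once the DB is populated. A usage sketch, using the same paths as the tutorial below and assuming the DB has been initialized via db.create_all():

from haystack.indexing.io import write_documents_to_db
from haystack.indexing.cleaning import clean_wiki_text

# first call writes the docs; the second call is skipped because the DB is no longer empty
write_documents_to_db(document_dir="data/article_txt_got", clean_func=clean_wiki_text, only_empty_db=True)
write_documents_to_db(document_dir="data/article_txt_got", clean_func=clean_wiki_text, only_empty_db=True)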


@@ -18,6 +18,7 @@ class FARMReader:
         no_answer_shift=-100,
         batch_size=16,
         use_gpu=True,
+        n_best_per_passage=2
     ):
         """
         Load a saved FARM model in Inference mode.
@@ -27,56 +28,77 @@ class FARMReader:
         self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu)
         self.model.model.prediction_heads[0].context_size = context_size
         self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift
+        self.model.model.prediction_heads[0].n_best = n_best_per_passage
 
-    def predict(self, input_dicts, top_k=None):
-        """
-        Run inference on the loaded model for the given input dicts.
-
-        :param input_dicts: list of input dicts
-        :param top_k: the maximum number of answers to return
-        :return:
-        """
-        results = self.model.inference_from_dicts(
-            dicts=input_dicts, rest_api_schema=True, use_multiprocessing=False
-        )
-
-        # The FARM Inferencer as of now do not support multi document QA.
-        # The QA inference is done for each text independently and the
-        # results are sorted descending by their `score`.
-        all_predictions = []
-        for res in results:
-            all_predictions.extend(res["predictions"])
-
-        all_answers = []
-        for pred in all_predictions:
-            answers = pred["answers"]
-            for a in answers:
-                # Two sets of offset fields are returned by FARM -- context level and document level.
-                # For the API, only context level offsets are relevant.
-                a["offset_start"] = a["offset_answer_start"] - a["offset_context_start"]
-                a["offset_end"] = a["offset_context_end"] - a["offset_answer_end"]
-            all_answers.extend(answers)
-
-        # remove all null answers (where an answers in not found in the text)
-        all_answers = [ans for ans in all_answers if ans["answer"]]
-
-        scores = np.asarray([ans["score"] for ans in all_answers])
-        probabilities = expit(scores / 8)
-        for ans, prob in zip(all_answers, probabilities):
-            ans["probability"] = prob
-
-        # sort answers by their `probability`
-        sorted_answers = sorted(
-            all_answers, key=lambda k: k["probability"], reverse=True
-        )
-
-        # all predictions here are for the same questions, so the the metadata from
-        # the first prediction in the list is taken.
-        if all_predictions:
-            resp = all_predictions[0]  # get the first prediction dict
-            resp["answers"] = sorted_answers[:top_k]
-        else:
-            resp = []
-
-        return {"results": [resp]}
+    def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None, max_processes=1):
+        """
+        Use loaded QA model to find answers for a question in the supplied paragraphs.
+
+        Returns a dict containing answers sorted by (desc.) probability.
+        Example:
+            {'question': 'Who is the father of Arya Stark?',
+             'answers': [
+                         {'answer': 'Eddard,',
+                          'context': " She travels with her father, Eddard, to King's Landing when he is ",
+                          'offset_answer_start': 147,
+                          'offset_answer_end': 154,
+                          'probability': 0.9787139466668613,
+                          'score': None,
+                          'document_id': None
+                          },
+                         ...
+                        ]
+            }
+
+        :param question: question string
+        :param paragraphs: list of strings in which to search for the answer
+        :param meta_data_paragraphs: list of dicts containing meta data for the paragraphs.
+                                     len(paragraphs) == len(meta_data_paragraphs)
+        :param top_k: the maximum number of answers to return
+        :param max_processes: max number of parallel processes
+        :return: dict containing question and answers
+        """
+        if meta_data_paragraphs is None:
+            meta_data_paragraphs = len(paragraphs) * [None]
+        assert len(paragraphs) == len(meta_data_paragraphs)
+
+        # convert input to FARM format
+        input_dicts = []
+        for paragraph, meta_data in zip(paragraphs, meta_data_paragraphs):
+            cur = {"text": paragraph,
+                   "questions": [question],
+                   "document_id": meta_data["document_id"]
+                   }
+            input_dicts.append(cur)
+
+        # get answers from QA model (top n_best_per_passage per input paragraph)
+        predictions = self.model.inference_from_dicts(
+            dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
+        )
+
+        # assemble answers from all the different paragraphs & format them
+        answers = []
+        for pred in predictions:
+            for a in pred["predictions"][0]["answers"]:
+                if a["answer"]:  # skip "no answer"
+                    cur = {"answer": a["answer"],
+                           "score": a["score"],
+                           "probability": expit(np.asarray([a["score"]]) / 8),  # just a pseudo prob for now
+                           "context": a["context"],
+                           "offset_start": a["offset_answer_start"] - a["offset_context_start"],
+                           "offset_end": a["offset_answer_end"] - a["offset_context_start"],
+                           "document_id": a["document_id"]}
+                    answers.append(cur)
+
+        # sort answers by their `probability` and select top-k
+        answers = sorted(
+            answers, key=lambda k: k["probability"], reverse=True
+        )
+        answers = answers[:top_k]
+
+        result = {"question": question,
+                  "answers": answers}
+
+        return result
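
Note: the `probability` field is only a pseudo-probability, not a calibrated one: the raw FARM span score is squashed through a logistic sigmoid, probability = expit(score / 8). A standalone sketch of that mapping (the scores below are made-up examples, just to show the shape of the curve):

from scipy.special import expit

for score in [-8.0, 0.0, 4.0, 16.0]:
    # sigmoid(score / 8): monotone in the score, bounded in (0, 1)
    print(score, expit(score / 8))   # ~0.27, 0.50, ~0.62, ~0.88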


@@ -0,0 +1,102 @@
+from transformers import pipeline
+
+
+class TransformersReader:
+    """
+    A reader using the QA Pipeline class from huggingface's Transformers.
+
+    Easily load any of the pretrained community QA models from here: https://huggingface.co/models
+    """
+    def __init__(
+        self,
+        model="distilbert-base-uncased-distilled-squad",
+        tokenizer="distilbert-base-uncased",
+        context_size=30,
+        # no_answer_shift=-100,
+        # batch_size=16,
+        use_gpu=0,
+        n_best_per_passage=2
+    ):
+        """
+        Load a QA model from Transformers.
+        Available models include:
+        - distilbert-base-uncased-distilled-squad
+        - bert-large-cased-whole-word-masking-finetuned-squad
+        - bert-large-uncased-whole-word-masking-finetuned-squad
+
+        See https://huggingface.co/models for a full list of available QA models
+
+        :param model: name of the model
+        :param tokenizer: name of the tokenizer (usually the same as model)
+        :param context_size: num of chars (before and after the answer) to return as "context" for each answer.
+                             The context usually helps users to understand if the answer really makes sense.
+        :param use_gpu: passed to the pipeline's `device` argument:
+                        < 0  -> use cpu
+                        >= 0 -> ordinal of the gpu to use
+        """
+        self.model = pipeline("question-answering", model=model, tokenizer=tokenizer, device=use_gpu)
+        self.context_size = context_size
+        self.n_best_per_passage = n_best_per_passage
+        # TODO param to modify bias for no_answer
+
+    def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None):
+        """
+        Use loaded QA model to find answers for a question in the supplied paragraphs.
+
+        Returns a dict containing answers sorted by (desc.) probability.
+        Example:
+            {'question': 'Who is the father of Arya Stark?',
+             'answers': [
+                         {'answer': 'Eddard,',
+                          'context': " She travels with her father, Eddard, to King's Landing when he is ",
+                          'offset_answer_start': 147,
+                          'offset_answer_end': 154,
+                          'probability': 0.9787139466668613,
+                          'score': None,
+                          'document_id': None
+                          },
+                         ...
+                        ]
+            }
+
+        :param question: question string
+        :param paragraphs: list of strings in which to search for the answer
+        :param meta_data_paragraphs: list of dicts containing meta data for the paragraphs.
+                                     len(paragraphs) == len(meta_data_paragraphs)
+        :param top_k: the maximum number of answers to return
+        :return: dict containing question and answers
+        """
+        # TODO pass metadata
+
+        # get top-answers for each candidate passage
+        answers = []
+        for p in paragraphs:
+            query = {"context": p, "question": question}
+            predictions = self.model(query, topk=self.n_best_per_passage)
+            # assemble and format all answers
+            for pred in predictions:
+                if pred["answer"]:
+                    context_start = max(0, pred["start"] - self.context_size)
+                    context_end = min(len(p), pred["end"] + self.context_size)
+                    answers.append({
+                        "answer": pred["answer"],
+                        "context": p[context_start:context_end],
+                        "offset_answer_start": pred["start"],
+                        "offset_answer_end": pred["end"],
+                        "probability": pred["score"],
+                        "score": None,
+                        "document_id": None
+                    })
+
+        # sort answers by their `probability` and select top-k
+        answers = sorted(
+            answers, key=lambda k: k["probability"], reverse=True
+        )
+        answers = answers[:top_k]
+
+        results = {"question": question,
+                   "answers": answers}
+
+        return results
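
Note: a minimal usage sketch of the new reader. It downloads the default distilbert SQuAD model from the Hugging Face model hub on first use; the example paragraph and the printed answer are illustrative only:

from haystack.reader.transformers import TransformersReader

reader = TransformersReader(use_gpu=-1)  # -1 is handed to the pipeline's `device`, i.e. run on CPU
result = reader.predict(
    question="Who is the father of Arya Stark?",
    paragraphs=["Arya Stark is the third child and second daughter of Eddard Stark and Catelyn Stark."],
    top_k=1,
)
print(result["answers"])  # e.g. [{'answer': 'Eddard Stark', ...}]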


@@ -34,7 +34,7 @@ class TfidfRetriever(BaseRetriever):
     Split documents into smaller units (eg, paragraphs or pages) to reduce the
     computations when text is passed on to a Reader for QA.
 
-    It uses sklearn TfidfVectorizer to compute a tf-idf matrix.
+    It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
     """
 
     def __init__(self):
@@ -69,15 +69,37 @@ class TfidfRetriever(BaseRetriever):
         logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
         return paragraphs
 
-    def retrieve(self, query, candidate_doc_ids=None, top_k=10):
+    def _calc_scores(self, query):
         question_vector = self.vectorizer.transform([query])
 
         scores = self.tfidf_matrix.dot(question_vector.T).toarray()
         idx_scores = [(idx, score) for idx, score in enumerate(scores)]
-        top_k_scores = OrderedDict(
-            sorted(idx_scores, key=(lambda tup: tup[1]), reverse=True)[:top_k]
+        indices_and_scores = OrderedDict(
+            sorted(idx_scores, key=(lambda tup: tup[1]), reverse=True)
         )
-        return top_k_scores
+        return indices_and_scores
+
+    def retrieve(self, query, candidate_doc_ids=None, top_k=10, verbose=True):
+        # get scores
+        indices_and_scores = self._calc_scores(query)
+
+        # rank & filter paragraphs
+        df_sliced = self.df.loc[indices_and_scores.keys()]
+        if candidate_doc_ids:
+            df_sliced = df_sliced[df_sliced.document_id.isin(candidate_doc_ids)]
+        df_sliced = df_sliced[:top_k]
+        if verbose:
+            logger.info(
+                f"Identified {df_sliced.shape[0]} candidates via retriever:\n {df_sliced.to_string(col_space=10, index=False)}"
+            )
+
+        # get actual content for the top candidates
+        paragraphs = list(df_sliced.text.values)
+        meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"]}
+                     for idx, row in df_sliced.iterrows()]
+
+        return paragraphs, meta_data
 
     def fit(self):
         self.df = pd.DataFrame.from_dict(self.paragraphs)
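
Note: `retrieve` now returns the paragraph texts together with a parallel list of meta data dicts instead of raw tf-idf scores, which is what the refactored Finder passes on to the reader. A consumption sketch, assuming documents have already been written to the DB and the retriever has been fit over them as in the tutorial below:

from haystack.retriever.tfidf import TfidfRetriever

retriever = TfidfRetriever()
paragraphs, meta_data = retriever.retrieve("Who is the father of Arya Stark?", top_k=3)
for text, meta in zip(paragraphs, meta_data):
    print(meta["document_id"], meta["paragraph_id"], text[:80])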


@@ -19,7 +19,7 @@ def create_db():
 
 def print_answers(results, details="all"):
-    answers = results[0]["answers"]
+    answers = results["answers"]
     pp = pprint.PrettyPrinter(indent=4)
     if details != "all":
         if details == "minimal":
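
Note: this change follows from the new return type of `Finder.get_answers`, which is now a single dict rather than a list of result dicts. For illustration, a result of the new shape with made-up values (keys mirror the reader output above):

prediction = {"question": "Who is the father of Arya Stark?",
              "answers": [{"answer": "Eddard Stark", "score": 12.3, "probability": 0.98,
                           "context": "...her father, Eddard, to King's Landing...",
                           "offset_start": 20, "offset_end": 32, "document_id": 17}]}
print_answers(prediction, details="minimal")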


@@ -1,4 +1,6 @@
-farm==0.3.2
+# FARM (incl. transformers 2.3.0 with pipelines)
+#farm -e git+https://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43
+-e git://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43#egg=farm
 flask
 flask_cors
 flask_restplus


@@ -88,7 +88,7 @@
     "# Now, let's write the docs to our DB. \n",
     "# You can supply a cleaning function that is applied to each doc (e.g. to remove footers)\n",
     "# It must take a str as input, and return a str.\n",
-    "write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text)"
+    "write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)"
    ]
   },
   {
@@ -146,7 +146,10 @@
     "# Reader use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0\n",
     "from haystack.indexing.io import fetch_archive_from_http\n",
     "fetch_archive_from_http(url=\"https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz\", output_dir=\"model\")\n",
-    "reader = FARMReader(model_dir=\"model/bert-english-qa-large\", use_gpu=False)"
+    "reader = FARMReader(model_dir=\"model/bert-english-qa-large\", use_gpu=False)\n",
+    "\n",
+    "# OR: use alternatively a reader from huggingface's Transformers package\n",
+    "# reader = TransformersReader(use_gpu=-1)"
    ]
   },
   {
@@ -276,10 +279,10 @@
    "pycharm": {
     "stem_cell": {
      "cell_type": "raw",
+     "source": [],
      "metadata": {
       "collapsed": false
-     },
-     "source": []
+     }
     }
    }
   },


@@ -0,0 +1,55 @@
+from haystack.reader.farm import FARMReader
+from haystack.reader.transformers import TransformersReader
+from haystack.retriever.tfidf import TfidfRetriever
+
+from haystack import Finder
+from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
+from haystack.indexing.cleaning import clean_wiki_text
+from haystack.utils import print_answers
+
+## Indexing & cleaning documents
+
+# Init a database (default: sqlite)
+from haystack.database import db
+db.create_all()
+
+# Let's first get some documents that we want to query
+# Here: 517 Wikipedia articles for Game of Thrones
+doc_dir = "data/article_txt_got"
+s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
+fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
+
+# Now, let's write the docs to our DB.
+# You can supply a cleaning function that is applied to each doc (e.g. to remove footers)
+# It must take a str as input, and return a str.
+write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)
+
+## Initialize Reader, Retriever & Finder
+
+# A retriever identifies the k most promising chunks of text that might contain the answer for our question
+# Retrievers use a simple but fast algorithm, here: TF-IDF
+retriever = TfidfRetriever()
+
+# A reader scans the text chunks in detail and extracts the k best answers
+# Readers use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0
+fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
+reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)
+
+# OR: alternatively use a reader from huggingface's Transformers package
+# reader = TransformersReader(use_gpu=-1)
+
+# The Finder sticks together reader and retriever in a pipeline to answer our actual questions
+finder = Finder(reader, retriever)
+
+## Voilà! Ask a question!
+
+# You can configure how many candidates the reader and retriever shall return
+# The higher top_k_retriever, the better (but also the slower) your answers.
+prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
+
+# prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
+# prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
+
+print_answers(prediction, details="minimal")