haystack/haystack/utils.py
Tanay Soni 337376c81d Add batch_size and generators to document stores. (#733)
* Add batch update of embeddings in document stores

* Resolve merge conflict

* Remove document ordering dependency in tests

* Adjust index buffer size for tests

* Adjust ES Scroll Slice

* Use generator for document store pagination

* Add pagination for InMemoryDocumentStore

* Fix missing index parameter in FAISS update_embeddings()

* Fix FAISS update_embeddings()

* Update FAISS tests

* Update eval tests

* Revert code formatting change

* Fix document count in FAISS update embeddings

* Fix vector_ids reset in SQLDocumentStore

* Update docstrings

* Update docstring
2021-01-21 16:00:08 +01:00

127 lines
3.9 KiB
Python

import json
from collections import defaultdict
from itertools import islice
import logging
import pprint
import pandas as pd
from typing import Dict, Any, List
from haystack.document_store.sql import DocumentORM
logger = logging.getLogger(__name__)
def print_answers(results: dict, details: str = "all"):
    """
    Pretty-print the answers contained in a pipeline/finder result dict.

    :param results: dict with an "answers" key holding a list of answer dicts
    :param details: "minimal" (answer + context), "medium" (answer + context + score),
                    or "all" (print the full results dict). Any other value falls
                    back to printing each answer with all of its keys.
    :return: None
    """
    answers = results["answers"]
    pp = pprint.PrettyPrinter(indent=4)
    if details == "all":
        pp.pprint(results)
        return

    if details == "minimal":
        keys_to_keep = {"answer", "context"}
    elif details == "medium":
        keys_to_keep = {"answer", "context", "score"}
    else:
        # Unknown detail level: keep every key of each answer.
        # (The previous code called .keys() on the answers *list*, which raised
        # AttributeError for any unrecognized `details` value.)
        keys_to_keep = None

    # filter the results
    filtered_answers = [
        {k: ans[k] for k in (keys_to_keep if keys_to_keep is not None else ans)}
        for ans in answers
    ]
    pp.pprint(filtered_answers)
def export_answers_to_csv(agg_results: list, output_file):
    """
    Exports answers coming from finder.get_answers() to a CSV file.

    :param agg_results: list of predictions coming from finder.get_answers()
                        (a single result dict is also accepted)
    :param output_file: filename of output file
    :return: None
    """
    if isinstance(agg_results, dict):
        agg_results = [agg_results]

    columns = ["query", "prediction", "prediction_rank", "prediction_context"]
    if not agg_results:
        # Nothing to export: still produce a valid CSV with just the header row
        # (previously an empty list crashed with IndexError below).
        pd.DataFrame(columns=columns).to_csv(output_file, index=False)
        return

    assert "query" in agg_results[0], f"Wrong format used for {agg_results[0]}"
    assert "answers" in agg_results[0], f"Wrong format used for {agg_results[0]}"

    data = {col: [] for col in columns}  # type: Dict[str, List[Any]]
    for res in agg_results:
        # Rank is 1-based: the first answer is the top prediction.
        for rank, answer in enumerate(res["answers"], start=1):
            data["query"].append(res["query"])
            data["prediction"].append(answer["answer"])
            data["prediction_rank"].append(rank)
            data["prediction_context"].append(answer["context"])

    df = pd.DataFrame(data)
    df.to_csv(output_file, index=False)
def convert_labels_to_squad(labels_file: str, output_file: str = "labels_in_squad_format.json"):
    """
    Convert the export from the labeling UI to SQuAD format for training.

    :param labels_file: path for export file from the labeling tool
    :param output_file: path of the JSON file the SQuAD-formatted labels are
                        written to (default keeps the previously hard-coded name)
    :return: None
    """
    with open(labels_file) as label_file:
        labels = json.load(label_file)

    labels_grouped_by_documents = defaultdict(list)
    for label in labels:
        labels_grouped_by_documents[label["document_id"]].append(label)

    labels_in_squad_format = {"data": []}  # type: Dict[str, Any]
    for document_id, document_labels in labels_grouped_by_documents.items():
        # All labels in this group refer to the same document, so fetch it
        # once instead of once per label.
        doc = DocumentORM.query.get(document_id)
        qas = []
        for label in document_labels:
            # Sanity check: the stored offsets must reproduce the selected text.
            assert (
                doc.text[label["start_offset"] : label["end_offset"]]
                == label["selected_text"]
            )
            qas.append(
                {
                    "question": label["question"],
                    "id": label["id"],
                    "question_id": label["question_id"],
                    "answers": [
                        {
                            "text": label["selected_text"],
                            "answer_start": label["start_offset"],
                            "labeller_id": label["labeler_id"],
                        }
                    ],
                    "is_impossible": False,
                }
            )
        squad_format_label = {
            "paragraphs": [
                {"qas": qas, "context": doc.text, "document_id": document_id}
            ]
        }
        labels_in_squad_format["data"].append(squad_format_label)

    with open(output_file, "w+") as outfile:
        json.dump(labels_in_squad_format, outfile)
def get_batches_from_generator(iterable, n):
    """
    Batch elements of an iterable into fixed-length chunks or blocks.

    Yields tuples of up to ``n`` items each; the final tuple may be shorter.
    """
    it = iter(iterable)
    # Two-argument iter(): call the lambda repeatedly until it returns the
    # sentinel (an empty tuple), i.e. until the underlying iterator is exhausted.
    for batch in iter(lambda: tuple(islice(it, n)), ()):
        yield batch