import json
import logging
import pprint
from collections import defaultdict
from itertools import islice
from typing import Dict, Any, List

import pandas as pd

from haystack.document_store.sql import DocumentORM

logger = logging.getLogger(__name__)


def print_answers(results: dict, details: str = "all"):
    """
    Pretty-print the answers contained in a results dict.

    :param results: dict with an "answers" key holding a list of answer dicts
    :param details: "minimal" (answer and context), "medium" (answer, context
                    and score), or "all" (the full results dict)
    """
    answers = results["answers"]
    pp = pprint.PrettyPrinter(indent=4)
    if details != "all":
        if details == "minimal":
            keys_to_keep = {"answer", "context"}
        elif details == "medium":
            keys_to_keep = {"answer", "context", "score"}
        else:
            # answers is a list, so take the keys of the first answer instead
            # of calling .keys() on the list itself
            keys_to_keep = answers[0].keys() if answers else set()

        # filter the results
        filtered_answers = []
        for ans in answers:
            filtered_answers.append({k: ans[k] for k in keys_to_keep})
        pp.pprint(filtered_answers)
    else:
        pp.pprint(results)
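
# A minimal usage sketch. The result structure below is an assumption based on
# the keys this function reads from Finder/pipeline output:
#
#   results = {
#       "query": "Who is the father of Arya Stark?",
#       "answers": [{"answer": "Eddard", "context": "...", "score": 0.95}],
#   }
#   print_answers(results, details="minimal")
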
def export_answers_to_csv(agg_results: list, output_file: str):
    """
    Export answers coming from finder.get_answers() to a CSV file.

    :param agg_results: list of predictions coming from finder.get_answers()
    :param output_file: filename of the output file
    :return: None
    """
    if isinstance(agg_results, dict):
        agg_results = [agg_results]

    assert "query" in agg_results[0], f"Wrong format used for {agg_results[0]}"
    assert "answers" in agg_results[0], f"Wrong format used for {agg_results[0]}"

    data = {}  # type: Dict[str, List[Any]]
    data["query"] = []
    data["prediction"] = []
    data["prediction_rank"] = []
    data["prediction_context"] = []

    for res in agg_results:
        # rank answers by their position in the prediction list, starting at 1
        for rank, answer in enumerate(res["answers"], start=1):
            data["query"].append(res["query"])
            data["prediction"].append(answer["answer"])
            data["prediction_rank"].append(rank)
            data["prediction_context"].append(answer["context"])

    df = pd.DataFrame(data)
    df.to_csv(output_file, index=False)
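
# Usage sketch. The prediction structure below is an assumption inferred from
# the keys this function reads; the file name is illustrative:
#
#   agg_results = [
#       {"query": "Who is the father of Arya Stark?",
#        "answers": [{"answer": "Eddard", "context": "..."}]}
#   ]
#   export_answers_to_csv(agg_results, "answers.csv")
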
def convert_labels_to_squad(labels_file: str):
    """
    Convert the export from the labeling UI to SQuAD format for training.

    :param labels_file: path of the export file from the labeling tool
    :return: None (writes "labels_in_squad_format.json" to the working directory)
    """
    with open(labels_file) as label_file:
        labels = json.load(label_file)

    labels_grouped_by_documents = defaultdict(list)
    for label in labels:
        labels_grouped_by_documents[label["document_id"]].append(label)

    labels_in_squad_format = {"data": []}  # type: Dict[str, Any]
    for document_id, labels_for_document in labels_grouped_by_documents.items():
        qas = []
        for label in labels_for_document:
            doc = DocumentORM.query.get(label["document_id"])

            # sanity check: the stored offsets must reproduce the selected text
            assert (
                doc.text[label["start_offset"] : label["end_offset"]]
                == label["selected_text"]
            ), f"Offsets do not match the selected text for label {label['id']}"

            qas.append(
                {
                    "question": label["question"],
                    "id": label["id"],
                    "question_id": label["question_id"],
                    "answers": [
                        {
                            "text": label["selected_text"],
                            "answer_start": label["start_offset"],
                            "labeller_id": label["labeler_id"],
                        }
                    ],
                    "is_impossible": False,
                }
            )

        squad_format_label = {
            "paragraphs": [
                {"qas": qas, "context": doc.text, "document_id": document_id}
            ]
        }

        labels_in_squad_format["data"].append(squad_format_label)

    with open("labels_in_squad_format.json", "w+") as outfile:
        json.dump(labels_in_squad_format, outfile)
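
# Each label record in the export is expected to carry at least the keys read
# above: "document_id", "question", "id", "question_id", "selected_text",
# "start_offset", "end_offset", "labeler_id".
# Usage sketch (the file name is illustrative):
#
#   convert_labels_to_squad("labels_export.json")
#
# writes "labels_in_squad_format.json" to the current working directory.
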
def get_batches_from_generator(iterable, n):
    """
    Batch elements of an iterable into fixed-length chunks or blocks.
    The last batch may be shorter than n.
    """
    it = iter(iterable)
    x = tuple(islice(it, n))
    while x:
        yield x
        x = tuple(islice(it, n))
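
# Example: batching a range into chunks of three (the last batch is shorter):
#
#   >>> list(get_batches_from_generator(range(7), 3))
#   [(0, 1, 2), (3, 4, 5), (6,)]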