2020-10-12 13:34:42 +02:00
|
|
|
import os
|
2023-05-25 11:19:46 +02:00
|
|
|
import tarfile
|
|
|
|
import tempfile
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
|
|
from haystack import Label, Document, Answer
|
|
|
|
from haystack.document_stores import eval_data_from_json
|
|
|
|
from haystack.utils import launch_es, launch_opensearch, launch_weaviate
|
2021-09-28 16:34:24 +02:00
|
|
|
from haystack.modeling.data_handler.processor import http_get
|
2021-04-09 17:24:16 +02:00
|
|
|
|
2020-10-12 13:34:42 +02:00
|
|
|
import logging
|
2023-05-25 11:19:46 +02:00
|
|
|
from typing import Dict, Union
|
2020-10-12 13:34:42 +02:00
|
|
|
from pathlib import Path
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2020-10-12 13:34:42 +02:00
|
|
|
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
2023-05-25 11:19:46 +02:00
|
|
|
def prepare_environment(pipeline_config: Dict, benchmark_config: Dict):
    """
    Prepare the environment for running a benchmark.

    Downloads the benchmark data when a "data_url" is configured, counts the documents in
    the configured documents directory, and launches the Docker container for the first
    DocumentStore component found in the pipeline config.

    :param pipeline_config: Pipeline config dict; its "components" list is scanned for a
        component whose "type" ends with "DocumentStore".
    :param benchmark_config: Benchmark settings; may contain "data_url" and
        "documents_directory" keys.
    """
    # Fetch the benchmark dataset if a download location was configured
    if "data_url" in benchmark_config:
        download_from_url(url=benchmark_config["data_url"], target_dir="data/")

    # Count regular, non-hidden files so the DocumentStore container can be sized for them
    n_docs = 0
    if "documents_directory" in benchmark_config:
        documents_dir = Path(benchmark_config["documents_directory"])
        n_docs = sum(
            1
            for file_path in documents_dir.iterdir()
            if file_path.is_file() and not file_path.name.startswith(".")
        )

    # Launch the Docker container for the first DocumentStore component, if any
    for comp in pipeline_config["components"]:
        if comp["type"].endswith("DocumentStore"):
            launch_document_store(comp["type"], n_docs=n_docs)
            break
|
|
|
|
|
2020-10-12 13:34:42 +02:00
|
|
|
|
2023-05-25 11:19:46 +02:00
|
|
|
def launch_document_store(document_store: str, n_docs: int = 0):
    """
    Launch a DocumentStore Docker container.

    :param document_store: Class name of the DocumentStore (e.g. "ElasticsearchDocumentStore",
        "OpenSearchDocumentStore", "WeaviateDocumentStore"). Any other name is a no-op.
    :param n_docs: Expected number of documents; at 500k or more, the JVM-backed stores are
        started with a larger (4 GB) heap.
    """
    # Give Elasticsearch/OpenSearch a 4 GB JVM heap once the corpus gets large
    java_opts = "-Xms4096m -Xmx4096m" if n_docs >= 500000 else None
    if document_store == "ElasticsearchDocumentStore":
        launch_es(sleep=30, delete_existing=True, java_opts=java_opts)
    elif document_store == "OpenSearchDocumentStore":
        launch_opensearch(sleep=30, delete_existing=True, java_opts=java_opts)
    elif document_store == "WeaviateDocumentStore":
        launch_weaviate(sleep=30, delete_existing=True)
|
2020-10-12 13:34:42 +02:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2023-05-25 11:19:46 +02:00
|
|
|
def download_from_url(url: str, target_dir: Union[str, Path]):
    """
    Download from a URL to a local file.

    If the downloaded payload is a tar archive, it is extracted into ``target_dir``;
    otherwise the raw bytes are written to ``target_dir/<url filename>``.

    :param url: URL
    :param target_dir: Local directory where the URL content will be saved.
    """
    # exist_ok=True avoids the check-then-create race of a separate exists()/makedirs() pair
    os.makedirs(target_dir, exist_ok=True)

    url_path = Path(url)
    logger.info("Downloading %s to %s", url_path.name, target_dir)
    with tempfile.NamedTemporaryFile() as temp_file:
        http_get(url=url, temp_file=temp_file)
        # Flush buffered writes and rewind so the content can be inspected and re-read
        temp_file.flush()
        temp_file.seek(0)
        if tarfile.is_tarfile(temp_file.name):
            with tarfile.open(temp_file.name) as tar:
                # NOTE(review): extractall() does not sanitize member paths — only safe for
                # trusted archives (path-traversal risk if the URL is untrusted).
                tar.extractall(target_dir)
        else:
            with open(Path(target_dir) / url_path.name, "wb") as file:
                file.write(temp_file.read())
|
2020-10-15 18:12:17 +02:00
|
|
|
|
2021-04-09 17:24:16 +02:00
|
|
|
|
2023-05-25 11:19:46 +02:00
|
|
|
def load_eval_data(eval_set_file: Path):
    """
    Load evaluation data from a file.

    Supports SQuAD-style ``.json`` files and ``.csv`` files with "question" and "context"
    columns (plus an optional "text" column containing the answer string).

    :param eval_set_file: Path to the evaluation data file.
    :return: Tuple of (labels, queries).
    :raises FileNotFoundError: If the file does not exist.
    :raises IsADirectoryError: If the path points to a directory instead of a file.
    :raises ValueError: If the file extension is neither ``.json`` nor ``.csv``.
    """
    if not os.path.exists(eval_set_file):
        raise FileNotFoundError(f"The file {eval_set_file} does not exist.")
    elif os.path.isdir(eval_set_file):
        raise IsADirectoryError(f"The path {eval_set_file} is a directory, not a file.")

    suffix = eval_set_file.suffix
    if suffix == ".json":
        # SQuAD-style file: labels carry the queries directly
        _, labels = eval_data_from_json(str(eval_set_file))
        queries = [label.query for label in labels]
    elif suffix == ".csv":
        eval_data = pd.read_csv(eval_set_file)

        labels = []
        queries = []
        for _, row in eval_data.iterrows():
            question = row["question"]
            passage = row["context"]
            # The "text" column (answer string) is only present for Reader evaluations
            gold_answer = Answer(answer=row["text"]) if "text" in row else None
            labels.append(
                Label(
                    query=question,
                    document=Document(passage),
                    answer=gold_answer,
                    is_correct_answer=True,
                    is_correct_document=True,
                    origin="gold-label",
                )
            )
            queries.append(question)
    else:
        raise ValueError(
            f"Unsupported file format: {suffix}. Provide a SQuAD-style .json or a .csv file containing "
            f"the columns 'question' and 'context' for Retriever evaluations and additionally 'text' (containing the "
            f"answer string) for Reader evaluations."
        )

    return labels, queries
|