mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-27 02:40:41 +00:00
72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
![]() |
from haystack.document_store import MilvusDocumentStore
|
||
|
from haystack.retriever import DensePassageRetriever
|
||
|
from retriever import prepare_data
|
||
|
import datetime
|
||
|
from pprint import pprint
|
||
|
from milvus import IndexType
|
||
|
|
||
|
def main(index_type, n_docs=100_000, similarity="dot_product"):
    """Benchmark a DensePassageRetriever on a MilvusDocumentStore and print the metrics.

    Loads documents and labels through ``prepare_data``, writes them to a
    ``MilvusDocumentStore`` configured according to ``index_type``, runs
    ``retriever.eval`` on them, pretty-prints the derived metrics, and finally
    deletes everything from both indices.

    Args:
        index_type: ``"flat"`` builds the store with its default index
            settings; ``"hnsw"`` builds it with ``IndexType.HNSW`` and the
            index/search parameters hard-coded below.
        n_docs: Forwarded to ``prepare_data`` as ``n_docs``.
        similarity: Forwarded to ``MilvusDocumentStore`` as ``similarity``.

    Raises:
        ValueError: If ``index_type`` is neither ``"flat"`` nor ``"hnsw"``.
    """
    # Fail fast: without this check an unknown index type only surfaced as an
    # UnboundLocalError on `doc_store`, after the expensive data preparation.
    supported_index_types = ("flat", "hnsw")
    if index_type not in supported_index_types:
        raise ValueError(
            f"Unsupported index_type {index_type!r}; "
            f"expected one of {supported_index_types}"
        )

    doc_index = "document"
    label_index = "label"

    docs, labels = prepare_data(
        data_dir="data/",
        filename_gold="nq2squad-dev.json",
        filename_negative="psgs_w100_minus_gold_100k.tsv",
        remote_url="https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/",
        embeddings_filenames=["wikipedia_passages_100k.pkl"],
        embeddings_dir="embeddings/",
        n_docs=n_docs,
        add_precomputed=True,
    )

    if index_type == "flat":
        doc_store = MilvusDocumentStore(index=doc_index, similarity=similarity)
    else:  # "hnsw" -- guaranteed by the validation above
        index_param = {"M": 64, "efConstruction": 80}
        search_param = {"ef": 20}
        doc_store = MilvusDocumentStore(
            index=doc_index,
            index_type=IndexType.HNSW,
            index_param=index_param,
            search_param=search_param,
            similarity=similarity,
        )

    try:
        doc_store.write_documents(documents=docs, index=doc_index)
        doc_store.write_labels(labels=labels, index=label_index)

        retriever = DensePassageRetriever(
            document_store=doc_store,
            query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
            passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
            use_gpu=True,
            use_fast_tokenizers=True,
        )

        raw_results = retriever.eval(label_index=label_index, doc_index=doc_index)
        n_questions = raw_results["n_questions"]
        retrieve_time = raw_results["retrieve_time"]
        results = {
            "n_queries": n_questions,
            "retrieve_time": retrieve_time,
            "queries_per_second": n_questions / retrieve_time,
            "seconds_per_query": retrieve_time / n_questions,
            # recall and map are scaled by 100 for reporting.
            "recall": raw_results["recall"] * 100,
            "map": raw_results["map"] * 100,
            "top_k": raw_results["top_k"],
            "date_time": datetime.datetime.now(),
            "error": None,
        }

        pprint(results)
    finally:
        # Always empty both indices, even when writing or evaluation fails,
        # so a failed run does not leave data behind for the next one.
        doc_store.delete_all_documents(index=doc_index)
        doc_store.delete_all_documents(index=label_index)
|
||
|
|
||
|
if __name__ == "__main__":
    # Settings shared by every benchmark run below.
    similarity = "l2"
    n_docs = 100_000

    # Run the flat index first, then HNSW, with identical settings.
    for index_type in ("flat", "hnsw"):
        main(index_type=index_type, similarity=similarity, n_docs=n_docs)
|