Mirror of https://github.com/deepset-ai/haystack.git
Align similarity scores across document stores (#1967)
* align document store similarity functions
* remove unnecessary imports
* undone accidental change
* stopped weaviate from pretending to support dot product similarity
* stopped weaviate from pretending to support dot product similarity
* Add latest docstring and tutorial changes
* fix fixture params for document stores
* use cosine similarity for most tests
* fix cosine similarity test
* fix faiss test
* fix weaviate test
* fix accidental deletion
* fix document_store fixture
* test fix; shouldn't be merged
* fix test_normalize_embeddings_diff_shapes
* probably a better fix
* fix for parameter combinations
* revert new pytest_generate_tests functionality
* simplify pytest_generate_tests
* normalize embeddings for test_dpr_embedding
* add to faiss doc that embeddings are normalized
* Add latest docstring and tutorial changes
* remove unnecessary parameters and add comments
* simplify two lines of memory.py into one
* test similarity scores with smaller language model
* fix test_similarity_score

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
parent 965b9614db
commit 3e4dbbb32c
@@ -1228,7 +1228,7 @@ the vector embeddings are indexed in a FAISS Index.
Benchmarks: XXX
- `faiss_index`: Pass an existing FAISS Index, i.e. an empty one that you configured manually
or one with docs that you used in Haystack before and want to load again.
- `return_embedding`: To return document embedding
- `return_embedding`: To return document embedding. Unlike other document stores, FAISS will return normalized embeddings
- `index`: Name of index in document store to use.
- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.

@@ -1322,7 +1322,7 @@ a large number of documents without having to load all documents in memory.
DocumentStore's default index (self.index) will be used.
- `filters`: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
- `return_embedding`: Whether to return the document embeddings.
- `return_embedding`: Whether to return the document embeddings. Unlike other document stores, FAISS will return normalized embeddings
- `batch_size`: When working with large number of documents, batching can help reduce memory footprint.

<a name="faiss.FAISSDocumentStore.get_embedding_count"></a>

@@ -1404,7 +1404,7 @@ Find the document that is most similar to the provided `query_emb` by using a ve
Example: {"name": ["some", "more"], "category": ["only_one"]}
- `top_k`: How many documents to return
- `index`: Index name to query the document from.
- `return_embedding`: To return document embedding
- `return_embedding`: To return document embedding. Unlike other document stores, FAISS will return normalized embeddings

**Returns**:

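As context for the docstring change above, here is a minimal sketch (not part of this diff) of creating the store with the recommended setting for Sentence-Transformer embeddings; the `sql_url` and `faiss_index_factory_str` arguments are assumed defaults, not taken from this change.

```python
from haystack.document_stores.faiss import FAISSDocumentStore

# Hypothetical setup following the docstring above: 'cosine' for Sentence-Transformer
# embeddings, 'dot_product' (the default) for DPR-style embeddings.
document_store = FAISSDocumentStore(
    sql_url="sqlite:///faiss_document_store.db",  # assumed default-style SQL backend
    faiss_index_factory_str="Flat",               # assumed; any FAISS index type works
    similarity="cosine",
)
```
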
@@ -1761,6 +1761,7 @@ Some of the key differences in contrast to FAISS & Milvus:
2. Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset
3. Has less variety of ANN algorithms, as of now only HNSW.
4. Requires document ids to be in uuid-format. If wrongly formatted ids are provided at indexing time they will be replaced with uuids automatically.
5. Only support cosine similarity.

Weaviate python client is used to connect to the server, more details are here
https://weaviate-python-client.readthedocs.io/en/docs/weaviate.html

@@ -1776,7 +1777,7 @@ The current implementation is not supporting the storage of labels, so you canno
#### \_\_init\_\_

```python
| __init__(host: Union[str, List[str]] = "http://localhost", port: Union[int, List[int]] = 8080, timeout_config: tuple = (5, 15), username: str = None, password: str = None, index: str = "Document", embedding_dim: int = 768, content_field: str = "content", name_field: str = "name", similarity: str = "dot_product", index_type: str = "hnsw", custom_schema: Optional[dict] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
| __init__(host: Union[str, List[str]] = "http://localhost", port: Union[int, List[int]] = 8080, timeout_config: tuple = (5, 15), username: str = None, password: str = None, index: str = "Document", embedding_dim: int = 768, content_field: str = "content", name_field: str = "name", similarity: str = "cosine", index_type: str = "hnsw", custom_schema: Optional[dict] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
```

**Arguments**:

@@ -1792,7 +1793,7 @@ The current implementation is not supporting the storage of labels, so you canno
- `content_field`: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text").
If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned.
- `name_field`: Name of field that contains the title of the the doc
- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default.
- `similarity`: The similarity function used to compare document vectors. 'cosine' is the only currently supported option and default.
'cosine' is recommended for Sentence Transformers.
- `index_type`: Index type of any vector object defined in weaviate schema. The vector index type is pluggable.
Currently, HSNW is only supported.

@@ -74,7 +74,7 @@ class FAISSDocumentStore(SQLDocumentStore):
Benchmarks: XXX
:param faiss_index: Pass an existing FAISS Index, i.e. an empty one that you configured manually
or one with docs that you used in Haystack before and want to load again.
:param return_embedding: To return document embedding
:param return_embedding: To return document embedding. Unlike other document stores, FAISS will return normalized embeddings
:param index: Name of index in document store to use.
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.

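A small illustration of the behaviour the docstring change calls out: with `return_embedding=True`, FAISS hands back L2-normalized vectors, unlike the other document stores. This is a sketch only; the write/query kwargs mirror those used in the tests further down, while the in-memory SQL URL is an assumption for a throwaway store.

```python
import numpy as np
from haystack.document_stores.faiss import FAISSDocumentStore

# Sketch only, not part of this diff.
store = FAISSDocumentStore(sql_url="sqlite://", similarity="cosine")

emb = np.random.rand(768).astype(np.float32)
store.write_documents([{"content": "some text", "embedding": emb}])

results = store.query_by_embedding(query_emb=emb, top_k=1, return_embedding=True)
print(np.linalg.norm(results[0].embedding))  # ~1.0, because FAISS stores normalized vectors
```
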
@@ -379,7 +379,7 @@ class FAISSDocumentStore(SQLDocumentStore):
DocumentStore's default index (self.index) will be used.
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param return_embedding: Whether to return the document embeddings. Unlike other document stores, FAISS will return normalized embeddings
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
if headers:

@@ -510,7 +510,7 @@ class FAISSDocumentStore(SQLDocumentStore):
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param top_k: How many documents to return
:param index: Index name to query the document from.
:param return_embedding: To return document embedding
:param return_embedding: To return document embedding. Unlike other document stores, FAISS will return normalized embeddings
:return:
"""
if headers:

@@ -210,13 +210,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
    new_document.embedding = doc.embedding if return_embedding is True else None

    if self.similarity == "dot_product":
        score = np.dot(query_emb, doc.embedding) / (
            np.linalg.norm(query_emb) * np.linalg.norm(doc.embedding)
        )
        score = np.dot(query_emb, doc.embedding)
    elif self.similarity == "cosine":
        # cosine similarity score = 1 - cosine distance
        score = 1 - cosine(query_emb, doc.embedding)
    new_document.score = (score + 1) / 2
    new_document.score = self.finalize_raw_score(score, self.similarity)
    candidate_docs.append(new_document)

return sorted(candidate_docs, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)[0:top_k]

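To make the new branching explicit, here is a rough standalone sketch of the raw scores computed above; it is not the actual InMemoryDocumentStore code. The `(raw + 1) / 2` scaling for cosine seen in the removed line reappears later in `test_cosine_sanity_check`, while `finalize_raw_score` now centralizes that mapping.

```python
import numpy as np
from scipy.spatial.distance import cosine

# Standalone illustration of the two raw score branches above.
query_emb = np.array([0.1, 0.2, 0.3], dtype="float32")
doc_emb = np.array([0.4, 0.5, 0.6], dtype="float32")

raw_dot = np.dot(query_emb, doc_emb)      # raw dot product, no longer normalized here
raw_cos = 1 - cosine(query_emb, doc_emb)  # cosine similarity = 1 - cosine distance

# finalize_raw_score then maps the raw value into [0, 1]; for cosine that is (raw + 1) / 2,
# e.g. (0.9746317 + 1) / 2 as used in test_cosine_sanity_check below.
print(raw_dot, raw_cos, (raw_cos + 1) / 2)
```
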
@@ -31,6 +31,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
2. Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset
3. Has less variety of ANN algorithms, as of now only HNSW.
4. Requires document ids to be in uuid-format. If wrongly formatted ids are provided at indexing time they will be replaced with uuids automatically.
5. Only support cosine similarity.

Weaviate python client is used to connect to the server, more details are here
https://weaviate-python-client.readthedocs.io/en/docs/weaviate.html

@@ -53,7 +54,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
embedding_dim: int = 768,
content_field: str = "content",
name_field: str = "name",
similarity: str = "dot_product",
similarity: str = "cosine",
index_type: str = "hnsw",
custom_schema: Optional[dict] = None,
return_embedding: bool = False,

@@ -74,7 +75,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
:param content_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text").
If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned.
:param name_field: Name of field that contains the title of the the doc
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default.
:param similarity: The similarity function used to compare document vectors. 'cosine' is the only currently supported option and default.
'cosine' is recommended for Sentence Transformers.
:param index_type: Index type of any vector object defined in weaviate schema. The vector index type is pluggable.
Currently, HSNW is only supported.

@@ -93,6 +94,8 @@ class WeaviateDocumentStore(BaseDocumentStore):
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already exists.
"""
if similarity != "cosine":
    raise ValueError(f"Weaviate only supports cosine similarity, but you provided {similarity}")
# save init parameters to enable export of component config as YAML
self.set_config(
    host=host, port=port, timeout_config=timeout_config, username=username, password=password,

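A hedged usage sketch of the new guard above: the constructor arguments come from the signature shown earlier, while the import path and the need for a running Weaviate instance are assumptions on my part.

```python
from haystack.document_stores.weaviate import WeaviateDocumentStore

# Works against a running Weaviate instance; 'cosine' is now the default and the
# only accepted value.
store = WeaviateDocumentStore(host="http://localhost", port=8080, similarity="cosine")

# Anything else is rejected by the check added above, before the store is set up.
try:
    WeaviateDocumentStore(host="http://localhost", port=8080, similarity="dot_product")
except ValueError as err:
    print(err)  # Weaviate only supports cosine similarity, but you provided dot_product
```
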
@@ -54,7 +54,7 @@ def pytest_generate_tests(metafunc):
# @pytest.mark.parametrize("document_store", ["memory"], indirect=False)
found_mark_parametrize_document_store = False
for marker in metafunc.definition.iter_markers('parametrize'):
    if 'document_store' in marker.args[0] or 'document_store_with_docs' in marker.args[0] or 'document_store_type' in marker.args[0]:
    if 'document_store' in marker.args[0]:
        found_mark_parametrize_document_store = True
        break
# for all others that don't have explicit parametrization, we add the ones from the CLI arg

@@ -479,34 +479,50 @@ def get_retriever(retriever_type, document_store):

@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
def document_store_with_docs(request, test_docs_xs):
    document_store = get_document_store(request.param)
    embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
    document_store = get_document_store(request.param, embedding_dim.args[0])
    document_store.write_documents(test_docs_xs)
    yield document_store
    document_store.delete_documents()


@pytest.fixture
def document_store(request, test_docs_xs):
def document_store(request):
    embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
    document_store = get_document_store(request.param, embedding_dim.args[0])
    yield document_store
    document_store.delete_documents()

@pytest.fixture(params=["faiss", "milvus", "weaviate"])
def document_store_cosine(request, test_docs_xs):
@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
def document_store_dot_product(request):
    embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
    yield document_store
    document_store.delete_documents()

@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
def document_store_dot_product_with_docs(request, test_docs_xs):
    embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
    document_store.write_documents(test_docs_xs)
    yield document_store
    document_store.delete_documents()

@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus"])
def document_store_dot_product_small(request):
    embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
    yield document_store
    document_store.delete_documents()

@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
def document_store_cosine_small(request, test_docs_xs):
def document_store_small(request):
    embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
    yield document_store
    document_store.delete_documents()
    document_store.delete_documents()

def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="dot_product"):
def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
    if document_store_type == "sql":
        document_store = SQLDocumentStore(url="sqlite://", index=index)
    elif document_store_type == "memory":

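For readers unfamiliar with the marker mechanism these fixtures rely on: the fixture reads `embedding_dim` via `get_closest_marker` and falls back to 768. The test below is a hypothetical illustration, not part of this diff, showing how a test opts into a smaller embedding size.

```python
import numpy as np
import pytest

# Hypothetical example: request a 384-dim store from the `document_store` fixture above
# via the custom embedding_dim marker.
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.embedding_dim(384)
def test_write_small_embeddings(document_store):
    doc = {"content": "hello", "embedding": np.random.rand(384).astype(np.float32)}
    document_store.write_documents([doc])
    assert document_store.get_embedding_count() == 1
```
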
@@ -14,7 +14,8 @@ from haystack.errors import DuplicateDocumentError
from haystack.schema import Document, Label, Answer, Span
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore

from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import DocumentSearchPipeline

@pytest.mark.elasticsearch
def test_init_elastic_client():

@@ -954,8 +955,27 @@ def test_elasticsearch_synonyms():
    indexed_settings = client.indices.get_settings(index="haystack_synonym_arg")

    assert synonym_type == indexed_settings['haystack_synonym_arg']['settings']['index']['analysis']['filter']['synonym']['type']
    assert synonyms == indexed_settings['haystack_synonym_arg']['settings']['index']['analysis']['filter']['synonym']['synonyms']
    assert synonyms == indexed_settings['haystack_synonym_arg']['settings']['index']['analysis']['filter']['synonym']['synonyms']

@pytest.mark.parametrize("document_store_with_docs", ["memory", "faiss", "milvus", "weaviate", "elasticsearch"], indirect=True)
@pytest.mark.embedding_dim(384)
def test_similarity_score(document_store_with_docs):
    retriever = EmbeddingRetriever(document_store=document_store_with_docs, embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2")
    document_store_with_docs.update_embeddings(retriever)
    pipeline = DocumentSearchPipeline(retriever)
    prediction = pipeline.run("Paul lives in New York")
    scores = [document.score for document in prediction["documents"]]
    assert scores == pytest.approx([0.9102500000000191, 0.6491700000000264, 0.6321699999999737], abs=1e-3)

@pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory", "faiss", "milvus", "elasticsearch"], indirect=True)
@pytest.mark.embedding_dim(384)
def test_similarity_score_dot_product(document_store_dot_product_with_docs):
    retriever = EmbeddingRetriever(document_store=document_store_dot_product_with_docs, embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2")
    document_store_dot_product_with_docs.update_embeddings(retriever)
    pipeline = DocumentSearchPipeline(retriever)
    prediction = pipeline.run("Paul lives in New York")
    scores = [document.score for document in prediction["documents"]]
    assert scores == pytest.approx([0.5526493562767626, 0.5189836204008691, 0.5179697571274173], abs=1e-3)

def test_custom_headers(document_store_with_docs: BaseDocumentStore):
    mock_client = None

@@ -157,7 +157,8 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
    original_doc = [d for d in DOCUMENTS if d["content"] == doc.content][0]
    stored_emb = document_store.faiss_indexes[document_store.index].reconstruct(int(doc.meta["vector_id"]))
    # compare original input vec with stored one (ignore extra dim added by hnsw)
    assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01)
    # original input vec is normalized as faiss only stores normalized vectors
    assert np.allclose(original_doc["embedding"] / np.linalg.norm(original_doc["embedding"]), stored_emb, rtol=0.01)


@pytest.mark.slow

@@ -178,7 +179,8 @@ def test_update_docs(document_store, retriever, batch_size):
    updated_embedding = retriever.embed_documents([Document.from_dict(original_doc)])
    stored_doc = document_store.get_all_documents(filters={"name": [doc.meta["name"]]})[0]
    # compare original input vec with stored one (ignore extra dim added by hnsw)
    assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01)
    # original input vec is normalized as faiss only stores normalized vectors
    assert np.allclose(updated_embedding / np.linalg.norm(updated_embedding), stored_doc.embedding, rtol=0.01)


@pytest.mark.slow

@@ -441,16 +443,16 @@ def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
    for d in docs:
        d["id"] = str(uuid.uuid4())

def test_cosine_similarity(document_store_cosine):
def test_cosine_similarity(document_store):
    # below we will write documents to the store and then query it to see if vectors were normalized

    ensure_ids_are_correct_uuids(docs=DOCUMENTS,document_store=document_store_cosine)
    document_store_cosine.write_documents(documents=DOCUMENTS)
    ensure_ids_are_correct_uuids(docs=DOCUMENTS,document_store=document_store)
    document_store.write_documents(documents=DOCUMENTS)

    # note that the same query will be used later when querying after updating the embeddings
    query = np.random.rand(768).astype(np.float32)

    query_results = document_store_cosine.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)
    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)

    # check if search with cosine similarity returns the correct number of results
    assert len(query_results) == len(DOCUMENTS)

@@ -461,7 +463,7 @@ def test_cosine_similarity(document_store_cosine):
    for doc in query_results:
        result_emb = doc.embedding
        original_emb = np.array([indexed_docs[doc.content]], dtype="float32")
        document_store_cosine.normalize_embedding(original_emb[0])
        document_store.normalize_embedding(original_emb[0])

        # check if the stored embedding was normalized
        assert np.allclose(original_emb[0], result_emb, rtol=0.01)

@@ -475,27 +477,27 @@ def test_cosine_similarity(document_store_cosine):
            return [np.random.rand(768).astype(np.float32) for doc in docs]

    retriever = MockRetriever()
    document_store_cosine.update_embeddings(retriever=retriever)
    query_results = document_store_cosine.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)
    document_store.update_embeddings(retriever=retriever)
    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)

    for doc in query_results:
        original_emb = np.array([indexed_docs[doc.content]], dtype="float32")
        document_store_cosine.normalize_embedding(original_emb[0])
        document_store.normalize_embedding(original_emb[0])
        # check if the original embedding has changed after updating the embeddings
        assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)


def test_normalize_embeddings_diff_shapes(document_store_cosine_small):
def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
    VEC_1 = np.array([.1, .2, .3], dtype="float32")
    document_store_cosine_small.normalize_embedding(VEC_1)
    document_store_dot_product_small.normalize_embedding(VEC_1)
    assert np.linalg.norm(VEC_1) - 1 < 0.01

    VEC_1 = np.array([.1, .2, .3], dtype="float32").reshape(1, -1)
    document_store_cosine_small.normalize_embedding(VEC_1)
    document_store_dot_product_small.normalize_embedding(VEC_1)
    assert np.linalg.norm(VEC_1) - 1 < 0.01


def test_cosine_sanity_check(document_store_cosine_small):
def test_cosine_sanity_check(document_store_small):
    VEC_1 = np.array([.1, .2, .3], dtype="float32")
    VEC_2 = np.array([.4, .5, .6], dtype="float32")

@@ -504,10 +506,10 @@ def test_cosine_sanity_check(document_store_cosine_small):
    KNOWN_COSINE = (0.9746317 + 1) / 2

    docs = [{"name": "vec_1", "text": "vec_1", "content": "vec_1", "embedding": VEC_1}]
    ensure_ids_are_correct_uuids(docs=docs,document_store=document_store_cosine_small)
    document_store_cosine_small.write_documents(documents=docs)
    ensure_ids_are_correct_uuids(docs=docs,document_store=document_store_small)
    document_store_small.write_documents(documents=docs)

    query_results = document_store_cosine_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True)
    query_results = document_store_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True)

    # check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318
    assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)

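A quick arithmetic check of the constant used above, in plain numpy and independent of any document store:

```python
import numpy as np

v1 = np.array([.1, .2, .3], dtype="float32")
v2 = np.array([.4, .5, .6], dtype="float32")

# cos(v1, v2) = 0.32 / (sqrt(0.14) * sqrt(0.77)) ≈ 0.9746317
cos_sim = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
print(cos_sim, (cos_sim + 1) / 2)  # ≈ 0.9746317 and ≈ 0.9873159, i.e. KNOWN_COSINE
```
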
@@ -154,17 +154,23 @@ def test_dpr_embedding(document_store, retriever, docs):
    document_store.update_embeddings(retriever=retriever)
    time.sleep(1)

    doc_1 = document_store.get_document_by_id("1")
    assert len(doc_1.embedding) == 768
    assert abs(doc_1.embedding[0] - (-0.3063)) < 0.001
    doc_2 = document_store.get_document_by_id("2")
    assert abs(doc_2.embedding[0] - (-0.3914)) < 0.001
    doc_3 = document_store.get_document_by_id("3")
    assert abs(doc_3.embedding[0] - (-0.2470)) < 0.001
    doc_4 = document_store.get_document_by_id("4")
    assert abs(doc_4.embedding[0] - (-0.0802)) < 0.001
    doc_5 = document_store.get_document_by_id("5")
    assert abs(doc_5.embedding[0] - (-0.0551)) < 0.001
    # always normalize vector as faiss returns normalized vectors and other document stores do not
    doc_1 = document_store.get_document_by_id("1").embedding
    doc_1 /= np.linalg.norm(doc_1)
    assert len(doc_1) == 768
    assert abs(doc_1[0] - (-0.0250)) < 0.001
    doc_2 = document_store.get_document_by_id("2").embedding
    doc_2 /= np.linalg.norm(doc_2)
    assert abs(doc_2[0] - (-0.0314)) < 0.001
    doc_3 = document_store.get_document_by_id("3").embedding
    doc_3 /= np.linalg.norm(doc_3)
    assert abs(doc_3[0] - (-0.0200)) < 0.001
    doc_4 = document_store.get_document_by_id("4").embedding
    doc_4 /= np.linalg.norm(doc_4)
    assert abs(doc_4[0] - (-0.0070)) < 0.001
    doc_5 = document_store.get_document_by_id("5").embedding
    doc_5 /= np.linalg.norm(doc_5)
    assert abs(doc_5[0] - (-0.0049)) < 0.001


@pytest.mark.slow

@@ -125,17 +125,17 @@ def test_most_similar_documents_pipeline(retriever, document_store):


@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_join_document_pipeline(document_store_with_docs, reader):
    es = ElasticsearchRetriever(document_store=document_store_with_docs)
def test_join_document_pipeline(document_store_dot_product_with_docs, reader):
    es = ElasticsearchRetriever(document_store=document_store_dot_product_with_docs)
    dpr = DensePassageRetriever(
        document_store=document_store_with_docs,
        document_store=document_store_dot_product_with_docs,
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        use_gpu=False,
    )
    document_store_with_docs.update_embeddings(dpr)
    document_store_dot_product_with_docs.update_embeddings(dpr)

    query = "Where does Carla live?"