mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 13:38:53 +00:00
Unify vector_dim and embedding_dim parameter in Document Store (#1922)
* Refactored code to unify vector_dim and embedding_dim parameter in DocumentStores * Unit test cases updated to use `embedding_dim` instead of `vector_dim` * Unit test case update to use embedding_dim instead of vector_dim * Add latest docstring and tutorial changes * Put usage of `vector_dim` param in same if-block as corresponding warning Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: bogdankostic <bogdankostic@web.de>
This commit is contained in:
parent
00dc30ae54
commit
a44b6c18c0
@ -1202,14 +1202,15 @@ the vector embeddings are indexed in a FAISS Index.
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
|
||||
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `sql_url`: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
|
||||
deployment, Postgres is recommended.
|
||||
- `vector_dim`: the embedding vector size.
|
||||
- `vector_dim`: Deprecated. Use embedding_dim instead.
|
||||
- `embedding_dim`: The embedding vector size. Default: 768.
|
||||
- `faiss_index_factory_str`: Create a new FAISS index of the specified type.
|
||||
The type is determined from the given string following the conventions
|
||||
of the original FAISS index factory.
|
||||
@ -1231,7 +1232,7 @@ the vector embeddings are indexed in a FAISS Index.
|
||||
- `index`: Name of index in document store to use.
|
||||
- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default since it is
|
||||
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
|
||||
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
|
||||
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
|
||||
For `dot_product`: expit(np.asarray(raw_score / 100))
|
||||
FOr `cosine`: (raw_score + 1) / 2
|
||||
- `embedding_field`: Name of field containing an embedding vector.
|
||||
@ -1424,7 +1425,7 @@ Save FAISS Index to the specified file.
|
||||
- `config_path`: Path to save the initial configuration parameters to.
|
||||
Defaults to the same as the file path, save the extension (.json).
|
||||
This file contains all the parameters passed to FAISSDocumentStore()
|
||||
at creation time (for example the SQL path, vector_dim, etc), and will be
|
||||
at creation time (for example the SQL path, embedding_dim, etc), and will be
|
||||
used by the `load` method to restore the index with the appropriate configuration.
|
||||
|
||||
**Returns**:
|
||||
@ -1478,7 +1479,7 @@ Usage:
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
|
||||
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
@ -1491,7 +1492,8 @@ Usage:
|
||||
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
|
||||
- `connection_pool`: Connection pool type to connect with Milvus server. Default: "SingletonThread".
|
||||
- `index`: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
|
||||
- `vector_dim`: The embedding vector size. Default: 768.
|
||||
- `vector_dim`: Deprecated. Use embedding_dim instead.
|
||||
- `embedding_dim`: The embedding vector size. Default: 768.
|
||||
- `index_file_size`: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
|
||||
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
|
||||
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
|
||||
|
||||
@ -54,7 +54,7 @@ For more info on which suits your use case: https://github.com/facebookresearch/
|
||||
```python
|
||||
from haystack.document_stores import FAISSDocumentStore
|
||||
|
||||
document_store = FAISSDocumentStore(vector_dim=128, faiss_index_factory_str="Flat")
|
||||
document_store = FAISSDocumentStore(embedding_dim=128, faiss_index_factory_str="Flat")
|
||||
```
|
||||
|
||||
### Cleaning & indexing documents
|
||||
|
||||
@ -8,6 +8,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Union, List, Optional, Dict, Generator
|
||||
from tqdm.auto import tqdm
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import faiss
|
||||
@ -37,7 +38,8 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
def __init__(
|
||||
self,
|
||||
sql_url: str = "sqlite:///faiss_document_store.db",
|
||||
vector_dim: int = 768,
|
||||
vector_dim: int = None,
|
||||
embedding_dim: int = 768,
|
||||
faiss_index_factory_str: str = "Flat",
|
||||
faiss_index: Optional["faiss.swigfaiss.Index"] = None,
|
||||
return_embedding: bool = False,
|
||||
@ -53,7 +55,8 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
"""
|
||||
:param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
|
||||
deployment, Postgres is recommended.
|
||||
:param vector_dim: the embedding vector size.
|
||||
:param vector_dim: Deprecated. Use embedding_dim instead.
|
||||
:param embedding_dim: The embedding vector size. Default: 768.
|
||||
:param faiss_index_factory_str: Create a new FAISS index of the specified type.
|
||||
The type is determined from the given string following the conventions
|
||||
of the original FAISS index factory.
|
||||
@ -75,7 +78,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
:param index: Name of index in document store to use.
|
||||
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
|
||||
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
|
||||
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
|
||||
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
|
||||
For `dot_product`: expit(np.asarray(raw_score / 100))
|
||||
FOr `cosine`: (raw_score + 1) / 2
|
||||
:param embedding_field: Name of field containing an embedding vector.
|
||||
@ -89,7 +92,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
exists.
|
||||
:param faiss_index_path: Stored FAISS index file. Can be created via calling `save()`.
|
||||
If specified no other params besides faiss_config_path must be specified.
|
||||
:param faiss_config_path: Stored FAISS initial configuration parameters.
|
||||
:param faiss_config_path: Stored FAISS initial configuration parameters.
|
||||
Can be created via calling `save()`
|
||||
"""
|
||||
# special case if we want to load an existing index from disk
|
||||
@ -103,14 +106,15 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
sql_url=sql_url,
|
||||
vector_dim=vector_dim,
|
||||
sql_url=sql_url,
|
||||
vector_dim=vector_dim,
|
||||
embedding_dim=embedding_dim,
|
||||
faiss_index_factory_str=faiss_index_factory_str,
|
||||
return_embedding=return_embedding,
|
||||
duplicate_documents=duplicate_documents,
|
||||
index=index,
|
||||
duplicate_documents=duplicate_documents,
|
||||
index=index,
|
||||
similarity=similarity,
|
||||
embedding_field=embedding_field,
|
||||
embedding_field=embedding_field,
|
||||
progress_bar=progress_bar
|
||||
)
|
||||
|
||||
@ -124,14 +128,20 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
raise ValueError("The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
|
||||
"Please set similarity to one of the above.")
|
||||
|
||||
self.vector_dim = vector_dim
|
||||
if vector_dim is not None:
|
||||
warnings.warn("The 'vector_dim' parameter is deprecated, "
|
||||
"use 'embedding_dim' instead.", DeprecationWarning, 2)
|
||||
self.embedding_dim = vector_dim
|
||||
else:
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.faiss_index_factory_str = faiss_index_factory_str
|
||||
self.faiss_indexes: Dict[str, faiss.swigfaiss.Index] = {}
|
||||
if faiss_index:
|
||||
self.faiss_indexes[index] = faiss_index
|
||||
else:
|
||||
self.faiss_indexes[index] = self._create_new_index(
|
||||
vector_dim=self.vector_dim,
|
||||
embedding_dim=self.embedding_dim,
|
||||
index_factory=faiss_index_factory_str,
|
||||
metric_type=self.metric_type,
|
||||
**kwargs
|
||||
@ -158,7 +168,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
if param.name not in allowed_params and param.default != locals[param.name]:
|
||||
invalid_param_set = True
|
||||
break
|
||||
|
||||
|
||||
if invalid_param_set or len(kwargs) > 0:
|
||||
raise ValueError("if faiss_index_path is passed no other params besides faiss_config_path are allowed.")
|
||||
|
||||
@ -172,12 +182,12 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
"configuration file correctly points to the same database that "
|
||||
"was used when creating the original index.")
|
||||
|
||||
def _create_new_index(self, vector_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
|
||||
def _create_new_index(self, embedding_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
|
||||
if index_factory == "HNSW":
|
||||
# faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
|
||||
# defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
|
||||
n_links = kwargs.get("n_links", 64)
|
||||
index = faiss.IndexHNSWFlat(vector_dim, n_links, metric_type)
|
||||
index = faiss.IndexHNSWFlat(embedding_dim, n_links, metric_type)
|
||||
index.hnsw.efSearch = kwargs.get("efSearch", 20)#20
|
||||
index.hnsw.efConstruction = kwargs.get("efConstruction", 80)#80
|
||||
if "ivf" in index_factory.lower(): # enable reconstruction of vectors for inverted index
|
||||
@ -185,7 +195,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
|
||||
logger.info(f"HNSW params: n_links: {n_links}, efSearch: {index.hnsw.efSearch}, efConstruction: {index.hnsw.efConstruction}")
|
||||
else:
|
||||
index = faiss.index_factory(vector_dim, index_factory, metric_type)
|
||||
index = faiss.index_factory(embedding_dim, index_factory, metric_type)
|
||||
return index
|
||||
|
||||
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
|
||||
@ -217,7 +227,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
|
||||
if not self.faiss_indexes.get(index):
|
||||
self.faiss_indexes[index] = self._create_new_index(
|
||||
vector_dim=self.vector_dim,
|
||||
embedding_dim=self.embedding_dim,
|
||||
index_factory=self.faiss_index_factory_str,
|
||||
metric_type=faiss.METRIC_INNER_PRODUCT,
|
||||
)
|
||||
@ -544,7 +554,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
:param config_path: Path to save the initial configuration parameters to.
|
||||
Defaults to the same as the file path, save the extension (.json).
|
||||
This file contains all the parameters passed to FAISSDocumentStore()
|
||||
at creation time (for example the SQL path, vector_dim, etc), and will be
|
||||
at creation time (for example the SQL path, embedding_dim, etc), and will be
|
||||
used by the `load` method to restore the index with the appropriate configuration.
|
||||
:return: None
|
||||
"""
|
||||
@ -574,7 +584,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
|
||||
# Add other init params to override the ones defined in the init params file
|
||||
init_params["faiss_index"] = faiss_index
|
||||
init_params["vector_dim"] = faiss_index.d
|
||||
init_params["embedding_dim"] = faiss_index.d
|
||||
|
||||
return init_params
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ if TYPE_CHECKING:
|
||||
from haystack.nodes.retriever import BaseRetriever
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from scipy.special import expit
|
||||
@ -41,7 +42,8 @@ class MilvusDocumentStore(SQLDocumentStore):
|
||||
milvus_url: str = "tcp://localhost:19530",
|
||||
connection_pool: str = "SingletonThread",
|
||||
index: str = "document",
|
||||
vector_dim: int = 768,
|
||||
vector_dim: int = None,
|
||||
embedding_dim: int = 768,
|
||||
index_file_size: int = 1024,
|
||||
similarity: str = "dot_product",
|
||||
index_type: IndexType = IndexType.FLAT,
|
||||
@ -62,7 +64,8 @@ class MilvusDocumentStore(SQLDocumentStore):
|
||||
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
|
||||
:param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
|
||||
:param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
|
||||
:param vector_dim: The embedding vector size. Default: 768.
|
||||
:param vector_dim: Deprecated. Use embedding_dim instead.
|
||||
:param embedding_dim: The embedding vector size. Default: 768.
|
||||
:param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
|
||||
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
|
||||
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
|
||||
@ -98,13 +101,20 @@ class MilvusDocumentStore(SQLDocumentStore):
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
sql_url=sql_url, milvus_url=milvus_url, connection_pool=connection_pool, index=index, vector_dim=vector_dim,
|
||||
index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
|
||||
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
|
||||
search_param=search_param, duplicate_documents=duplicate_documents,
|
||||
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
|
||||
)
|
||||
|
||||
self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
|
||||
self.vector_dim = vector_dim
|
||||
|
||||
if vector_dim is not None:
|
||||
warnings.warn("The 'vector_dim' parameter is deprecated, "
|
||||
"use 'embedding_dim' instead.", DeprecationWarning, 2)
|
||||
self.embedding_dim = vector_dim
|
||||
else:
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.index_file_size = index_file_size
|
||||
|
||||
if similarity in ("dot_product", "cosine"):
|
||||
@ -147,7 +157,7 @@ class MilvusDocumentStore(SQLDocumentStore):
|
||||
if not ok:
|
||||
collection_param = {
|
||||
'collection_name': index,
|
||||
'dimension': self.vector_dim,
|
||||
'dimension': self.embedding_dim,
|
||||
'index_file_size': self.index_file_size,
|
||||
'metric_type': self.metric_type
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -59,7 +60,8 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
port: str = "19530",
|
||||
connection_pool: str = "SingletonThread",
|
||||
index: str = "document",
|
||||
vector_dim: int = 768,
|
||||
vector_dim: int = None,
|
||||
embedding_dim: int = 768,
|
||||
index_file_size: int = 1024,
|
||||
similarity: str = "dot_product",
|
||||
index_type: str = "IVF_FLAT",
|
||||
@ -81,7 +83,8 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
|
||||
:param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
|
||||
:param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
|
||||
:param vector_dim: The embedding vector size. Default: 768.
|
||||
:param vector_dim: Deprecated. Use embedding_dim instead.
|
||||
:param embedding_dim: The embedding vector size. Default: 768.
|
||||
:param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
|
||||
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
|
||||
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
|
||||
@ -120,7 +123,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
sql_url=sql_url, host=host, port=port, connection_pool=connection_pool, index=index, vector_dim=vector_dim,
|
||||
index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
|
||||
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
|
||||
search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field,
|
||||
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
|
||||
custom_fields=custom_fields,
|
||||
@ -135,7 +138,13 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
connections.add_connection(default={"host": host, "port": port})
|
||||
connections.connect()
|
||||
|
||||
self.vector_dim = vector_dim
|
||||
if vector_dim is not None:
|
||||
warnings.warn("The 'vector_dim' parameter is deprecated, "
|
||||
"use 'embedding_dim' instead.", DeprecationWarning, 2)
|
||||
self.embedding_dim = vector_dim
|
||||
else:
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.index_file_size = index_file_size
|
||||
|
||||
if similarity == "dot_product":
|
||||
@ -187,7 +196,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
if not has_collection:
|
||||
fields = [
|
||||
FieldSchema(name=self.id_field, dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||
FieldSchema(name=self.embedding_field, dtype=DataType.FLOAT_VECTOR, dim=self.vector_dim)
|
||||
FieldSchema(name=self.embedding_field, dtype=DataType.FLOAT_VECTOR, dim=self.embedding_dim)
|
||||
]
|
||||
|
||||
for field in custom_fields:
|
||||
|
||||
@ -487,22 +487,22 @@ def document_store_with_docs(request, test_docs_xs):
|
||||
|
||||
@pytest.fixture
|
||||
def document_store(request, test_docs_xs):
|
||||
vector_dim = request.node.get_closest_marker("vector_dim", pytest.mark.vector_dim(768))
|
||||
document_store = get_document_store(request.param, vector_dim.args[0])
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0])
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
@pytest.fixture(params=["faiss", "milvus", "weaviate"])
|
||||
def document_store_cosine(request, test_docs_xs):
|
||||
vector_dim = request.node.get_closest_marker("vector_dim", pytest.mark.vector_dim(768))
|
||||
document_store = get_document_store(request.param, vector_dim.args[0], similarity="cosine")
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
|
||||
def document_store_cosine_small(request, test_docs_xs):
|
||||
vector_dim = request.node.get_closest_marker("vector_dim", pytest.mark.vector_dim(3))
|
||||
document_store = get_document_store(request.param, vector_dim.args[0], similarity="cosine")
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
@ -522,7 +522,7 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
|
||||
)
|
||||
elif document_store_type == "faiss":
|
||||
document_store = FAISSDocumentStore(
|
||||
vector_dim=embedding_dim,
|
||||
embedding_dim=embedding_dim,
|
||||
sql_url="sqlite://",
|
||||
return_embedding=True,
|
||||
embedding_field=embedding_field,
|
||||
@ -531,7 +531,7 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
|
||||
)
|
||||
elif document_store_type == "milvus":
|
||||
document_store = MilvusDocumentStore(
|
||||
vector_dim=embedding_dim,
|
||||
embedding_dim=embedding_dim,
|
||||
sql_url="sqlite://",
|
||||
return_embedding=True,
|
||||
embedding_field=embedding_field,
|
||||
|
||||
@ -9,4 +9,4 @@ markers =
|
||||
pipeline: marks tests with pipeline
|
||||
summarizer: marks summarizer tests
|
||||
weaviate: marks tests that require weaviate container
|
||||
vector_dim: marks usage of document store with non-default embedding dimension (e.g @pytest.mark.vector_dim(128))
|
||||
embedding_dim: marks usage of document store with non-default embedding dimension (e.g @pytest.mark.embedding_dim(128))
|
||||
@ -375,7 +375,7 @@ def test_update_embeddings(document_store, retriever):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
|
||||
@pytest.mark.vector_dim(512)
|
||||
@pytest.mark.embedding_dim(512)
|
||||
def test_update_embeddings_table_text_retriever(document_store, retriever):
|
||||
documents = []
|
||||
for i in range(3):
|
||||
|
||||
@ -70,7 +70,7 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
|
||||
@pytest.mark.generator
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
|
||||
@pytest.mark.vector_dim(128)
|
||||
@pytest.mark.embedding_dim(128)
|
||||
def test_lfqa_pipeline(document_store, retriever, eli5_generator):
|
||||
# reuse existing DOCS but regenerate embeddings with retribert
|
||||
docs: List[Document] = []
|
||||
@ -90,7 +90,7 @@ def test_lfqa_pipeline(document_store, retriever, eli5_generator):
|
||||
@pytest.mark.generator
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
|
||||
@pytest.mark.vector_dim(128)
|
||||
@pytest.mark.embedding_dim(128)
|
||||
def test_lfqa_pipeline_unknown_converter(document_store, retriever):
|
||||
# reuse existing DOCS but regenerate embeddings with retribert
|
||||
docs: List[Document] = []
|
||||
@ -112,7 +112,7 @@ def test_lfqa_pipeline_unknown_converter(document_store, retriever):
|
||||
@pytest.mark.generator
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
|
||||
@pytest.mark.vector_dim(128)
|
||||
@pytest.mark.embedding_dim(128)
|
||||
def test_lfqa_pipeline_invalid_converter(document_store, retriever):
|
||||
# reuse existing DOCS but regenerate embeddings with retribert
|
||||
docs: List[Document] = []
|
||||
|
||||
@ -169,7 +169,7 @@ def test_dpr_embedding(document_store, retriever, docs):
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
|
||||
@pytest.mark.vector_dim(128)
|
||||
@pytest.mark.embedding_dim(128)
|
||||
def test_retribert_embedding(document_store, retriever, docs):
|
||||
if isinstance(document_store, WeaviateDocumentStore):
|
||||
# Weaviate sets the embedding dimension to 768 as soon as it is initialized.
|
||||
@ -197,7 +197,7 @@ def test_retribert_embedding(document_store, retriever, docs):
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.vector_dim(512)
|
||||
@pytest.mark.embedding_dim(512)
|
||||
def test_table_text_retriever_embedding(document_store, retriever, docs):
|
||||
|
||||
document_store.return_embedding = True
|
||||
@ -277,7 +277,7 @@ def test_dpr_saving_and_loading(retriever, document_store):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
|
||||
@pytest.mark.vector_dim(512)
|
||||
@pytest.mark.embedding_dim(512)
|
||||
def test_table_text_retriever_saving_and_loading(retriever, document_store):
|
||||
retriever.save("test_table_text_retriever_save")
|
||||
|
||||
@ -325,7 +325,7 @@ def test_table_text_retriever_saving_and_loading(retriever, document_store):
|
||||
assert loaded_retriever.query_tokenizer.model_max_length == 512
|
||||
|
||||
|
||||
@pytest.mark.vector_dim(128)
|
||||
@pytest.mark.embedding_dim(128)
|
||||
def test_table_text_retriever_training(document_store):
|
||||
retriever = TableTextRetriever(
|
||||
document_store=document_store,
|
||||
|
||||
@ -115,7 +115,7 @@
|
||||
"source": [
|
||||
"from haystack.document_stores import FAISSDocumentStore\n",
|
||||
"\n",
|
||||
"document_store = FAISSDocumentStore(vector_dim=128, faiss_index_factory_str=\"Flat\")"
|
||||
"document_store = FAISSDocumentStore(embedding_dim=128, faiss_index_factory_str=\"Flat\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
@ -343,4 +343,4 @@
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@ def tutorial12_lfqa():
|
||||
|
||||
from haystack.document_stores.faiss import FAISSDocumentStore
|
||||
|
||||
document_store = FAISSDocumentStore(vector_dim=128, faiss_index_factory_str="Flat")
|
||||
document_store = FAISSDocumentStore(embedding_dim=128, faiss_index_factory_str="Flat")
|
||||
|
||||
"""
|
||||
Cleaning & indexing documents:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user