feat: Add IVF and Product Quantization support for OpenSearchDocumentStore (#3850)

* Add IVF and Product Quantization support for OpenSearchDocumentStore

* Remove unused import statement

* Fix mypy

* Adapt doc strings and error messages to account for PQ

* Adapt validation of indices

* Adapt existing tests

* Fix pylint

* Add tests

* Update lg

* Adapt based on PR review comments

* Fix Pylint

* Adapt based on PR review

* Add request_timeout

* Adapt based on PR review

* Adapt based on PR review

* Adapt tests

* Pin tenacity

* Unpin tenacity

* Adapt based on PR comments

* Add match to tests

---------

Co-authored-by: agnieszka-m <amarzec13@gmail.com>
bogdankostic 2023-02-17 10:28:36 +01:00 committed by GitHub
parent 8370715e7c
commit 7eeb3e07bf
7 changed files with 811 additions and 60 deletions


@@ -51,7 +51,6 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
timeout: int = 30,
return_embedding: bool = False,
duplicate_documents: str = "overwrite",
index_type: str = "flat",
scroll: str = "1d",
skip_missing_embeddings: bool = True,
synonyms: Optional[List] = None,
@@ -113,8 +112,6 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
:param index_type: The type of index to be created. Choose from 'flat' and 'hnsw'. Currently the
ElasticsearchDocumentStore does not support HNSW but OpenDistroElasticsearchDocumentStore does.
:param scroll: Determines how long the current index is fixed, e.g. during updating all documents with embeddings.
Defaults to "1d" and should not be larger than this. Can also be in minutes "5m" or hours "15h"
For details, see https://www.elastic.co/guide/en/elasticsearch/reference/current/scroll-api.html
@@ -132,13 +129,6 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
:param use_system_proxy: Whether to use system proxy.
"""
# hnsw is only supported in OpensearchDocumentStore
if index_type == "hnsw":
raise DocumentStoreError(
"The HNSW algorithm for approximate nearest neighbours calculation is currently not available in the ElasticSearchDocumentStore. "
"Try the OpenSearchDocumentStore instead."
)
# Base constructor might need the client to be ready, create it first
client = self._init_elastic_client(
host=host,
@@ -173,7 +163,6 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
similarity=similarity,
return_embedding=return_embedding,
duplicate_documents=duplicate_documents,
index_type=index_type,
scroll=scroll,
skip_missing_embeddings=skip_missing_embeddings,
synonyms=synonyms,


@@ -459,7 +459,7 @@ class FAISSDocumentStore(SQLDocumentStore):
def train_index(
self,
documents: Optional[Union[List[dict], List[Document]]],
documents: Optional[Union[List[dict], List[Document]]] = None,
embeddings: Optional[np.ndarray] = None,
index: Optional[str] = None,
):
@@ -474,15 +474,20 @@ class FAISSDocumentStore(SQLDocumentStore):
:return: None
"""
index = index or self.index
if embeddings and documents:
if isinstance(embeddings, np.ndarray) and documents:
raise ValueError("Either pass `documents` or `embeddings`. You passed both.")
if documents:
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
doc_embeddings = [doc.embedding for doc in document_objects]
doc_embeddings = [doc.embedding for doc in document_objects if doc.embedding is not None]
embeddings_for_train = np.array(doc_embeddings, dtype="float32")
self.faiss_indexes[index].train(embeddings_for_train)
if embeddings:
elif isinstance(embeddings, np.ndarray):
self.faiss_indexes[index].train(embeddings)
else:
logger.warning(
"When calling `train_index`, you must provide either Documents or embeddings. Because none of these values was provided, no training will be performed. "
)
def delete_all_documents(
self,

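For illustration, a minimal sketch of the updated `train_index` API, training an untrained FAISS IVF index from plain embeddings (the factory string, sample count, and dimension are illustrative; 768 is the store's default embedding_dim):

```python
import numpy as np

from haystack.document_stores import FAISSDocumentStore

# "IVF1,Flat" is an illustrative factory string that produces an index requiring training
document_store = FAISSDocumentStore(faiss_index_factory_str="IVF1,Flat")

# Either pass Documents carrying embeddings, or plain embeddings as shown here
train_embeddings = np.random.rand(64, 768).astype("float32")
document_store.train_index(embeddings=train_embeddings)
assert document_store.faiss_indexes[document_store.index].is_trained
```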

@@ -4,6 +4,7 @@ import logging
import numpy as np
from tqdm.auto import tqdm
from tenacity import retry, wait_exponential, retry_if_not_result
try:
from opensearchpy import OpenSearch, Urllib3HttpConnection, RequestsHttpConnection, NotFoundError, RequestError
@@ -33,6 +34,8 @@ SIMILARITY_SPACE_TYPE_MAPPINGS = {
class OpenSearchDocumentStore(SearchEngineDocumentStore):
valid_index_types = ["flat", "hnsw", "ivf", "ivf_pq"]
def __init__(
self,
scheme: str = "https", # Mind this different default param
@@ -69,6 +72,8 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
synonym_type: str = "synonym",
use_system_proxy: bool = False,
knn_engine: str = "nmslib",
knn_parameters: Optional[Dict] = None,
ivf_train_size: Optional[int] = None,
):
"""
Document Store using OpenSearch (https://opensearch.org/). It is compatible with the Amazon OpenSearch Service.
@@ -119,11 +124,15 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
:param index_type: The type of index to be created. Choose from 'flat' and 'hnsw'.
As OpenSearch currently does not support all similarity functions (e.g. dot_product) in exact vector similarity calculations,
we don't make use of exact vector similarity when index_type='flat'. Instead we use the same approximate vector similarity calculations like in 'hnsw', but further optimized for accuracy.
Exact vector similarity is only used as fallback when there's a mismatch between certain requested and indexed similarity types.
In these cases however, a warning will be displayed. See similarity param for more information.
:param index_type: The type of index you want to create. Choose from 'flat', 'hnsw', 'ivf', or 'ivf_pq'.
'ivf_pq' is an IVF index optimized for memory through product quantization.
('ivf' and 'ivf_pq' are only available with 'faiss' as knn_engine.)
If index_type='flat', we use OpenSearch's default index settings (which is an hnsw index
optimized for accuracy and memory footprint), since OpenSearch does not require a special
index for exact vector similarity calculations. Note that OpenSearchDocumentStore will only
perform exact vector calculations if the selected knn_engine supports it (currently only
knn_engine='score_script'). For the other knn_engines, we use hnsw, as it usually offers
the best trade-off between accuracy and latency.
:param scroll: Determines how long the current index is fixed, e.g. during updating all documents with embeddings.
Defaults to "1d" and should not be larger than this. Can also be in minutes "5m" or hours "15h"
For details, see https://www.elastic.co/guide/en/elasticsearch/reference/current/scroll-api.html
@@ -140,6 +149,22 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
More info at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-synonym-graph-tokenfilter.html
:param knn_engine: The engine you want to use for the nearest neighbor search by OpenSearch's KNN plug-in. Possible values: "nmslib", "faiss" or "score_script". Defaults to "nmslib".
For more information, see [k-NN Index](https://opensearch.org/docs/latest/search-plugins/knn/knn-index/).
:param knn_parameters: Custom parameters for the KNN engine. Parameter names depend on the index type you use.
Configurable parameters for indices of type...
- `hnsw`: `"ef_construction"`, `"ef_search"`, `"m"`
- `ivf`: `"nlist"`, `"nprobes"`
- `ivf_pq`: `"nlist"`, `"nprobes"`, `"m"`, `"code_size"`
If you don't specify any parameters, OpenSearch's default values are used.
(With the exception of index_type='hnsw', where we use values other than OpenSearch's
default ones to achieve comparability throughout DocumentStores in Haystack.)
For more information on configuration of knn indices, see
[OpenSearch Documentation](https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#method-definitions).
:param ivf_train_size: Number of embeddings to use for training the IVF index. Training starts automatically
once the number of indexed embeddings exceeds ivf_train_size. If `None`, the minimum
number of embeddings recommended for training by FAISS is used (depends on the desired
index type and knn parameters). If `0`, training doesn't happen automatically but needs
to be triggered manually via the `train_index` method.
Default: `None`
"""
# These parameters aren't used by Opensearch at the moment but could be in the future, see
# https://github.com/opensearch-project/security/issues/1504. Let's not deprecate them for
@@ -178,7 +203,23 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
if knn_engine not in {"nmslib", "faiss", "score_script"}:
raise ValueError(f"knn_engine must be either 'nmslib', 'faiss' or 'score_script' but was {knn_engine}")
if index_type in self.valid_index_types:
if index_type in ["ivf", "ivf_pq"] and knn_engine != "faiss":
raise DocumentStoreError("Use 'faiss' as knn_engine when using 'ivf' as index_type.")
self.index_type = index_type
else:
raise DocumentStoreError(
f"Invalid value for index_type in constructor. Choose one of these values: {self.valid_index_types}."
)
self.knn_engine = knn_engine
self.knn_parameters = {} if knn_parameters is None else knn_parameters
if ivf_train_size is not None:
if ivf_train_size <= 0:
raise DocumentStoreError("`ivf_train_on_write_size` must be None or a positive integer.")
self.ivf_train_size = ivf_train_size
elif self.index_type in ["ivf", "ivf_pq"]:
self.ivf_train_size = self._recommended_ivf_train_size()
self.space_type = SIMILARITY_SPACE_TYPE_MAPPINGS[knn_engine][similarity]
super().__init__(
client=client,
@@ -198,7 +239,6 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
similarity=similarity,
return_embedding=return_embedding,
duplicate_documents=duplicate_documents,
index_type=index_type,
scroll=scroll,
skip_missing_embeddings=skip_missing_embeddings,
synonyms=synonyms,
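Taken together, the new constructor options can be used like this. A minimal sketch assuming a reachable OpenSearch instance; the index name and parameter values are illustrative:

```python
from haystack.document_stores import OpenSearchDocumentStore

# 'ivf' and 'ivf_pq' require knn_engine="faiss". With ivf_train_size set, training
# starts automatically once at least that many embeddings have been indexed.
document_store = OpenSearchDocumentStore(
    index="my_ivf_index",
    knn_engine="faiss",
    index_type="ivf_pq",
    knn_parameters={"nlist": 4, "nprobes": 1, "m": 1, "code_size": 8},
    ivf_train_size=10_000,
)
```

Setting `ivf_train_size=0` disables automatic training, leaving the explicit `train_index` call (shown further down) as the only way to train the model.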
@@ -315,6 +355,9 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
:raises DuplicateDocumentError: Exception trigger on duplicate document
:return: None
"""
if index is None:
index = self.index
if self.knn_engine == "faiss" and self.similarity == "cosine":
field_map = self._create_document_field_map()
documents = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
@@ -331,6 +374,16 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
headers=headers,
)
# Train IVF index if number of embeddings exceeds ivf_train_size
if (
self.index_type in ["ivf", "ivf_pq"]
and not index.startswith(".")
and not self._ivf_model_exists(index=index)
):
if self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
self._train_ivf_index(index=index, documents=train_docs, headers=headers)
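In other words, once enough embeddings have been written, training happens as a side effect of `write_documents`. A small sketch of that flow, assuming the store setup from the constructor example above (`documents_with_embeddings` is illustrative):

```python
# Hidden indices (names starting with ".") are never auto-trained. For a store created
# with ivf_train_size=6, the write that brings the embedding count to 6 or more trains
# the IVF model on all embeddings indexed so far.
document_store.write_documents(documents_with_embeddings)
```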
def _embed_documents(self, documents: List[Document], retriever: DenseRetriever) -> np.ndarray:
"""
Embed a list of documents using a Retriever.
@@ -438,6 +491,9 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
if return_embedding is None:
return_embedding = self.return_embedding
if self.index_type in ["ivf", "ivf_pq"] and not self._ivf_model_exists(index=index):
self._ivf_index_not_trained_error(index=index, headers=headers)
if not self.embedding_field:
raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore")
body = self._construct_dense_query_body(
@@ -451,8 +507,117 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
for hit in result
]
if self.index_type == "hnsw":
ef_search = self._get_ef_search_value()
if top_k > ef_search:
logger.warning(
"top_k (%i) is greater than ef_search (%i). "
"We recommend setting ef_search >= top_k for optimal performance.",
top_k,
ef_search,
)
return documents
def query_by_embedding_batch(
self,
query_embs: Union[List[np.ndarray], np.ndarray],
filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
top_k: int = 10,
index: Optional[str] = None,
return_embedding: Optional[bool] = None,
headers: Optional[Dict[str, str]] = None,
scale_score: bool = True,
) -> List[List[Document]]:
"""
Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.
:param query_embs: Embeddings of the queries (e.g. gathered from DPR).
Can be a list of one-dimensional numpy arrays or a two-dimensional numpy array.
:param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain
conditions.
Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical
operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`,
`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name.
Logical operator keys take a dictionary of metadata field names and/or logical operators as
value. Metadata field names take a dictionary of comparison operators as value. Comparison
operator keys take a single value or (in case of `"$in"`) a list of values as value.
If no logical operator is provided, `"$and"` is used as default operation. If no comparison
operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default
operation.
__Example__:
```python
filters = {
"$and": {
"type": {"$eq": "article"},
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": {"$in": ["economy", "politics"]},
"publisher": {"$eq": "nytimes"}
}
}
}
# or simpler using default operators
filters = {
"type": "article",
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": ["economy", "politics"],
"publisher": "nytimes"
}
}
```
To use the same logical operator multiple times on the same level, logical operators take
optionally a list of dictionaries as value.
__Example__:
```python
filters = {
"$or": [
{
"$and": {
"Type": "News Paper",
"Date": {
"$lt": "2019-01-01"
}
}
},
{
"$and": {
"Type": "Blog Post",
"Date": {
"$gte": "2019-01-01"
}
}
}
]
}
```
:param top_k: How many documents to return
:param index: Index name for storing the docs and metadata
:param return_embedding: To return document embedding
:param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
:return: List of lists of Documents, one list of most similar Documents per query embedding.
"""
if index is None:
index = self.index
if self.index_type in ["ivf", "ivf_pq"] and not self._ivf_model_exists(index=index):
self._ivf_index_not_trained_error(index=index, headers=headers)
return super().query_by_embedding_batch(
query_embs, filters, top_k, index, return_embedding, headers, scale_score
)
def _construct_dense_query_body(
self, query_emb: np.ndarray, return_embedding: bool, filters: Optional[FilterType] = None, top_k: int = 10
):
@@ -509,8 +674,11 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
index_definition["settings"]["index"] = {"knn": True} # TODO: option to turn off for script scoring
# global ef_search setting affects only nmslib, for faiss it is set in the field mapping
if self.knn_engine == "nmslib" and self.index_type == "hnsw":
index_definition["settings"]["index"]["knn.algo_param.ef_search"] = 20
index_definition["mappings"]["properties"][self.embedding_field] = self._get_embedding_field_mapping()
ef_search = self._get_ef_search_value()
index_definition["settings"]["index"]["knn.algo_param.ef_search"] = ef_search
index_definition["mappings"]["properties"][self.embedding_field] = self._get_embedding_field_mapping(
index=index_name
)
try:
self.client.indices.create(index=index_name, body=index_definition, headers=headers)
@@ -522,6 +690,68 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
if not self._index_exists(index_name, headers=headers):
raise e
def train_index(
self,
documents: Optional[Union[List[dict], List[Document]]] = None,
embeddings: Optional[np.ndarray] = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
"""
Trains an IVF index on the provided Documents or embeddings if the index hasn't been trained yet.
The train vectors should come from the same distribution as your final vectors.
You can pass either Documents (including embeddings) or just plain embeddings you want to train the index on.
:param documents: Documents (including the embeddings) you want to train the index on.
:param embeddings: Plain embeddings you want to train the index on.
:param index: Name of the index to train. If `None`, the DocumentStore's default index (self.index) is used.
:param headers: Custom HTTP headers to pass to the OpenSearch client (for example {'Authorization': 'Basic YWRtaW46cm9vdA=='}).
For more information, see [HTTP/REST clients and security](https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html).
"""
if self.index_type not in ["ivf", "ivf_pq"]:
raise DocumentStoreError(
"You can only train an index if you set `index_type` to 'ivf' or 'ivf_pq' in your DocumentStore. "
"Other index types don't require training."
)
if index is None:
index = self.index
if isinstance(embeddings, np.ndarray) and documents:
raise ValueError("Pass either `documents` or `embeddings`. You passed both.")
if documents:
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
document_objects = [doc for doc in document_objects if doc.embedding is not None]
self._train_ivf_index(index=index, documents=document_objects, headers=headers)
elif isinstance(embeddings, np.ndarray):
document_objects = [
Document(content=f"Embedding {i}", embedding=embedding) for i, embedding in enumerate(embeddings)
]
self._train_ivf_index(index=index, documents=document_objects, headers=headers)
else:
logger.warning(
"When calling `train_index`, you must provide either Documents or embeddings. "
"Because none of these values was provided, the index won't be trained."
)
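A minimal usage sketch for manual training (sample count is illustrative); this is the required path when `ivf_train_size=0`:

```python
import numpy as np

# Train vectors should come from the same distribution as the vectors you index later
train_embeddings = np.random.rand(10_000, 768).astype("float32")
document_store.train_index(embeddings=train_embeddings)
```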
def delete_index(self, index: str):
"""
Delete an existing search index. The index together with all data will be removed.
If the index is of type `"ivf"` or `"ivf_pq"`, this method also deletes the corresponding IVF and PQ model.
:param index: The name of the index to delete.
:return: None
"""
# Check if index uses an IVF model and delete it
index_mapping = self.client.indices.get(index)[index]["mappings"]["properties"]
if self.embedding_field in index_mapping and "model_id" in index_mapping[self.embedding_field]:
model_id = index_mapping[self.embedding_field]["model_id"]
self.client.transport.perform_request("DELETE", f"/_plugins/_knn/models/{model_id}")
super().delete_index(index)
def _validate_and_adjust_document_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
"""
Validates an existing document index. If there's no embedding field, we'll add it.
@@ -565,7 +795,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
if existing_embedding_field is None:
# create embedding field
mappings["properties"][self.embedding_field] = self._get_embedding_field_mapping()
mappings["properties"][self.embedding_field] = self._get_embedding_field_mapping(index=index_name)
self.client.indices.put_mapping(index=index_id, body=mappings, headers=headers)
else:
# check type of existing embedding field
@@ -579,14 +809,16 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
)
# Check if existing embedding field fits desired knn settings
if self.knn_engine != "score_script":
training_required = self.index_type in ["ivf", "ivf_pq"] and "model_id" not in existing_embedding_field
if self.knn_engine != "score_script" and not training_required:
self._validate_approximate_knn_settings(existing_embedding_field, index_settings, index_id)
# Adjust global ef_search setting (nmslib only). If not set, default is 512.
# Adjust global ef_search setting (nmslib only).
if self.knn_engine == "nmslib":
ef_search = index_settings.get("knn.algo_param", {}).get("ef_search", 512)
if self.index_type == "hnsw" and ef_search != 20:
body = {"knn.algo_param.ef_search": 20}
desired_ef_search = self._get_ef_search_value()
if self.index_type == "hnsw" and ef_search != desired_ef_search:
body = {"knn.algo_param.ef_search": desired_ef_search}
self.client.indices.put_settings(index=index_id, body=body, headers=headers)
logger.info("Set ef_search to 20 for hnsw index '%s'.", index_id)
elif self.index_type == "flat" and ef_search != 512:
@@ -603,19 +835,13 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
If settings are not specified we infer the same default values as https://opensearch.org/docs/latest/search-plugins/knn/knn-index/
"""
method = existing_embedding_field.get("method", {})
embedding_field_space_type = method.get("space_type", "l2")
embedding_field_knn_engine = method.get("engine", "nmslib")
embedding_field_method_name = method.get("name", "hnsw")
parameters = method.get("parameters", {})
embedding_field_ef_construction = parameters.get("ef_construction", 512)
embedding_field_m = parameters.get("m", 16)
# ef_search is configured in the index settings and not in the mapping for nmslib
if embedding_field_knn_engine == "nmslib":
embedding_field_ef_search = index_settings.get("knn.algo_param", {}).get("ef_search", 512)
if "model_id" in existing_embedding_field:
embedding_field_knn_engine = "faiss"
else:
embedding_field_ef_search = parameters.get("ef_search", 512)
embedding_field_knn_engine = method.get("engine", "nmslib")
embedding_field_space_type = method.get("space_type", "l2")
# Validate knn engine
if embedding_field_knn_engine != self.knn_engine:
raise DocumentStoreError(
f"Existing embedding field '{self.embedding_field}' of OpenSearch index '{index_id}' has knn_engine "
@@ -626,6 +852,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
f" - Overwrite the existing index by setting `recreate_index=True`. Note that you'll lose all existing data."
)
# Validate space type
if embedding_field_space_type != self.space_type:
supported_similarities = [
k
@@ -646,7 +873,32 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
f" - Overwrite the existing index by setting `recreate_index=True`. Note that you'll lose all existing data."
)
# Validate HNSW indices
if self.index_type in ["flat", "hnsw"]:
self._validate_hnsw_settings(existing_embedding_field, index_settings, index_id)
# Validate IVF indices
elif self.index_type in ["ivf", "ivf_pq"]:
self._validate_ivf_settings(existing_embedding_field, index_settings, index_id)
else:
raise DocumentStoreError("Unknown index_type. Must be one of 'flat', 'hnsw', 'ivf', or 'ivf_pq'.")
def _validate_hnsw_settings(
self, existing_embedding_field: Dict[str, Any], index_settings: Dict[str, Any], index_id: str
):
method = existing_embedding_field.get("method", {})
parameters = method.get("parameters", {})
embedding_field_method_name = method.get("name", "hnsw")
embedding_field_ef_construction = parameters.get("ef_construction", 512)
embedding_field_m = parameters.get("m", 16)
embedding_field_knn_engine = method.get("engine", "nmslib")
# ef_search is configured in the index settings and not in the mapping for nmslib
if embedding_field_knn_engine == "nmslib":
embedding_field_ef_search = index_settings.get("knn.algo_param", {}).get("ef_search", 512)
else:
embedding_field_ef_search = parameters.get("ef_search", 512)
# Check method params according to requested index_type
# Indices of type "flat" that don't use "score_script" as knn_engine use an HNSW index optimized for accuracy
if self.index_type == "flat":
self._assert_embedding_param(
name="method.name", actual=embedding_field_method_name, expected="hnsw", index_id=index_id
@@ -659,17 +911,54 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
self._assert_embedding_param(
name="ef_search", actual=embedding_field_ef_search, expected=512, index_id=index_id
)
if self.index_type == "hnsw":
elif self.index_type == "hnsw":
expected_ef_construction = self.knn_parameters.get("ef_construction", 80)
expected_m = self.knn_parameters.get("m", 64)
expected_ef_search = self.knn_parameters.get("ef_search", 20)
self._assert_embedding_param(
name="method.name", actual=embedding_field_method_name, expected="hnsw", index_id=index_id
)
self._assert_embedding_param(
name="ef_construction", actual=embedding_field_ef_construction, expected=80, index_id=index_id
name="ef_construction",
actual=embedding_field_ef_construction,
expected=expected_ef_construction,
index_id=index_id,
)
self._assert_embedding_param(name="m", actual=embedding_field_m, expected=64, index_id=index_id)
self._assert_embedding_param(name="m", actual=embedding_field_m, expected=expected_m, index_id=index_id)
if self.knn_engine == "faiss":
self._assert_embedding_param(
name="ef_search", actual=embedding_field_ef_search, expected=20, index_id=index_id
name="ef_search", actual=embedding_field_ef_search, expected=expected_ef_search, index_id=index_id
)
def _validate_ivf_settings(
self, existing_embedding_field: Dict[str, Any], index_settings: Dict[str, Any], index_id: str
):
# If the embedding field has no model_id, the index hasn't been trained yet and still uses the
# default mapping, so there are no IVF settings to validate yet
if "model_id" in existing_embedding_field:
model_endpoint = f"/_plugins/_knn/models/{existing_embedding_field['model_id']}"
response = self.client.transport.perform_request("GET", url=model_endpoint)
model_settings_list = [setting.split(":") for setting in response["description"].split()]
model_settings = {k: (int(v) if v.isnumeric() else v) for k, v in model_settings_list}
embedding_field_nlist = model_settings.get("nlist")
embedding_field_nprobes = model_settings.get("nprobes")
expected_nlist = self.knn_parameters.get("nlist", 4)
expected_nprobes = self.knn_parameters.get("nprobes", 1)
self._assert_embedding_param(
name="nlist", actual=embedding_field_nlist, expected=expected_nlist, index_id=index_id
)
self._assert_embedding_param(
name="nprobes", actual=embedding_field_nprobes, expected=expected_nprobes, index_id=index_id
)
if self.index_type == "ivf_pq":
embedding_field_m = model_settings.get("m")
embedding_field_code_size = model_settings.get("code_size")
expected_m = self.knn_parameters.get("m", 1)
expected_code_size = self.knn_parameters.get("code_size", 8)
self._assert_embedding_param(name="m", actual=embedding_field_m, expected=expected_m, index_id=index_id)
self._assert_embedding_param(
name="code_size", actual=embedding_field_code_size, expected=expected_code_size, index_id=index_id
)
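Note that the model's hyperparameters travel in the KNN model's `description` field as space-separated `key:value` pairs; a standalone sketch of the parsing used above:

```python
# Same shape as the descriptions written by _train_ivf_index
description = "index_type:ivf_pq nlist:4 nprobes:1 m:1 code_size:8"
settings = {k: (int(v) if v.isnumeric() else v) for k, v in (s.split(":") for s in description.split())}
assert settings == {"index_type": "ivf_pq", "nlist": 4, "nprobes": 1, "m": 1, "code_size": 8}
```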
def _assert_embedding_param(self, name: str, actual: Any, expected: Any, index_id: str) -> None:
@@ -690,6 +979,7 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
space_type: Optional[str] = None,
index_type: Optional[str] = None,
embedding_dim: Optional[int] = None,
index: Optional[str] = None,
) -> Dict[str, Any]:
if space_type is None:
space_type = self.space_type
@@ -699,27 +989,78 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
index_type = self.index_type
if embedding_dim is None:
embedding_dim = self.embedding_dim
if index is None:
index = self.index
embeddings_field_mapping = {"type": "knn_vector", "dimension": embedding_dim}
if knn_engine != "score_script":
method: dict = {"space_type": space_type, "name": "hnsw", "engine": knn_engine}
method: dict = {"space_type": space_type, "engine": knn_engine}
ef_construction = (
80 if "ef_construction" not in self.knn_parameters else self.knn_parameters["ef_construction"]
)
ef_search = self._get_ef_search_value()
m = 64 if "m" not in self.knn_parameters else self.knn_parameters["m"]
if index_type == "flat":
# We're using HNSW with the knn_engines nmslib and faiss, as they do not support exact knn.
method["name"] = "hnsw"
# use default parameters from https://opensearch.org/docs/1.2/search-plugins/knn/knn-index/
# we need to set them explicitly as aws managed instances starting from version 1.2 do not support empty parameters
method["parameters"] = {"ef_construction": 512, "m": 16}
elif index_type == "hnsw":
method["parameters"] = {"ef_construction": 80, "m": 64}
method["name"] = "hnsw"
method["parameters"] = {"ef_construction": ef_construction, "m": m}
# for nmslib this is a global index setting
if knn_engine == "faiss":
method["parameters"]["ef_search"] = 20
method["parameters"]["ef_search"] = ef_search
elif index_type in ["ivf", "ivf_pq"]:
if knn_engine != "faiss":
raise DocumentStoreError("To use 'ivf' or 'ivf_pq as index_type, set knn_engine to 'faiss'.")
# Check if IVF model already exists
if self._ivf_model_exists(index):
logger.info("Using existing IVF model '%s-ivf' for index '%s'.", index, index)
embeddings_field_mapping = {"type": "knn_vector", "model_id": f"{index}-ivf"}
method = {}
else:
# IVF indices require training before they can be initialized. Falling back to the default
# mapping (Haystack's 'flat' settings) until the index is trained
logger.info("Using index of type 'flat' for index '%s' until IVF model is trained.", index)
method = {}
else:
logger.error("Set index_type to either 'flat' or 'hnsw'")
logger.error("Set index_type to either 'flat', 'hnsw', 'ivf', or 'ivf_pq'.")
method["name"] = "hnsw"
embeddings_field_mapping["method"] = method
if method:
embeddings_field_mapping["method"] = method
return embeddings_field_mapping
def _ivf_model_exists(self, index: str) -> bool:
if self._index_exists(".opensearch-knn-models"):
response = self.client.transport.perform_request("GET", "/_plugins/_knn/models/_search")
existing_ivf_models = set(
model["_source"]["model_id"]
for model in response["hits"]["hits"]
if model["_source"]["state"] != "failed"
)
else:
existing_ivf_models = set()
return f"{index}-ivf" in existing_ivf_models
def _ivf_index_not_trained_error(self, index: str, headers: Optional[Dict[str, str]] = None):
add_num_of_embs = ""
if self.ivf_train_size != 0:
embs_to_add = self.ivf_train_size - self.get_embedding_count(index=index, headers=headers)
add_num_of_embs = (
f"or add at least {embs_to_add} more embeddings to automatically start the training process "
)
raise DocumentStoreError(
f"Index of type '{self.index_type}' is not trained yet. Train the index manually using "
f"`train_index` {add_num_of_embs}before querying it."
)
def _create_label_index(self, index_name: str, headers: Optional[Dict[str, str]] = None):
mapping = {
"mappings": {
@@ -798,6 +1139,156 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
return score
def _train_ivf_index(
self, index: Optional[str], documents: List[Document], headers: Optional[Dict[str, str]] = None
):
"""
If the provided index is not an IVF index yet, this method trains it on the provided Documents
and converts the index to an IVF index.
"""
if index is None:
index = self.index
# Check if IVF index is already trained by checking if embedding mapping contains a model_id field
if "model_id" in self.client.indices.get(index)[index]["mappings"]["properties"][self.embedding_field]:
logger.info("IVF index '%s' is already trained. Skipping training.", index)
# IVF model is not trained yet -> train it and convert HNSW index to IVF index
else:
nlist = self.knn_parameters.get("nlist", 4)
nprobes = self.knn_parameters.get("nprobes", 1)
recommended_train_size = self._recommended_ivf_train_size()
documents = [doc for doc in documents if doc.embedding is not None]
if len(documents) < nlist:
raise DocumentStoreError(
f"IVF training requires the number of training samples to be greater than or "
f"equal to `nlist`. Number of provided training samples is `{len(documents)}` "
f"and nlist is `{nlist}`."
)
if len(documents) < recommended_train_size:
logger.warning(
"Consider increasing the number of training samples to at least "
"`%i` to get a reliable %s index.",
recommended_train_size,
self.index_type,
)
# Create temporary index containing training embeddings
self._create_document_index(index_name=f".{index}_ivf_training", headers=headers)
self.write_documents(documents=documents, index=f".{index}_ivf_training", headers=headers)
settings = f"index_type:{self.index_type} nlist:{nlist} nprobes:{nprobes}"
training_req_body: Dict = {
"training_index": f".{index}_ivf_training",
"training_field": self.embedding_field,
"dimension": self.embedding_dim,
"method": {
"name": "ivf",
"engine": "faiss",
"space_type": self.space_type,
"parameters": {"nlist": nlist, "nprobes": nprobes},
},
}
# Add product quantization
if self.index_type == "ivf_pq":
m = self.knn_parameters.get("m", 1)
code_size = self.knn_parameters.get("code_size", 8)
if code_size > 8:
raise DocumentStoreError(
f"code_size parameter for product quantization must be less than or equal to 8. "
f"Provided code_size is `{code_size}`."
)
# see FAISS doc for details: https://github.com/facebookresearch/faiss/wiki/FAQ#can-i-ignore-warning-clustering-xxx-points-to-yyy-centroids
n_clusters = 2**code_size
if len(documents) < n_clusters:
raise DocumentStoreError(
f"PQ training requires the number of training samples to be greater than or "
f"equal to the number of clusters. Number of provided training samples is `{len(documents)}` "
f"and the number of clusters is `{n_clusters}`."
)
encoder = {"name": "pq", "parameters": {"m": m, "code_size": code_size}}
settings += f" m:{m} code_size:{code_size}"
training_req_body["method"]["parameters"]["encoder"] = encoder
training_req_body["description"] = settings
logger.info("Training IVF index '%s' using {len(documents)} embeddings.", index)
train_endpoint = f"/_plugins/_knn/models/{index}-ivf/_train"
response = self.client.transport.perform_request(
"POST", url=train_endpoint, headers=headers, body=training_req_body
)
ivf_model = response["model_id"]
# Wait until model training is finished, _knn_model_trained uses a retry decorator
if self._knn_model_trained(ivf_model, headers=headers):
logger.info("Training of IVF index '%s' finished.", index)
# Delete temporary training index
self.client.indices.delete(index=f".{index}_ivf_training", headers=headers)
# Clone original index to temporary one
self.client.indices.add_block(index=index, block="read_only")
self.client.indices.clone(
index=index, target=f".{index}_temp", body={"settings": {"index": {"blocks": {"read_only": False}}}}
)
self.client.indices.put_settings(index=index, body={"index": {"blocks": {"read_only": False}}})
self.client.indices.delete(index=index)
# Reindex original index to newly created IVF index
self._create_document_index(index_name=index, headers=headers)
self.client.reindex(
body={"source": {"index": f".{index}_temp"}, "dest": {"index": index}},
params={"request_timeout": 24 * 60 * 60},
)
self.client.indices.delete(index=f".{index}_temp")
def _recommended_ivf_train_size(self) -> int:
"""
Calculates the minimum recommended number of training samples for IVF training as suggested in the FAISS docs.
https://github.com/facebookresearch/faiss/wiki/FAQ#can-i-ignore-warning-clustering-xxx-points-to-yyy-centroids
"""
min_points_per_cluster = 39
if self.index_type == "ivf":
n_clusters = self.knn_parameters.get("nlist", 4)
return n_clusters * min_points_per_cluster
elif self.index_type == "ivf_pq":
n_clusters = 2 ** self.knn_parameters.get("code_size", 8)
return n_clusters * min_points_per_cluster
else:
raise DocumentStoreError(f"Invalid index type '{self.index_type}'.")
@retry(retry=retry_if_not_result(bool), wait=wait_exponential(min=1, max=10))
def _knn_model_trained(self, model_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
model_state_endpoint = f"/_plugins/_knn/models/{model_name}"
response = self.client.transport.perform_request("GET", url=model_state_endpoint, headers=headers)
model_state = response["state"]
if model_state == "created":
return True
elif model_state == "failed":
error_message = response["error"]
raise DocumentStoreError(f"Failed to train the KNN model. Error message: {error_message}")
return False
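The polling above relies on tenacity's `retry_if_not_result`: the decorated function is re-invoked with exponential backoff until it returns a truthy value. A self-contained sketch of the same pattern:

```python
import random

from tenacity import retry, retry_if_not_result, wait_exponential

@retry(retry=retry_if_not_result(bool), wait=wait_exponential(min=1, max=10))
def poll_until_ready() -> bool:
    # Stand-in for a remote status check; tenacity waits 1s, 2s, 4s, ...
    # (capped at 10s) between calls until this returns True.
    return random.random() > 0.7

poll_until_ready()
```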
def _get_ef_search_value(self) -> int:
ef_search = 20 if "ef_search" not in self.knn_parameters else self.knn_parameters["ef_search"]
return ef_search
def _delete_index(self, index: str):
if self._index_exists(index):
self.client.indices.delete(index=index, ignore=[400, 404])
self._delete_ivf_model(index)
logger.info("Index '%s' deleted.", index)
def _delete_ivf_model(self, index: str):
"""
If index is an index of type 'ivf' or 'ivf_pq', this method deletes the corresponding IVF model.
"""
if self._index_exists(".opensearch-knn-models"):
response = self.client.transport.perform_request("GET", "/_plugins/_knn/models/_search")
existing_ivf_models = set(model["_source"]["model_id"] for model in response["hits"]["hits"])
if f"{index}-ivf" in existing_ivf_models:
self.client.transport.perform_request("DELETE", f"/_plugins/_knn/models/{index}-ivf")
def clone_embedding_field(
self,
new_embedding_field: str,


@@ -66,7 +66,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
similarity: str = "dot_product",
return_embedding: bool = False,
duplicate_documents: str = "overwrite",
index_type: str = "flat",
scroll: str = "1d",
skip_missing_embeddings: bool = True,
synonyms: Optional[List] = None,
@@ -105,10 +104,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
raise DocumentStoreError(
f"Invalid value {similarity} for similarity, choose between 'cosine', 'l2' and 'dot_product'"
)
if index_type in ["flat", "hnsw"]:
self.index_type = index_type
else:
raise Exception("Invalid value for index_type in constructor. Choose between 'flat' and 'hnsw'")
self._init_indices(
index=index, label_index=label_index, create_index=create_index, recreate_index=recreate_index


@@ -66,6 +66,7 @@ dependencies = [
"azure-ai-formrecognizer>=3.2.0b2", # forms reader
# audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
"huggingface-hub>=0.5.0",
"tenacity", # retry decorator
# Preprocessing
"more_itertools", # for windowing


@@ -136,6 +136,35 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
# Check that get_embedding_count works as expected
assert document_store.get_embedding_count() == len(documents_with_embeddings)
@pytest.mark.integration
def test_train_index_from_docs(self, documents_with_embeddings, tmp_path):
document_store = FAISSDocumentStore(
sql_url=f"sqlite:///{tmp_path}/test_faiss_retrieving.db",
faiss_index_factory_str="IVF1,Flat",
isolation_level="AUTOCOMMIT",
return_embedding=True,
)
document_store.delete_all_documents(index=document_store.index)
assert not document_store.faiss_indexes[document_store.index].is_trained
document_store.train_index(documents_with_embeddings)
assert document_store.faiss_indexes[document_store.index].is_trained
@pytest.mark.integration
def test_train_index_from_embeddings(self, documents_with_embeddings, tmp_path):
document_store = FAISSDocumentStore(
sql_url=f"sqlite:///{tmp_path}/test_faiss_retrieving.db",
faiss_index_factory_str="IVF1,Flat",
isolation_level="AUTOCOMMIT",
return_embedding=True,
)
document_store.delete_all_documents(index=document_store.index)
embeddings = np.array([doc.embedding for doc in documents_with_embeddings])
assert not document_store.faiss_indexes[document_store.index].is_trained
document_store.train_index(embeddings=embeddings)
assert document_store.faiss_indexes[document_store.index].is_trained
@pytest.mark.integration
def test_write_docs_different_indexes(self, ds, documents_with_embeddings):
docs_a = documents_with_embeddings[:2]


@@ -127,8 +127,11 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
OpenSearchDocumentStore(index="nmslib_index", create_index=True)
@pytest.mark.integration
def test___init___faiss(self):
OpenSearchDocumentStore(index="faiss_index", create_index=True, knn_engine="faiss")
@pytest.mark.parametrize("index_type", ["flat", "hnsw", "ivf", "ivf_pq"])
def test___init___faiss(self, index_type):
OpenSearchDocumentStore(
index=f"faiss_index_{index_type}", recreate_index=True, knn_engine="faiss", index_type=index_type
)
@pytest.mark.integration
def test___init___score_script(self):
@@ -185,6 +188,107 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
for result in results:
assert len(result) == 3
@pytest.mark.integration
@pytest.mark.parametrize("index_type", ["ivf", "ivf_pq"])
def test_train_index_from_documents(self, ds: OpenSearchDocumentStore, documents, index_type):
# Create another document store on top of the previous one
ds = OpenSearchDocumentStore(
index=ds.index,
label_index=ds.label_index,
recreate_index=True,
knn_engine="faiss",
index_type=index_type,
knn_parameters={"code_size": 2},
)
# Check that IVF indices use HNSW with default settings before training
emb_field_settings = ds.client.indices.get(ds.index)[ds.index]["mappings"]["properties"][ds.embedding_field]
assert emb_field_settings == {"type": "knn_vector", "dimension": 768}
ds.train_index(documents)
# Check that embedding_field_settings have been updated
emb_field_settings = ds.client.indices.get(ds.index)[ds.index]["mappings"]["properties"][ds.embedding_field]
assert emb_field_settings == {"type": "knn_vector", "model_id": f"{ds.index}-ivf"}
# Check that model uses expected parameters
expected_model_settings = {"index_type": index_type, "nlist": 4, "nprobes": 1}
if index_type == "ivf_pq":
expected_model_settings["code_size"] = 2
expected_model_settings["m"] = 1
model_endpoint = f"/_plugins/_knn/models/{ds.index}-ivf"
response = ds.client.transport.perform_request("GET", url=model_endpoint)
model_settings_list = [setting.split(":") for setting in response["description"].split()]
model_settings = {k: (int(v) if v.isnumeric() else v) for k, v in model_settings_list}
assert model_settings == expected_model_settings
@pytest.mark.integration
@pytest.mark.parametrize("index_type", ["ivf", "ivf_pq"])
def test_train_index_from_embeddings(self, ds: OpenSearchDocumentStore, documents, index_type):
# Create another document store on top of the previous one
ds = OpenSearchDocumentStore(
index=ds.index,
label_index=ds.label_index,
recreate_index=True,
knn_engine="faiss",
index_type=index_type,
knn_parameters={"code_size": 2},
)
# Check that IVF indices use HNSW with default settings before training
emb_field_settings = ds.client.indices.get(ds.index)[ds.index]["mappings"]["properties"][ds.embedding_field]
assert emb_field_settings == {"type": "knn_vector", "dimension": 768}
embeddings = np.array([doc.embedding for doc in documents if doc.embedding is not None])
ds.train_index(embeddings=embeddings)
# Check that embedding_field_settings have been updated
emb_field_settings = ds.client.indices.get(ds.index)[ds.index]["mappings"]["properties"][ds.embedding_field]
assert emb_field_settings == {"type": "knn_vector", "model_id": f"{ds.index}-ivf"}
# Check that model uses expected parameters
expected_model_settings = {"index_type": index_type, "nlist": 4, "nprobes": 1}
if index_type == "ivf_pq":
expected_model_settings["code_size"] = 2
expected_model_settings["m"] = 1
model_endpoint = f"/_plugins/_knn/models/{ds.index}-ivf"
response = ds.client.transport.perform_request("GET", url=model_endpoint)
model_settings_list = [setting.split(":") for setting in response["description"].split()]
model_settings = {k: (int(v) if v.isnumeric() else v) for k, v in model_settings_list}
assert model_settings == expected_model_settings
@pytest.mark.integration
@pytest.mark.parametrize("index_type", ["ivf", "ivf_pq"])
def test_train_index_with_write_documents(self, ds: OpenSearchDocumentStore, documents, index_type):
# Create another document store on top of the previous one
ds = OpenSearchDocumentStore(
index=ds.index,
label_index=ds.label_index,
recreate_index=True,
knn_engine="faiss",
index_type=index_type,
knn_parameters={"code_size": 2},
ivf_train_size=6,
)
# Check that IVF indices use HNSW with default settings before training
emb_field_settings = ds.client.indices.get(ds.index)[ds.index]["mappings"]["properties"][ds.embedding_field]
assert emb_field_settings == {"type": "knn_vector", "dimension": 768}
ds.write_documents(documents)
# Check that embedding_field_settings have been updated
emb_field_settings = ds.client.indices.get(ds.index)[ds.index]["mappings"]["properties"][ds.embedding_field]
assert emb_field_settings == {"type": "knn_vector", "model_id": f"{ds.index}-ivf"}
# Check that model uses expected parameters
expected_model_settings = {"index_type": index_type, "nlist": 4, "nprobes": 1}
if index_type == "ivf_pq":
expected_model_settings["code_size"] = 2
expected_model_settings["m"] = 1
model_endpoint = f"/_plugins/_knn/models/{ds.index}-ivf"
response = ds.client.transport.perform_request("GET", url=model_endpoint)
model_settings_list = [setting.split(":") for setting in response["description"].split()]
model_settings = {k: (int(v) if v.isnumeric() else v) for k, v in model_settings_list}
assert model_settings == expected_model_settings
# Unit tests
@pytest.mark.unit
@@ -294,6 +398,20 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
with pytest.raises(DocumentStoreError):
mocked_document_store.query_by_embedding(self.query_emb)
@pytest.mark.unit
def test_query_by_embedding_raises_if_ivf_untrained(self, mocked_document_store):
mocked_document_store.index_type = "ivf"
mocked_document_store.ivf_train_size = 10
with pytest.raises(DocumentStoreError, match="Index of type 'ivf' is not trained yet."):
mocked_document_store.query_by_embedding(self.query_emb)
@pytest.mark.unit
def test_query_by_embedding_batch_raises_if_ivf_untrained(self, mocked_document_store):
mocked_document_store.index_type = "ivf"
mocked_document_store.ivf_train_size = 10
with pytest.raises(DocumentStoreError, match="Index of type 'ivf' is not trained yet."):
mocked_document_store.query_by_embedding_batch([self.query_emb])
@pytest.mark.unit
def test_query_by_embedding_filters(self, mocked_document_store):
assert mocked_document_store.knn_engine != "score_script"
@@ -649,8 +767,10 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
@pytest.mark.unit
def test__init_indices_creates_index_if_exists_and_recreate_index(self, mocked_document_store):
# delete_index askes twice + one check for each index creation
mocked_document_store.client.indices.exists.side_effect = [True, True, False, False]
# delete_index asks four times: one check for doc index, one check for label index
# + one check for both if ivf model exists
# create_index asks two times: one for doc index, one for label index
mocked_document_store.client.indices.exists.side_effect = [True, False, True, False, False, False]
mocked_document_store._init_indices(self.index_name, "label_index", create_index=True, recreate_index=True)
mocked_document_store.client.indices.delete.assert_called()
@@ -824,7 +944,7 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
}
@pytest.mark.unit
def test__get_embedding_field_mapping_hnsw(self, mocked_document_store):
def test__get_embedding_field_mapping_default_hnsw(self, mocked_document_store):
mocked_document_store.index_type = "hnsw"
assert mocked_document_store._get_embedding_field_mapping() == {
@@ -839,7 +959,7 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
}
@pytest.mark.unit
def test__get_embedding_field_mapping_hnsw_faiss(self, mocked_document_store):
def test__get_embedding_field_mapping_default_hnsw_faiss(self, mocked_document_store):
mocked_document_store.index_type = "hnsw"
mocked_document_store.knn_engine = "faiss"
@@ -854,6 +974,127 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
},
}
@pytest.mark.unit
def test__get_embedding_field_mapping_custom_hnsw(self, mocked_document_store):
mocked_document_store.index_type = "hnsw"
mocked_document_store.knn_parameters = {"ef_construction": 1, "m": 2}
assert mocked_document_store._get_embedding_field_mapping() == {
"type": "knn_vector",
"dimension": 768,
"method": {
"space_type": "innerproduct",
"engine": "nmslib",
"name": "hnsw",
"parameters": {"ef_construction": 1, "m": 2},
},
}
@pytest.mark.unit
def test__get_embedding_field_mapping_custom_hnsw_faiss(self, mocked_document_store):
mocked_document_store.index_type = "hnsw"
mocked_document_store.knn_engine = "faiss"
mocked_document_store.knn_parameters = {"ef_construction": 1, "m": 2, "ef_search": 3}
assert mocked_document_store._get_embedding_field_mapping() == {
"type": "knn_vector",
"dimension": 768,
"method": {
"space_type": "innerproduct",
"engine": "faiss",
"name": "hnsw",
"parameters": {"ef_construction": 1, "m": 2, "ef_search": 3},
},
}
@pytest.mark.unit
def test__get_embedding_field_mapping_ivf(self, mocked_document_store):
mocked_document_store.index_type = "ivf"
mocked_document_store.knn_engine = "faiss"
mocked_document_store.client.indices.exists.return_value = False
# Before training, IVF indices use HNSW with default settings
assert mocked_document_store._get_embedding_field_mapping() == {"type": "knn_vector", "dimension": 768}
# Assume we have trained the index
mocked_document_store.client.indices.exists.return_value = True
mocked_document_store.client.transport.perform_request.return_value = {
"took": 4,
"timed_out": False,
"_shards": {"total": 1, "successful": 1, "skipped": 0, "failed": 0},
"hits": {
"total": {"value": 1, "relation": "eq"},
"max_score": 1.0,
"hits": [
{
"_index": ".opensearch-knn-models",
"_type": "_doc",
"_id": "document-ivf",
"_score": 1.0,
"_source": {
"model_blob": "<SOME MODEL BLOB>",
"engine": "faiss",
"space_type": "innerproduct",
"description": "index_type:ivf nlist:4 nprobes:1",
"model_id": f"{mocked_document_store.index}-ivf",
"state": "created",
"error": "",
"dimension": 768,
"timestamp": "2023-01-25T16:04:21.284398Z",
},
}
],
},
}
assert mocked_document_store._get_embedding_field_mapping() == {
"type": "knn_vector",
"model_id": f"{mocked_document_store.index}-ivf",
}
@pytest.mark.unit
def test__get_embedding_field_mapping_ivfpq(self, mocked_document_store):
mocked_document_store.index_type = "ivf_pq"
mocked_document_store.knn_engine = "faiss"
mocked_document_store.client.indices.exists.return_value = False
# Before training, IVF indices use HNSW with default settings
assert mocked_document_store._get_embedding_field_mapping() == {"type": "knn_vector", "dimension": 768}
# Assume we have trained the index
mocked_document_store.client.indices.exists.return_value = True
mocked_document_store.client.transport.perform_request.return_value = {
"took": 4,
"timed_out": False,
"_shards": {"total": 1, "successful": 1, "skipped": 0, "failed": 0},
"hits": {
"total": {"value": 1, "relation": "eq"},
"max_score": 1.0,
"hits": [
{
"_index": ".opensearch-knn-models",
"_type": "_doc",
"_id": "document-ivf",
"_score": 1.0,
"_source": {
"model_blob": "<SOME MODEL BLOB>",
"engine": "faiss",
"space_type": "innerproduct",
"description": "index_type:ivf_pq nlist:4 nprobes:1 m:1 code_size:8",
"model_id": f"{mocked_document_store.index}-ivf",
"state": "created",
"error": "",
"dimension": 768,
"timestamp": "2023-01-25T16:04:21.284398Z",
},
}
],
},
}
assert mocked_document_store._get_embedding_field_mapping() == {
"type": "knn_vector",
"model_id": f"{mocked_document_store.index}-ivf",
}
@pytest.mark.unit
def test__get_embedding_field_mapping_wrong(self, mocked_document_store, caplog):
mocked_document_store.index_type = "foo"
@@ -861,7 +1102,7 @@ class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDoc
with caplog.at_level(logging.ERROR, logger="haystack.document_stores.opensearch"):
retval = mocked_document_store._get_embedding_field_mapping()
assert "Set index_type to either 'flat' or 'hnsw'" in caplog.text
assert "Set index_type to either 'flat', 'hnsw', 'ivf', or 'ivf_pq'" in caplog.text
assert retval == {
"type": "knn_vector",
"dimension": 768,