Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-13 08:33:57 +00:00
* Add BasePipeline.validate_config, BasePipeline.validate_yaml, and some new custom exception classes
* Make error composition work properly
* Clarify typing
* Help mypy a bit more
* Update Documentation & Code Style
* Enable autogenerated docs for Milvus1 and 2 separately
* Revert "Enable autogenerated docs for Milvus1 and 2 separately" (this reverts commit 282be4a78a6e95862a9b4c924fc3dea5ca71e28d)
* Update Documentation & Code Style
* Re-enable 'additionalProperties: False'
* Add pipeline.type to JSON Schema, was somehow forgotten
* Disable additionalProperties on the pipeline properties too
* Fix json-schemas for 1.1.0 and 1.2.0 (should not do it again in the future)
* Cal super in PipelineValidationError
* Improve _read_pipeline_config_from_yaml's error handling
* Fix generate_json_schema.py to include document stores
* Fix json schemas (retro-fix 1.1.0 again)
* Improve custom errors printing, add link to docs
* Add function in BaseComponent to list its subclasses in a module
* Make some document stores base classes abstract
* Add marker 'integration' in pytest flags
* Slighly improve validation of pipelines at load
* Adding tests for YAML loading and validation
* Make custom_query Optional for validation issues
* Fix bug in _read_pipeline_config_from_yaml
* Improve error handling in BasePipeline and Pipeline and add DAG check
* Move json schema generation into haystack/nodes/_json_schema.py (useful for tests)
* Simplify errors slightly
* Add some YAML validation tests
* Remove load_from_config from BasePipeline, it was never used anyway
* Improve tests
* Include json-schemas in package
* Fix conftest imports
* Make BasePipeline abstract
* Improve mocking by making the test independent from the YAML version
* Add exportable_to_yaml decorator to forget about set_config on mock nodes
* Fix mypy errors
* Comment out one monkeypatch
* Fix typing again
* Improve error message for validation
* Add required properties to pipelines
* Fix YAML version for REST API YAMLs to 1.2.0
* Fix load_from_yaml call in load_from_deepset_cloud
* fix HaystackError.__getattr__
* Add super().__init__() in most nodes and docstore, comment set_config
* Remove type from REST API pipelines
* Remove useless init from doc2answers
* Call super in Seq3SeqGenerator
* Typo in deepsetcloud.py
* Fix rest api indexing error mismatch and mock version of JSON schema in all tests
* Working on pipeline tests
* Improve errors printing slightly
* Add back test_pipeline.yaml
* _json_schema.py supports different versions with identical schemas
* Add type to 0.7 schema for backwards compatibility
* Fix small bug in _json_schema.py
* Try alternative to generate json schemas on the CI
* Update Documentation & Code Style
* Make linux CI match autoformat CI
* Fix super-init-not-called
* Accidentally committed file
* Update Documentation & Code Style
* fix test_summarizer_translation.py's import
* Mock YAML in a few suites, split and simplify test_pipeline_debug_and_validation.py::test_invalid_run_args
* Fix json schema for ray tests too
* Update Documentation & Code Style
* Reintroduce validation
* Usa unstable version in tests and rest api
* Make unstable support the latest versions
* Update Documentation & Code Style
* Remove needless fixture
* Make type in pipeline optional in the strings validation
* Fix schemas
* Fix string validation for pipeline type
* Improve validate_config_strings
* Remove type from test p[ipelines
* Update Documentation & Code Style
* Fix test_pipeline
* Removing more type from pipelines
* Temporary CI patc
* Fix issue with exportable_to_yaml never invoking the wrapped init
* rm stray file
* pipeline tests are green again
* Linux CI now needs .[all] to generate the schema
* Bugfixes, pipeline tests seems to be green
* Typo in version after merge
* Implement missing methods in Weaviate
* Trying to avoid FAISS tests from running in the Milvus1 test suite
* Fix some stray test paths and faiss index dumping
* Fix pytest markers list
* Temporarily disable cache to be able to see tests failures
* Fix pyproject.toml syntax
* Use only tmp_path
* Fix preprocessor signature after merge
* Fix faiss bug
* Fix Ray test
* Fix documentation issue by removing quotes from faiss type
* Update Documentation & Code Style
* use document properly in preprocessor tests
* Update Documentation & Code Style
* make preprocessor capable of handling documents
* import document
* Revert support for documents in preprocessor, do later
* Fix bug in _json_schema.py that was breaking validation
* re-enable cache
* Update Documentation & Code Style
* Simplify calling _json_schema.py from the CI
* Remove redundant ABC inheritance
* Ensure exportable_to_yaml works only on implementations
* Rename subclass to class_ in Meta
* Make run() and get_config() abstract in BasePipeline
* Revert unintended change in preprocessor
* Move outgoing_edges_input_node check inside try block
* Rename VALID_CODE_GEN_INPUT_REGEX into VALID_INPUT_REGEX
* Add check for a RecursionError on validate_config_strings
* Address usages of _pipeline_config in data silo and elasticsearch
* Rename _pipeline_config into _init_parameters
* Fix pytest marker and remove unused imports
* Remove most redundant ABCs
* Rename _init_parameters into _component_configuration
* Remove set_config and type from _component_configuration's dict
* Remove last instances of set_config and replace with super().__init__()
* Implement __init_subclass__ approach
* Simplify checks on the existence of _component_configuration
* Fix faiss issue
* Dynamic generation of node schemas & weed out old schemas
* Add debatable test
* Add docstring to debatable test
* Positive diff between schemas implemented
* Improve diff printing
* Rename REST API YAML files to trigger IDE validation
* Fix typing issues
* Fix more typing
* Typo in YAML filename
* Remove needless type:ignore
* Add tests
* Fix tests & validation feedback for accessory classes in custom nodes
* Refactor RAGeneratorType out
* Fix broken import in conftest
* Improve source error handling
* Remove unused import in test_eval.py breaking tests
* Fix changed error message in tests matches too
* Normalize generate_openapi_specs.py and generate_json_schema.py in the actions
* Fix path to generate_openapi_specs.py in autoformat.yml
* Update Documentation & Code Style
* Add test for FAISSDocumentStore-like situations (superclass with init params)
* Update Documentation & Code Style
* Fix indentation
* Remove commented set_config
* Store model_name_or_path in FARMReader to use in DistillationDataSilo
* Rename _component_configuration into _component_config
* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
649 lines
31 KiB
Python
from typing import TYPE_CHECKING, Any, Union, List, Optional, Dict, Generator

if TYPE_CHECKING:
    from haystack.nodes.retriever import BaseRetriever

import json
import logging
import warnings
import numpy as np
from copy import deepcopy
from pathlib import Path
from tqdm.auto import tqdm
from inspect import Signature, signature

try:
    import faiss
    from haystack.document_stores.sql import (
        SQLDocumentStore,
    )  # its deps are optional, but get installed with the `faiss` extra
except (ImportError, ModuleNotFoundError) as ie:
    from haystack.utils.import_utils import _optional_component_not_installed

    _optional_component_not_installed(__name__, "faiss", ie)

from haystack.schema import Document
from haystack.document_stores.base import get_batches_from_generator


logger = logging.getLogger(__name__)


class FAISSDocumentStore(SQLDocumentStore):
    """
    Document store for very large-scale, embedding-based dense retrievers like DPR.

    It uses the FAISS library (https://github.com/facebookresearch/faiss)
    to perform similarity search on vectors.

    The document text and metadata (for filtering) are stored using the SQLDocumentStore, while
    the vector embeddings are indexed in a FAISS index.
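
    Usage example (a minimal sketch, assuming the `faiss` extra is installed; `retriever` is a placeholder
    for any Haystack dense retriever attached to this store):

    ```python
    from haystack.document_stores import FAISSDocumentStore
    from haystack.schema import Document

    document_store = FAISSDocumentStore(embedding_dim=768, faiss_index_factory_str="Flat")
    document_store.write_documents([Document(content="Berlin is the capital of Germany.")])
    # document_store.update_embeddings(retriever)  # create & index the embeddings with your retriever
    ```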
    """

    def __init__(
        self,
        sql_url: str = "sqlite:///faiss_document_store.db",
        vector_dim: Optional[int] = None,
        embedding_dim: int = 768,
        faiss_index_factory_str: str = "Flat",
        faiss_index: Optional[faiss.swigfaiss.Index] = None,
        return_embedding: bool = False,
        index: str = "document",
        similarity: str = "dot_product",
        embedding_field: str = "embedding",
        progress_bar: bool = True,
        duplicate_documents: str = "overwrite",
        faiss_index_path: Optional[Union[str, Path]] = None,
        faiss_config_path: Optional[Union[str, Path]] = None,
        isolation_level: Optional[str] = None,
        **kwargs,
    ):
        """
        :param sql_url: SQL connection URL for the database. It defaults to a local, file-based SQLite DB. For large scale
                        deployments, Postgres is recommended.
        :param vector_dim: Deprecated. Use embedding_dim instead.
        :param embedding_dim: The embedding vector size. Default: 768.
        :param faiss_index_factory_str: Create a new FAISS index of the specified type.
                                        The type is determined from the given string following the conventions
                                        of the original FAISS index factory.
                                        Recommended options:
                                        - "Flat" (default): Best accuracy (= exact). Becomes slow and RAM-intensive for > 1 million docs.
                                        - "HNSW": Graph-based heuristic. If not further specified,
                                          we use the following config: HNSW64, efConstruction=80 and efSearch=20.
                                        - "IVFx,Flat": Inverted index. Replace x with the number of centroids, aka nlist.
                                          Rule of thumb: nlist = 10 * sqrt(num_docs) is a good starting point.
                                        For more details see:
                                        - Overview of indices: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
                                        - Guidelines for choosing an index: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
                                        - FAISS index factory: https://github.com/facebookresearch/faiss/wiki/The-index-factory
                                        Benchmarks: XXX
        :param faiss_index: Pass an existing FAISS index, i.e. an empty one that you configured manually
                            or one with docs that you used in Haystack before and want to load again.
        :param return_embedding: Whether to return the document embedding. Unlike other document stores, FAISS will return normalized embeddings.
        :param index: Name of the index in the document store to use.
        :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
                           more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
                           In both cases, the returned values in Document.score are normalized to be in range [0,1]:
                           For `dot_product`: expit(np.asarray(raw_score / 100))
                           For `cosine`: (raw_score + 1) / 2
        :param embedding_field: Name of the field containing an embedding vector.
        :param progress_bar: Whether to show a tqdm progress bar or not.
                             Can be helpful to disable in production deployments to keep the logs clean.
        :param duplicate_documents: Handle duplicate documents based on parameter options.
                                    Parameter options: ('skip', 'overwrite', 'fail')
                                    skip: Ignore duplicate documents.
                                    overwrite: Update any existing documents with the same ID when adding documents.
                                    fail: Raise an error if the document ID of the document being added already exists.
        :param faiss_index_path: Stored FAISS index file. Can be created by calling `save()`.
                                 If specified, no other params besides faiss_config_path may be passed.
        :param faiss_config_path: Stored FAISS initial configuration parameters.
                                  Can be created by calling `save()`.
        :param isolation_level: See SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level).
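
        Example (a minimal, illustrative sketch of the recommended index types; each call is an alternative
        way to construct the store, and the HNSW tuning parameters are forwarded to FAISS via **kwargs):

        ```python
        # Exact (flat) index, a good default for up to ~1 million documents:
        document_store = FAISSDocumentStore(embedding_dim=768, faiss_index_factory_str="Flat")

        # HNSW index with custom graph parameters:
        document_store = FAISSDocumentStore(faiss_index_factory_str="HNSW", n_links=64, efSearch=20, efConstruction=80)

        # IVF index with 1024 centroids; requires training before adding vectors (see train_index()):
        document_store = FAISSDocumentStore(faiss_index_factory_str="IVF1024,Flat")
        ```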
        """
        # special case if we want to load an existing index from disk
        # load init params from disk and run init again
        if faiss_index_path is not None:
            sig = signature(self.__class__.__init__)
            self._validate_params_load_from_disk(sig, locals(), kwargs)
            init_params = self._load_init_params_from_config(faiss_index_path, faiss_config_path)
            self.__class__.__init__(self, **init_params)  # pylint: disable=non-parent-init-called
            return

        if similarity in ("dot_product", "cosine"):
            self.similarity = similarity
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        elif similarity == "l2":
            self.similarity = similarity
            self.metric_type = faiss.METRIC_L2
        else:
            raise ValueError(
                "The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
                "Please set similarity to one of the above."
            )

        if vector_dim is not None:
            warnings.warn(
                "The 'vector_dim' parameter is deprecated, use 'embedding_dim' instead.", DeprecationWarning, 2
            )
            self.embedding_dim = vector_dim
        else:
            self.embedding_dim = embedding_dim

        self.faiss_index_factory_str = faiss_index_factory_str
        self.faiss_indexes: Dict[str, faiss.swigfaiss.Index] = {}
        if faiss_index:
            self.faiss_indexes[index] = faiss_index
        else:
            self.faiss_indexes[index] = self._create_new_index(
                embedding_dim=self.embedding_dim,
                index_factory=faiss_index_factory_str,
                metric_type=self.metric_type,
                **kwargs,
            )

        self.return_embedding = return_embedding
        self.embedding_field = embedding_field

        self.progress_bar = progress_bar

        super().__init__(
            url=sql_url, index=index, duplicate_documents=duplicate_documents, isolation_level=isolation_level
        )

        self._validate_index_sync()

    def _validate_params_load_from_disk(self, sig: Signature, locals: dict, kwargs: dict):
        allowed_params = ["faiss_index_path", "faiss_config_path", "self", "kwargs"]
        invalid_param_set = False

        for param in sig.parameters.values():
            if param.name not in allowed_params and param.default != locals[param.name]:
                invalid_param_set = True
                break

        if invalid_param_set or len(kwargs) > 0:
            raise ValueError("If faiss_index_path is passed, no other params besides faiss_config_path are allowed.")

    def _validate_index_sync(self):
        # This check ensures the correct document database was loaded.
        # If it fails, make sure you provided the path to the database
        # used when creating the original FAISS index
        if not self.get_document_count() == self.get_embedding_count():
            raise ValueError(
                "The number of documents present in the SQL database does not "
                "match the number of embeddings in FAISS. Make sure your FAISS "
                "configuration file correctly points to the same database that "
                "was used when creating the original index."
            )

    def _create_new_index(self, embedding_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
        if index_factory == "HNSW":
            # faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
            # defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
            n_links = kwargs.get("n_links", 64)
            index = faiss.IndexHNSWFlat(embedding_dim, n_links, metric_type)
            index.hnsw.efSearch = kwargs.get("efSearch", 20)  # 20
            index.hnsw.efConstruction = kwargs.get("efConstruction", 80)  # 80

            logger.info(
                f"HNSW params: n_links: {n_links}, efSearch: {index.hnsw.efSearch}, efConstruction: {index.hnsw.efConstruction}"
            )
        else:
            index = faiss.index_factory(embedding_dim, index_factory, metric_type)
            if "ivf" in index_factory.lower():  # enable reconstruction of vectors for inverted index
                index.set_direct_map_type(faiss.DirectMap.Hashtable)
        return index

    def write_documents(
        self,
        documents: Union[List[dict], List[Document]],
        index: Optional[str] = None,
        batch_size: int = 10_000,
        duplicate_documents: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """
        Add new documents to the DocumentStore.

        :param documents: List of `Dicts` or list of `Documents`. If they already contain the embeddings, we'll index
                          them right away in FAISS. If not, you can later call update_embeddings() to create & index them.
        :param index: (SQL) index name for storing the docs and metadata.
        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
        :param duplicate_documents: Handle duplicate documents based on parameter options.
                                    Parameter options: ('skip', 'overwrite', 'fail')
                                    skip: Ignore duplicate documents.
                                    overwrite: Update any existing documents with the same ID when adding documents.
                                    fail: Raise an error if the document ID of the document being added already exists.
        :raises DuplicateDocumentError: Exception raised on a duplicate document when `duplicate_documents='fail'`.
        :return: None
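
        Example (a minimal sketch; `document_store` is an existing `FAISSDocumentStore` instance):

        ```python
        from haystack.schema import Document

        docs = [
            Document(content="Berlin is the capital of Germany."),
            Document(content="Paris is the capital of France."),
        ]
        document_store.write_documents(docs, duplicate_documents="overwrite")
        ```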
        """
        if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")

        index = index or self.index
        duplicate_documents = duplicate_documents or self.duplicate_documents
        assert (
            duplicate_documents in self.duplicate_documents_options
        ), f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}"

        if not self.faiss_indexes.get(index):
            self.faiss_indexes[index] = self._create_new_index(
                embedding_dim=self.embedding_dim,
                index_factory=self.faiss_index_factory_str,
                metric_type=faiss.METRIC_INNER_PRODUCT,
            )

        field_map = self._create_document_field_map()
        document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
        document_objects = self._handle_duplicate_documents(
            documents=document_objects, index=index, duplicate_documents=duplicate_documents
        )
        if len(document_objects) > 0:
            add_vectors = document_objects[0].embedding is not None

            if duplicate_documents == "overwrite" and add_vectors:
                logger.warning(
                    "`FAISSDocumentStore` does not support updating embeddings of existing documents in the "
                    "FAISS index when `duplicate_documents='overwrite'` is used.\n"
                    "Please call `update_embeddings` to repopulate the `faiss_index`."
                )

            vector_id = self.faiss_indexes[index].ntotal
            with tqdm(
                total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents"
            ) as progress_bar:
                for i in range(0, len(document_objects), batch_size):
                    if add_vectors:
                        embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
                        embeddings_to_index = np.array(embeddings, dtype="float32")

                        if self.similarity == "cosine":
                            self.normalize_embedding(embeddings_to_index)

                        self.faiss_indexes[index].add(embeddings_to_index)

                    docs_to_write_in_sql = []
                    for doc in document_objects[i : i + batch_size]:
                        meta = doc.meta
                        if add_vectors:
                            meta["vector_id"] = vector_id
                            vector_id += 1
                        docs_to_write_in_sql.append(doc)

                    super(FAISSDocumentStore, self).write_documents(
                        docs_to_write_in_sql,
                        index=index,
                        duplicate_documents=duplicate_documents,
                        batch_size=batch_size,
                    )
                    progress_bar.update(batch_size)
            progress_bar.close()

    def _create_document_field_map(self) -> Dict:
        return {self.index: self.embedding_field}

    def update_embeddings(
        self,
        retriever: "BaseRetriever",
        index: Optional[str] = None,
        update_existing_embeddings: bool = True,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in FAISSDocStore
        batch_size: int = 10_000,
    ):
        """
        Updates the embeddings in the document store using the encoding model specified in the retriever.
        This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the retriever config).

        :param retriever: Retriever to use to get embeddings for text.
        :param index: Index name for which embeddings are to be updated. If set to None, the default self.index is used.
        :param update_existing_embeddings: Whether to update existing embeddings of the documents. If set to False,
                                           only documents without embeddings are processed. This mode can be used for
                                           incremental updating of embeddings, wherein only newly indexed documents
                                           get processed.
        :param filters: Optional filters to narrow down the documents for which embeddings are to be updated.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
        :return: None
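
        Example (a minimal sketch; `retriever` stands in for any initialized Haystack dense retriever,
        e.g. an EmbeddingRetriever or DensePassageRetriever connected to this store):

        ```python
        document_store.update_embeddings(retriever=retriever, batch_size=10_000)
        ```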
        """
        index = index or self.index

        if update_existing_embeddings is True:
            if filters is None:
                self.faiss_indexes[index].reset()
                self.reset_vector_ids(index)
            else:
                raise Exception("update_existing_embeddings=True is not supported with filters.")

        if not self.faiss_indexes.get(index):
            raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")

        document_count = self.get_document_count(index=index)
        if document_count == 0:
            logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
            return

        logger.info(f"Updating embeddings for {document_count} docs...")
        vector_id = sum([index.ntotal for index in self.faiss_indexes.values()])

        result = self._query(
            index=index,
            vector_ids=None,
            batch_size=batch_size,
            filters=filters,
            only_documents_without_embedding=not update_existing_embeddings,
        )
        batched_documents = get_batches_from_generator(result, batch_size)
        with tqdm(
            total=document_count, disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding"
        ) as progress_bar:
            for document_batch in batched_documents:
                embeddings = retriever.embed_documents(document_batch)  # type: ignore
                assert len(document_batch) == len(embeddings)

                embeddings_to_index = np.array(embeddings, dtype="float32")

                if self.similarity == "cosine":
                    self.normalize_embedding(embeddings_to_index)

                self.faiss_indexes[index].add(embeddings_to_index)

                vector_id_map = {}
                for doc in document_batch:
                    vector_id_map[str(doc.id)] = str(vector_id)
                    vector_id += 1
                self.update_vector_ids(vector_id_map, index=index)
                progress_bar.set_description_str("Documents Processed")
                progress_bar.update(batch_size)

    def get_all_documents(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in FAISSDocStore
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
        headers: Optional[Dict[str, str]] = None,
    ) -> List[Document]:
        if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")

        result = self.get_all_documents_generator(
            index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
        )
        documents = list(result)
        return documents

    def get_all_documents_generator(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in FAISSDocStore
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
        headers: Optional[Dict[str, str]] = None,
    ) -> Generator[Document, None, None]:
        """
        Get all documents from the document store. Under the hood, documents are fetched in batches from the
        document store and yielded as individual documents. This method can be used to iteratively process
        a large number of documents without having to load all documents in memory.

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to return the document embeddings. Unlike other document stores, FAISS will return normalized embeddings.
        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
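
        Example (a minimal sketch):

        ```python
        for doc in document_store.get_all_documents_generator(batch_size=1_000):
            print(doc.id, doc.meta.get("vector_id"))
        ```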
        """
        if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")

        index = index or self.index
        documents = super(FAISSDocumentStore, self).get_all_documents_generator(
            index=index, filters=filters, batch_size=batch_size, return_embedding=False
        )
        if return_embedding is None:
            return_embedding = self.return_embedding

        for doc in documents:
            if return_embedding:
                if doc.meta and doc.meta.get("vector_id") is not None:
                    doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
            yield doc

    def get_documents_by_id(
        self,
        ids: List[str],
        index: Optional[str] = None,
        batch_size: int = 10_000,
        headers: Optional[Dict[str, str]] = None,
    ) -> List[Document]:
        if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")

        index = index or self.index
        documents = super(FAISSDocumentStore, self).get_documents_by_id(ids=ids, index=index, batch_size=batch_size)
        if self.return_embedding:
            for doc in documents:
                if doc.meta and doc.meta.get("vector_id") is not None:
                    doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
        return documents

    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None) -> int:
        """
        Return the count of embeddings in the document store.
        """
        if filters:
            raise Exception("filters are not supported for get_embedding_count in FAISSDocumentStore")
        index = index or self.index
        return self.faiss_indexes[index].ntotal

    def train_index(
        self,
        documents: Optional[Union[List[dict], List[Document]]],
        embeddings: Optional[np.ndarray] = None,
        index: Optional[str] = None,
    ):
        """
        Some FAISS indices (e.g. IVF) require initial "training" on a sample of vectors before you can add your final vectors.
        The train vectors should come from the same distribution as your final ones.
        You can pass either documents (incl. embeddings) or just the plain embeddings that the index shall be trained on.

        :param documents: Documents (incl. the embeddings).
        :param embeddings: Plain embeddings.
        :param index: Name of the index to train. If None, the DocumentStore's default index (self.index) will be used.
        :return: None
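
        Example (a minimal sketch; the random vectors are a hypothetical stand-in for a sample of your
        real document embeddings):

        ```python
        import numpy as np

        document_store = FAISSDocumentStore(faiss_index_factory_str="IVF256,Flat", embedding_dim=768)
        sample_embeddings = np.random.rand(10_000, 768).astype("float32")
        document_store.train_index(documents=None, embeddings=sample_embeddings)
        ```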
        """
        index = index or self.index
        if embeddings is not None and documents:
            raise ValueError("Either pass `documents` or `embeddings`. You passed both.")
        if documents:
            document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
            doc_embeddings = [doc.embedding for doc in document_objects]
            embeddings_for_train = np.array(doc_embeddings, dtype="float32")
            self.faiss_indexes[index].train(embeddings_for_train)
        if embeddings is not None:
            self.faiss_indexes[index].train(embeddings)

    def delete_all_documents(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in FAISSDocStore
        headers: Optional[Dict[str, str]] = None,
    ):
        """
        Delete all documents from the document store.
        """
        if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")

        logger.warning(
            """DEPRECATION WARNINGS:
                1. delete_all_documents() method is deprecated, please use delete_documents method
                For more details, please refer to the issue: https://github.com/deepset-ai/haystack/issues/1045
                """
        )
        self.delete_documents(index, None, filters)

    def delete_documents(
        self,
        index: Optional[str] = None,
        ids: Optional[List[str]] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in FAISSDocStore
        headers: Optional[Dict[str, str]] = None,
    ):
        """
        Delete documents from the document store. All documents are deleted if no filters are passed.

        :param index: Index name to delete the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param ids: Optional list of IDs to narrow down the documents to be deleted.
        :param filters: Optional filters to narrow down the documents to be deleted.
                        Example filters: {"name": ["some", "more"], "category": ["only_one"]}.
                        If filters are provided along with a list of IDs, this method deletes the
                        intersection of the two query results (documents that match the filters and
                        have their ID in the list).
        :return: None
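
        Example (a minimal sketch):

        ```python
        # Delete only documents matching a metadata filter:
        document_store.delete_documents(filters={"category": ["outdated"]})

        # Delete specific documents by ID:
        document_store.delete_documents(ids=["doc-1", "doc-2"])
        ```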
        """
        if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")

        index = index or self.index
        if index in self.faiss_indexes.keys():
            if not filters and not ids:
                self.faiss_indexes[index].reset()
            else:
                affected_docs = self.get_all_documents(filters=filters)
                if ids:
                    affected_docs = [doc for doc in affected_docs if doc.id in ids]
                doc_ids = [
                    doc.meta.get("vector_id")
                    for doc in affected_docs
                    if doc.meta and doc.meta.get("vector_id") is not None
                ]
                self.faiss_indexes[index].remove_ids(np.array(doc_ids, dtype="int64"))

        super().delete_documents(index=index, ids=ids, filters=filters)

    def query_by_embedding(
        self,
        query_emb: np.ndarray,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in FAISSDocStore
        top_k: int = 10,
        index: Optional[str] = None,
        return_embedding: Optional[bool] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> List[Document]:
        """
        Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR).
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param top_k: How many documents to return.
        :param index: Index name to query the documents from.
        :param return_embedding: Whether to return the document embedding. Unlike other document stores, FAISS will return normalized embeddings.
        :return: The documents most similar to `query_emb`, with the normalized similarity score stored in `Document.score`.
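
        Example (a minimal sketch; `query_emb` is assumed to come from the same model that produced the
        document embeddings, e.g. `retriever.embed_queries(["Who is the father of Arya Stark?"])[0]`):

        ```python
        results = document_store.query_by_embedding(query_emb=query_emb, top_k=5)
        for doc in results:
            print(doc.score, doc.content[:80])
        ```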
        """
        if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")

        if filters:
            logger.warning("Query filters are not implemented for the FAISSDocumentStore.")

        index = index or self.index
        if not self.faiss_indexes.get(index):
            raise Exception(f"Index named '{index}' does not exist. Use 'update_embeddings()' to create an index.")

        if return_embedding is None:
            return_embedding = self.return_embedding

        query_emb = query_emb.reshape(1, -1).astype(np.float32)

        if self.similarity == "cosine":
            self.normalize_embedding(query_emb)

        score_matrix, vector_id_matrix = self.faiss_indexes[index].search(query_emb, top_k)
        vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]

        documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)

        # assign query score to each document
        scores_for_vector_ids: Dict[str, float] = {
            str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])
        }
        for doc in documents:
            raw_score = scores_for_vector_ids[doc.meta["vector_id"]]
            doc.score = self.finalize_raw_score(raw_score, self.similarity)

            if return_embedding is True:
                doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))

        return documents

    def save(self, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
        """
        Save the FAISS index to the specified file.

        :param index_path: Path to save the FAISS index to.
        :param config_path: Path to save the initial configuration parameters to.
                            Defaults to the same path as `index_path`, but with a ".json" extension.
                            This file contains all the parameters passed to FAISSDocumentStore()
                            at creation time (for example the SQL path, embedding_dim, etc), and will be
                            used by the `load` method to restore the index with the appropriate configuration.
        :return: None
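
        Example (a minimal sketch):

        ```python
        document_store.save(index_path="my_faiss_index.faiss")
        # Writes "my_faiss_index.faiss" plus "my_faiss_index.json" holding the init parameters.
        ```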
        """
        if not config_path:
            index_path = Path(index_path)
            config_path = index_path.with_suffix(".json")

        faiss.write_index(self.faiss_indexes[self.index], str(index_path))

        config_to_save = deepcopy(self._component_config["params"])
        keys_to_remove = ["faiss_index", "faiss_index_path"]
        for key in keys_to_remove:
            if key in config_to_save.keys():
                del config_to_save[key]

        with open(config_path, "w") as ipp:
            json.dump(config_to_save, ipp, default=str)

    def _load_init_params_from_config(
        self, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None
    ):
        if not config_path:
            index_path = Path(index_path)
            config_path = index_path.with_suffix(".json")

        init_params: dict = {}
        try:
            with open(config_path, "r") as ipp:
                init_params = json.load(ipp)
        except OSError as e:
            raise ValueError(
                f"Can't open FAISS configuration file `{config_path}`. "
                "Make sure the file exists and that you have the correct permissions "
                "to access it."
            ) from e

        faiss_index = faiss.read_index(str(index_path))

        # Add other init params to override the ones defined in the init params file
        init_params["faiss_index"] = faiss_index
        init_params["embedding_dim"] = faiss_index.d

        return init_params

    @classmethod
    def load(cls, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
        """
        Load a saved FAISS index from a file and connect to the SQL database.
        Note: In order to have a correct mapping from FAISS to SQL,
        make sure to use the same SQL DB that you used when calling `save()`.

        :param index_path: Stored FAISS index file. Can be created by calling `save()`.
        :param config_path: Stored FAISS initial configuration parameters.
                            Can be created by calling `save()`.
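
        Example (a minimal sketch; assumes the files written in the `save()` example above, and that the
        SQL database referenced in the stored config is still reachable):

        ```python
        document_store = FAISSDocumentStore.load(index_path="my_faiss_index.faiss")
        ```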
        """
        return cls(faiss_index_path=index_path, faiss_config_path=config_path)