Sara Zan 11cf94a965
Pipeline's YAML: syntax validation (#2226)
* Add BasePipeline.validate_config, BasePipeline.validate_yaml, and some new custom exception classes

* Make error composition work properly

* Clarify typing

* Help mypy a bit more

* Update Documentation & Code Style

* Enable autogenerated docs for Milvus1 and 2 separately

* Revert "Enable autogenerated docs for Milvus1 and 2 separately"

This reverts commit 282be4a78a6e95862a9b4c924fc3dea5ca71e28d.

* Update Documentation & Code Style

* Re-enable 'additionalProperties: False'

* Add pipeline.type to JSON Schema, was somehow forgotten

* Disable additionalProperties on the pipeline properties too

* Fix json-schemas for 1.1.0 and 1.2.0 (should not do it again in the future)

* Call super in PipelineValidationError

* Improve _read_pipeline_config_from_yaml's error handling

* Fix generate_json_schema.py to include document stores

* Fix json schemas (retro-fix 1.1.0 again)

* Improve custom errors printing, add link to docs

* Add function in BaseComponent to list its subclasses in a module

* Make some document stores base classes abstract

* Add marker 'integration' in pytest flags

* Slightly improve validation of pipelines at load

* Adding tests for YAML loading and validation

* Make custom_query Optional for validation issues

* Fix bug in _read_pipeline_config_from_yaml

* Improve error handling in BasePipeline and Pipeline and add DAG check

* Move json schema generation into haystack/nodes/_json_schema.py (useful for tests)

* Simplify errors slightly

* Add some YAML validation tests

* Remove load_from_config from BasePipeline, it was never used anyway

* Improve tests

* Include json-schemas in package

* Fix conftest imports

* Make BasePipeline abstract

* Improve mocking by making the test independent from the YAML version

* Add exportable_to_yaml decorator to forget about set_config on mock nodes

* Fix mypy errors

* Comment out one monkeypatch

* Fix typing again

* Improve error message for validation

* Add required properties to pipelines

* Fix YAML version for REST API YAMLs to 1.2.0

* Fix load_from_yaml call in load_from_deepset_cloud

* fix HaystackError.__getattr__

* Add super().__init__() in most nodes and document stores, comment out set_config

* Remove type from REST API pipelines

* Remove useless init from doc2answers

* Call super in Seq3SeqGenerator

* Typo in deepsetcloud.py

* Fix rest api indexing error mismatch and mock version of JSON schema in all tests

* Working on pipeline tests

* Improve errors printing slightly

* Add back test_pipeline.yaml

* _json_schema.py supports different versions with identical schemas

* Add type to 0.7 schema for backwards compatibility

* Fix small bug in _json_schema.py

* Try alternative to generate json schemas on the CI

* Update Documentation & Code Style

* Make linux CI match autoformat CI

* Fix super-init-not-called

* Accidentally committed file

* Update Documentation & Code Style

* fix test_summarizer_translation.py's import

* Mock YAML in a few suites, split and simplify test_pipeline_debug_and_validation.py::test_invalid_run_args

* Fix json schema for ray tests too

* Update Documentation & Code Style

* Reintroduce validation

* Use unstable version in tests and REST API

* Make unstable support the latest versions

* Update Documentation & Code Style

* Remove needless fixture

* Make type in pipeline optional in the strings validation

* Fix schemas

* Fix string validation for pipeline type

* Improve validate_config_strings

* Remove type from test pipelines

* Update Documentation & Code Style

* Fix test_pipeline

* Removing more type from pipelines

* Temporary CI patch

* Fix issue with exportable_to_yaml never invoking the wrapped init

* rm stray file

* pipeline tests are green again

* Linux CI now needs .[all] to generate the schema

* Bugfixes, pipeline tests seems to be green

* Typo in version after merge

* Implement missing methods in Weaviate

* Try to prevent FAISS tests from running in the Milvus1 test suite

* Fix some stray test paths and faiss index dumping

* Fix pytest markers list

* Temporarily disable cache to be able to see tests failures

* Fix pyproject.toml syntax

* Use only tmp_path

* Fix preprocessor signature after merge

* Fix faiss bug

* Fix Ray test

* Fix documentation issue by removing quotes from faiss type

* Update Documentation & Code Style

* use document properly in preprocessor tests

* Update Documentation & Code Style

* make preprocessor capable of handling documents

* import document

* Revert support for documents in preprocessor, do later

* Fix bug in _json_schema.py that was breaking validation

* re-enable cache

* Update Documentation & Code Style

* Simplify calling _json_schema.py from the CI

* Remove redundant ABC inheritance

* Ensure exportable_to_yaml works only on implementations

* Rename subclass to class_ in Meta

* Make run() and get_config() abstract in BasePipeline

* Revert unintended change in preprocessor

* Move outgoing_edges_input_node check inside try block

* Rename VALID_CODE_GEN_INPUT_REGEX into VALID_INPUT_REGEX

* Add check for a RecursionError on validate_config_strings

* Address usages of _pipeline_config in data silo and elasticsearch

* Rename _pipeline_config into _init_parameters

* Fix pytest marker and remove unused imports

* Remove most redundant ABCs

* Rename _init_parameters into _component_configuration

* Remove set_config and type from _component_configuration's dict

* Remove last instances of set_config and replace with super().__init__()

* Implement __init_subclass__ approach

* Simplify checks on the existence of _component_configuration

* Fix faiss issue

* Dynamic generation of node schemas & weed out old schemas

* Add debatable test

* Add docstring to debatable test

* Positive diff between schemas implemented

* Improve diff printing

* Rename REST API YAML files to trigger IDE validation

* Fix typing issues

* Fix more typing

* Typo in YAML filename

* Remove needless type:ignore

* Add tests

* Fix tests & validation feedback for accessory classes in custom nodes

* Refactor RAGeneratorType out

* Fix broken import in conftest

* Improve source error handling

* Remove unused import in test_eval.py breaking tests

* Fix changed error message in tests matches too

* Normalize generate_openapi_specs.py and generate_json_schema.py in the actions

* Fix path to generate_openapi_specs.py in autoformat.yml

* Update Documentation & Code Style

* Add test for FAISSDocumentStore-like situations (superclass with init params)

* Update Documentation & Code Style

* Fix indentation

* Remove commented set_config

* Store model_name_or_path in FARMReader to use in DistillationDataSilo

* Rename _component_configuration into _component_config

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2022-03-15 11:17:26 +01:00
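
The net effect of these changes is that pipeline YAMLs are now validated against a versioned JSON schema at load time. A minimal usage sketch follows (the `Pipeline.load_from_yaml` entry point already exists in Haystack; the exception class named in the comment is taken from the commit notes above, so treat the error handling as illustrative rather than the definitive API):

    from pathlib import Path
    from haystack.pipelines import Pipeline

    try:
        # Loading validates the YAML against the JSON schema for its declared `version`
        # and checks that the component graph is a proper DAG before building it.
        pipeline = Pipeline.load_from_yaml(Path("pipeline.yaml"), pipeline_name="query")
    except Exception as error:  # e.g. the PipelineValidationError added in this PR (name from the notes above)
        print(f"Pipeline YAML failed validation: {error}")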


from typing import TYPE_CHECKING, Any, Union, List, Optional, Dict, Generator
if TYPE_CHECKING:
from haystack.nodes.retriever import BaseRetriever
import json
import logging
import warnings
import numpy as np
from copy import deepcopy
from pathlib import Path
from tqdm.auto import tqdm
from inspect import Signature, signature
try:
import faiss
from haystack.document_stores.sql import (
SQLDocumentStore,
) # its deps are optional, but get installed with the `faiss` extra
except (ImportError, ModuleNotFoundError) as ie:
from haystack.utils.import_utils import _optional_component_not_installed
_optional_component_not_installed(__name__, "faiss", ie)
from haystack.schema import Document
from haystack.document_stores.base import get_batches_from_generator
logger = logging.getLogger(__name__)
class FAISSDocumentStore(SQLDocumentStore):
"""
    Document store for very large-scale, embedding-based dense retrievers like DPR.
    It uses the FAISS library (https://github.com/facebookresearch/faiss)
    to perform similarity search on vectors.
The document text and meta-data (for filtering) are stored using the SQLDocumentStore, while
the vector embeddings are indexed in a FAISS Index.
"""
def __init__(
self,
sql_url: str = "sqlite:///faiss_document_store.db",
vector_dim: int = None,
embedding_dim: int = 768,
faiss_index_factory_str: str = "Flat",
faiss_index: Optional[faiss.swigfaiss.Index] = None,
return_embedding: bool = False,
index: str = "document",
similarity: str = "dot_product",
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = "overwrite",
faiss_index_path: Union[str, Path] = None,
faiss_config_path: Union[str, Path] = None,
isolation_level: str = None,
**kwargs,
):
"""
:param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
deployment, Postgres is recommended.
:param vector_dim: Deprecated. Use embedding_dim instead.
:param embedding_dim: The embedding vector size. Default: 768.
:param faiss_index_factory_str: Create a new FAISS index of the specified type.
The type is determined from the given string following the conventions
of the original FAISS index factory.
Recommended options:
- "Flat" (default): Best accuracy (= exact). Becomes slow and RAM intense for > 1 Mio docs.
- "HNSW": Graph-based heuristic. If not further specified,
we use the following config:
HNSW64, efConstruction=80 and efSearch=20
- "IVFx,Flat": Inverted Index. Replace x with the number of centroids aka nlist.
Rule of thumb: nlist = 10 * sqrt (num_docs) is a good starting point.
For more details see:
- Overview of indices https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
- Guideline for choosing an index https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
- FAISS Index factory https://github.com/facebookresearch/faiss/wiki/The-index-factory
Benchmarks: XXX
:param faiss_index: Pass an existing FAISS Index, i.e. an empty one that you configured manually
or one with docs that you used in Haystack before and want to load again.
        :param return_embedding: Whether to return document embeddings. Unlike other document stores, FAISS will return normalized embeddings.
:param index: Name of index in document store to use.
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
For `dot_product`: expit(np.asarray(raw_score / 100))
                           For `cosine`: (raw_score + 1) / 2
:param embedding_field: Name of field containing an embedding vector.
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
        :param duplicate_documents: Handle duplicate documents based on parameter options.
                                    Parameter options: ('skip', 'overwrite', 'fail')
                                    skip: Ignore duplicate documents.
                                    overwrite: Update any existing documents with the same ID when adding documents.
                                    fail: Raise an error if the document ID of the document being added already
                                          exists.
:param faiss_index_path: Stored FAISS index file. Can be created via calling `save()`.
If specified no other params besides faiss_config_path must be specified.
:param faiss_config_path: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# special case if we want to load an existing index from disk
# load init params from disk and run init again
if faiss_index_path is not None:
sig = signature(self.__class__.__init__)
self._validate_params_load_from_disk(sig, locals(), kwargs)
init_params = self._load_init_params_from_config(faiss_index_path, faiss_config_path)
self.__class__.__init__(self, **init_params) # pylint: disable=non-parent-init-called
return
if similarity in ("dot_product", "cosine"):
self.similarity = similarity
self.metric_type = faiss.METRIC_INNER_PRODUCT
elif similarity == "l2":
self.similarity = similarity
self.metric_type = faiss.METRIC_L2
else:
raise ValueError(
"The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
"Please set similarity to one of the above."
)
if vector_dim is not None:
warnings.warn(
"The 'vector_dim' parameter is deprecated, " "use 'embedding_dim' instead.", DeprecationWarning, 2
)
self.embedding_dim = vector_dim
else:
self.embedding_dim = embedding_dim
self.faiss_index_factory_str = faiss_index_factory_str
self.faiss_indexes: Dict[str, faiss.swigfaiss.Index] = {}
if faiss_index:
self.faiss_indexes[index] = faiss_index
else:
self.faiss_indexes[index] = self._create_new_index(
embedding_dim=self.embedding_dim,
index_factory=faiss_index_factory_str,
metric_type=self.metric_type,
**kwargs,
)
self.return_embedding = return_embedding
self.embedding_field = embedding_field
self.progress_bar = progress_bar
super().__init__(
url=sql_url, index=index, duplicate_documents=duplicate_documents, isolation_level=isolation_level
)
self._validate_index_sync()
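    # Usage sketch (illustrative, not part of the class): the constructor arguments below mirror
    # the docstring above; the SQLite URL and factory string are example values, not requirements.
    #
    #   document_store = FAISSDocumentStore(
    #       sql_url="sqlite:///faiss_document_store.db",
    #       embedding_dim=768,
    #       faiss_index_factory_str="HNSW",   # or "IVFx,Flat" with x chosen per the rule of thumb above
    #       similarity="dot_product",
    #   )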
def _validate_params_load_from_disk(self, sig: Signature, locals: dict, kwargs: dict):
allowed_params = ["faiss_index_path", "faiss_config_path", "self", "kwargs"]
invalid_param_set = False
for param in sig.parameters.values():
if param.name not in allowed_params and param.default != locals[param.name]:
invalid_param_set = True
break
if invalid_param_set or len(kwargs) > 0:
raise ValueError("if faiss_index_path is passed no other params besides faiss_config_path are allowed.")
def _validate_index_sync(self):
# This check ensures the correct document database was loaded.
# If it fails, make sure you provided the path to the database
# used when creating the original FAISS index
if not self.get_document_count() == self.get_embedding_count():
raise ValueError(
"The number of documents present in the SQL database does not "
"match the number of embeddings in FAISS. Make sure your FAISS "
"configuration file correctly points to the same database that "
"was used when creating the original index."
)
def _create_new_index(self, embedding_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
if index_factory == "HNSW":
# faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
# defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
n_links = kwargs.get("n_links", 64)
index = faiss.IndexHNSWFlat(embedding_dim, n_links, metric_type)
index.hnsw.efSearch = kwargs.get("efSearch", 20) # 20
index.hnsw.efConstruction = kwargs.get("efConstruction", 80) # 80
if "ivf" in index_factory.lower(): # enable reconstruction of vectors for inverted index
self.faiss_indexes[index].set_direct_map_type(faiss.DirectMap.Hashtable)
logger.info(
f"HNSW params: n_links: {n_links}, efSearch: {index.hnsw.efSearch}, efConstruction: {index.hnsw.efConstruction}"
)
else:
index = faiss.index_factory(embedding_dim, index_factory, metric_type)
return index
def write_documents(
self,
documents: Union[List[dict], List[Document]],
index: Optional[str] = None,
batch_size: int = 10_000,
duplicate_documents: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
) -> None:
"""
Add new documents to the DocumentStore.
:param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
them right away in FAISS. If not, you can later call update_embeddings() to create & index them.
:param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
        :param duplicate_documents: Handle duplicate documents based on parameter options.
                                    Parameter options: ('skip', 'overwrite', 'fail')
                                    skip: Ignore duplicate documents.
                                    overwrite: Update any existing documents with the same ID when adding documents.
                                    fail: Raise an error if the document ID of the document being added already
                                          exists.
        :raises DuplicateDocumentError: Exception triggered on duplicate document.
:return: None
"""
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")
index = index or self.index
duplicate_documents = duplicate_documents or self.duplicate_documents
assert (
duplicate_documents in self.duplicate_documents_options
), f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}"
if not self.faiss_indexes.get(index):
self.faiss_indexes[index] = self._create_new_index(
embedding_dim=self.embedding_dim,
index_factory=self.faiss_index_factory_str,
metric_type=faiss.METRIC_INNER_PRODUCT,
)
field_map = self._create_document_field_map()
document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
document_objects = self._handle_duplicate_documents(
documents=document_objects, index=index, duplicate_documents=duplicate_documents
)
if len(document_objects) > 0:
add_vectors = False if document_objects[0].embedding is None else True
if self.duplicate_documents == "overwrite" and add_vectors:
                logger.warning(
                    "`duplicate_documents='overwrite'` is set, but `FAISSDocumentStore` does not support "
                    "updating embeddings in an existing `faiss_index`.\n"
                    "Please call the `update_embeddings` method to repopulate the `faiss_index`."
                )
vector_id = self.faiss_indexes[index].ntotal
with tqdm(
total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents"
) as progress_bar:
for i in range(0, len(document_objects), batch_size):
if add_vectors:
embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
embeddings_to_index = np.array(embeddings, dtype="float32")
if self.similarity == "cosine":
self.normalize_embedding(embeddings_to_index)
self.faiss_indexes[index].add(embeddings_to_index)
docs_to_write_in_sql = []
for doc in document_objects[i : i + batch_size]:
meta = doc.meta
if add_vectors:
meta["vector_id"] = vector_id
vector_id += 1
docs_to_write_in_sql.append(doc)
super(FAISSDocumentStore, self).write_documents(
docs_to_write_in_sql,
index=index,
duplicate_documents=duplicate_documents,
batch_size=batch_size,
)
progress_bar.update(batch_size)
progress_bar.close()
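    # Usage sketch (illustrative): documents without embeddings can be written first and indexed
    # later via update_embeddings() with a retriever, as the docstring above notes. `my_retriever`
    # is an assumed, pre-configured dense retriever.
    #
    #   document_store.write_documents([{"content": "some text", "meta": {"name": "doc1"}}])
    #   document_store.update_embeddings(retriever=my_retriever)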
def _create_document_field_map(self) -> Dict:
return {self.index: self.embedding_field}
def update_embeddings(
self,
retriever: "BaseRetriever",
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore
batch_size: int = 10_000,
):
"""
        Updates the embeddings in the document store using the encoding model specified in the retriever.
        This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the retriever config).
:param retriever: Retriever to use to get embeddings for text
:param index: Index name for which embeddings are to be updated. If set to None, the default self.index is used.
:param update_existing_embeddings: Whether to update existing embeddings of the documents. If set to False,
only documents without embeddings are processed. This mode can be used for
                                           incremental updating of embeddings, wherein only newly indexed documents
get processed.
:param filters: Optional filters to narrow down the documents for which embeddings are to be updated.
Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
:return: None
"""
index = index or self.index
if update_existing_embeddings is True:
if filters is None:
self.faiss_indexes[index].reset()
self.reset_vector_ids(index)
else:
raise Exception("update_existing_embeddings=True is not supported with filters.")
if not self.faiss_indexes.get(index):
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
document_count = self.get_document_count(index=index)
if document_count == 0:
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
return
logger.info(f"Updating embeddings for {document_count} docs...")
vector_id = sum([index.ntotal for index in self.faiss_indexes.values()])
result = self._query(
index=index,
vector_ids=None,
batch_size=batch_size,
filters=filters,
only_documents_without_embedding=not update_existing_embeddings,
)
batched_documents = get_batches_from_generator(result, batch_size)
with tqdm(
total=document_count, disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding"
) as progress_bar:
for document_batch in batched_documents:
embeddings = retriever.embed_documents(document_batch) # type: ignore
assert len(document_batch) == len(embeddings)
embeddings_to_index = np.array(embeddings, dtype="float32")
if self.similarity == "cosine":
self.normalize_embedding(embeddings_to_index)
self.faiss_indexes[index].add(embeddings_to_index)
vector_id_map = {}
for doc in document_batch:
vector_id_map[str(doc.id)] = str(vector_id)
vector_id += 1
self.update_vector_ids(vector_id_map, index=index)
progress_bar.set_description_str("Documents Processed")
progress_bar.update(batch_size)
def get_all_documents(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")
result = self.get_all_documents_generator(
index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
)
documents = list(result)
return documents
def get_all_documents_generator(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
headers: Optional[Dict[str, str]] = None,
) -> Generator[Document, None, None]:
"""
        Get all documents from the document store. Under the hood, documents are fetched in batches from the
document store and yielded as individual documents. This method can be used to iteratively process
a large number of documents without having to load all documents in memory.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings. Unlike other document stores, FAISS will return normalized embeddings
        :param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
"""
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")
index = index or self.index
documents = super(FAISSDocumentStore, self).get_all_documents_generator(
index=index, filters=filters, batch_size=batch_size, return_embedding=False
)
if return_embedding is None:
return_embedding = self.return_embedding
for doc in documents:
if return_embedding:
if doc.meta and doc.meta.get("vector_id") is not None:
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
yield doc
def get_documents_by_id(
self,
ids: List[str],
index: Optional[str] = None,
batch_size: int = 10_000,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")
index = index or self.index
documents = super(FAISSDocumentStore, self).get_documents_by_id(ids=ids, index=index, batch_size=batch_size)
if self.return_embedding:
for doc in documents:
if doc.meta and doc.meta.get("vector_id") is not None:
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
return documents
def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None) -> int:
"""
Return the count of embeddings in the document store.
"""
if filters:
raise Exception("filters are not supported for get_embedding_count in FAISSDocumentStore")
index = index or self.index
return self.faiss_indexes[index].ntotal
def train_index(
self,
documents: Optional[Union[List[dict], List[Document]]],
embeddings: Optional[np.ndarray] = None,
index: Optional[str] = None,
):
"""
Some FAISS indices (e.g. IVF) require initial "training" on a sample of vectors before you can add your final vectors.
The train vectors should come from the same distribution as your final ones.
You can pass either documents (incl. embeddings) or just the plain embeddings that the index shall be trained on.
:param documents: Documents (incl. the embeddings)
:param embeddings: Plain embeddings
:param index: Name of the index to train. If None, the DocumentStore's default index (self.index) will be used.
:return: None
"""
index = index or self.index
        if embeddings is not None and documents:
raise ValueError("Either pass `documents` or `embeddings`. You passed both.")
if documents:
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
doc_embeddings = [doc.embedding for doc in document_objects]
embeddings_for_train = np.array(doc_embeddings, dtype="float32")
self.faiss_indexes[index].train(embeddings_for_train)
        if embeddings is not None:
self.faiss_indexes[index].train(embeddings)
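    # Usage sketch (illustrative): IVF-style indexes need training before documents with embeddings
    # are added, e.g. on a representative sample of vectors. The random array below is a stand-in
    # for real embeddings of the configured dimensionality (768 by default).
    #
    #   sample_embeddings = np.random.rand(10_000, 768).astype("float32")
    #   document_store.train_index(documents=None, embeddings=sample_embeddings)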
def delete_all_documents(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore
headers: Optional[Dict[str, str]] = None,
):
"""
Delete all documents from the document store.
"""
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")
logger.warning(
"""DEPRECATION WARNINGS:
1. delete_all_documents() method is deprecated, please use delete_documents method
For more details, please refer to the issue: https://github.com/deepset-ai/haystack/issues/1045
"""
)
self.delete_documents(index, None, filters)
def delete_documents(
self,
index: Optional[str] = None,
ids: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore
headers: Optional[Dict[str, str]] = None,
):
"""
Delete documents from the document store. All documents are deleted if no filters are passed.
:param index: Index name to delete the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param ids: Optional list of IDs to narrow down the documents to be deleted.
:param filters: Optional filters to narrow down the documents to be deleted.
Example filters: {"name": ["some", "more"], "category": ["only_one"]}.
If filters are provided along with a list of IDs, this method deletes the
intersection of the two query results (documents that match the filters and
have their ID in the list).
:return: None
"""
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")
index = index or self.index
if index in self.faiss_indexes.keys():
if not filters and not ids:
self.faiss_indexes[index].reset()
else:
affected_docs = self.get_all_documents(filters=filters)
if ids:
affected_docs = [doc for doc in affected_docs if doc.id in ids]
doc_ids = [
doc.meta.get("vector_id")
for doc in affected_docs
if doc.meta and doc.meta.get("vector_id") is not None
]
self.faiss_indexes[index].remove_ids(np.array(doc_ids, dtype="int64"))
super().delete_documents(index=index, ids=ids, filters=filters)
def query_by_embedding(
self,
query_emb: np.ndarray,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore
top_k: int = 10,
index: Optional[str] = None,
return_embedding: Optional[bool] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""
        Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param filters: Optional filters to narrow down the search space.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param top_k: How many documents to return
:param index: Index name to query the document from.
        :param return_embedding: Whether to return document embeddings. Unlike other document stores, FAISS will return normalized embeddings.
        :return: List of Documents that are most similar to `query_emb`.
"""
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")
if filters:
logger.warning("Query filters are not implemented for the FAISSDocumentStore.")
index = index or self.index
if not self.faiss_indexes.get(index):
raise Exception(f"Index named '{index}' does not exists. Use 'update_embeddings()' to create an index.")
if return_embedding is None:
return_embedding = self.return_embedding
query_emb = query_emb.reshape(1, -1).astype(np.float32)
if self.similarity == "cosine":
self.normalize_embedding(query_emb)
score_matrix, vector_id_matrix = self.faiss_indexes[index].search(query_emb, top_k)
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)
# assign query score to each document
scores_for_vector_ids: Dict[str, float] = {
str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])
}
for doc in documents:
raw_score = scores_for_vector_ids[doc.meta["vector_id"]]
doc.score = self.finalize_raw_score(raw_score, self.similarity)
if return_embedding is True:
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
return documents
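    # Usage sketch (illustrative): the query embedding typically comes from the same retriever
    # that produced the document embeddings. `my_retriever` is an assumed dense retriever.
    #
    #   query_emb = my_retriever.embed_queries(["what is haystack?"])[0]
    #   results = document_store.query_by_embedding(query_emb, top_k=5)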
def save(self, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
"""
Save FAISS Index to the specified file.
:param index_path: Path to save the FAISS index to.
:param config_path: Path to save the initial configuration parameters to.
                            Defaults to `index_path` with its extension replaced by `.json`.
This file contains all the parameters passed to FAISSDocumentStore()
at creation time (for example the SQL path, embedding_dim, etc), and will be
used by the `load` method to restore the index with the appropriate configuration.
:return: None
"""
if not config_path:
index_path = Path(index_path)
config_path = index_path.with_suffix(".json")
faiss.write_index(self.faiss_indexes[self.index], str(index_path))
config_to_save = deepcopy(self._component_config["params"])
keys_to_remove = ["faiss_index", "faiss_index_path"]
for key in keys_to_remove:
if key in config_to_save.keys():
del config_to_save[key]
with open(config_path, "w") as ipp:
json.dump(config_to_save, ipp, default=str)
def _load_init_params_from_config(
self, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None
):
if not config_path:
index_path = Path(index_path)
config_path = index_path.with_suffix(".json")
init_params: dict = {}
try:
with open(config_path, "r") as ipp:
init_params = json.load(ipp)
except OSError as e:
raise ValueError(
f"Can't open FAISS configuration file `{config_path}`. "
"Make sure the file exists and the you have the correct permissions "
"to access it."
) from e
faiss_index = faiss.read_index(str(index_path))
# Add other init params to override the ones defined in the init params file
init_params["faiss_index"] = faiss_index
init_params["embedding_dim"] = faiss_index.d
return init_params
@classmethod
def load(cls, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
"""
Load a saved FAISS index from a file and connect to the SQL database.
Note: In order to have a correct mapping from FAISS to SQL,
make sure to use the same SQL DB that you used when calling `save()`.
:param index_path: Stored FAISS index file. Can be created via calling `save()`
:param config_path: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
"""
return cls(faiss_index_path=index_path, faiss_config_path=config_path)
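    # Usage sketch (illustrative): persisting and restoring the store. save() writes the FAISS
    # index plus a JSON config next to it; load() rebuilds the store from those two files and
    # reconnects to the SQL database recorded in the config.
    #
    #   document_store.save("my_faiss_index.faiss")           # also writes my_faiss_index.json
    #   document_store = FAISSDocumentStore.load("my_faiss_index.faiss")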