Lalit Pagaria 2e9f3c1512
Fix update_embeddings function in FAISSDocumentStore and add retriever fixture in tests (#481)
* 1. Prevent the update_embeddings function in FAISSDocumentStore from setting faiss_index to None when the document store does not have any docs.

2. Clean up tests by adding a fixture for the retriever.

* TfidfRetriever needs a document store with documents during initialization, as it calls the fit() function in its constructor; fix this by checking whether self.paragraphs is None.

* Fix naming of retriever's fixture (embedded to embedding and tfid to tfidf)
2020-10-14 16:15:04 +02:00

255 lines
12 KiB
Python

import logging
from sys import platform
from pathlib import Path
from typing import Union, List, Optional, Dict
from tqdm import tqdm
import numpy as np
from haystack import Document
from haystack.document_store.sql import SQLDocumentStore
from haystack.retriever.base import BaseRetriever
# FAISSDocumentStore is not supported on Windows / Cygwin: fail at import time with a
# clear message there instead of attempting to import faiss.
if platform in ('win32', 'cygwin'):
    raise ModuleNotFoundError("FAISSDocumentStore on windows platform is not supported")
import faiss

logger = logging.getLogger(__name__)
class FAISSDocumentStore(SQLDocumentStore):
    """
    Document store for very large scale embedding based dense retrievers like the DPR.

    It implements the FAISS library (https://github.com/facebookresearch/faiss)
    to perform similarity search on vectors.

    The document text and meta-data (for filtering) are stored using the SQLDocumentStore, while
    the vector embeddings are indexed in a FAISS Index. Both are linked through the `vector_id`
    entry in each document's meta, which holds the id FAISS assigned to the document's vector.
    """

    def __init__(
        self,
        sql_url: str = "sqlite:///",
        index_buffer_size: int = 10_000,
        vector_dim: int = 768,
        faiss_index_factory_str: str = "Flat",
        faiss_index: Optional[faiss.swigfaiss.Index] = None,
        **kwargs,
    ):
        """
        :param sql_url: SQL connection URL for database. It defaults to a local SQLite DB. For large scale
                        deployment, Postgres is recommended.
        :param index_buffer_size: When working with large datasets, the ingestion process (FAISS + SQL) can be buffered in
                                  smaller chunks to reduce memory footprint.
        :param vector_dim: the embedding vector size.
        :param faiss_index_factory_str: Create a new FAISS index of the specified type.
                                        The type is determined from the given string following the conventions
                                        of the original FAISS index factory.
                                        Recommended options:
                                        - "Flat" (default): Best accuracy (= exact). Becomes slow and RAM intense for > 1 Mio docs.
                                        - "HNSW": Graph-based heuristic. If not further specified (and with the default
                                                  inner-product metric), we use a RAM intense, but more accurate config:
                                                  n_links=128, efConstruction=80 and efSearch=20. These can be overridden
                                                  via the `n_links`, `efConstruction` and `efSearch` keyword arguments.
                                        - "IVFx,Flat": Inverted Index. Replace x with the number of centroids aka nlist.
                                                       Rule of thumb: nlist = 10 * sqrt (num_docs) is a good starting point.
                                        For more details see:
                                        - Overview of indices https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
                                        - Guideline for choosing an index https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
                                        - FAISS Index factory https://github.com/facebookresearch/faiss/wiki/The-index-factory
        :param faiss_index: Pass an existing FAISS Index, i.e. an empty one that you configured manually
                            or one with docs that you used in Haystack before and want to load again.
        :param kwargs: Only used when no `faiss_index` is given: forwarded to the creation of the new index
                       (`metric_type`, and for "HNSW" `n_links`, `efSearch`, `efConstruction`).
        """
        self.vector_dim = vector_dim
        if faiss_index is not None:
            self.faiss_index = faiss_index
        else:
            self.faiss_index = self._create_new_index(
                vector_dim=self.vector_dim, index_factory=faiss_index_factory_str, **kwargs
            )
        self.index_buffer_size = index_buffer_size
        super().__init__(url=sql_url)

    def _create_new_index(self, vector_dim: int, index_factory: str = "Flat", metric_type=faiss.METRIC_INNER_PRODUCT, **kwargs):
        """
        Create a new, empty FAISS index.

        :param vector_dim: Dimension of the vectors the index will hold.
        :param index_factory: FAISS index factory string (e.g. "Flat", "HNSW", "IVF100,Flat").
        :param metric_type: FAISS metric used for similarity; defaults to inner product.
        :param kwargs: Only used for "HNSW" with inner product: `n_links` (default 128),
                       `efSearch` (default 20) and `efConstruction` (default 80).
        :return: The newly created FAISS index.
        """
        if index_factory == "HNSW" and metric_type == faiss.METRIC_INNER_PRODUCT:
            # faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
            # defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
            n_links = kwargs.get("n_links", 128)
            index = faiss.IndexHNSWFlat(vector_dim, n_links, metric_type)
            index.hnsw.efSearch = kwargs.get("efSearch", 20)
            index.hnsw.efConstruction = kwargs.get("efConstruction", 80)
            logger.info(
                f"HNSW params: n_links: {n_links}, efSearch: {index.hnsw.efSearch}, "
                f"efConstruction: {index.hnsw.efConstruction}"
            )
        else:
            index = faiss.index_factory(vector_dim, index_factory, metric_type)
        return index

    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
        """
        Add new documents to the DocumentStore.

        :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
                          them right away in FAISS. If not, you can later call update_embeddings() to create & index them.
                          Whether embeddings are indexed is decided from the first document, so all documents of one
                          call should either all have embeddings or all lack them.
        :param index: (SQL) index name for storing the docs and metadata
        :return: None
        :raises ValueError: if no FAISS index is available.
        """
        # vector index
        if self.faiss_index is None:
            raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")

        # doc + metadata index
        index = index or self.index
        document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
        if not document_objects:
            # Nothing to write; previously this crashed with an IndexError on document_objects[0].
            logger.warning("Calling DocumentStore.write_documents() with an empty list of documents")
            return

        add_vectors = document_objects[0].embedding is not None

        for start in range(0, len(document_objects), self.index_buffer_size):
            batch = document_objects[start: start + self.index_buffer_size]
            if add_vectors:
                # Vectors added without explicit ids get consecutive ids starting at the current
                # `ntotal`; record the same ids in the SQL meta so results can be mapped back.
                first_vector_id = self.faiss_index.ntotal
                embeddings = np.array([doc.embedding for doc in batch], dtype="float32")
                self.faiss_index.add(embeddings)
                for offset, doc in enumerate(batch):
                    doc.meta["vector_id"] = first_vector_id + offset
            super().write_documents(batch, index=index)

    def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
        """
        Updates the embeddings in the document store using the encoding model specified in the retriever.
        This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the
        retriever config). All vectors currently in the FAISS index are replaced.

        :param retriever: Retriever to use to get embeddings for text
        :param index: (SQL) index name for storing the docs and metadata
        :return: None
        :raises ValueError: if no FAISS index is available, or if the retriever returns a different number of
                            embeddings than there are documents.
        """
        if self.faiss_index is None:
            raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")

        index = index or self.index
        documents = self.get_all_documents(index=index)

        if not documents:
            logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
            return

        logger.info(f"Updating embeddings for {len(documents)} docs...")
        embeddings = retriever.embed_passages(documents)  # type: ignore
        # Explicit check instead of `assert`, which is stripped when running with `python -O`.
        if len(embeddings) != len(documents):
            raise ValueError(
                f"The retriever returned {len(embeddings)} embeddings for {len(documents)} documents."
            )
        for doc, embedding in zip(documents, embeddings):
            doc.embedding = embedding

        # Clear out the FAISS index contents and free the memory in use by the index. This happens only
        # after the new embeddings were computed, so a failing retriever leaves the existing index intact.
        self.faiss_index.reset()

        logger.info("Indexing embeddings and updating vectors_ids...")
        for start in tqdm(range(0, len(documents), self.index_buffer_size)):
            batch = documents[start: start + self.index_buffer_size]
            # Vectors get consecutive ids starting at the current `ntotal`.
            first_vector_id = self.faiss_index.ntotal
            batch_embeddings = np.array([doc.embedding for doc in batch], dtype="float32")
            self.faiss_index.add(batch_embeddings)
            vector_id_map = {doc.id: first_vector_id + offset for offset, doc in enumerate(batch)}
            self.update_vector_ids(vector_id_map, index=index)

    def train_index(
        self,
        documents: Optional[Union[List[dict], List[Document]]] = None,
        embeddings: Optional[np.ndarray] = None,
    ):
        """
        Some FAISS indices (e.g. IVF) require initial "training" on a sample of vectors before you can add your final vectors.
        The train vectors should come from the same distribution as your final ones.
        You can pass either documents (incl. embeddings) or just the plain embeddings that the index shall be trained on.

        :param documents: Documents (incl. the embeddings)
        :param embeddings: Plain embeddings (a numpy array or a list of vectors)
        :return: None
        :raises ValueError: if both or neither of `documents` and `embeddings` are given.
        """
        # `embeddings is not None` rather than truthiness: bool() of a numpy array with more than
        # one element raises "The truth value of an array ... is ambiguous".
        if documents and embeddings is not None:
            raise ValueError("Either pass `documents` or `embeddings`. You passed both.")
        if documents:
            document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
            embeddings = [doc.embedding for doc in document_objects]
        if embeddings is None:
            raise ValueError("Either pass `documents` or `embeddings`. You passed neither.")

        embeddings = np.array(embeddings, dtype="float32")
        self.faiss_index.train(embeddings)

    def delete_all_documents(self, index: Optional[str] = None):
        """
        Delete all documents from the (SQL) index and remove all vectors from the FAISS index.

        Note: there is a single FAISS index for this store, so its vectors are cleared regardless of
        which SQL `index` is given.

        :param index: (SQL) index name; defaults to the store's default index.
        :return: None
        """
        index = index or self.index
        self.faiss_index.reset()
        super().delete_all_documents(index=index)

    def query_by_embedding(
        self, query_emb: np.ndarray, filters: Optional[dict] = None, top_k: int = 10, index: Optional[str] = None
    ) -> List[Document]:
        """
        Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR)
        :param filters: Not supported by the FAISSDocumentStore yet; passing any filter raises an exception.
        :param top_k: How many documents to return
        :param index: (SQL) index name for storing the docs and metadata
        :return: The matching documents, each with `score` and `probability` set.
        """
        if filters:
            raise Exception("Query filters are not implemented for the FAISSDocumentStore.")
        if self.faiss_index is None:
            raise Exception("No index exists. Use `update_embeddings()` to create an index.")

        index = index or self.index
        query_emb = query_emb.reshape(1, -1).astype(np.float32)
        score_matrix, vector_id_matrix = self.faiss_index.search(query_emb, top_k)
        # FAISS pads the result with id -1 when fewer than top_k vectors are available.
        vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
        documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)

        # Assign the query score to each document.
        scores_for_vector_ids: Dict[str, float] = {
            str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])
        }
        for doc in documents:
            doc.score = scores_for_vector_ids[doc.meta["vector_id"]]  # type: ignore
            # Maps a score in [-1, 1] to [0, 1]; presumably assumes normalized embeddings with
            # inner-product similarity — TODO confirm for other metrics.
            doc.probability = (doc.score + 1) / 2
        return documents

    def save(self, file_path: Union[str, Path]):
        """
        Save the FAISS index to the specified file.

        Only the FAISS index is written; documents and metadata remain in the SQL database.

        :param file_path: Path to save to.
        :return: None
        """
        faiss.write_index(self.faiss_index, str(file_path))

    @classmethod
    def load(
        cls,
        faiss_file_path: Union[str, Path],
        sql_url: str,
        index_buffer_size: int = 10_000,
    ):
        """
        Load a saved FAISS index from a file and connect to the SQL database.

        Note: In order to have a correct mapping from FAISS to SQL,
              make sure to use the same SQL DB that you used when calling `save()`.

        :param faiss_file_path: Stored FAISS index file. Can be created via calling `save()`
        :param sql_url: Connection string to the SQL database that contains your docs and metadata.
        :param index_buffer_size: When working with large datasets, the ingestion process (FAISS + SQL) can be buffered in
                                  smaller chunks to reduce memory footprint.
        :return: A FAISSDocumentStore whose `vector_dim` is taken from the loaded index.
        """
        faiss_index = faiss.read_index(str(faiss_file_path))
        return cls(
            faiss_index=faiss_index,
            sql_url=sql_url,
            index_buffer_size=index_buffer_size,
            vector_dim=faiss_index.d,
        )