refactor: add batch_size to FAISS __init__ (#6401)

* refactor: add batch_size to FAISS __init__

* refactor: add batch_size to FAISS __init__

* add release note to refactor: add batch_size to FAISS __init__

* fix release note

* add batch_size to docstrings

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Authored by pandasar13 on 2023-11-23 17:27:24 +01:00, committed by GitHub
parent 4ec6a60a76
commit edb40b6c1b
2 changed files with 21 additions and 5 deletions


@@ -57,6 +57,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         ef_search: int = 20,
         ef_construction: int = 80,
         validate_index_sync: bool = True,
+        batch_size: int = 10_000,
     ):
         """
         :param sql_url: SQL connection URL for the database. The default value is "sqlite:///faiss_document_store.db"`. It defaults to a local, file-based SQLite DB. For large scale deployment, we recommend Postgres.
@@ -103,6 +104,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param ef_search: Used only if `index_factory == "HNSW"`.
         :param ef_construction: Used only if `index_factory == "HNSW"`.
         :param validate_index_sync: Checks if the document count equals the embedding count at initialization time.
+        :param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face
+            memory issues, decrease the batch_size.
         """
         faiss_import.check()
         # special case if we want to load an existing index from disk
@@ -152,6 +155,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         self.return_embedding = return_embedding
         self.embedding_field = embedding_field
+        self.batch_size = batch_size
         self.progress_bar = progress_bar
@@ -216,7 +220,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         self,
         documents: Union[List[dict], List[Document]],
         index: Optional[str] = None,
-        batch_size: int = 10_000,
+        batch_size: Optional[int] = None,
         duplicate_documents: Optional[str] = None,
         headers: Optional[Dict[str, str]] = None,
     ) -> None:
@@ -240,6 +244,8 @@ class FAISSDocumentStore(SQLDocumentStore):
            raise NotImplementedError("FAISSDocumentStore does not support headers.")
         index = index or self.index
+        batch_size = batch_size or self.batch_size
         duplicate_documents = duplicate_documents or self.duplicate_documents
         assert (
             duplicate_documents in self.duplicate_documents_options
@@ -324,7 +330,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         index: Optional[str] = None,
         update_existing_embeddings: bool = True,
         filters: Optional[FilterType] = None,
-        batch_size: int = 10_000,
+        batch_size: Optional[int] = None,
     ):
         """
         Updates the embeddings in the the document store using the encoding model specified in the retriever.
@@ -342,6 +348,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         :return: None
         """
         index = index or self.index
+        batch_size = batch_size or self.batch_size
         if update_existing_embeddings is True:
             if filters is None:
@@ -404,9 +411,10 @@ class FAISSDocumentStore(SQLDocumentStore):
         index: Optional[str] = None,
         filters: Optional[FilterType] = None,
         return_embedding: Optional[bool] = None,
-        batch_size: int = 10_000,
+        batch_size: Optional[int] = None,
         headers: Optional[Dict[str, str]] = None,
     ) -> List[Document]:
+        batch_size = batch_size or self.batch_size
         if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")
@@ -421,7 +429,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         index: Optional[str] = None,
         filters: Optional[FilterType] = None,
         return_embedding: Optional[bool] = None,
-        batch_size: int = 10_000,
+        batch_size: Optional[int] = None,
         headers: Optional[Dict[str, str]] = None,
     ) -> Generator[Document, None, None]:
         """
@@ -440,6 +448,7 @@ class FAISSDocumentStore(SQLDocumentStore):
            raise NotImplementedError("FAISSDocumentStore does not support headers.")
         index = index or self.index
+        batch_size = batch_size or self.batch_size
         documents = super(FAISSDocumentStore, self).get_all_documents_generator(
            index=index, filters=filters, batch_size=batch_size, return_embedding=False
         )
@@ -455,13 +464,15 @@ class FAISSDocumentStore(SQLDocumentStore):
         self,
         ids: List[str],
         index: Optional[str] = None,
-        batch_size: int = 10_000,
+        batch_size: Optional[int] = None,
         headers: Optional[Dict[str, str]] = None,
     ) -> List[Document]:
         if headers:
            raise NotImplementedError("FAISSDocumentStore does not support headers.")
         index = index or self.index
+        batch_size = batch_size or self.batch_size
         documents = super(FAISSDocumentStore, self).get_documents_by_id(ids=ids, index=index, batch_size=batch_size)
         if self.return_embedding:
             for doc in documents:
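
The diff above applies one pattern throughout: __init__ stores an instance-wide default, and every batched method resolves the per-call argument against it with batch_size = batch_size or self.batch_size. A minimal sketch of that pattern in isolation (ToyStore and its print-based body are illustrative stand-ins, not FAISSDocumentStore code):

from typing import List, Optional


class ToyStore:
    def __init__(self, batch_size: int = 10_000):
        # Instance-wide default, mirroring the new FAISSDocumentStore.__init__ argument.
        self.batch_size = batch_size

    def write_documents(self, documents: List[dict], batch_size: Optional[int] = None) -> None:
        # A value passed per call wins; otherwise fall back to the instance default.
        batch_size = batch_size or self.batch_size
        for start in range(0, len(documents), batch_size):
            batch = documents[start : start + batch_size]
            print(f"indexing {len(batch)} document(s)")


docs = [{"content": "a"}, {"content": "b"}, {"content": "c"}]
ToyStore(batch_size=2).write_documents(docs)                # two batches: 2 + 1
ToyStore(batch_size=2).write_documents(docs, batch_size=1)  # explicit value overrides the default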


@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    Add batch_size to the __init__ method of FAISS Document Store. This works as the default value for all methods of
+    FAISS Document Store that support batch_size.
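
A usage sketch of the new default, assuming the Haystack 1.x import path and arbitrary batch sizes (dict documents with a "content" field are accepted per the write_documents signature shown above):

from haystack.document_stores import FAISSDocumentStore

# Set the store-wide default once; every method that accepts batch_size falls back to it.
document_store = FAISSDocumentStore(sql_url="sqlite:///faiss_document_store.db", batch_size=5_000)

docs = [{"content": f"document {i}"} for i in range(20_000)]
document_store.write_documents(docs)                    # indexed in batches of 5_000
document_store.write_documents(docs, batch_size=1_000)  # an explicit value still overrides the default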