Make FAISSDocumentStore work with yaml (#1727)

* add faiss_index_path and faiss_config_path

* Add latest docstring and tutorial changes

* remove duplicate cleaning stuff

* refactoring + test for invalid param combination

* adjust type hints

* Add latest docstring and tutorial changes

* add documentation to @preload_index

* Add latest docstring and tutorial changes

* recursive __init__ instead of decorator

* Add latest docstring and tutorial changes

* validate instead of check

* combine ifs

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
tstadel 2021-11-11 11:02:22 +01:00 committed by GitHub
parent 42c8edca54
commit 158460504b
6 changed files with 185 additions and 38 deletions


@@ -1117,7 +1117,7 @@ the vector embeddings are indexed in a FAISS Index.
#### \_\_init\_\_
```python
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
```
**Arguments**:
@@ -1158,6 +1158,10 @@ the vector embeddings are indexed in a FAISS Index.
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
- `faiss_index_path`: Stored FAISS index file. Can be created via calling `save()`.
If specified, no other params besides `faiss_config_path` must be specified (see the example below).
- `faiss_config_path`: Stored FAISS initial configuration parameters.
Can be created via calling `save()`.
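
For reference, a minimal sketch of a save/restore round trip through these new parameters; the file names are illustrative and the `haystack.document_stores` import path is an assumption:
```python
from haystack.document_stores import FAISSDocumentStore

# Build a store and persist it. save() writes the FAISS index file and a
# JSON file with the initial configuration parameters next to it.
document_store = FAISSDocumentStore(faiss_index_factory_str="Flat")
document_store.save("my_faiss_index.faiss")

# Restore the store purely from the saved files. Passing any other init
# parameter together with faiss_index_path raises a ValueError.
restored_store = FAISSDocumentStore(
    faiss_index_path="my_faiss_index.faiss",
    faiss_config_path="my_faiss_index.json",
)
```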
<a name="faiss.FAISSDocumentStore.write_documents"></a>
#### write\_documents
@@ -1359,15 +1363,6 @@ Note: In order to have a correct mapping from FAISS to SQL,
- `index_path`: Stored FAISS index file. Can be created via calling `save()`
- `config_path`: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
- `sql_url`: Connection string to the SQL database that contains your docs and metadata.
Overrides the value defined in the `faiss_init_params_path` file, if present
- `index`: Index name to load the FAISS index as. It must match the index name used for
when creating the FAISS index. Overrides the value defined in the
`faiss_init_params_path` file, if present
**Returns**:
the DocumentStore
<a name="milvus"></a>
# Module milvus


@@ -18,6 +18,7 @@ import numpy as np
from haystack.schema import Document
from haystack.document_stores.sql import SQLDocumentStore
from haystack.document_stores.base import get_batches_from_generator
from inspect import Signature, signature
logger = logging.getLogger(__name__)
@@ -45,6 +46,8 @@ class FAISSDocumentStore(SQLDocumentStore):
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = 'overwrite',
faiss_index_path: Union[str, Path] = None,
faiss_config_path: Union[str, Path] = None,
**kwargs,
):
"""
@@ -84,7 +87,20 @@ class FAISSDocumentStore(SQLDocumentStore):
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
:param faiss_index_path: Stored FAISS index file. Can be created via calling `save()`.
If specified, no other params besides faiss_config_path must be specified.
:param faiss_config_path: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
"""
# special case if we want to load an existing index from disk
# load init params from disk and run init again
if faiss_index_path is not None:
sig = signature(self.__class__.__init__)
self._validate_params_load_from_disk(sig, locals(), kwargs)
init_params = self._load_init_params_from_config(faiss_index_path, faiss_config_path)
self.__class__.__init__(self, **init_params)
return
# save init parameters to enable export of component config as YAML
self.set_config(
sql_url=sql_url,
@@ -133,6 +149,30 @@ class FAISSDocumentStore(SQLDocumentStore):
index=index
)
self._validate_index_sync()
def _validate_params_load_from_disk(self, sig: Signature, locals: dict, kwargs: dict):
allowed_params = ["faiss_index_path", "faiss_config_path", "self", "kwargs"]
invalid_param_set = False
for param in sig.parameters.values():
if param.name not in allowed_params and param.default != locals[param.name]:
invalid_param_set = True
break
if invalid_param_set or len(kwargs) > 0:
raise ValueError("if faiss_index_path is passed no other params besides faiss_config_path are allowed.")
def _validate_index_sync(self):
# This check ensures the correct document database was loaded.
# If it fails, make sure you provided the path to the database
# used when creating the original FAISS index
if not self.get_document_count() == self.get_embedding_count():
raise ValueError("The number of documents present in the SQL database does not "
"match the number of embeddings in FAISS. Make sure your FAISS "
"configuration file correctly points to the same database that "
"was used when creating the original index.")
def _create_new_index(self, vector_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
if index_factory == "HNSW":
# faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
@@ -489,23 +529,7 @@ class FAISSDocumentStore(SQLDocumentStore):
with open(config_path, 'w') as ipp:
json.dump(self.pipeline_config["params"], ipp)
@classmethod
def load(cls, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
"""
Load a saved FAISS index from a file and connect to the SQL database.
Note: In order to have a correct mapping from FAISS to SQL,
make sure to use the same SQL DB that you used when calling `save()`.
:param index_path: Stored FAISS index file. Can be created via calling `save()`
:param config_path: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
:param sql_url: Connection string to the SQL database that contains your docs and metadata.
Overrides the value defined in the `faiss_init_params_path` file, if present
:param index: Index name to load the FAISS index as. It must match the index name used for
when creating the FAISS index. Overrides the value defined in the
`faiss_init_params_path` file, if present
:return: the DocumentStore
"""
def _load_init_params_from_config(self, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
if not config_path:
index_path = Path(index_path)
config_path = index_path.with_suffix(".json")
@@ -523,17 +547,19 @@ class FAISSDocumentStore(SQLDocumentStore):
# Add other init params to override the ones defined in the init params file
init_params["faiss_index"] = faiss_index
init_params["vector_dim"]=faiss_index.d
init_params["vector_dim"] = faiss_index.d
document_store = cls(**init_params)
return init_params
# This check ensures the correct document database was loaded.
# If it fails, make sure you provided the path to the database
# used when creating the original FAISS index
if not document_store.get_document_count() == document_store.get_embedding_count():
raise ValueError("The number of documents present in the SQL database does not "
"match the number of embeddings in FAISS. Make sure your FAISS "
"configuration file correctly points to the same database that "
"was used when creating the original index.")
@classmethod
def load(cls, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
"""
Load a saved FAISS index from a file and connect to the SQL database.
Note: In order to have a correct mapping from FAISS to SQL,
make sure to use the same SQL DB that you used when calling `save()`.
return document_store
:param index_path: Stored FAISS index file. Can be created via calling `save()`
:param config_path: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
"""
return cls(faiss_index_path=index_path, faiss_config_path=config_path)
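
A usage sketch of the slimmed-down classmethod, which now simply forwards to `__init__(faiss_index_path=..., faiss_config_path=...)`; the paths are illustrative:
```python
from haystack.document_stores import FAISSDocumentStore

# Assumes an index was persisted earlier, e.g. via
# document_store.save("existing_faiss_document_store")
document_store = FAISSDocumentStore.load(
    index_path="existing_faiss_document_store",
    config_path="existing_faiss_document_store.json",
)
```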


@@ -0,0 +1,31 @@
version: '0.7'
components:
- name: DPRRetriever
type: DensePassageRetriever
params:
document_store: NewFAISSDocumentStore
- name: PDFConverter
type: PDFToTextConverter
params:
remove_numeric_tables: false
- name: Preprocessor
type: PreProcessor
params:
clean_whitespace: true
- name: NewFAISSDocumentStore
type: FAISSDocumentStore
pipelines:
- name: indexing_pipeline
type: Pipeline
nodes:
- name: PDFConverter
inputs: [File]
- name: Preprocessor
inputs: [PDFConverter]
- name: DPRRetriever
inputs: [Preprocessor]
- name: NewFAISSDocumentStore
inputs: [DPRRetriever]
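
Assuming the indexing configuration above is stored as `test_pipeline_faiss_indexing.yaml` (the name used by the test further below), running it could look roughly like this; the `Pipeline` import path and file locations are assumptions:
```python
from pathlib import Path

from haystack import Pipeline  # import path may differ between Haystack versions

# Build the FAISS index by running the indexing pipeline defined in the YAML.
indexing_pipeline = Pipeline.load_from_yaml(
    Path("test_pipeline_faiss_indexing.yaml"), pipeline_name="indexing_pipeline"
)
indexing_pipeline.run(file_paths=[Path("sample_pdf_1.pdf")])

# Persist index and config so a query pipeline can reference the store
# via faiss_index_path (see the retrieval YAML below).
indexing_pipeline.get_document_store().save("existing_faiss_document_store")
```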


@@ -0,0 +1,19 @@
version: '0.7'
components:
- name: DPRRetriever
type: DensePassageRetriever
params:
document_store: ExistingFAISSDocumentStore
- name: ExistingFAISSDocumentStore
type: FAISSDocumentStore
params:
faiss_index_path: 'existing_faiss_document_store'
pipelines:
- name: query_pipeline
type: Pipeline
nodes:
- name: DPRRetriever
inputs: [Query]
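
A matching sketch for the query side, mirroring the new test below; `ExistingFAISSDocumentStore` is restored from disk via `faiss_index_path`, so the retriever queries the previously built index (import path and file name are assumptions):
```python
from pathlib import Path

from haystack import Pipeline  # import path may differ between Haystack versions

query_pipeline = Pipeline.load_from_yaml(
    Path("test_pipeline_faiss_retrieval.yaml"), pipeline_name="query_pipeline"
)
prediction = query_pipeline.run(
    query="Who made the PDF specification?", params={"DPRRetriever": {"top_k": 10}}
)
print(prediction["documents"])
```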


@@ -51,6 +51,16 @@ def test_faiss_index_save_and_load(tmp_path):
# Check if the init parameters are kept
assert not new_document_store.progress_bar
# test loading the index via init
new_document_store = FAISSDocumentStore(faiss_index_path=tmp_path / "haystack_test_faiss")
# check faiss index is restored
assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
# check if documents are restored
assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
# Check if the init parameters are kept
assert not new_document_store.progress_bar
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
def test_faiss_index_save_and_load_custom_path(tmp_path):
@@ -80,6 +90,31 @@ def test_faiss_index_save_and_load_custom_path(tmp_path):
# Check if the init parameters are kept
assert not new_document_store.progress_bar
# test loading the index via init
new_document_store = FAISSDocumentStore(faiss_index_path=tmp_path / "haystack_test_faiss", faiss_config_path=tmp_path / "custom_path.json")
# check faiss index is restored
assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
# check if documents are restored
assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
# Check if the init parameters are kept
assert not new_document_store.progress_bar
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
def test_faiss_index_mutual_exclusive_args(tmp_path):
with pytest.raises(ValueError):
FAISSDocumentStore(
sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
faiss_index_path=f"{tmp_path/'haystack_test'}"
)
with pytest.raises(ValueError):
FAISSDocumentStore(
f"sqlite:////{tmp_path/'haystack_test.db'}",
faiss_index_path=f"{tmp_path/'haystack_test'}"
)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])


@@ -1,5 +1,6 @@
from pathlib import Path
import os
import json
import math
import pytest
@@ -745,3 +746,43 @@ def test_query_pipeline_with_document_classifier(document_store):
assert prediction["answers"][0].meta["classification"]["label"] == "joy"
assert "_debug" not in prediction.keys()
def test_existing_faiss_document_store():
clean_faiss_document_store()
pipeline = Pipeline.load_from_yaml(
Path(__file__).parent/"samples"/"pipeline"/"test_pipeline_faiss_indexing.yaml", pipeline_name="indexing_pipeline"
)
pipeline.run(
file_paths=Path(__file__).parent/"samples"/"pdf"/"sample_pdf_1.pdf"
)
new_document_store = pipeline.get_document_store()
new_document_store.save('existing_faiss_document_store')
# test correct load of query pipeline from yaml
pipeline = Pipeline.load_from_yaml(
Path(__file__).parent/"samples"/"pipeline"/"test_pipeline_faiss_retrieval.yaml", pipeline_name="query_pipeline"
)
retriever = pipeline.get_node("DPRRetriever")
existing_document_store = retriever.document_store
faiss_index = existing_document_store.faiss_indexes['document']
assert faiss_index.ntotal == 2
prediction = pipeline.run(
query="Who made the PDF specification?", params={"DPRRetriever": {"top_k": 10}}
)
assert prediction["query"] == "Who made the PDF specification?"
assert len(prediction["documents"]) == 2
clean_faiss_document_store()
def clean_faiss_document_store():
if Path('existing_faiss_document_store').exists():
os.remove('existing_faiss_document_store')
if Path('existing_faiss_document_store.json').exists():
os.remove('existing_faiss_document_store.json')
if Path('faiss_document_store.db').exists():
os.remove('faiss_document_store.db')