mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 15:38:36 +00:00
Make FAISSDocumentStore work with yaml (#1727)
* add faiss_index_path and faiss_config_path * Add latest docstring and tutorial changes * remove duplicate cleaning stuff * refactoring + test for invalid param combination * adjust type hints * Add latest docstring and tutorial changes * add documentation to @preload_index * Add latest docstring and tutorial changes * recursive __init__ instead of decorator * Add latest docstring and tutorial changes * validate instead of check * combine ifs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
42c8edca54
commit
158460504b
@ -1117,7 +1117,7 @@ the vector embeddings are indexed in a FAISS Index.
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
|
||||
__init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs)
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
@ -1158,6 +1158,10 @@ the vector embeddings are indexed in a FAISS Index.
|
||||
overwrite: Update any existing documents with the same ID when adding documents.
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
- `faiss_index_path`: Stored FAISS index file. Can be created via calling `save()`.
|
||||
If specified, no other params besides `faiss_config_path` may be passed.
|
||||
- `faiss_config_path`: Stored FAISS initial configuration parameters.
|
||||
Can be created via calling `save()`
|
||||
|
||||
<a name="faiss.FAISSDocumentStore.write_documents"></a>
|
||||
#### write\_documents
|
||||
@ -1359,15 +1363,6 @@ Note: In order to have a correct mapping from FAISS to SQL,
|
||||
- `index_path`: Stored FAISS index file. Can be created via calling `save()`
|
||||
- `config_path`: Stored FAISS initial configuration parameters.
|
||||
Can be created via calling `save()`
|
||||
- `sql_url`: Connection string to the SQL database that contains your docs and metadata.
|
||||
Overrides the value defined in the `faiss_init_params_path` file, if present
|
||||
- `index`: Index name to load the FAISS index as. It must match the index name used
|
||||
when creating the FAISS index. Overrides the value defined in the
|
||||
`faiss_init_params_path` file, if present
|
||||
|
||||
**Returns**:
|
||||
|
||||
the DocumentStore
|
||||
|
||||
<a name="milvus"></a>
|
||||
# Module milvus
|
||||
|
||||
@ -18,6 +18,7 @@ import numpy as np
|
||||
from haystack.schema import Document
|
||||
from haystack.document_stores.sql import SQLDocumentStore
|
||||
from haystack.document_stores.base import get_batches_from_generator
|
||||
from inspect import Signature, signature
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -45,6 +46,8 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
embedding_field: str = "embedding",
|
||||
progress_bar: bool = True,
|
||||
duplicate_documents: str = 'overwrite',
|
||||
faiss_index_path: Union[str, Path] = None,
|
||||
faiss_config_path: Union[str, Path] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -84,7 +87,20 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
overwrite: Update any existing documents with the same ID when adding documents.
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
:param faiss_index_path: Stored FAISS index file. Can be created via calling `save()`.
|
||||
If specified, no other params besides faiss_config_path may be passed.
|
||||
:param faiss_config_path: Stored FAISS initial configuration parameters.
|
||||
Can be created via calling `save()`
|
||||
"""
|
||||
# special case if we want to load an existing index from disk
|
||||
# load init params from disk and run init again
|
||||
if faiss_index_path is not None:
|
||||
sig = signature(self.__class__.__init__)
|
||||
self._validate_params_load_from_disk(sig, locals(), kwargs)
|
||||
init_params = self._load_init_params_from_config(faiss_index_path, faiss_config_path)
|
||||
self.__class__.__init__(self, **init_params)
|
||||
return
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
sql_url=sql_url,
|
||||
@ -133,6 +149,30 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
index=index
|
||||
)
|
||||
|
||||
self._validate_index_sync()
|
||||
|
||||
def _validate_params_load_from_disk(self, sig: Signature, locals: dict, kwargs: dict):
|
||||
allowed_params = ["faiss_index_path", "faiss_config_path", "self", "kwargs"]
|
||||
invalid_param_set = False
|
||||
|
||||
for param in sig.parameters.values():
|
||||
if param.name not in allowed_params and param.default != locals[param.name]:
|
||||
invalid_param_set = True
|
||||
break
|
||||
|
||||
if invalid_param_set or len(kwargs) > 0:
|
||||
raise ValueError("if faiss_index_path is passed no other params besides faiss_config_path are allowed.")
|
||||
|
||||
def _validate_index_sync(self):
|
||||
# This check ensures the correct document database was loaded.
|
||||
# If it fails, make sure you provided the path to the database
|
||||
# used when creating the original FAISS index
|
||||
if not self.get_document_count() == self.get_embedding_count():
|
||||
raise ValueError("The number of documents present in the SQL database does not "
|
||||
"match the number of embeddings in FAISS. Make sure your FAISS "
|
||||
"configuration file correctly points to the same database that "
|
||||
"was used when creating the original index.")
|
||||
|
||||
def _create_new_index(self, vector_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
|
||||
if index_factory == "HNSW":
|
||||
# faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
|
||||
@ -489,23 +529,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
with open(config_path, 'w') as ipp:
|
||||
json.dump(self.pipeline_config["params"], ipp)
|
||||
|
||||
@classmethod
|
||||
def load(cls, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
|
||||
"""
|
||||
Load a saved FAISS index from a file and connect to the SQL database.
|
||||
Note: In order to have a correct mapping from FAISS to SQL,
|
||||
make sure to use the same SQL DB that you used when calling `save()`.
|
||||
|
||||
:param index_path: Stored FAISS index file. Can be created via calling `save()`
|
||||
:param config_path: Stored FAISS initial configuration parameters.
|
||||
Can be created via calling `save()`
|
||||
:param sql_url: Connection string to the SQL database that contains your docs and metadata.
|
||||
Overrides the value defined in the `faiss_init_params_path` file, if present
|
||||
:param index: Index name to load the FAISS index as. It must match the index name used for
|
||||
when creating the FAISS index. Overrides the value defined in the
|
||||
`faiss_init_params_path` file, if present
|
||||
:return: the DocumentStore
|
||||
"""
|
||||
def _load_init_params_from_config(self, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
|
||||
if not config_path:
|
||||
index_path = Path(index_path)
|
||||
config_path = index_path.with_suffix(".json")
|
||||
@ -523,17 +547,19 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
|
||||
# Add other init params to override the ones defined in the init params file
|
||||
init_params["faiss_index"] = faiss_index
|
||||
init_params["vector_dim"]=faiss_index.d
|
||||
init_params["vector_dim"] = faiss_index.d
|
||||
|
||||
document_store = cls(**init_params)
|
||||
return init_params
|
||||
|
||||
# This check ensures the correct document database was loaded.
|
||||
# If it fails, make sure you provided the path to the database
|
||||
# used when creating the original FAISS index
|
||||
if not document_store.get_document_count() == document_store.get_embedding_count():
|
||||
raise ValueError("The number of documents present in the SQL database does not "
|
||||
"match the number of embeddings in FAISS. Make sure your FAISS "
|
||||
"configuration file correctly points to the same database that "
|
||||
"was used when creating the original index.")
|
||||
@classmethod
def load(cls, index_path: Union[str, Path], config_path: Optional[Union[str, Path]] = None):
    """
    Load a saved FAISS index from a file and connect to the SQL database.
    Note: In order to have a correct mapping from FAISS to SQL,
        make sure to use the same SQL DB that you used when calling `save()`.

    :param index_path: Stored FAISS index file. Can be created via calling `save()`
    :param config_path: Stored FAISS initial configuration parameters.
        Can be created via calling `save()`
    :return: the DocumentStore
    """
    # Delegates to __init__, which detects faiss_index_path and re-initializes
    # itself from the stored configuration file.
    return cls(faiss_index_path=index_path, faiss_config_path=config_path)
|
||||
|
||||
31
test/samples/pipeline/test_pipeline_faiss_indexing.yaml
Normal file
31
test/samples/pipeline/test_pipeline_faiss_indexing.yaml
Normal file
@ -0,0 +1,31 @@
|
||||
# Indexing pipeline fixture: PDF -> text conversion -> preprocessing ->
# DPR embedding -> storage in a freshly created FAISS document store.
version: '0.7'

components:
  - name: DPRRetriever
    type: DensePassageRetriever
    params:
      document_store: NewFAISSDocumentStore
  - name: PDFConverter
    type: PDFToTextConverter
    params:
      remove_numeric_tables: false
  - name: Preprocessor
    type: PreProcessor
    params:
      clean_whitespace: true
  # No params: the store is created with its defaults (new index).
  - name: NewFAISSDocumentStore
    type: FAISSDocumentStore


pipelines:
  - name: indexing_pipeline
    type: Pipeline
    nodes:
      - name: PDFConverter
        inputs: [File]
      - name: Preprocessor
        inputs: [PDFConverter]
      - name: DPRRetriever
        inputs: [Preprocessor]
      - name: NewFAISSDocumentStore
        inputs: [DPRRetriever]
|
||||
19
test/samples/pipeline/test_pipeline_faiss_retrieval.yaml
Normal file
19
test/samples/pipeline/test_pipeline_faiss_retrieval.yaml
Normal file
@ -0,0 +1,19 @@
|
||||
# Query pipeline fixture: DPR retrieval against a FAISS document store
# that is loaded from a previously saved index file on disk.
version: '0.7'

components:
  - name: DPRRetriever
    type: DensePassageRetriever
    params:
      document_store: ExistingFAISSDocumentStore
  - name: ExistingFAISSDocumentStore
    type: FAISSDocumentStore
    params:
      # Triggers the load-from-disk path in FAISSDocumentStore.__init__
      faiss_index_path: 'existing_faiss_document_store'


pipelines:
  - name: query_pipeline
    type: Pipeline
    nodes:
      - name: DPRRetriever
        inputs: [Query]
|
||||
@ -51,6 +51,16 @@ def test_faiss_index_save_and_load(tmp_path):
|
||||
# Check if the init parameters are kept
|
||||
assert not new_document_store.progress_bar
|
||||
|
||||
# test loading the index via init
|
||||
new_document_store = FAISSDocumentStore(faiss_index_path=tmp_path / "haystack_test_faiss")
|
||||
|
||||
# check faiss index is restored
|
||||
assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
|
||||
# check if documents are restored
|
||||
assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
|
||||
# Check if the init parameters are kept
|
||||
assert not new_document_store.progress_bar
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
|
||||
def test_faiss_index_save_and_load_custom_path(tmp_path):
|
||||
@ -80,6 +90,31 @@ def test_faiss_index_save_and_load_custom_path(tmp_path):
|
||||
# Check if the init parameters are kept
|
||||
assert not new_document_store.progress_bar
|
||||
|
||||
# test loading the index via init
|
||||
new_document_store = FAISSDocumentStore(faiss_index_path=tmp_path / "haystack_test_faiss", faiss_config_path=tmp_path / "custom_path.json")
|
||||
|
||||
# check faiss index is restored
|
||||
assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
|
||||
# check if documents are restored
|
||||
assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
|
||||
# Check if the init parameters are kept
|
||||
assert not new_document_store.progress_bar
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
def test_faiss_index_mutual_exclusive_args(tmp_path):
    # Supplying faiss_index_path together with any other init parameter
    # must be rejected, whether that parameter is passed by keyword or
    # positionally.
    sql_url = f"sqlite:////{tmp_path/'haystack_test.db'}"
    index_path = f"{tmp_path/'haystack_test'}"

    with pytest.raises(ValueError):
        FAISSDocumentStore(sql_url=sql_url, faiss_index_path=index_path)

    with pytest.raises(ValueError):
        FAISSDocumentStore(sql_url, faiss_index_path=index_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
||||
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import pytest
|
||||
@ -745,3 +746,43 @@ def test_query_pipeline_with_document_classifier(document_store):
|
||||
assert prediction["answers"][0].meta["classification"]["label"] == "joy"
|
||||
assert "_debug" not in prediction.keys()
|
||||
|
||||
|
||||
def test_existing_faiss_document_store():
    """End-to-end check that a FAISS store saved by an indexing pipeline can be
    loaded from disk by a query pipeline defined purely in YAML."""
    # Remove leftover index/config/DB files from previous runs — they use fixed
    # names in the working directory and would skew the counts below.
    clean_faiss_document_store()

    # Build and run the indexing pipeline (PDF -> preprocess -> DPR -> FAISS).
    pipeline = Pipeline.load_from_yaml(
        Path(__file__).parent/"samples"/"pipeline"/"test_pipeline_faiss_indexing.yaml", pipeline_name="indexing_pipeline"
    )
    pipeline.run(
        file_paths=Path(__file__).parent/"samples"/"pdf"/"sample_pdf_1.pdf"
    )

    # Persist the freshly built index so the query pipeline can load it from disk.
    new_document_store = pipeline.get_document_store()
    new_document_store.save('existing_faiss_document_store')

    # test correct load of query pipeline from yaml
    pipeline = Pipeline.load_from_yaml(
        Path(__file__).parent/"samples"/"pipeline"/"test_pipeline_faiss_retrieval.yaml", pipeline_name="query_pipeline"
    )

    # The loaded store must contain the embeddings written during indexing
    # (the sample PDF yields 2 documents after preprocessing).
    retriever = pipeline.get_node("DPRRetriever")
    existing_document_store = retriever.document_store
    faiss_index = existing_document_store.faiss_indexes['document']
    assert faiss_index.ntotal == 2

    prediction = pipeline.run(
        query="Who made the PDF specification?", params={"DPRRetriever": {"top_k": 10}}
    )

    assert prediction["query"] == "Who made the PDF specification?"
    assert len(prediction["documents"]) == 2
    # Clean up the on-disk artifacts so later tests start from a blank state.
    clean_faiss_document_store()
|
||||
|
||||
|
||||
def clean_faiss_document_store():
    """Delete FAISS index, config, and SQLite artifacts left in the working
    directory by tests that save a document store under fixed names."""
    # unlink(missing_ok=True) makes the cleanup idempotent (Python 3.8+),
    # replacing the repeated exists()/os.remove() pairs.
    for leftover in (
        'existing_faiss_document_store',
        'existing_faiss_document_store.json',
        'faiss_document_store.db',
    ):
        Path(leftover).unlink(missing_ok=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user