mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 03:46:30 +00:00
Integrate Weaviate as another DocumentStore (#1064)
* Annotation Tool: data is not persisted when using local version #853 * First version of weaviate * First version of weaviate * First version of weaviate * Updated comments * Updated comments * ran query, get and write tests * update embeddings, dynamic schema and filters implemented * Initial set of tests and fixes * Tests added for update_embeddings and delete documents * introduced duplicate documents fix * fixed mypy errors * Added Weaviate to requirements * Fix the weaviate docker env variables * Fixing test dependencies for now * Created weaviate test marker and fixed query * Update docstring * Add documentation * Bump up weaviate version * Bump up weaviate version in documentation * Bump up weaviate version in documentation * Upgrade weaviate version Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent
db17d73a82
commit
49886f88f0
3
.github/workflows/ci.yml
vendored
3
.github/workflows/ci.yml
vendored
@ -76,6 +76,9 @@ jobs:
|
||||
- name: Run Milvus
|
||||
run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.1.0-cpu-d050721-5e559c
|
||||
|
||||
- name: Run Weaviate
|
||||
run: docker run -d -p 8080:8080 --name haystack_test_weaviate --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0
|
||||
|
||||
- name: Run GraphDB
|
||||
run: docker run -d -p 7200:7200 --name haystack_test_graphdb deepset/graphdb-free:9.4.1-adoptopenjdk11
|
||||
|
||||
|
@ -116,6 +116,27 @@ from haystack.document_store import SQLDocumentStore
|
||||
document_store = SQLDocumentStore()
|
||||
```
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="tab">
|
||||
<input type="radio" id="tab-1-6" name="tab-group-1">
|
||||
<label class="labelouter" for="tab-1-6">Weaviate</label>
|
||||
<div class="tabcontent">
|
||||
|
||||
The `WeaviateDocumentStore` requires a running Weaviate Server.
|
||||
You can start a basic instance like this (see Weaviate docs for details):
|
||||
```
|
||||
docker run -d -p 8080:8080 --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0
|
||||
```
|
||||
|
||||
Afterwards, you can use it in Haystack:
|
||||
```python
|
||||
from haystack.document_store import WeaviateDocumentStore
|
||||
|
||||
document_store = WeaviateDocumentStore()
|
||||
```
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -264,6 +285,24 @@ The Document Stores have different characteristics. You should choose one depend
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="tab">
|
||||
<input type="radio" id="tab-2-6" name="tab-group-2">
|
||||
<label class="labelouter" for="tab-2-6">Weaviate</label>
|
||||
<div class="tabcontent">
|
||||
|
||||
**Pros:**
|
||||
- Simple vector search
|
||||
- Stores everything in one place: documents, meta data and vectors - so less network overhead when scaling this up
|
||||
- Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset
|
||||
|
||||
**Cons:**
|
||||
- Fewer options for ANN algorithms than FAISS or Milvus
|
||||
- No BM25 / Tf-idf retrieval
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="recommendation">
|
||||
@ -276,4 +315,4 @@ The Document Stores have different characteristics. You should choose one depend
|
||||
|
||||
**Vector Specialist:** Use the `MilvusDocumentStore`, if you want to focus on dense retrieval and possibly deal with larger datasets
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
716
haystack/document_store/weaviate.py
Normal file
716
haystack/document_store/weaviate.py
Normal file
@ -0,0 +1,716 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack import Document
|
||||
from haystack.document_store.base import BaseDocumentStore
|
||||
from haystack.utils import get_batches_from_generator
|
||||
|
||||
from weaviate import client, auth, AuthClientPassword
|
||||
from weaviate import ObjectsBatchRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WeaviateDocumentStore(BaseDocumentStore):
|
||||
"""
|
||||
|
||||
Weaviate is a cloud-native, modular, real-time vector search engine built to scale your machine learning models.
|
||||
(See https://www.semi.technology/developers/weaviate/current/index.html#what-is-weaviate)
|
||||
|
||||
Some of the key differences in contrast to FAISS & Milvus:
|
||||
1. Stores everything in one place: documents, meta data and vectors - so less network overhead when scaling this up
|
||||
2. Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset
|
||||
3. Has less variety of ANN algorithms, as of now only HNSW.
|
||||
|
||||
Weaviate python client is used to connect to the server, more details are here
|
||||
https://weaviate-python-client.readthedocs.io/en/docs/weaviate.html
|
||||
|
||||
Usage:
|
||||
1. Start a Weaviate server (see https://www.semi.technology/developers/weaviate/current/getting-started/installation.html)
|
||||
2. Init a WeaviateDocumentStore in Haystack
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: Union[str, List[str]] = "http://localhost",
|
||||
port: Union[int, List[int]] = 8080,
|
||||
timeout_config: tuple = (5, 15),
|
||||
username: str = None,
|
||||
password: str = None,
|
||||
index: str = "Document",
|
||||
embedding_dim: int = 768,
|
||||
text_field: str = "text",
|
||||
name_field: str = "name",
|
||||
faq_question_field = "question",
|
||||
similarity: str = "dot_product",
|
||||
index_type: str = "hnsw",
|
||||
custom_schema: Optional[dict] = None,
|
||||
return_embedding: bool = False,
|
||||
embedding_field: str = "embedding",
|
||||
progress_bar: bool = True,
|
||||
duplicate_documents: str = 'overwrite',
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
:param host: Weaviate server connection URL for storing and processing documents and vectors.
|
||||
For more details, refer "https://www.semi.technology/developers/weaviate/current/getting-started/installation.html"
|
||||
:param port: port of Weaviate instance
|
||||
:param timeout_config: Weaviate Timeout config as a tuple of (retries, time out seconds).
|
||||
:param username: username (standard authentication via http_auth)
|
||||
:param password: password (standard authentication via http_auth)
|
||||
:param index: Index name for document text, embedding and metadata (in Weaviate terminology, this is a "Class" in Weaviate schema).
|
||||
:param embedding_dim: The embedding vector size. Default: 768.
|
||||
:param text_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text").
|
||||
If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned.
|
||||
:param name_field: Name of field that contains the title of the the doc
|
||||
:param faq_question_field: Name of field containing the question in case of FAQ-Style QA
|
||||
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default.
|
||||
:param index_type: Index type of any vector object defined in weaviate schema. The vector index type is pluggable.
|
||||
Currently, HSNW is only supported.
|
||||
See: https://www.semi.technology/developers/weaviate/current/more-resources/performance.html
|
||||
:param custom_schema: Allows to create custom schema in Weaviate, for more details
|
||||
See https://www.semi.technology/developers/weaviate/current/data-schema/schema-configuration.html
|
||||
:param module_name : Vectorization module to convert data into vectors. Default is "text2vec-trasnformers"
|
||||
For more details, See https://www.semi.technology/developers/weaviate/current/modules/
|
||||
:param return_embedding: To return document embedding.
|
||||
:param embedding_field: Name of field containing an embedding vector.
|
||||
:param progress_bar: Whether to show a tqdm progress bar or not.
|
||||
Can be helpful to disable in production deployments to keep the logs clean.
|
||||
:param duplicate_documents:Handle duplicates document based on parameter options.
|
||||
Parameter options : ( 'skip','overwrite','fail')
|
||||
skip: Ignore the duplicates documents
|
||||
overwrite: Update any existing documents with the same ID when adding documents.
|
||||
fail: an error is raised if the document ID of the document being added already exists.
|
||||
"""
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
host=host, port=port, timeout_config=timeout_config, username=username, password=password,
|
||||
index=index, embedding_dim=embedding_dim, text_field=text_field, name_field=name_field,
|
||||
faq_question_field=faq_question_field, similarity=similarity, index_type=index_type,
|
||||
custom_schema=custom_schema,return_embedding=return_embedding, embedding_field=embedding_field,
|
||||
progress_bar=progress_bar, duplicate_documents=duplicate_documents
|
||||
)
|
||||
|
||||
# Connect to Weaviate server using python binding
|
||||
weaviate_url =f"{host}:{port}"
|
||||
if username and password:
|
||||
secret = AuthClientPassword(username, password)
|
||||
self.weaviate_client = client.Client(url=weaviate_url,
|
||||
auth_client_secret=secret,
|
||||
timeout_config=timeout_config)
|
||||
else:
|
||||
self.weaviate_client = client.Client(url=weaviate_url,
|
||||
timeout_config=timeout_config)
|
||||
|
||||
# Test Weaviate connection
|
||||
try:
|
||||
status = self.weaviate_client.is_ready()
|
||||
if not status:
|
||||
raise ConnectionError(
|
||||
f"Initial connection to Weaviate failed. Make sure you run Weaviate instance "
|
||||
f"at `{weaviate_url}` and that it has finished the initial ramp up (can take > 30s)."
|
||||
)
|
||||
except Exception:
|
||||
raise ConnectionError(
|
||||
f"Initial connection to Weaviate failed. Make sure you run Weaviate instance "
|
||||
f"at `{weaviate_url}` and that it has finished the initial ramp up (can take > 30s)."
|
||||
)
|
||||
self.index = index
|
||||
self.embedding_dim = embedding_dim
|
||||
self.text_field = text_field
|
||||
self.name_field = name_field
|
||||
self.faq_question_field = faq_question_field
|
||||
self.similarity = similarity
|
||||
self.index_type = index_type
|
||||
self.custom_schema = custom_schema
|
||||
self.return_embedding = return_embedding
|
||||
self.embedding_field = embedding_field
|
||||
self.progress_bar = progress_bar
|
||||
self.duplicate_documents = duplicate_documents
|
||||
|
||||
self._create_schema_and_index_if_not_exist(self.index)
|
||||
|
||||
    def _create_schema_and_index_if_not_exist(
        self,
        index: Optional[str] = None,
    ):
        """Create a new index (schema/class in Weaviate) for storing documents in case if an index (schema) with the name doesn't exist already.

        :param index: Class name to create; defaults to ``self.index``.
        """
        index = index or self.index

        if self.custom_schema:
            # Caller supplied a complete Weaviate schema dict — use it verbatim.
            schema = self.custom_schema
        else:
            # Default minimal schema: a single class with name/question/text
            # properties. "vectorizer": "none" means embeddings are supplied by
            # Haystack rather than computed by a Weaviate vectorizer module.
            schema = {
                "classes": [
                    {
                        "class": index,
                        "description": "Haystack index, it's a class in Weaviate",
                        "invertedIndexConfig": {
                            "cleanupIntervalSeconds": 60
                        },
                        "vectorizer": "none",
                        "properties": [
                            {
                                "dataType": [
                                    "string"
                                ],
                                "description": "Name Field",
                                "name": self.name_field
                            },
                            {
                                "dataType": [
                                    "string"
                                ],
                                "description": "Question Field",
                                "name": self.faq_question_field
                            },
                            {
                                "dataType": [
                                    "text"
                                ],
                                "description": "Document Text",
                                "name": self.text_field
                            },
                        ],
                    }
                ]
            }
        # Only create the class when it is not already present on the server.
        if not self.weaviate_client.schema.contains(schema):
            self.weaviate_client.schema.create(schema)
|
||||
|
||||
    def _convert_weaviate_result_to_document(
        self,
        result: dict,
        return_embedding: bool
    ) -> Document:
        """
        Convert weaviate result dict into haystack document object. This is more involved because
        weaviate search result dict varies between get and query interfaces.
        Weaviate get methods return the data items in properties key, whereas the query doesn't.

        :param result: A single hit dict returned by a Weaviate get or query call.
        :param return_embedding: When True (and a vector is present), cast the
                                 vector to a numpy float32 array on the Document.
        """
        score = None
        probability = None
        text = ""
        question = None

        id = result.get("id")
        embedding = result.get("vector")

        # If properties key is present, get all the document fields from it.
        # otherwise, a direct lookup in result root dict
        props = result.get("properties")
        if not props:
            props = result

        if props.get(self.text_field) is not None:
            text = str(props.get(self.text_field))

        if props.get(self.faq_question_field) is not None:
            question = props.get(self.faq_question_field)

        # Weaviate creates "_additional" key for semantic search
        if "_additional" in props:
            if "certainty" in props["_additional"]:
                # Certainty serves as both score and probability here.
                score = props["_additional"]['certainty']
                probability = score
            if "id" in props["_additional"]:
                id = props["_additional"]['id']
            if "vector" in props["_additional"]:
                embedding = props["_additional"]['vector']
            # NOTE(review): this mutates the caller's dict in place — confirm intended.
            props.pop("_additional", None)

        # We put all additional data of the doc into meta_data and return it in the API
        meta_data = {k:v for k,v in props.items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)}

        if return_embedding and embedding:
            embedding = np.asarray(embedding, dtype=np.float32)

        document = Document(
            id=id,
            text=text,
            meta=meta_data,
            score=score,
            probability=probability,
            question=question,
            embedding=embedding,
        )
        return document
|
||||
|
||||
def _create_document_field_map(self) -> Dict:
|
||||
return {
|
||||
self.text_field: "text",
|
||||
self.embedding_field: "embedding",
|
||||
self.faq_question_field if self.faq_question_field else "question": "question"
|
||||
}
|
||||
|
||||
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
|
||||
"""Fetch a document by specifying its text id string"""
|
||||
# Sample result dict from a get method
|
||||
'''{'class': 'Document',
|
||||
'creationTimeUnix': 1621075584724,
|
||||
'id': '1bad51b7-bd77-485d-8871-21c50fab248f',
|
||||
'properties': {'meta': "{'key1':'value1'}",
|
||||
'name': 'name_5',
|
||||
'text': 'text_5'},
|
||||
'vector': []}'''
|
||||
index = index or self.index
|
||||
document = None
|
||||
result = self.weaviate_client.data_object.get_by_id(id, with_vector=True)
|
||||
if result:
|
||||
document = self._convert_weaviate_result_to_document(result, return_embedding=True)
|
||||
return document
|
||||
|
||||
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None,
|
||||
batch_size: int = 10_000) -> List[Document]:
|
||||
"""Fetch documents by specifying a list of text id strings"""
|
||||
index = index or self.index
|
||||
documents = []
|
||||
#TODO: better implementation with multiple where filters instead of chatty call below?
|
||||
for id in ids:
|
||||
result = self.weaviate_client.data_object.get_by_id(id, with_vector=True)
|
||||
if result:
|
||||
document = self._convert_weaviate_result_to_document(result, return_embedding=True)
|
||||
documents.append(document)
|
||||
return documents
|
||||
|
||||
def _get_current_properties(self, index: Optional[str] = None) -> List[str]:
|
||||
"""Get all the existing properties in the schema"""
|
||||
index = index or self.index
|
||||
cur_properties = []
|
||||
for class_item in self.weaviate_client.schema.get()['classes']:
|
||||
if class_item['class'] == index:
|
||||
cur_properties = [item['name'] for item in class_item['properties']]
|
||||
|
||||
return cur_properties
|
||||
|
||||
def _build_filter_clause(self, filters:Dict[str, List[str]]) -> dict:
|
||||
"""Transform Haystack filter conditions to Weaviate where filter clauses"""
|
||||
weaviate_filters = []
|
||||
weaviate_filter = {}
|
||||
for key, values in filters.items():
|
||||
for value in values:
|
||||
weaviate_filter = {
|
||||
"path": [key],
|
||||
"operator": "Equal",
|
||||
"valueString": value
|
||||
}
|
||||
weaviate_filters.append(weaviate_filter)
|
||||
if len(weaviate_filters) > 1:
|
||||
filter_dict = {
|
||||
"operator": "Or",
|
||||
"operands": weaviate_filters
|
||||
}
|
||||
return filter_dict
|
||||
else:
|
||||
return weaviate_filter
|
||||
|
||||
def _update_schema(self, new_prop:str, index: Optional[str] = None):
|
||||
"""Updates the schema with a new property"""
|
||||
index = index or self.index
|
||||
property_dict = {
|
||||
"dataType": [
|
||||
"string"
|
||||
],
|
||||
"description": f"dynamic property {new_prop}",
|
||||
"name": new_prop
|
||||
}
|
||||
self.weaviate_client.schema.property.create(index, property_dict)
|
||||
|
||||
def _check_document(self, cur_props: List[str], doc: dict) -> List[str]:
|
||||
"""Find the properties in the document that don't exist in the existing schema"""
|
||||
return [item for item in doc.keys() if item not in cur_props]
|
||||
|
||||
def write_documents(
|
||||
self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
|
||||
batch_size: int = 10_000, duplicate_documents: Optional[str] = None):
|
||||
"""
|
||||
Add new documents to the DocumentStore.
|
||||
|
||||
:param documents: List of `Dicts` or List of `Documents`. Passing an Embedding/Vector is mandatory in case weaviate is not
|
||||
configured with a module. If a module is configured, the embedding is automatically generated by Weaviate.
|
||||
:param index: index name for storing the docs and metadata
|
||||
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
|
||||
:param duplicate_documents: Handle duplicates document based on parameter options.
|
||||
Parameter options : ( 'skip','overwrite','fail')
|
||||
skip: Ignore the duplicates documents
|
||||
overwrite: Update any existing documents with the same ID when adding documents.
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
:raises DuplicateDocumentError: Exception trigger on duplicate document
|
||||
:return: None
|
||||
"""
|
||||
index = index or self.index
|
||||
self._create_schema_and_index_if_not_exist(index)
|
||||
field_map = self._create_document_field_map()
|
||||
|
||||
duplicate_documents = duplicate_documents or self.duplicate_documents
|
||||
assert duplicate_documents in self.duplicate_documents_options, \
|
||||
f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}"
|
||||
|
||||
if len(documents) == 0:
|
||||
logger.warning("Calling DocumentStore.write_documents() with empty list")
|
||||
return
|
||||
|
||||
# Auto schema feature https://github.com/semi-technologies/weaviate/issues/1539
|
||||
# Get and cache current properties in the schema
|
||||
current_properties = self._get_current_properties(index)
|
||||
|
||||
document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
|
||||
document_objects = self._handle_duplicate_documents(document_objects, duplicate_documents)
|
||||
|
||||
batched_documents = get_batches_from_generator(document_objects, batch_size)
|
||||
with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar:
|
||||
for document_batch in batched_documents:
|
||||
docs_batch = ObjectsBatchRequest()
|
||||
for idx, doc in enumerate(document_batch):
|
||||
_doc = {
|
||||
**doc.to_dict(field_map=self._create_document_field_map())
|
||||
}
|
||||
_ = _doc.pop("score", None)
|
||||
_ = _doc.pop("probability", None)
|
||||
|
||||
# In order to have a flat structure in elastic + similar behaviour to the other DocumentStores,
|
||||
# we "unnest" all value within "meta"
|
||||
if "meta" in _doc.keys():
|
||||
for k, v in _doc["meta"].items():
|
||||
_doc[k] = v
|
||||
_doc.pop("meta")
|
||||
|
||||
doc_id = str(_doc.pop("id"))
|
||||
vector = _doc.pop(self.embedding_field)
|
||||
if _doc.get(self.faq_question_field) is None:
|
||||
_doc.pop(self.faq_question_field)
|
||||
|
||||
# Check if additional properties are in the document, if so,
|
||||
# append the schema with all the additional properties
|
||||
missing_props = self._check_document(current_properties, _doc)
|
||||
if missing_props:
|
||||
for property in missing_props:
|
||||
self._update_schema(property, index)
|
||||
current_properties.append(property)
|
||||
|
||||
docs_batch.add(_doc, class_name=index, uuid=doc_id, vector=vector)
|
||||
|
||||
# Ingest a batch of documents
|
||||
results = self.weaviate_client.batch.create(docs_batch)
|
||||
# Weaviate returns errors for every failed document in the batch
|
||||
if results is not None:
|
||||
for result in results:
|
||||
if 'result' in result and 'errors' in result['result'] \
|
||||
and 'error' in result['result']['errors']:
|
||||
for message in result['result']['errors']['error']:
|
||||
logger.error(f"{message['message']}")
|
||||
progress_bar.update(batch_size)
|
||||
progress_bar.close()
|
||||
|
||||
    def update_document_meta(self, id: str, meta: Dict[str, str]):
        """
        Update the metadata dictionary of a document by specifying its string id

        :param id: Weaviate uuid of the object to update.
        :param meta: Metadata key/value pairs to write.
        """
        # NOTE(review): assumes data_object.update merges the given fields into
        # the stored object rather than replacing it — confirm against the
        # weaviate client version in use.
        self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id)
|
||||
|
||||
def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
|
||||
"""
|
||||
Return the number of documents in the document store.
|
||||
"""
|
||||
index = index or self.index
|
||||
doc_count = 0
|
||||
if filters:
|
||||
filter_dict = self._build_filter_clause(filters=filters)
|
||||
result = self.weaviate_client.query.aggregate(index) \
|
||||
.with_fields("meta { count }") \
|
||||
.with_where(filter_dict)\
|
||||
.do()
|
||||
else:
|
||||
result = self.weaviate_client.query.aggregate(index)\
|
||||
.with_fields("meta { count }")\
|
||||
.do()
|
||||
|
||||
if "data" in result:
|
||||
if "Aggregate" in result.get('data'):
|
||||
doc_count = result.get('data').get('Aggregate').get(index)[0]['meta']['count']
|
||||
|
||||
return doc_count
|
||||
|
||||
def get_all_documents(
|
||||
self,
|
||||
index: Optional[str] = None,
|
||||
filters: Optional[Dict[str, List[str]]] = None,
|
||||
return_embedding: Optional[bool] = None,
|
||||
batch_size: int = 10_000,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Get documents from the document store.
|
||||
|
||||
:param index: Name of the index to get the documents from. If None, the
|
||||
DocumentStore's default index (self.index) will be used.
|
||||
:param filters: Optional filters to narrow down the documents to return.
|
||||
Example: {"name": ["some", "more"], "category": ["only_one"]}
|
||||
:param return_embedding: Whether to return the document embeddings.
|
||||
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
|
||||
"""
|
||||
index = index or self.index
|
||||
result = self.get_all_documents_generator(
|
||||
index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
|
||||
)
|
||||
documents = list(result)
|
||||
return documents
|
||||
|
||||
    def _get_all_documents_in_index(
        self,
        index: Optional[str],
        filters: Optional[Dict[str, List[str]]] = None,
        batch_size: int = 10_000,
        only_documents_without_embedding: bool = False,
    ) -> Generator[dict, None, None]:
        """
        Return all documents in a specific index in the document store

        :param index: Class name to fetch from; falls back to ``self.index`` when None.
        :param filters: Optional Haystack filters, translated to a Weaviate `where` clause.
        :param batch_size: NOTE(review): currently unused — the whole result set
                           is fetched in a single query; confirm whether paging was intended.
        :param only_documents_without_embedding: NOTE(review): currently unused
                                                 in this implementation.
        """
        index = index or self.index

        # Build the properties to retrieve from Weaviate
        properties = self._get_current_properties(index)
        # "_additional" asks Weaviate to also return id, certainty and vector per hit.
        properties.append("_additional {id, certainty, vector}")

        if filters:
            filter_dict = self._build_filter_clause(filters=filters)
            result = self.weaviate_client.query.get(class_name=index, properties=properties)\
                .with_where(filter_dict)\
                .do()
        else:
            result = self.weaviate_client.query.get(class_name=index, properties=properties)\
                .do()

        # Defensive unwrapping of the GraphQL response: data -> Get -> <index>.
        # Falls back to an empty container when any level is missing.
        all_docs = {}
        if result and "data" in result and "Get" in result.get("data"):
            if result.get("data").get("Get").get(index):
                all_docs = result.get("data").get("Get").get(index)

        yield from all_docs
|
||||
|
||||
def get_all_documents_generator(
|
||||
self,
|
||||
index: Optional[str] = None,
|
||||
filters: Optional[Dict[str, List[str]]] = None,
|
||||
return_embedding: Optional[bool] = None,
|
||||
batch_size: int = 10_000,
|
||||
) -> Generator[Document, None, None]:
|
||||
"""
|
||||
Get documents from the document store. Under-the-hood, documents are fetched in batches from the
|
||||
document store and yielded as individual documents. This method can be used to iteratively process
|
||||
a large number of documents without having to load all documents in memory.
|
||||
|
||||
:param index: Name of the index to get the documents from. If None, the
|
||||
DocumentStore's default index (self.index) will be used.
|
||||
:param filters: Optional filters to narrow down the documents to return.
|
||||
Example: {"name": ["some", "more"], "category": ["only_one"]}
|
||||
:param return_embedding: Whether to return the document embeddings.
|
||||
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
|
||||
"""
|
||||
|
||||
if index is None:
|
||||
index = self.index
|
||||
|
||||
if return_embedding is None:
|
||||
return_embedding = self.return_embedding
|
||||
|
||||
results = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)
|
||||
for result in results:
|
||||
document = self._convert_weaviate_result_to_document(result, return_embedding=return_embedding)
|
||||
yield document
|
||||
|
||||
    def query(
        self,
        query: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        top_k: int = 10,
        custom_query: Optional[str] = None,
        index: Optional[str] = None,
    ) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number documents
        that are most relevant to the query as defined by Weaviate semantic search.

        :param query: The query
        :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
        :param top_k: How many documents to return per query.
        :param custom_query: Custom query that will executed using query.raw method, for more details refer
                             https://www.semi.technology/developers/weaviate/current/graphql-references/filters.html
        :param index: The name of the index in the DocumentStore from which to retrieve documents
        :raises NotImplementedError: when neither ``custom_query`` nor ``filters`` is given —
                                     Weaviate has no inverted-index text search, so the plain
                                     ``query`` string alone cannot be executed.
        """
        index = index or self.index

        # Build the properties to retrieve from Weaviate
        properties = self._get_current_properties(index)
        # "_additional" asks Weaviate to also return id, certainty and vector per hit.
        properties.append("_additional {id, certainty, vector}")

        if custom_query:
            # A raw GraphQL query supplied by the caller takes precedence.
            query_output = self.weaviate_client.query.raw(custom_query)
        elif filters:
            filter_dict = self._build_filter_clause(filters)
            query_output = self.weaviate_client.query\
                .get(class_name=index, properties=properties)\
                .with_where(filter_dict)\
                .with_limit(top_k)\
                .do()
        else:
            # NOTE: the `query` argument itself is never used by this method.
            raise NotImplementedError("Weaviate does not support inverted index text query. However, "
                                      "it allows to search by filters example : {'text': 'some text'} or "
                                      "use a custom GraphQL query in text format!")

        # Defensive unwrapping of the GraphQL response: data -> Get -> <index>
        results = []
        if query_output and "data" in query_output and "Get" in query_output.get("data"):
            if query_output.get("data").get("Get").get(index):
                results = query_output.get("data").get("Get").get(index)

        documents = []
        for result in results:
            doc = self._convert_weaviate_result_to_document(result, return_embedding=True)
            documents.append(doc)

        return documents
|
||||
|
||||
    def query_by_embedding(self,
                           query_emb: np.ndarray,
                           filters: Optional[dict] = None,
                           top_k: int = 10,
                           index: Optional[str] = None,
                           return_embedding: Optional[bool] = None) -> List[Document]:
        """
        Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR)
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param top_k: How many documents to return
        :param index: index name for storing the docs and metadata
        :param return_embedding: To return document embedding
        :return: List of documents most similar to the query embedding.
        """
        if return_embedding is None:
            return_embedding = self.return_embedding
        index = index or self.index

        # Build the properties to retrieve from Weaviate
        properties = self._get_current_properties(index)
        # "_additional" asks Weaviate to also return id, certainty and vector per hit.
        properties.append("_additional {id, certainty, vector}")

        # Weaviate's nearVector filter expects a dict of the form {"vector": [...]}.
        query_emb = query_emb.reshape(1, -1).astype(np.float32)
        query_string = {
            "vector" : query_emb
        }
        if filters:
            filter_dict = self._build_filter_clause(filters)
            query_output = self.weaviate_client.query\
                .get(class_name=index, properties=properties)\
                .with_where(filter_dict)\
                .with_near_vector(query_string)\
                .with_limit(top_k)\
                .do()
        else:
            query_output = self.weaviate_client.query\
                .get(class_name=index, properties=properties)\
                .with_near_vector(query_string)\
                .with_limit(top_k)\
                .do()

        # Defensive unwrapping of the GraphQL response: data -> Get -> <index>
        results = []
        if query_output and "data" in query_output and "Get" in query_output.get("data"):
            if query_output.get("data").get("Get").get(index):
                results = query_output.get("data").get("Get").get(index)

        documents = []
        for result in results:
            doc = self._convert_weaviate_result_to_document(result, return_embedding=return_embedding)
            documents.append(doc)

        return documents
|
||||
|
||||
def update_embeddings(
    self,
    retriever,
    index: Optional[str] = None,
    filters: Optional[Dict[str, List[str]]] = None,
    update_existing_embeddings: bool = True,
    batch_size: int = 10_000
):
    """
    Updates the embeddings in the document store using the encoding model specified in the retriever.
    This can be useful if want to change the embeddings for your documents (e.g. after changing the retriever config).

    :param retriever: Retriever to use to update the embeddings.
    :param index: Index name to update
    :param update_existing_embeddings: Weaviate mandates an embedding while creating the document itself.
    This option must be always true for weaviate and it will update the embeddings for all the documents.
    :param filters: Optional filters to narrow down the documents for which embeddings are to be updated.
    Example: {"name": ["some", "more"], "category": ["only_one"]}
    :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
    :return: None
    """
    if index is None:
        index = self.index

    if not self.embedding_field:
        raise RuntimeError("Specify the arg `embedding_field` when initializing WeaviateDocumentStore()")

    if update_existing_embeddings:
        logger.info(f"Updating embeddings for all {self.get_document_count(index=index)} docs ...")
    else:
        # Every Weaviate object is created with a vector, so "only docs without
        # an embedding" is not a meaningful subset here -- reject the request.
        raise RuntimeError("All the documents in Weaviate store have an embedding by default. Only update is allowed!")

    # Generator over the raw Weaviate results for this index (optionally filtered).
    result = self._get_all_documents_in_index(
        index=index,
        filters=filters,
        batch_size=batch_size,
    )

    for result_batch in get_batches_from_generator(result, batch_size):
        # Convert raw hits to Document objects; embeddings are re-computed below,
        # so there is no need to fetch the stored vectors.
        document_batch = [self._convert_weaviate_result_to_document(hit, return_embedding=False) for hit in result_batch]
        embeddings = retriever.embed_passages(document_batch) # type: ignore
        assert len(document_batch) == len(embeddings)

        if embeddings[0].shape[0] != self.embedding_dim:
            raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
                               f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
                               "Specify the arg `embedding_dim` when initializing WeaviateDocumentStore()")
        for doc, emb in zip(document_batch, embeddings):
            # This doc processing will not required once weaviate's update
            # method works. To be improved.
            # Build a plain dict in the store's field naming so it can be sent
            # back to Weaviate as the replacement object.
            _doc = {
                **doc.to_dict(field_map=self._create_document_field_map())
            }
            # score/probability are query-time artifacts, not stored properties.
            _ = _doc.pop("score", None)
            _ = _doc.pop("probability", None)

            # Flatten nested meta into top-level properties (Weaviate schema
            # stores meta fields as individual properties).
            if "meta" in _doc.keys():
                for k, v in _doc["meta"].items():
                    _doc[k] = v
                _doc.pop("meta")

            # The id becomes the object uuid; the vector is passed separately.
            doc_id = str(_doc.pop("id"))
            _ = _doc.pop(self.embedding_field)
            # Weaviate rejects properties with null values, so strip them out.
            keys_to_remove = [k for k,v in _doc.items() if v is None]
            for key in keys_to_remove:
                _doc.pop(key)

            # TODO: Weaviate's update throws an error while passing a vector now, have to improve this later
            self.weaviate_client.data_object.replace(_doc, class_name=index, uuid=doc_id, vector=emb)
|
||||
|
||||
def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
    """
    Delete documents in an index. All documents are deleted if no filters are passed.

    :param index: Index name to delete the document from.
    :param filters: Optional filters to narrow down the documents to be deleted.
    :return: None
    """
    target_index = index or self.index
    if not filters:
        # Without filters, dropping the whole class and recreating its schema
        # is equivalent to (and much cheaper than) deleting every object.
        self.weaviate_client.schema.delete_class(target_index)
        self._create_schema_and_index_if_not_exist(target_index)
        return
    # With filters we have to look up the matching documents and remove each
    # object individually by its uuid.
    for doc in self.get_all_documents(target_index, filters=filters):
        self.weaviate_client.data_object.delete(doc.id)
|
@ -30,4 +30,5 @@ pymilvus
|
||||
#selenium
|
||||
#webdriver-manager
|
||||
SPARQLWrapper
|
||||
mmh3
|
||||
weaviate-client
|
@ -9,6 +9,9 @@ from elasticsearch import Elasticsearch
|
||||
from haystack.knowledge_graph.graphdb import GraphDBKnowledgeGraph
|
||||
from milvus import Milvus
|
||||
|
||||
import weaviate
|
||||
from haystack.document_store.weaviate import WeaviateDocumentStore
|
||||
|
||||
from haystack.document_store.milvus import MilvusDocumentStore
|
||||
from haystack.generator.transformers import RAGenerator, RAGeneratorType
|
||||
|
||||
@ -61,6 +64,8 @@ def pytest_collection_modifyitems(items):
|
||||
item.add_marker(pytest.mark.pipeline)
|
||||
elif "slow" in item.nodeid:
|
||||
item.add_marker(pytest.mark.slow)
|
||||
elif "weaviate" in item.nodeid:
|
||||
item.add_marker(pytest.mark.weaviate)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -98,6 +103,27 @@ def milvus_fixture():
|
||||
'milvusdb/milvus:0.10.5-cpu-d010621-4eda95'], shell=True)
|
||||
time.sleep(40)
|
||||
|
||||
@pytest.fixture(scope="session")
def weaviate_fixture():
    """Session fixture: ensure a Weaviate server is reachable on localhost:8080.

    Tests if a Weaviate server is already running. If not, starts a Weaviate
    docker container locally. Make sure you have given > 6GB memory to the
    docker engine.
    """
    try:
        weaviate_server = weaviate.Client(url='http://localhost:8080', timeout_config=(5, 15))
        weaviate_server.is_ready()
    except Exception:
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit;
        # only ordinary errors (e.g. connection refused) should trigger startup.
        print("Starting Weaviate servers ...")
        # Best effort: remove a stale container from a previous run; failure is fine.
        subprocess.run(
            ['docker rm haystack_test_weaviate'],
            shell=True
        )
        status = subprocess.run(
            ['docker run -d --name haystack_test_weaviate -p 8080:8080 semitechnologies/weaviate:1.4.0'],
            shell=True
        )
        if status.returncode:
            raise Exception(
                "Failed to launch Weaviate. Please check docker container logs.")
        # Give the container time to come up before any test talks to it.
        time.sleep(60)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def graphdb_fixture():
|
||||
@ -310,7 +336,6 @@ def document_store_with_docs(request, test_docs_xs):
|
||||
yield document_store
|
||||
document_store.delete_all_documents()
|
||||
|
||||
|
||||
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
|
||||
def document_store(request, test_docs_xs):
|
||||
document_store = get_document_store(request.param)
|
||||
@ -352,6 +377,14 @@ def get_document_store(document_store_type, embedding_field="embedding"):
|
||||
if collection.startswith("haystack_test"):
|
||||
document_store.milvus_server.drop_collection(collection)
|
||||
return document_store
|
||||
elif document_store_type == "weaviate":
|
||||
document_store = WeaviateDocumentStore(
|
||||
weaviate_url="http://localhost:8080",
|
||||
index="Haystacktest"
|
||||
)
|
||||
document_store.weaviate_client.schema.delete_all()
|
||||
document_store._create_schema_and_index_if_not_exist()
|
||||
return document_store
|
||||
else:
|
||||
raise Exception(f"No document store fixture for '{document_store_type}'")
|
||||
|
||||
|
@ -8,3 +8,4 @@ markers =
|
||||
generator: marks generator tests (deselect with '-m "not generator"')
|
||||
pipeline: marks tests with pipeline
|
||||
summarizer: marks summarizer tests
|
||||
weaviate: marks tests that require weaviate container
|
||||
|
330
test/test_weaviate.py
Normal file
330
test/test_weaviate.py
Normal file
@ -0,0 +1,330 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from haystack import Document
|
||||
from conftest import get_document_store
|
||||
import uuid
|
||||
|
||||
# Dimensionality used for all randomly generated test embeddings in this module.
embedding_dim = 768
||||
|
||||
def get_uuid():
    """Return a freshly generated random UUID4 rendered as a string."""
    return "{}".format(uuid.uuid4())
|
||||
|
||||
# Five docs sharing the meta key "key" ("a" once, "b" four times) -- used by the
# count/filter tests below. Ids are generated as UUID strings via get_uuid().
DOCUMENTS = [
    {"text": "text1", "id":get_uuid(), "key": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    {"text": "text2", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    {"text": "text3", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    {"text": "text4", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    {"text": "text5", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
]

# Three docs exercising the different accepted input formats for write_documents().
DOCUMENTS_XS = [
    # current "dict" format for a document
    {"text": "My name is Carla and I live in Berlin", "id":get_uuid(), "meta": {"metafield": "test1", "name": "filename1"}, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    # meta_field at the top level for backward compatibility
    {"text": "My name is Paul and I live in New York", "id":get_uuid(), "metafield": "test2", "name": "filename2", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    # Document object for a doc
    Document(text="My name is Christelle and I live in Paris", id=get_uuid(), meta={"metafield": "test3", "name": "filename3"}, embedding=np.random.rand(embedding_dim).astype(np.float32))
]
|
||||
|
||||
@pytest.fixture(params=["weaviate"])
def document_store_with_docs(request):
    """Yield a Weaviate document store pre-populated with DOCUMENTS_XS; wiped after the test."""
    document_store = get_document_store(request.param)
    document_store.write_documents(DOCUMENTS_XS)
    yield document_store
    document_store.delete_all_documents()
|
||||
|
||||
@pytest.fixture(params=["weaviate"])
def document_store(request):
    """Yield an empty Weaviate document store; wiped after the test."""
    document_store = get_document_store(request.param)
    yield document_store
    document_store.delete_all_documents()
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_get_all_documents_without_filters(document_store_with_docs):
    """get_all_documents() with no filters returns every stored doc with meta intact."""
    documents = document_store_with_docs.get_all_documents()
    assert all(isinstance(d, Document) for d in documents)
    assert len(documents) == 3
    assert {d.meta["name"] for d in documents} == {"filename1", "filename2", "filename3"}
    assert {d.meta["metafield"] for d in documents} == {"test1", "test2", "test3"}
|
||||
|
||||
@pytest.mark.weaviate
def test_get_all_documents_with_correct_filters(document_store_with_docs):
    """Meta filters narrow the result set; multi-value filters act as OR."""
    documents = document_store_with_docs.get_all_documents(filters={"metafield": ["test2"]})
    assert len(documents) == 1
    assert documents[0].meta["name"] == "filename2"

    documents = document_store_with_docs.get_all_documents(filters={"metafield": ["test1", "test3"]})
    assert len(documents) == 2
    assert {d.meta["name"] for d in documents} == {"filename1", "filename3"}
    assert {d.meta["metafield"] for d in documents} == {"test1", "test3"}
|
||||
|
||||
@pytest.mark.weaviate
def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs):
    """Filtering on a non-existent meta field matches nothing (no error)."""
    documents = document_store_with_docs.get_all_documents(filters={"incorrectmetafield": ["test2"]})
    assert len(documents) == 0
|
||||
|
||||
@pytest.mark.weaviate
def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs):
    """Filtering on a value no doc carries matches nothing (no error)."""
    documents = document_store_with_docs.get_all_documents(filters={"metafield": ["incorrect_value"]})
    assert len(documents) == 0
|
||||
|
||||
@pytest.mark.weaviate
def test_get_documents_by_id(document_store_with_docs):
    """get_document_by_id returns the same doc that get_all_documents listed."""
    documents = document_store_with_docs.get_all_documents()
    doc = document_store_with_docs.get_document_by_id(documents[0].id)
    assert doc.id == documents[0].id
    assert doc.text == documents[0].text
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_get_document_count(document_store):
    """Document counts, overall and per meta-field filter, match what was written."""
    document_store.write_documents(DOCUMENTS)
    assert document_store.get_document_count() == 5
    assert document_store.get_document_count(filters={"key": ["a"]}) == 1
    assert document_store.get_document_count(filters={"key": ["b"]}) == 4
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
@pytest.mark.parametrize("batch_size", [2])
def test_weaviate_write_docs(document_store, batch_size):
    """Writing docs in small batches indexes all of them, and batched reads see them all."""
    # Write in small batches
    for i in range(0, len(DOCUMENTS), batch_size):
        document_store.write_documents(DOCUMENTS[i: i + batch_size])

    documents_indexed = document_store.get_all_documents()
    assert len(documents_indexed) == len(DOCUMENTS)

    documents_indexed = document_store.get_all_documents(batch_size=batch_size)
    assert len(documents_indexed) == len(DOCUMENTS)
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_get_all_document_filter_duplicate_value(document_store):
    """Filtering on a meta value returns only the matching doc even when texts are duplicated."""
    documents = [
        Document(
            text="Doc1",
            meta={"fone": "f0"},
            id = get_uuid(),
            embedding= np.random.rand(embedding_dim).astype(np.float32)
        ),
        Document(
            text="Doc1",
            meta={"fone": "f1", "metaid": "0"},
            id = get_uuid(),
            embedding = np.random.rand(embedding_dim).astype(np.float32)
        ),
        Document(
            text="Doc2",
            meta={"fthree": "f0"},
            id = get_uuid(),
            embedding=np.random.rand(embedding_dim).astype(np.float32)
        )
    ]
    document_store.write_documents(documents)
    documents = document_store.get_all_documents(filters={"fone": ["f1"]})
    assert documents[0].text == "Doc1"
    assert len(documents) == 1
    assert {d.meta["metaid"] for d in documents} == {"0"}
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_get_all_documents_generator(document_store):
    """The generator variant streams every stored doc across multiple batches."""
    document_store.write_documents(DOCUMENTS)
    assert len(list(document_store.get_all_documents_generator(batch_size=2))) == 5
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_write_with_duplicate_doc_ids(document_store):
    """Duplicate ids are tolerated with duplicate_documents="skip" but rejected with "fail"."""
    # One shared id for both docs; named doc_id to avoid shadowing the builtin `id`.
    doc_id = get_uuid()
    documents = [
        Document(
            text="Doc1",
            id=doc_id,
            embedding=np.random.rand(embedding_dim).astype(np.float32)
        ),
        Document(
            text="Doc2",
            id=doc_id,
            embedding=np.random.rand(embedding_dim).astype(np.float32)
        )
    ]
    document_store.write_documents(documents, duplicate_documents="skip")
    with pytest.raises(Exception):
        document_store.write_documents(documents, duplicate_documents="fail")
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
@pytest.mark.parametrize("update_existing_documents", [True, False])
def test_update_existing_documents(document_store, update_existing_documents):
    """Re-writing a doc with the same id either overwrites it or fails, per the duplicate policy."""
    # Use a UUID *string* like the rest of this file (the original passed a raw
    # uuid.UUID object, inconsistent with get_uuid()); also avoids shadowing `id`.
    doc_id = str(uuid.uuid4())
    original_docs = [
        {"text": "text1_orig", "id": doc_id, "metafieldforcount": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    ]

    updated_docs = [
        {"text": "text1_new", "id": doc_id, "metafieldforcount": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    ]

    document_store.update_existing_documents = update_existing_documents
    document_store.write_documents(original_docs)
    assert document_store.get_document_count() == 1

    if update_existing_documents:
        document_store.write_documents(updated_docs, duplicate_documents="overwrite")
    else:
        with pytest.raises(Exception):
            document_store.write_documents(updated_docs, duplicate_documents="fail")

    stored_docs = document_store.get_all_documents()
    assert len(stored_docs) == 1
    if update_existing_documents:
        assert stored_docs[0].text == updated_docs[0]["text"]
    else:
        assert stored_docs[0].text == original_docs[0]["text"]
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_write_document_meta(document_store):
    """Meta is preserved whether given nested, flattened, or omitted, for dicts and Document objects."""
    uid1 = get_uuid()
    uid2 = get_uuid()
    uid3 = get_uuid()
    uid4 = get_uuid()
    documents = [
        {"text": "dict_without_meta", "id": uid1, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
        {"text": "dict_with_meta", "metafield": "test2", "name": "filename2", "id": uid2, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
        Document(text="document_object_without_meta", id=uid3, embedding=np.random.rand(embedding_dim).astype(np.float32)),
        Document(text="document_object_with_meta", meta={"metafield": "test4", "name": "filename3"}, id=uid4, embedding=np.random.rand(embedding_dim).astype(np.float32)),
    ]
    document_store.write_documents(documents)
    documents_in_store = document_store.get_all_documents()
    assert len(documents_in_store) == 4

    assert not document_store.get_document_by_id(uid1).meta
    assert document_store.get_document_by_id(uid2).meta["metafield"] == "test2"
    assert not document_store.get_document_by_id(uid3).meta
    assert document_store.get_document_by_id(uid4).meta["metafield"] == "test4"
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_write_document_index(document_store):
    """Docs written to different indices are isolated from each other and from the default index."""
    # Ids as UUID strings for consistency with the rest of the file
    # (the original passed raw uuid.UUID objects).
    documents = [
        {"text": "text1", "id": str(uuid.uuid4()), "embedding": np.random.rand(embedding_dim).astype(np.float32)},
        {"text": "text2", "id": str(uuid.uuid4()), "embedding": np.random.rand(embedding_dim).astype(np.float32)},
    ]

    document_store.write_documents([documents[0]], index="Haystackone")
    assert len(document_store.get_all_documents(index="Haystackone")) == 1

    document_store.write_documents([documents[1]], index="Haystacktwo")
    assert len(document_store.get_all_documents(index="Haystacktwo")) == 1

    assert len(document_store.get_all_documents(index="Haystackone")) == 1
    assert len(document_store.get_all_documents()) == 0
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_update_embeddings(document_store, retriever):
    """update_embeddings() re-embeds stored docs, optionally restricted by meta filters."""
    # 6 unique texts plus a duplicate of text_0 -> 7 docs total.
    documents = []
    for i in range(6):
        documents.append({"text": f"text_{i}", "id": str(uuid.uuid4()), "metafield": f"value_{i}", "embedding": np.random.rand(embedding_dim).astype(np.float32)})
    documents.append({"text": "text_0", "id": str(uuid.uuid4()), "metafield": "value_0", "embedding": np.random.rand(embedding_dim).astype(np.float32)})

    document_store.write_documents(documents, index="HaystackTestOne")
    document_store.update_embeddings(retriever, index="HaystackTestOne", batch_size=3)
    documents = document_store.get_all_documents(index="HaystackTestOne", return_embedding=True)
    assert len(documents) == 7
    for doc in documents:
        assert type(doc.embedding) is np.ndarray

    # Docs with identical text ("text_0") must receive (nearly) identical embeddings.
    documents = document_store.get_all_documents(
        index="HaystackTestOne",
        filters={"metafield": ["value_0"]},
        return_embedding=True,
    )
    assert len(documents) == 2
    for doc in documents:
        assert doc.meta["metafield"] == "value_0"
    np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)

    # Docs with different texts must receive different embeddings.
    documents = document_store.get_all_documents(
        index="HaystackTestOne",
        filters={"metafield": ["value_1", "value_5"]},
        return_embedding=True,
    )
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        documents[0].embedding,
        documents[1].embedding
    )

    # Add one doc with a known retriever-produced embedding.
    doc = {"text": "text_7", "id": str(uuid.uuid4()), "metafield": "value_7",
           "embedding": retriever.embed_queries(texts=["a random string"])[0]}
    document_store.write_documents([doc], index="HaystackTestOne")

    doc_before_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
    embedding_before_update = doc_before_update.embedding

    # A filtered update must leave docs outside the filter ("value_7") untouched.
    document_store.update_embeddings(
        retriever, index="HaystackTestOne", batch_size=3, filters={"metafield": ["value_0", "value_1"]}
    )
    doc_after_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
    embedding_after_update = doc_after_update.embedding
    np.testing.assert_array_equal(embedding_before_update, embedding_after_update)

    # test update all embeddings
    document_store.update_embeddings(retriever, index="HaystackTestOne", batch_size=3, update_existing_embeddings=True)
    assert document_store.get_document_count(index="HaystackTestOne") == 8
    doc_after_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
    embedding_after_update = doc_after_update.embedding
    np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_query_by_embedding(document_store_with_docs):
    """Vector search returns all docs by default and honours top_k and meta filters."""
    docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32))
    assert len(docs) == 3

    docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32),
                                                       top_k=1)
    assert len(docs) == 1

    docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32),
                                                       filters = {"name": ['filename2']})
    assert len(docs) == 1
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_query(document_store_with_docs):
    """Plain string queries raise; filter-only queries work, including filters on the text field."""
    query_text = 'My name is Carla and I live in Berlin'
    with pytest.raises(Exception):
        docs = document_store_with_docs.query(query_text)

    docs = document_store_with_docs.query(filters = {"name": ['filename2']})
    assert len(docs) == 1

    docs = document_store_with_docs.query(filters={"text":[query_text.lower()]})
    assert len(docs) == 1

    docs = document_store_with_docs.query(filters={"text":['live']})
    assert len(docs) == 3
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_delete_all_documents(document_store_with_docs):
    """delete_all_documents() without filters empties the index."""
    assert len(document_store_with_docs.get_all_documents()) == 3

    document_store_with_docs.delete_all_documents()
    documents = document_store_with_docs.get_all_documents()
    assert len(documents) == 0
|
||||
|
||||
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_delete_documents_with_filters(document_store_with_docs):
    """Filtered deletion removes only the matching docs."""
    document_store_with_docs.delete_all_documents(filters={"metafield": ["test1", "test2"]})
    documents = document_store_with_docs.get_all_documents()
    assert len(documents) == 1
    assert documents[0].meta["metafield"] == "test3"
|
||||
|
Loading…
x
Reference in New Issue
Block a user