Integrate Weaviate as another DocumentStore (#1064)

* Annotation Tool: data is not persisted when using local version #853

* First version of weaviate

* Updated comments

* ran query, get and write tests

* update embeddings, dynamic schema and filters implemented

* Initial set of tests and fixes

* Tests added for update_embeddings and delete documents

* introduced duplicate documents fix

* fixed mypy errors

* Added Weaviate to requirements

* Fix the weaviate docker env variables

* Fixing test dependencies for now

* Created weaviate test marker and fixed query

* Update docstring

* Add documentation

* Bump up weaviate version

* Bump up weaviate version in documentation

* Upgrade weaviate version

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
venuraja79 2021-06-10 13:13:53 +05:30 committed by GitHub
parent db17d73a82
commit 49886f88f0
7 changed files with 1126 additions and 3 deletions


@@ -76,6 +76,9 @@ jobs:
- name: Run Milvus
run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.1.0-cpu-d050721-5e559c
- name: Run Weaviate
run: docker run -d -p 8080:8080 --name haystack_test_weaviate --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0
- name: Run GraphDB
run: docker run -d -p 7200:7200 --name haystack_test_graphdb deepset/graphdb-free:9.4.1-adoptopenjdk11


@@ -116,6 +116,27 @@ from haystack.document_store import SQLDocumentStore
document_store = SQLDocumentStore()
```
</div>
</div>
<div class="tab">
<input type="radio" id="tab-1-6" name="tab-group-1">
<label class="labelouter" for="tab-1-6">Weaviate</label>
<div class="tabcontent">
The `WeaviateDocumentStore` requires a running Weaviate Server.
You can start a basic instance like this (see Weaviate docs for details):
```
docker run -d -p 8080:8080 --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0
```
Afterwards, you can use it in Haystack:
```python
from haystack.document_store import WeaviateDocumentStore
document_store = WeaviateDocumentStore()
```
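Since a Weaviate instance without a vectorizer module expects an embedding for every document, writing documents typically looks like the following sketch. The 768-dimensional random vectors stand in for real retriever embeddings, and the final write call is commented out because it needs a running Weaviate server:

```python
import numpy as np

# Build documents as plain dicts; "embedding" must match embedding_dim (768 by default).
docs = [
    {"text": "Berlin is the capital of Germany",
     "meta": {"topic": "geo"},
     "embedding": np.random.rand(768).astype(np.float32)},
    {"text": "Paris is the capital of France",
     "meta": {"topic": "geo"},
     "embedding": np.random.rand(768).astype(np.float32)},
]

# With a running Weaviate server you would then index them:
# document_store.write_documents(docs)
```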
</div>
</div>
@@ -264,6 +285,24 @@ The Document Stores have different characteristics. You should choose one depend
</div>
</div>
<div class="tab">
<input type="radio" id="tab-2-6" name="tab-group-2">
<label class="labelouter" for="tab-2-6">Weaviate</label>
<div class="tabcontent">
**Pros:**
- Simple vector search
- Stores everything in one place: documents, metadata and vectors - so less network overhead when scaling this up
- Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset
**Cons:**
- Fewer options for ANN algorithms than FAISS or Milvus
- No BM25 / TF-IDF retrieval
</div>
</div>
</div>
<div class="recommendation">
@@ -276,4 +315,4 @@ The Document Stores have different characteristics. You should choose one depend
**Vector Specialist:** Use the `MilvusDocumentStore`, if you want to focus on dense retrieval and possibly deal with larger datasets
</div>
</div>


@@ -0,0 +1,716 @@
import logging
from typing import Any, Dict, Generator, List, Optional, Union
import numpy as np
from tqdm import tqdm
from haystack import Document
from haystack.document_store.base import BaseDocumentStore
from haystack.utils import get_batches_from_generator
from weaviate import client, auth, AuthClientPassword
from weaviate import ObjectsBatchRequest
logger = logging.getLogger(__name__)
class WeaviateDocumentStore(BaseDocumentStore):
"""
Weaviate is a cloud-native, modular, real-time vector search engine built to scale your machine learning models.
(See https://www.semi.technology/developers/weaviate/current/index.html#what-is-weaviate)
Some of the key differences in contrast to FAISS & Milvus:
1. Stores everything in one place: documents, metadata and vectors - so less network overhead when scaling this up
2. Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset
3. Has a smaller variety of ANN algorithms; as of now only HNSW.
The Weaviate Python client is used to connect to the server; more details are here:
https://weaviate-python-client.readthedocs.io/en/docs/weaviate.html
Usage:
1. Start a Weaviate server (see https://www.semi.technology/developers/weaviate/current/getting-started/installation.html)
2. Init a WeaviateDocumentStore in Haystack
"""
def __init__(
self,
host: Union[str, List[str]] = "http://localhost",
port: Union[int, List[int]] = 8080,
timeout_config: tuple = (5, 15),
username: Optional[str] = None,
password: Optional[str] = None,
index: str = "Document",
embedding_dim: int = 768,
text_field: str = "text",
name_field: str = "name",
faq_question_field: str = "question",
similarity: str = "dot_product",
index_type: str = "hnsw",
custom_schema: Optional[dict] = None,
return_embedding: bool = False,
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = 'overwrite',
**kwargs,
):
"""
:param host: Weaviate server connection URL for storing and processing documents and vectors.
For more details, refer "https://www.semi.technology/developers/weaviate/current/getting-started/installation.html"
:param port: port of Weaviate instance
:param timeout_config: Weaviate timeout config as a tuple of (retries, timeout seconds).
:param username: username (standard authentication via http_auth)
:param password: password (standard authentication via http_auth)
:param index: Index name for document text, embedding and metadata (in Weaviate terminology, this is a "Class" in Weaviate schema).
:param embedding_dim: The embedding vector size. Default: 768.
:param text_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text").
If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned.
:param name_field: Name of field that contains the title of the doc
:param faq_question_field: Name of field containing the question in case of FAQ-Style QA
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default.
:param index_type: Index type of any vector object defined in weaviate schema. The vector index type is pluggable.
Currently, only HNSW is supported.
See: https://www.semi.technology/developers/weaviate/current/more-resources/performance.html
:param custom_schema: Allows to create custom schema in Weaviate, for more details
See https://www.semi.technology/developers/weaviate/current/data-schema/schema-configuration.html
:param module_name: Vectorization module to convert data into vectors. Default is "text2vec-transformers".
For more details, see https://www.semi.technology/developers/weaviate/current/modules/
:param return_embedding: Whether to return document embeddings.
:param embedding_field: Name of field containing an embedding vector.
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
:param duplicate_documents: Handle duplicate documents based on parameter options.
Parameter options: ('skip', 'overwrite', 'fail')
skip: Ignore duplicate documents.
overwrite: Update any existing documents with the same ID when adding documents.
fail: Raise an error if the document ID of the document being added already exists.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
host=host, port=port, timeout_config=timeout_config, username=username, password=password,
index=index, embedding_dim=embedding_dim, text_field=text_field, name_field=name_field,
faq_question_field=faq_question_field, similarity=similarity, index_type=index_type,
custom_schema=custom_schema,return_embedding=return_embedding, embedding_field=embedding_field,
progress_bar=progress_bar, duplicate_documents=duplicate_documents
)
# Connect to Weaviate server using python binding
weaviate_url = f"{host}:{port}"
if username and password:
secret = AuthClientPassword(username, password)
self.weaviate_client = client.Client(url=weaviate_url,
auth_client_secret=secret,
timeout_config=timeout_config)
else:
self.weaviate_client = client.Client(url=weaviate_url,
timeout_config=timeout_config)
# Test Weaviate connection
connection_error_msg = (
f"Initial connection to Weaviate failed. Make sure you run a Weaviate instance "
f"at `{weaviate_url}` and that it has finished the initial ramp up (can take > 30s)."
)
try:
status = self.weaviate_client.is_ready()
except Exception as e:
raise ConnectionError(connection_error_msg) from e
if not status:
raise ConnectionError(connection_error_msg)
self.index = index
self.embedding_dim = embedding_dim
self.text_field = text_field
self.name_field = name_field
self.faq_question_field = faq_question_field
self.similarity = similarity
self.index_type = index_type
self.custom_schema = custom_schema
self.return_embedding = return_embedding
self.embedding_field = embedding_field
self.progress_bar = progress_bar
self.duplicate_documents = duplicate_documents
self._create_schema_and_index_if_not_exist(self.index)
def _create_schema_and_index_if_not_exist(
self,
index: Optional[str] = None,
):
"""Create a new index (schema/class in Weaviate) for storing documents in case if an index (schema) with the name doesn't exist already."""
index = index or self.index
if self.custom_schema:
schema = self.custom_schema
else:
schema = {
"classes": [
{
"class": index,
"description": "Haystack index, it's a class in Weaviate",
"invertedIndexConfig": {
"cleanupIntervalSeconds": 60
},
"vectorizer": "none",
"properties": [
{
"dataType": [
"string"
],
"description": "Name Field",
"name": self.name_field
},
{
"dataType": [
"string"
],
"description": "Question Field",
"name": self.faq_question_field
},
{
"dataType": [
"text"
],
"description": "Document Text",
"name": self.text_field
},
],
}
]
}
if not self.weaviate_client.schema.contains(schema):
self.weaviate_client.schema.create(schema)
def _convert_weaviate_result_to_document(
self,
result: dict,
return_embedding: bool
) -> Document:
"""
Convert a Weaviate result dict into a Haystack Document object. This is more involved because
the Weaviate search result dict varies between the get and query interfaces:
Weaviate get methods return the data items under a "properties" key, whereas query results don't.
"""
score = None
probability = None
text = ""
question = None
id = result.get("id")
embedding = result.get("vector")
# If properties key is present, get all the document fields from it.
# otherwise, a direct lookup in result root dict
props = result.get("properties")
if not props:
props = result
if props.get(self.text_field) is not None:
text = str(props.get(self.text_field))
if props.get(self.faq_question_field) is not None:
question = props.get(self.faq_question_field)
# Weaviate creates "_additional" key for semantic search
if "_additional" in props:
if "certainty" in props["_additional"]:
score = props["_additional"]['certainty']
probability = score
if "id" in props["_additional"]:
id = props["_additional"]['id']
if "vector" in props["_additional"]:
embedding = props["_additional"]['vector']
props.pop("_additional", None)
# We put all additional data of the doc into meta_data and return it in the API
meta_data = {k:v for k,v in props.items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)}
if return_embedding and embedding:
embedding = np.asarray(embedding, dtype=np.float32)
document = Document(
id=id,
text=text,
meta=meta_data,
score=score,
probability=probability,
question=question,
embedding=embedding,
)
return document
def _create_document_field_map(self) -> Dict:
return {
self.text_field: "text",
self.embedding_field: "embedding",
self.faq_question_field if self.faq_question_field else "question": "question"
}
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
# Sample result dict from a get method
'''{'class': 'Document',
'creationTimeUnix': 1621075584724,
'id': '1bad51b7-bd77-485d-8871-21c50fab248f',
'properties': {'meta': "{'key1':'value1'}",
'name': 'name_5',
'text': 'text_5'},
'vector': []}'''
index = index or self.index
document = None
result = self.weaviate_client.data_object.get_by_id(id, with_vector=True)
if result:
document = self._convert_weaviate_result_to_document(result, return_embedding=True)
return document
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None,
batch_size: int = 10_000) -> List[Document]:
"""Fetch documents by specifying a list of text id strings"""
index = index or self.index
documents = []
#TODO: better implementation with multiple where filters instead of chatty call below?
for id in ids:
result = self.weaviate_client.data_object.get_by_id(id, with_vector=True)
if result:
document = self._convert_weaviate_result_to_document(result, return_embedding=True)
documents.append(document)
return documents
def _get_current_properties(self, index: Optional[str] = None) -> List[str]:
"""Get all the existing properties in the schema"""
index = index or self.index
cur_properties = []
for class_item in self.weaviate_client.schema.get()['classes']:
if class_item['class'] == index:
cur_properties = [item['name'] for item in class_item['properties']]
return cur_properties
def _build_filter_clause(self, filters:Dict[str, List[str]]) -> dict:
"""Transform Haystack filter conditions to Weaviate where filter clauses"""
weaviate_filters = []
weaviate_filter = {}
for key, values in filters.items():
for value in values:
weaviate_filter = {
"path": [key],
"operator": "Equal",
"valueString": value
}
weaviate_filters.append(weaviate_filter)
if len(weaviate_filters) > 1:
filter_dict = {
"operator": "Or",
"operands": weaviate_filters
}
return filter_dict
else:
return weaviate_filter
def _update_schema(self, new_prop:str, index: Optional[str] = None):
"""Updates the schema with a new property"""
index = index or self.index
property_dict = {
"dataType": [
"string"
],
"description": f"dynamic property {new_prop}",
"name": new_prop
}
self.weaviate_client.schema.property.create(index, property_dict)
def _check_document(self, cur_props: List[str], doc: dict) -> List[str]:
"""Find the properties in the document that don't exist in the existing schema"""
return [item for item in doc.keys() if item not in cur_props]
def write_documents(
self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
batch_size: int = 10_000, duplicate_documents: Optional[str] = None):
"""
Add new documents to the DocumentStore.
:param documents: List of `Dicts` or List of `Documents`. Passing an embedding/vector is mandatory if Weaviate is not
configured with a vectorization module. If a module is configured, the embedding is automatically generated by Weaviate.
:param index: index name for storing the docs and metadata
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
:param duplicate_documents: Handle duplicate documents based on parameter options.
Parameter options: ('skip', 'overwrite', 'fail')
skip: Ignore duplicate documents.
overwrite: Update any existing documents with the same ID when adding documents.
fail: Raise an error if the document ID of the document being added already exists.
:raises DuplicateDocumentError: Exception triggered on duplicate document
:return: None
"""
index = index or self.index
self._create_schema_and_index_if_not_exist(index)
field_map = self._create_document_field_map()
duplicate_documents = duplicate_documents or self.duplicate_documents
assert duplicate_documents in self.duplicate_documents_options, \
f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}"
if len(documents) == 0:
logger.warning("Calling DocumentStore.write_documents() with empty list")
return
# Auto schema feature https://github.com/semi-technologies/weaviate/issues/1539
# Get and cache current properties in the schema
current_properties = self._get_current_properties(index)
document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
document_objects = self._handle_duplicate_documents(document_objects, duplicate_documents)
batched_documents = get_batches_from_generator(document_objects, batch_size)
with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar:
for document_batch in batched_documents:
docs_batch = ObjectsBatchRequest()
for idx, doc in enumerate(document_batch):
_doc = {
**doc.to_dict(field_map=self._create_document_field_map())
}
_ = _doc.pop("score", None)
_ = _doc.pop("probability", None)
# In order to have a flat structure and behaviour similar to the other DocumentStores,
# we "unnest" all values within "meta"
if "meta" in _doc.keys():
for k, v in _doc["meta"].items():
_doc[k] = v
_doc.pop("meta")
doc_id = str(_doc.pop("id"))
vector = _doc.pop(self.embedding_field)
if _doc.get(self.faq_question_field) is None:
_doc.pop(self.faq_question_field)
# Check if additional properties are in the document, if so,
# append the schema with all the additional properties
missing_props = self._check_document(current_properties, _doc)
if missing_props:
for prop in missing_props:
self._update_schema(prop, index)
current_properties.append(prop)
docs_batch.add(_doc, class_name=index, uuid=doc_id, vector=vector)
# Ingest a batch of documents
results = self.weaviate_client.batch.create(docs_batch)
# Weaviate returns errors for every failed document in the batch
if results is not None:
for result in results:
if 'result' in result and 'errors' in result['result'] \
and 'error' in result['result']['errors']:
for message in result['result']['errors']['error']:
logger.error(f"{message['message']}")
progress_bar.update(len(document_batch))
progress_bar.close()
def update_document_meta(self, id: str, meta: Dict[str, str]):
"""
Update the metadata dictionary of a document by specifying its string id
"""
self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id)
def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
"""
Return the number of documents in the document store.
"""
index = index or self.index
doc_count = 0
if filters:
filter_dict = self._build_filter_clause(filters=filters)
result = self.weaviate_client.query.aggregate(index) \
.with_fields("meta { count }") \
.with_where(filter_dict)\
.do()
else:
result = self.weaviate_client.query.aggregate(index)\
.with_fields("meta { count }")\
.do()
if "data" in result:
if "Aggregate" in result.get('data'):
doc_count = result.get('data').get('Aggregate').get(index)[0]['meta']['count']
return doc_count
def get_all_documents(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> List[Document]:
"""
Get documents from the document store.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
index = index or self.index
result = self.get_all_documents_generator(
index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
)
documents = list(result)
return documents
def _get_all_documents_in_index(
self,
index: Optional[str],
filters: Optional[Dict[str, List[str]]] = None,
batch_size: int = 10_000,
only_documents_without_embedding: bool = False,
) -> Generator[dict, None, None]:
"""
Return all documents in a specific index in the document store
"""
index = index or self.index
# Build the properties to retrieve from Weaviate
properties = self._get_current_properties(index)
properties.append("_additional {id, certainty, vector}")
if filters:
filter_dict = self._build_filter_clause(filters=filters)
result = self.weaviate_client.query.get(class_name=index, properties=properties)\
.with_where(filter_dict)\
.do()
else:
result = self.weaviate_client.query.get(class_name=index, properties=properties)\
.do()
all_docs = {}
if result and "data" in result and "Get" in result.get("data"):
if result.get("data").get("Get").get(index):
all_docs = result.get("data").get("Get").get(index)
yield from all_docs
def get_all_documents_generator(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> Generator[Document, None, None]:
"""
Get documents from the document store. Under-the-hood, documents are fetched in batches from the
document store and yielded as individual documents. This method can be used to iteratively process
a large number of documents without having to load all documents in memory.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
if index is None:
index = self.index
if return_embedding is None:
return_embedding = self.return_embedding
results = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)
for result in results:
document = self._convert_weaviate_result_to_document(result, return_embedding=return_embedding)
yield document
def query(
self,
query: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
top_k: int = 10,
custom_query: Optional[str] = None,
index: Optional[str] = None,
) -> List[Document]:
"""
Scan through documents in DocumentStore and return a small number of documents
that are most relevant to the query as defined by Weaviate semantic search.
:param query: The query
:param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
:param top_k: How many documents to return per query.
:param custom_query: Custom query that will be executed using the query.raw method; for more details refer to
https://www.semi.technology/developers/weaviate/current/graphql-references/filters.html
:param index: The name of the index in the DocumentStore from which to retrieve documents
"""
index = index or self.index
# Build the properties to retrieve from Weaviate
properties = self._get_current_properties(index)
properties.append("_additional {id, certainty, vector}")
if custom_query:
query_output = self.weaviate_client.query.raw(custom_query)
elif filters:
filter_dict = self._build_filter_clause(filters)
query_output = self.weaviate_client.query\
.get(class_name=index, properties=properties)\
.with_where(filter_dict)\
.with_limit(top_k)\
.do()
else:
raise NotImplementedError("Weaviate does not support inverted-index text queries. However, "
"it allows searching by filters, e.g. {'text': 'some text'}, or "
"using a custom GraphQL query in text format!")
results = []
if query_output and "data" in query_output and "Get" in query_output.get("data"):
if query_output.get("data").get("Get").get(index):
results = query_output.get("data").get("Get").get(index)
documents = []
for result in results:
doc = self._convert_weaviate_result_to_document(result, return_embedding=True)
documents.append(doc)
return documents
def query_by_embedding(self,
query_emb: np.ndarray,
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None,
return_embedding: Optional[bool] = None) -> List[Document]:
"""
Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param filters: Optional filters to narrow down the search space.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param top_k: How many documents to return
:param index: index name for storing the docs and metadata
:param return_embedding: Whether to return the document embedding.
:return:
"""
if return_embedding is None:
return_embedding = self.return_embedding
index = index or self.index
# Build the properties to retrieve from Weaviate
properties = self._get_current_properties(index)
properties.append("_additional {id, certainty, vector}")
query_emb = query_emb.reshape(1, -1).astype(np.float32)
query_string = {
"vector" : query_emb
}
if filters:
filter_dict = self._build_filter_clause(filters)
query_output = self.weaviate_client.query\
.get(class_name=index, properties=properties)\
.with_where(filter_dict)\
.with_near_vector(query_string)\
.with_limit(top_k)\
.do()
else:
query_output = self.weaviate_client.query\
.get(class_name=index, properties=properties)\
.with_near_vector(query_string)\
.with_limit(top_k)\
.do()
results = []
if query_output and "data" in query_output and "Get" in query_output.get("data"):
if query_output.get("data").get("Get").get(index):
results = query_output.get("data").get("Get").get(index)
documents = []
for result in results:
doc = self._convert_weaviate_result_to_document(result, return_embedding=return_embedding)
documents.append(doc)
return documents
def update_embeddings(
self,
retriever,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
update_existing_embeddings: bool = True,
batch_size: int = 10_000
):
"""
Updates the embeddings in the document store using the encoding model specified in the retriever.
This can be useful if you want to change the embeddings for your documents (e.g. after changing the retriever config).
:param retriever: Retriever to use to update the embeddings.
:param index: Index name to update
:param update_existing_embeddings: Weaviate mandates an embedding when creating the document itself.
This option must always be True for Weaviate; it will update the embeddings for all the documents.
:param filters: Optional filters to narrow down the documents for which embeddings are to be updated.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
:return: None
"""
if index is None:
index = self.index
if not self.embedding_field:
raise RuntimeError("Specify the arg `embedding_field` when initializing WeaviateDocumentStore()")
if update_existing_embeddings:
logger.info(f"Updating embeddings for all {self.get_document_count(index=index)} docs ...")
else:
raise RuntimeError("All the documents in Weaviate store have an embedding by default. Only update is allowed!")
result = self._get_all_documents_in_index(
index=index,
filters=filters,
batch_size=batch_size,
)
for result_batch in get_batches_from_generator(result, batch_size):
document_batch = [self._convert_weaviate_result_to_document(hit, return_embedding=False) for hit in result_batch]
embeddings = retriever.embed_passages(document_batch) # type: ignore
assert len(document_batch) == len(embeddings)
if embeddings[0].shape[0] != self.embedding_dim:
raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
"Specify the arg `embedding_dim` when initializing WeaviateDocumentStore()")
for doc, emb in zip(document_batch, embeddings):
# This doc processing will not be required once Weaviate's update
# method works. To be improved.
_doc = {
**doc.to_dict(field_map=self._create_document_field_map())
}
_ = _doc.pop("score", None)
_ = _doc.pop("probability", None)
if "meta" in _doc.keys():
for k, v in _doc["meta"].items():
_doc[k] = v
_doc.pop("meta")
doc_id = str(_doc.pop("id"))
_ = _doc.pop(self.embedding_field)
keys_to_remove = [k for k,v in _doc.items() if v is None]
for key in keys_to_remove:
_doc.pop(key)
# TODO: Weaviate's update throws an error while passing a vector now, have to improve this later
self.weaviate_client.data_object.replace(_doc, class_name=index, uuid=doc_id, vector=emb)
def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
"""
Delete documents in an index. All documents are deleted if no filters are passed.
:param index: Index name to delete the document from.
:param filters: Optional filters to narrow down the documents to be deleted.
:return: None
"""
index = index or self.index
if filters:
docs_to_delete = self.get_all_documents(index, filters=filters)
for doc in docs_to_delete:
self.weaviate_client.data_object.delete(doc.id)
else:
self.weaviate_client.schema.delete_class(index)
self._create_schema_and_index_if_not_exist(index)
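The Haystack-to-Weaviate filter translation implemented in `_build_filter_clause` above can be illustrated standalone. This sketch reproduces its logic for plain dicts (the function name here is illustrative, not part of the Haystack API):

```python
def build_filter_clause(filters):
    # Each (field, value) pair becomes an "Equal" where-clause;
    # multiple clauses are OR-ed together, mirroring _build_filter_clause above.
    clauses = [
        {"path": [key], "operator": "Equal", "valueString": value}
        for key, values in filters.items()
        for value in values
    ]
    if len(clauses) > 1:
        return {"operator": "Or", "operands": clauses}
    return clauses[0]

single = build_filter_clause({"name": ["some"]})
multi = build_filter_clause({"name": ["some", "more"], "category": ["only_one"]})
# single is a bare Equal clause; multi is an Or over three Equal clauses.
```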


@@ -30,4 +30,5 @@ pymilvus
#selenium
#webdriver-manager
SPARQLWrapper
mmh3
weaviate-client


@@ -9,6 +9,9 @@ from elasticsearch import Elasticsearch
from haystack.knowledge_graph.graphdb import GraphDBKnowledgeGraph
from milvus import Milvus
import weaviate
from haystack.document_store.weaviate import WeaviateDocumentStore
from haystack.document_store.milvus import MilvusDocumentStore
from haystack.generator.transformers import RAGenerator, RAGeneratorType
@@ -61,6 +64,8 @@ def pytest_collection_modifyitems(items):
item.add_marker(pytest.mark.pipeline)
elif "slow" in item.nodeid:
item.add_marker(pytest.mark.slow)
elif "weaviate" in item.nodeid:
item.add_marker(pytest.mark.weaviate)
@pytest.fixture(scope="session")
@@ -98,6 +103,27 @@ def milvus_fixture():
'milvusdb/milvus:0.10.5-cpu-d010621-4eda95'], shell=True)
time.sleep(40)
@pytest.fixture(scope="session")
def weaviate_fixture():
# test if a Weaviate server is already running. If not, start Weaviate docker container locally.
# Make sure you have given > 6GB memory to docker engine
try:
weaviate_server = weaviate.Client(url='http://localhost:8080', timeout_config=(5, 15))
weaviate_server.is_ready()
except Exception:
print("Starting Weaviate servers ...")
status = subprocess.run(
['docker rm haystack_test_weaviate'],
shell=True
)
status = subprocess.run(
['docker run -d --name haystack_test_weaviate -p 8080:8080 semitechnologies/weaviate:1.4.0'],
shell=True
)
if status.returncode:
raise Exception(
"Failed to launch Weaviate. Please check docker container logs.")
time.sleep(60)
@pytest.fixture(scope="session")
def graphdb_fixture():
@@ -310,7 +336,6 @@ def document_store_with_docs(request, test_docs_xs):
yield document_store
document_store.delete_all_documents()
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
def document_store(request, test_docs_xs):
document_store = get_document_store(request.param)
@@ -352,6 +377,14 @@ def get_document_store(document_store_type, embedding_field="embedding"):
if collection.startswith("haystack_test"):
document_store.milvus_server.drop_collection(collection)
return document_store
elif document_store_type == "weaviate":
document_store = WeaviateDocumentStore(
host="http://localhost", port=8080,
index="Haystacktest"
)
document_store.weaviate_client.schema.delete_all()
document_store._create_schema_and_index_if_not_exist()
return document_store
else:
raise Exception(f"No document store fixture for '{document_store_type}'")


@@ -8,3 +8,4 @@ markers =
generator: marks generator tests (deselect with '-m "not generator"')
pipeline: marks tests with pipeline
summarizer: marks summarizer tests
weaviate: marks tests that require weaviate container
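With the new marker registered, the Weaviate tests can be run (or skipped) selectively. A sketch of the local workflow, reusing the same docker image and flags as the CI job above:

```shell
# Start a local Weaviate instance for the tests
docker run -d -p 8080:8080 --name haystack_test_weaviate \
  --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' \
  --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' \
  semitechnologies/weaviate:1.4.0

# Run only the weaviate-marked tests (deselect them with -m "not weaviate")
pytest -m weaviate test/test_weaviate.py
```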

test/test_weaviate.py

@@ -0,0 +1,330 @@
import numpy as np
import pytest
from haystack import Document
from conftest import get_document_store
import uuid
embedding_dim = 768
def get_uuid():
return str(uuid.uuid4())
DOCUMENTS = [
{"text": "text1", "id":get_uuid(), "key": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"text": "text2", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"text": "text3", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"text": "text4", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"text": "text5", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
]
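Every fixture document above follows the same shape. A small helper (hypothetical, not part of the test file) makes the invariants explicit: ids are string UUIDs, and embeddings are random 768-dimensional float32 vectors used as stand-ins for real model output.

```python
import uuid

import numpy as np

embedding_dim = 768

def make_doc(text: str, key: str) -> dict:
    # Hypothetical helper mirroring the DOCUMENTS entries above:
    # a string UUID id plus a random float32 embedding of the right dimension.
    return {
        "text": text,
        "id": str(uuid.uuid4()),
        "key": key,
        "embedding": np.random.rand(embedding_dim).astype(np.float32),
    }

doc = make_doc("text1", "a")
assert doc["embedding"].shape == (embedding_dim,)
assert doc["embedding"].dtype == np.float32
```

Using `str(uuid.uuid4())` rather than a raw `uuid.UUID` object keeps the ids JSON-serializable and consistent with `get_uuid()` above.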
DOCUMENTS_XS = [
# current "dict" format for a document
{"text": "My name is Carla and I live in Berlin", "id":get_uuid(), "meta": {"metafield": "test1", "name": "filename1"}, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
# meta_field at the top level for backward compatibility
{"text": "My name is Paul and I live in New York", "id":get_uuid(), "metafield": "test2", "name": "filename2", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
# Document object for a doc
Document(text="My name is Christelle and I live in Paris", id=get_uuid(), meta={"metafield": "test3", "name": "filename3"}, embedding=np.random.rand(embedding_dim).astype(np.float32))
]
@pytest.fixture(params=["weaviate"])
def document_store_with_docs(request):
document_store = get_document_store(request.param)
document_store.write_documents(DOCUMENTS_XS)
yield document_store
document_store.delete_all_documents()
@pytest.fixture(params=["weaviate"])
def document_store(request):
document_store = get_document_store(request.param)
yield document_store
document_store.delete_all_documents()
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_get_all_documents_without_filters(document_store_with_docs):
documents = document_store_with_docs.get_all_documents()
assert all(isinstance(d, Document) for d in documents)
assert len(documents) == 3
assert {d.meta["name"] for d in documents} == {"filename1", "filename2", "filename3"}
assert {d.meta["metafield"] for d in documents} == {"test1", "test2", "test3"}
@pytest.mark.weaviate
def test_get_all_documents_with_correct_filters(document_store_with_docs):
documents = document_store_with_docs.get_all_documents(filters={"metafield": ["test2"]})
assert len(documents) == 1
assert documents[0].meta["name"] == "filename2"
documents = document_store_with_docs.get_all_documents(filters={"metafield": ["test1", "test3"]})
assert len(documents) == 2
assert {d.meta["name"] for d in documents} == {"filename1", "filename3"}
assert {d.meta["metafield"] for d in documents} == {"test1", "test3"}
@pytest.mark.weaviate
def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs):
documents = document_store_with_docs.get_all_documents(filters={"incorrectmetafield": ["test2"]})
assert len(documents) == 0
@pytest.mark.weaviate
def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs):
documents = document_store_with_docs.get_all_documents(filters={"metafield": ["incorrect_value"]})
assert len(documents) == 0
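The three filter tests above all rely on the same semantics: a document matches only if every filter key exists in its metadata and the stored value is one of the allowed values. A toy model in plain Python (illustrative only, not the Weaviate implementation):

```python
def matches(meta: dict, filters: dict) -> bool:
    # Toy model of the filter semantics these tests rely on: every filter
    # key must be present in the document's metadata, and its stored value
    # must be one of the allowed values for that key.
    return all(meta.get(key) in allowed for key, allowed in filters.items())

meta = {"metafield": "test2", "name": "filename2"}
assert matches(meta, {"metafield": ["test2"]})
assert not matches(meta, {"incorrectmetafield": ["test2"]})   # unknown key: no match
assert not matches(meta, {"metafield": ["incorrect_value"]})  # wrong value: no match
```

This is why both the incorrect-filter-name and incorrect-filter-value tests expect zero documents rather than an error.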
@pytest.mark.weaviate
def test_get_documents_by_id(document_store_with_docs):
documents = document_store_with_docs.get_all_documents()
doc = document_store_with_docs.get_document_by_id(documents[0].id)
assert doc.id == documents[0].id
assert doc.text == documents[0].text
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_get_document_count(document_store):
document_store.write_documents(DOCUMENTS)
assert document_store.get_document_count() == 5
assert document_store.get_document_count(filters={"key": ["a"]}) == 1
assert document_store.get_document_count(filters={"key": ["b"]}) == 4
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
@pytest.mark.parametrize("batch_size", [2])
def test_weaviate_write_docs(document_store, batch_size):
# Write in small batches
for i in range(0, len(DOCUMENTS), batch_size):
document_store.write_documents(DOCUMENTS[i: i + batch_size])
documents_indexed = document_store.get_all_documents()
assert len(documents_indexed) == len(DOCUMENTS)
documents_indexed = document_store.get_all_documents(batch_size=batch_size)
assert len(documents_indexed) == len(DOCUMENTS)
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_get_all_document_filter_duplicate_value(document_store):
documents = [
Document(
text="Doc1",
meta={"fone": "f0"},
id=get_uuid(),
embedding=np.random.rand(embedding_dim).astype(np.float32)
),
Document(
text="Doc1",
meta={"fone": "f1", "metaid": "0"},
id=get_uuid(),
embedding=np.random.rand(embedding_dim).astype(np.float32)
),
Document(
text="Doc2",
meta={"fthree": "f0"},
id=get_uuid(),
embedding=np.random.rand(embedding_dim).astype(np.float32)
)
]
document_store.write_documents(documents)
documents = document_store.get_all_documents(filters={"fone": ["f1"]})
assert documents[0].text == "Doc1"
assert len(documents) == 1
assert {d.meta["metaid"] for d in documents} == {"0"}
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_get_all_documents_generator(document_store):
document_store.write_documents(DOCUMENTS)
assert len(list(document_store.get_all_documents_generator(batch_size=2))) == 5
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_write_with_duplicate_doc_ids(document_store):
id = get_uuid()
documents = [
Document(
text="Doc1",
id=id,
embedding=np.random.rand(embedding_dim).astype(np.float32)
),
Document(
text="Doc2",
id=id,
embedding=np.random.rand(embedding_dim).astype(np.float32)
)
]
document_store.write_documents(documents, duplicate_documents="skip")
with pytest.raises(Exception):
document_store.write_documents(documents, duplicate_documents="fail")
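The skip/fail behaviour exercised in this test can be sketched in plain Python. The following is an illustrative model of the three `duplicate_documents` policies, not Haystack's actual implementation:

```python
def write_docs(store: dict, docs: list, duplicate_documents: str = "overwrite") -> dict:
    # Toy model of the duplicate_documents policies: "skip" ignores ids that
    # are already present, "fail" raises on them, and "overwrite" replaces
    # the stored document.
    for doc in docs:
        if doc["id"] in store:
            if duplicate_documents == "skip":
                continue
            if duplicate_documents == "fail":
                raise ValueError(f"Duplicate document id: {doc['id']}")
        store[doc["id"]] = doc
    return store

store = {}
write_docs(store, [{"id": "1", "text": "Doc1"}, {"id": "1", "text": "Doc2"}], "skip")
assert store["1"]["text"] == "Doc1"  # the second write with the same id was skipped
```

With "skip" the first occurrence of each id wins, which is why writing the two documents above succeeds, while a second write with `duplicate_documents="fail"` raises.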
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
@pytest.mark.parametrize("update_existing_documents", [True, False])
def test_update_existing_documents(document_store, update_existing_documents):
id = str(uuid.uuid4())
original_docs = [
{"text": "text1_orig", "id": id, "metafieldforcount": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
]
updated_docs = [
{"text": "text1_new", "id": id, "metafieldforcount": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
]
document_store.update_existing_documents = update_existing_documents
document_store.write_documents(original_docs)
assert document_store.get_document_count() == 1
if update_existing_documents:
document_store.write_documents(updated_docs, duplicate_documents="overwrite")
else:
with pytest.raises(Exception):
document_store.write_documents(updated_docs, duplicate_documents="fail")
stored_docs = document_store.get_all_documents()
assert len(stored_docs) == 1
if update_existing_documents:
assert stored_docs[0].text == updated_docs[0]["text"]
else:
assert stored_docs[0].text == original_docs[0]["text"]
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_write_document_meta(document_store):
uid1 = get_uuid()
uid2 = get_uuid()
uid3 = get_uuid()
uid4 = get_uuid()
documents = [
{"text": "dict_without_meta", "id": uid1, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"text": "dict_with_meta", "metafield": "test2", "name": "filename2", "id": uid2, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
Document(text="document_object_without_meta", id=uid3, embedding=np.random.rand(embedding_dim).astype(np.float32)),
Document(text="document_object_with_meta", meta={"metafield": "test4", "name": "filename3"}, id=uid4, embedding=np.random.rand(embedding_dim).astype(np.float32)),
]
document_store.write_documents(documents)
documents_in_store = document_store.get_all_documents()
assert len(documents_in_store) == 4
assert not document_store.get_document_by_id(uid1).meta
assert document_store.get_document_by_id(uid2).meta["metafield"] == "test2"
assert not document_store.get_document_by_id(uid3).meta
assert document_store.get_document_by_id(uid4).meta["metafield"] == "test4"
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_write_document_index(document_store):
documents = [
{"text": "text1", "id": str(uuid.uuid4()), "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"text": "text2", "id": str(uuid.uuid4()), "embedding": np.random.rand(embedding_dim).astype(np.float32)},
]
document_store.write_documents([documents[0]], index="Haystackone")
assert len(document_store.get_all_documents(index="Haystackone")) == 1
document_store.write_documents([documents[1]], index="Haystacktwo")
assert len(document_store.get_all_documents(index="Haystacktwo")) == 1
assert len(document_store.get_all_documents(index="Haystackone")) == 1
assert len(document_store.get_all_documents()) == 0
@pytest.mark.weaviate
@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_update_embeddings(document_store, retriever):
documents = []
for i in range(6):
documents.append({"text": f"text_{i}", "id": str(uuid.uuid4()), "metafield": f"value_{i}", "embedding": np.random.rand(embedding_dim).astype(np.float32)})
documents.append({"text": "text_0", "id": str(uuid.uuid4()), "metafield": "value_0", "embedding": np.random.rand(embedding_dim).astype(np.float32)})
document_store.write_documents(documents, index="HaystackTestOne")
document_store.update_embeddings(retriever, index="HaystackTestOne", batch_size=3)
documents = document_store.get_all_documents(index="HaystackTestOne", return_embedding=True)
assert len(documents) == 7
for doc in documents:
assert type(doc.embedding) is np.ndarray
documents = document_store.get_all_documents(
index="HaystackTestOne",
filters={"metafield": ["value_0"]},
return_embedding=True,
)
assert len(documents) == 2
for doc in documents:
assert doc.meta["metafield"] == "value_0"
np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)
documents = document_store.get_all_documents(
index="HaystackTestOne",
filters={"metafield": ["value_1", "value_5"]},
return_embedding=True,
)
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
documents[0].embedding,
documents[1].embedding
)
doc = {"text": "text_7", "id": str(uuid.uuid4()), "metafield": "value_7",
"embedding": retriever.embed_queries(texts=["a random string"])[0]}
document_store.write_documents([doc], index="HaystackTestOne")
doc_before_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
embedding_before_update = doc_before_update.embedding
document_store.update_embeddings(
retriever, index="HaystackTestOne", batch_size=3, filters={"metafield": ["value_0", "value_1"]}
)
doc_after_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
embedding_after_update = doc_after_update.embedding
np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
# test update all embeddings
document_store.update_embeddings(retriever, index="HaystackTestOne", batch_size=3, update_existing_embeddings=True)
assert document_store.get_document_count(index="HaystackTestOne") == 8
doc_after_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
embedding_after_update = doc_after_update.embedding
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_query_by_embedding(document_store_with_docs):
docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32))
assert len(docs) == 3
docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32),
top_k=1)
assert len(docs) == 1
docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32),
filters={"name": ["filename2"]})
assert len(docs) == 1
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_query(document_store_with_docs):
query_text = 'My name is Carla and I live in Berlin'
with pytest.raises(Exception):
docs = document_store_with_docs.query(query_text)
docs = document_store_with_docs.query(filters={"name": ["filename2"]})
assert len(docs) == 1
docs = document_store_with_docs.query(filters={"text": [query_text.lower()]})
assert len(docs) == 1
docs = document_store_with_docs.query(filters={"text": ["live"]})
assert len(docs) == 3
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_delete_all_documents(document_store_with_docs):
assert len(document_store_with_docs.get_all_documents()) == 3
document_store_with_docs.delete_all_documents()
documents = document_store_with_docs.get_all_documents()
assert len(documents) == 0
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_delete_documents_with_filters(document_store_with_docs):
document_store_with_docs.delete_all_documents(filters={"metafield": ["test1", "test2"]})
documents = document_store_with_docs.get_all_documents()
assert len(documents) == 1
assert documents[0].meta["metafield"] == "test3"