refactor: Refactor Weaviate tests (#3541)

* refactor tests

* fix job

* revert

* revert

* revert

* use latest weaviate

* fix abstract methods signatures

* pass class_name to all the CRUD methods

* finish moving all the tests

* bump weaviate version

* raise, don't pass
This commit is contained in:
Massimiliano Pippi 2022-11-14 09:57:30 +01:00 committed by GitHub
parent da6b0dc66f
commit 4dfddf0d10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 251 additions and 220 deletions

View File

@ -325,6 +325,44 @@ jobs:
if: failure() && github.repository_owner == 'deepset-ai' && github.ref == 'refs/heads/main'
integration-tests-weaviate:
name: Integration / Weaviate / ${{ matrix.os }}
needs:
- unit-tests
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
runs-on: ${{ matrix.os }}
services:
weaviate:
image: semitechnologies/weaviate:1.16.0
env:
AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: "true"
PERSISTENCE_DATA_PATH: "/var/lib/weaviate"
ENABLE_EXPERIMENTAL_BM25: "true"
DISK_USE_READONLY_PERCENTAGE: 95
ports:
- 8080:8080
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: ./.github/actions/python_cache/
- name: Install Haystack
run: pip install -U .[docstores]
- name: Run tests
run: |
pytest --maxfail=5 -m "document_store and integration" test/document_stores/test_weaviate.py
- uses: act10ns/slack@v1
with:
status: ${{ job.status }}
channel: '#haystack'
if: failure() && github.repository_owner == 'deepset-ai' && github.ref == 'refs/heads/main'
#
# TODO: the following steps need to be revisited
#
@ -502,78 +540,6 @@ jobs:
# pytest ${{ env.PYTEST_PARAMS }} -m "milvus and not integration" ${{ env.SUITES_EXCLUDED_FROM_WINDOWS }} test/document_stores/ --document_store_type=milvus
weaviate-tests-linux:
needs: [mypy, pylint, black]
runs-on: ubuntu-latest
if: contains(github.event.pull_request.labels.*.name, 'topic:weaviate') || !github.event.pull_request.draft
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: ./.github/actions/python_cache/
- name: Setup Weaviate
run: docker run -d -p 8080:8080 --name haystack_test_weaviate --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' --env ENABLE_EXPERIMENTAL_BM25='true' --env DISK_USE_READONLY_PERCENTAGE='95' semitechnologies/weaviate:1.14.1
# TODO Let's try to remove this one from the unit tests
- name: Install pdftotext
run: wget --no-check-certificate https://dl.xpdfreader.com/xpdf-tools-linux-4.04.tar.gz && tar -xvf xpdf-tools-linux-4.04.tar.gz && sudo cp xpdf-tools-linux-4.04/bin64/pdftotext /usr/local/bin
- name: Install Haystack
run: pip install .[weaviate]
- name: Run tests
env:
TOKENIZERS_PARALLELISM: 'false'
run: |
pytest ${{ env.PYTEST_PARAMS }} -m "weaviate and not integration" test/document_stores/ --document_store_type=weaviate
- name: Dump docker logs on failure
if: failure()
uses: jwalton/gh-docker-logs@v1
- uses: act10ns/slack@v1
with:
status: ${{ job.status }}
channel: '#haystack'
if: failure() && github.repository_owner == 'deepset-ai' && github.ref == 'refs/heads/main'
# FIXME: seems like we can't run containers on Windows
# weaviate-tests-windows:
# needs:
# - mypy
# - pylint
# runs-on: windows-latest
# if: contains(github.event.pull_request.labels.*.name, 'topic:weaviate') && contains(github.event.pull_request.labels.*.name, 'topic:windows') || !github.event.pull_request.draft
# steps:
# - uses: actions/checkout@v3
# - name: Setup Python
# uses: ./.github/actions/python_cache/
# with:
# prefix: windows
# - name: Setup Weaviate
# run: docker run -d -p 8080:8080 --name haystack_test_weaviate --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' --env ENABLE_EXPERIMENTAL_BM25='true' --env DISK_USE_READONLY_PERCENTAGE='95' semitechnologies/weaviate:1.14.1
# - name: Install pdftotext
# run: |
# choco install xpdf-utils
# choco install openjdk11
# refreshenv
# - name: Install Haystack
# run: pip install .[weaviate]
# - name: Run tests
# env:
# TOKENIZERS_PARALLELISM: 'false'
# run: |
# pytest ${{ env.PYTEST_PARAMS }} -m "weaviate and not integration" ${{ env.SUITES_EXCLUDED_FROM_WINDOWS }} test/document_stores/ --document_store_type=weaviate
pinecone-tests-linux:
needs: [mypy, pylint, black]
runs-on: ubuntu-latest

View File

@ -17,7 +17,7 @@ except (ImportError, ModuleNotFoundError) as ie:
_optional_component_not_installed(__name__, "weaviate", ie)
from haystack.schema import Document
from haystack.schema import Document, Label
from haystack.document_stores import BaseDocumentStore
from haystack.document_stores.base import get_batches_from_generator
from haystack.document_stores.filter_utils import LogicalFilterClause
@ -312,7 +312,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
id = self._sanitize_id(id=id, index=index)
result = None
try:
result = self.weaviate_client.data_object.get_by_id(id, with_vector=True)
result = self.weaviate_client.data_object.get_by_id(id, class_name=index, with_vector=True)
except weaviate.exceptions.UnexpectedStatusCodeException as usce:
logging.debug("Weaviate could not get the document requested: %s", usce)
if result:
@ -339,7 +339,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
id = self._sanitize_id(id=id, index=index)
result = None
try:
result = self.weaviate_client.data_object.get_by_id(id, with_vector=True)
result = self.weaviate_client.data_object.get_by_id(id, class_name=index, with_vector=True)
except weaviate.exceptions.UnexpectedStatusCodeException as usce:
logging.debug("Weaviate could not get the document requested: %s", usce)
if result:
@ -1352,7 +1352,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
if ids and not filters:
for id in ids:
self.weaviate_client.data_object.delete(id)
self.weaviate_client.data_object.delete(id, class_name=index)
else:
# Use filters to restrict list of retrieved documents, before checking these against provided ids
@ -1360,7 +1360,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
if ids:
docs_to_delete = [doc for doc in docs_to_delete if doc.id in ids]
for doc in docs_to_delete:
self.weaviate_client.data_object.delete(doc.id)
self.weaviate_client.data_object.delete(doc.id, class_name=index)
def delete_index(self, index: str):
"""
@ -1382,7 +1382,13 @@ class WeaviateDocumentStore(BaseDocumentStore):
self.weaviate_client.schema.delete_class(index)
logger.info("Index '%s' deleted.", index)
def delete_labels(self):
def delete_labels(
self,
index: Optional[str] = None,
ids: Optional[List[str]] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
headers: Optional[Dict[str, str]] = None,
):
"""
Implemented to respect BaseDocumentStore's contract.
@ -1390,7 +1396,12 @@ class WeaviateDocumentStore(BaseDocumentStore):
"""
raise NotImplementedError("Weaviate does not support labels (yet).")
def get_all_labels(self):
def get_all_labels(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Label]:
"""
Implemented to respect BaseDocumentStore's contract.
@ -1398,7 +1409,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
"""
raise NotImplementedError("Weaviate does not support labels (yet).")
def get_label_count(self):
def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> int:
"""
Implemented to respect BaseDocumentStore's contract.
@ -1406,10 +1417,15 @@ class WeaviateDocumentStore(BaseDocumentStore):
"""
raise NotImplementedError("Weaviate does not support labels (yet).")
def write_labels(self):
def write_labels(
self,
labels: Union[List[Label], List[dict]],
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
"""
Implemented to respect BaseDocumentStore's contract.
Weaviate does not support labels (yet).
"""
pass
raise NotImplementedError("Weaviate does not support labels (yet).")

View File

@ -70,7 +70,7 @@ def launch_weaviate(sleep=15):
logger.debug("Starting Weaviate ...")
status = subprocess.run(
[
f"docker start {WEAVIATE_CONTAINER_NAME} > /dev/null 2>&1 || docker run -d -p 8080:8080 --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' --name {WEAVIATE_CONTAINER_NAME} semitechnologies/weaviate:1.14.0"
f"docker start {WEAVIATE_CONTAINER_NAME} > /dev/null 2>&1 || docker run -d -p 8080:8080 --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' --name {WEAVIATE_CONTAINER_NAME} semitechnologies/weaviate:latest"
],
shell=True,
)

View File

@ -124,7 +124,7 @@ milvus = [
"farm-haystack[sql,only-milvus]",
]
weaviate = [
"weaviate-client==3.6.0",
"weaviate-client==3.9.0",
]
only-pinecone = [
"pinecone-client>=2.0.11,<3",
@ -314,9 +314,6 @@ disable = [
"simplifiable-if-expression",
"use-list-literal",
# To review later
"cyclic-import",
"import-outside-toplevel",
@ -334,7 +331,7 @@ addopts = "--strict-markers"
markers = [
"unit: unit tests",
"integration: integration tests",
"generator: generator tests",
"summarizer: summarizer tests",
"embedding_dim: uses a document store with non-default embedding dimension (e.g @pytest.mark.embedding_dim(128))",

View File

@ -493,7 +493,7 @@ def weaviate_fixture():
print("Starting Weaviate servers ...")
status = subprocess.run(["docker rm haystack_test_weaviate"], shell=True)
status = subprocess.run(
["docker run -d --name haystack_test_weaviate -p 8080:8080 semitechnologies/weaviate:1.14.1"], shell=True
["docker run -d --name haystack_test_weaviate -p 8080:8080 semitechnologies/weaviate:latest"], shell=True
)
if status.returncode:
raise Exception("Failed to launch Weaviate. Please check docker container logs.")

View File

@ -177,13 +177,13 @@ class DocumentStoreBaseTestAbstract:
def test_comparison_filters(self, ds, documents):
ds.write_documents(documents)
result = ds.get_all_documents(filters={"numbers": {"$gt": 0}})
result = ds.get_all_documents(filters={"numbers": {"$gt": 0.0}})
assert len(result) == 3
result = ds.get_all_documents(filters={"numbers": {"$gte": -2}})
result = ds.get_all_documents(filters={"numbers": {"$gte": -2.0}})
assert len(result) == 6
result = ds.get_all_documents(filters={"numbers": {"$lt": 0}})
result = ds.get_all_documents(filters={"numbers": {"$lt": 0.0}})
assert len(result) == 3
result = ds.get_all_documents(filters={"numbers": {"$lte": 2.0}})
@ -297,7 +297,7 @@ class DocumentStoreBaseTestAbstract:
@pytest.mark.integration
def test_get_document_count(self, ds, documents):
ds.write_documents(documents)
assert ds.get_document_count() == 9
assert ds.get_document_count() == len(documents)
assert ds.get_document_count(filters={"year": ["2020"]}) == 3
assert ds.get_document_count(filters={"month": ["02"]}) == 3

View File

@ -1,3 +1,11 @@
import pytest
from haystack.document_stores.weaviate import WeaviateDocumentStore
from haystack.schema import Document
from .test_base import DocumentStoreBaseTestAbstract
import uuid
from unittest.mock import MagicMock
@ -5,7 +13,6 @@ import numpy as np
import pytest
from haystack.schema import Document
from ..conftest import get_document_store
embedding_dim = 768
@ -14,162 +21,207 @@ def get_uuid():
return str(uuid.uuid4())
DOCUMENTS = [
{"content": "text1", "id": "not a correct uuid", "key": "a"},
{"content": "text2", "id": get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"content": "text3", "id": get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"content": "text4", "id": get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
{"content": "text5", "id": get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
]
class TestWeaviateDocumentStore(DocumentStoreBaseTestAbstract):
# Constants
DOCUMENTS_XS = [
# current "dict" format for a document
{
"content": "My name is Carla and I live in Berlin",
"id": get_uuid(),
"meta": {"metafield": "test1", "name": "filename1"},
"embedding": np.random.rand(embedding_dim).astype(np.float32),
},
# meta_field at the top level for backward compatibility
{
"content": "My name is Paul and I live in New York",
"id": get_uuid(),
"metafield": "test2",
"name": "filename2",
"embedding": np.random.rand(embedding_dim).astype(np.float32),
},
# Document object for a doc
Document(
content="My name is Christelle and I live in Paris",
id=get_uuid(),
meta={"metafield": "test3", "name": "filename3"},
embedding=np.random.rand(embedding_dim).astype(np.float32),
),
]
index_name = "DocumentsTest"
@pytest.fixture
def ds(self):
return WeaviateDocumentStore(index=self.index_name, recreate_index=True)
@pytest.fixture(params=["weaviate"])
def document_store_with_docs(request, tmp_path):
document_store = get_document_store(request.param, tmp_path=tmp_path)
document_store.write_documents(DOCUMENTS_XS)
yield document_store
document_store.delete_index(document_store.index)
@pytest.fixture(scope="class")
def documents(self):
documents = []
for i in range(3):
documents.append(
Document(
id=get_uuid(),
content=f"A Foo Document {i}",
meta={"name": f"name_{i}", "year": "2020", "month": "01", "numbers": [2.0, 4.0]},
embedding=np.random.rand(768).astype(np.float32),
)
)
documents.append(
Document(
id=get_uuid(),
content=f"A Bar Document {i}",
meta={"name": f"name_{i}", "year": "2021", "month": "02", "numbers": [-2.0, -4.0]},
embedding=np.random.rand(768).astype(np.float32),
)
)
@pytest.fixture(params=["weaviate"])
def document_store(request, tmp_path):
document_store = get_document_store(request.param, tmp_path=tmp_path)
yield document_store
document_store.delete_index(document_store.index)
documents.append(
Document(
id=get_uuid(),
content=f"A Baz Document {i}",
meta={"name": f"name_{i}", "month": "03"},
embedding=np.random.rand(768).astype(np.float32),
)
)
return documents
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
@pytest.mark.parametrize("batch_size", [2])
def test_weaviate_write_docs(document_store, batch_size):
# Write in small batches
for i in range(0, len(DOCUMENTS), batch_size):
document_store.write_documents(DOCUMENTS[i : i + batch_size])
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_write_labels(self):
pass
documents_indexed = document_store.get_all_documents()
assert len(documents_indexed) == len(DOCUMENTS)
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_delete_labels(self):
pass
documents_indexed = document_store.get_all_documents(batch_size=batch_size)
assert len(documents_indexed) == len(DOCUMENTS)
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_delete_labels_by_id(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_delete_labels_by_filter(self):
pass
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_query_by_embedding(document_store_with_docs):
docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32))
assert len(docs) == 3
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_delete_labels_by_filter_id(self):
pass
docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32), top_k=1)
assert len(docs) == 1
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_get_label_count(self):
pass
docs = document_store_with_docs.query_by_embedding(
np.random.rand(embedding_dim).astype(np.float32), filters={"name": ["filename2"]}
)
assert len(docs) == 1
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_write_labels_duplicate(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_write_get_all_labels(self):
pass
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_query(document_store_with_docs):
query_text = "My name is Carla and I live in Berlin"
docs = document_store_with_docs.query(query_text)
assert len(docs) == 3
@pytest.mark.integration
def test_ne_filters(self, ds, documents):
"""
Weaviate doesn't include documents if the field is missing,
so we customize this test
"""
ds.write_documents(documents)
# BM25 retrieval WITH filters is not yet supported as of Weaviate v1.14.1
with pytest.raises(Exception):
docs = document_store_with_docs.query(query_text, filters={"name": ["filename2"]})
result = ds.get_all_documents(filters={"year": {"$ne": "2020"}})
assert len(result) == 3
docs = document_store_with_docs.query(filters={"name": ["filename2"]})
assert len(docs) == 1
@pytest.mark.integration
def test_nin_filters(self, ds, documents):
"""
Weaviate doesn't include documents if the field is missing,
so we customize this test
"""
ds.write_documents(documents)
docs = document_store_with_docs.query(filters={"content": [query_text.lower()]})
assert len(docs) == 1
result = ds.get_all_documents(filters={"year": {"$nin": ["2020", "2021", "n.a."]}})
assert len(result) == 0
docs = document_store_with_docs.query(filters={"content": ["live"]})
assert len(docs) == 3
@pytest.mark.integration
def test_delete_index(self, ds, documents):
"""Contrary to other Document Stores, this doesn't raise if the index is empty"""
ds.write_documents(documents, index="custom_index")
assert ds.get_document_count(index="custom_index") == len(documents)
ds.delete_index(index="custom_index")
assert ds.get_document_count(index="custom_index") == 0
@pytest.mark.integration
def test_query_by_embedding(self, ds, documents):
ds.write_documents(documents)
@pytest.mark.weaviate
def test_get_all_documents_unaffected_by_QUERY_MAXIMUM_RESULTS(document_store_with_docs, monkeypatch):
"""
Ensure `get_all_documents` works no matter the value of QUERY_MAXIMUM_RESULTS
see https://github.com/deepset-ai/haystack/issues/2517
"""
monkeypatch.setattr(document_store_with_docs, "get_document_count", lambda **kwargs: 13_000)
docs = document_store_with_docs.get_all_documents()
assert len(docs) == 3
docs = ds.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32))
assert len(docs) == 9
docs = ds.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32), top_k=1)
assert len(docs) == 1
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
def test_deleting_by_id_or_by_filters(document_store_with_docs):
# This test verifies that deleting an object by its ID does not first require fetching all documents. This fixes
# a bug, as described in https://github.com/deepset-ai/haystack/issues/2898
document_store_with_docs.get_all_documents = MagicMock(wraps=document_store_with_docs.get_all_documents)
docs = ds.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32), filters={"name": ["name_1"]})
assert len(docs) == 3
assert document_store_with_docs.get_document_count() == 3
@pytest.mark.integration
def test_query(self, ds, documents):
ds.write_documents(documents)
# Delete a document by its ID. This should bypass the get_all_documents() call
document_store_with_docs.delete_documents(ids=[DOCUMENTS_XS[0]["id"]])
document_store_with_docs.get_all_documents.assert_not_called()
assert document_store_with_docs.get_document_count() == 2
query_text = "Foo"
docs = ds.query(query_text)
assert len(docs) == 3
document_store_with_docs.get_all_documents.reset_mock()
# Delete a document with filters. Prove that using the filters will go through get_all_documents()
document_store_with_docs.delete_documents(filters={"name": ["filename2"]})
document_store_with_docs.get_all_documents.assert_called()
assert document_store_with_docs.get_document_count() == 1
# BM25 retrieval WITH filters is not yet supported as of Weaviate v1.14.1
with pytest.raises(Exception):
docs = ds.query(query_text, filters={"name": ["filename2"]})
docs = ds.query(filters={"name": ["name_0"]})
assert len(docs) == 3
@pytest.mark.weaviate
@pytest.mark.parametrize("similarity", ["cosine", "l2", "dot_product"])
def test_similarity_existing_index(tmp_path, similarity):
"""Testing non-matching similarity"""
# create the document_store
document_store = get_document_store("weaviate", tmp_path, similarity=similarity, recreate_index=True)
docs = ds.query(filters={"content": [query_text.lower()]})
assert len(docs) == 3
# try to connect to the same document store but using the wrong similarity
non_matching_similarity = "l2" if similarity == "cosine" else "cosine"
with pytest.raises(ValueError, match=r"This index already exists in Weaviate with similarity .*"):
document_store2 = get_document_store(
"weaviate", tmp_path, similarity=non_matching_similarity, recreate_index=False
docs = ds.query(filters={"content": ["baz"]})
assert len(docs) == 3
@pytest.mark.integration
def test_get_all_documents_unaffected_by_QUERY_MAXIMUM_RESULTS(self, ds, documents, monkeypatch):
"""
Ensure `get_all_documents` works no matter the value of QUERY_MAXIMUM_RESULTS
see https://github.com/deepset-ai/haystack/issues/2517
"""
ds.write_documents(documents)
monkeypatch.setattr(ds, "get_document_count", lambda **kwargs: 13_000)
docs = ds.get_all_documents()
assert len(docs) == 9
@pytest.mark.integration
def test_deleting_by_id_or_by_filters(self, ds, documents):
ds.write_documents(documents)
# This test verifies that deleting an object by its ID does not first require fetching all documents. This fixes
# a bug, as described in https://github.com/deepset-ai/haystack/issues/2898
ds.get_all_documents = MagicMock(wraps=ds.get_all_documents)
assert ds.get_document_count() == 9
# Delete a document by its ID. This should bypass the get_all_documents() call
ds.delete_documents(ids=[documents[0].id])
ds.get_all_documents.assert_not_called()
assert ds.get_document_count() == 8
ds.get_all_documents.reset_mock()
# Delete a document with filters. Prove that using the filters will go through get_all_documents()
ds.delete_documents(filters={"name": ["name_0"]})
ds.get_all_documents.assert_called()
assert ds.get_document_count() == 6
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["cosine", "l2", "dot_product"])
def test_similarity_existing_index(self, similarity):
"""Testing non-matching similarity"""
# create the document_store
document_store = WeaviateDocumentStore(
similarity=similarity, index=f"test_similarity_existing_index_{similarity}", recreate_index=True
)
# try to connect to the same document store but using the wrong similarity
non_matching_similarity = "l2" if similarity == "cosine" else "cosine"
with pytest.raises(ValueError, match=r"This index already exists in Weaviate with similarity .*"):
document_store2 = WeaviateDocumentStore(
similarity=non_matching_similarity,
index=f"test_similarity_existing_index_{similarity}",
recreate_index=False,
)
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_cant_write_id_in_meta(document_store):
with pytest.raises(ValueError, match='"meta" info contains duplicate key "id"'):
document_store.write_documents([Document(content="test", meta={"id": "test-id"})])
@pytest.mark.integration
def test_cant_write_id_in_meta(self, ds):
with pytest.raises(ValueError, match='"meta" info contains duplicate key "id"'):
ds.write_documents([Document(content="test", meta={"id": "test-id"})])
@pytest.mark.weaviate
@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
def test_cant_write_top_level_fields_in_meta(document_store):
with pytest.raises(ValueError, match='"meta" info contains duplicate key "content"'):
document_store.write_documents([Document(content="test", meta={"content": "test-id"})])
@pytest.mark.integration
def test_cant_write_top_level_fields_in_meta(self, ds):
with pytest.raises(ValueError, match='"meta" info contains duplicate key "content"'):
ds.write_documents([Document(content="test", meta={"content": "test-id"})])