mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-14 11:41:46 +00:00
403 lines
16 KiB
Python
403 lines
16 KiB
Python
![]() |
import logging
|
||
|
import os
|
||
|
import random
|
||
|
from time import monotonic, sleep
|
||
|
from typing import List
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from autogen.agentchat.contrib.vectordb.base import Document
|
||
|
|
||
|
try:
    # These two imports only verify that the optional dependencies are installed.
    import pymongo
    import sentence_transformers

    from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB
except ImportError:
    # An optional dependency is missing: skip the whole module instead of failing at import time.
    # To display the warning, set log_cli = true under [tool.pytest.ini_options] in pyproject.toml.
    logger = logging.getLogger(__name__)
    _skip_reason = f"skipping {__name__}. It requires one to pip install pymongo or the extra [retrievechat-mongodb]"
    logger.warning(_skip_reason)
    # Hand the reason to pytest as well so it shows up in the skip summary, not only in the log.
    pytest.skip(_skip_reason, allow_module_level=True)
|
||
|
|
||
|
from pymongo import MongoClient
|
||
|
from pymongo.collection import Collection
|
||
|
from pymongo.errors import OperationFailure
|
||
|
|
||
|
# Module-level logger used by the tests in this file.
logger = logging.getLogger(__name__)

# Connection settings. Each value can be overridden through the environment
# variable named in the call; the second argument is the fallback.
MONGODB_URI = os.getenv("MONGODB_URI", "mongodb://localhost:27017/?directConnection=true")
# NOTE: unlike its siblings, this setting is read from "DATABASE" (no MONGODB_ prefix).
MONGODB_DATABASE = os.getenv("DATABASE", "autogen_test_db")
MONGODB_COLLECTION = os.getenv("MONGODB_COLLECTION", "autogen_test_vectorstore")
MONGODB_INDEX = os.getenv("MONGODB_INDEX", "vector_index")
|
||
|
|
||
|
# NOTE(review): RETRIES is not referenced anywhere in the visible part of this module.
RETRIES = 10
# Default polling interval, in seconds, for _wait_for_predicate.
DELAY = 2
# Default overall wait budget, in seconds; also passed to the vector store constructors below.
TIMEOUT = 120.0


def _wait_for_predicate(predicate, err, timeout=TIMEOUT, interval=DELAY):
    """Block until ``predicate`` returns a truthy value.

    Args:
        predicate (Callable[[], bool]): Zero-argument callable that is polled until it returns a truthy value.
        err (str): Message carried by the TimeoutError raised when the wait times out.
        timeout (float, optional): Maximum number of seconds to keep polling. Defaults to TIMEOUT.
        interval (float, optional): Seconds to sleep between two polls. Defaults to DELAY.

    Raises:
        TimeoutError: If ``predicate`` is still falsy after more than ``timeout`` seconds.
    """
    start = monotonic()
    while not predicate():
        # Use the caller-supplied limits; the module constants are only the defaults.
        if monotonic() - start > timeout:
            raise TimeoutError(err)
        sleep(interval)
|
||
|
|
||
|
|
||
|
def _delete_search_indexes(collection: Collection, wait=True):
    """Drop every search index of a collection.

    Args:
        collection (pymongo.Collection): MongoDB Collection Abstraction
        wait (bool, optional): When True, block until ``list_search_indexes`` no longer
            reports any index. Defaults to True.

    Raises:
        TimeoutError: Propagated from ``_wait_for_predicate`` when ``wait`` is True and
            search indexes are still listed after the wait budget is used up.
    """
    for index in collection.list_search_indexes():
        try:
            collection.drop_search_index(index["name"])
        except OperationFailure:
            # Delete already issued
            pass
    if wait:
        # The predicate checks search indexes, so the timeout message names them (not collections).
        _wait_for_predicate(lambda: not list(collection.list_search_indexes()), "Not all search indexes deleted")
|
||
|
|
||
|
|
||
|
def _empty_collections_and_delete_indexes(database, collections=None, wait=True):
    """Remove the search indexes of the selected collections, then drop the collections.

    Args:
        database (pymongo.Database): MongoDB Database Abstraction
        collections (optional): Names of the collections to clean up. When falsy,
            every name returned by ``database.list_collection_names()`` is used.
        wait (bool, optional): Forwarded to ``_delete_search_indexes``. Defaults to True.
    """
    targets = collections or database.list_collection_names()
    for name in targets:
        # Search indexes are removed first, then the collection itself is dropped.
        _delete_search_indexes(database[name], wait)
        database[name].drop()
|
||
|
|
||
|
|
||
|
@pytest.fixture
def db():
    """VectorDB setup and teardown, including collections and search indexes.

    Yields:
        MongoDBAtlasVectorDB: Vector store bound to the test database, built with
        ``overwrite=True`` and ``wait_until_index_ready=TIMEOUT``.
    """
    # The context manager closes this helper client once teardown has finished,
    # so the fixture no longer leaks a connection per test.
    with MongoClient(MONGODB_URI) as client:
        database = client[MONGODB_DATABASE]
        # Start from a clean database: no collections, no search indexes.
        _empty_collections_and_delete_indexes(database)
        vectorstore = MongoDBAtlasVectorDB(
            connection_string=MONGODB_URI,
            database_name=MONGODB_DATABASE,
            wait_until_index_ready=TIMEOUT,
            overwrite=True,
        )
        yield vectorstore
        # Teardown: remove everything the test created.
        _empty_collections_and_delete_indexes(database)
|
||
|
|
||
|
|
||
|
@pytest.fixture
def example_documents() -> List[Document]:
    """Return four sample documents whose ids deliberately mix integers and strings."""
    rows = (
        (1, "Dogs are tough.", {"a": 1}),
        (2, "Cats have fluff.", {"b": 1}),
        ("1", "What is a sandwich?", {"c": 1}),
        ("2", "A sandwich makes a great lunch.", {"d": 1, "e": 2}),
    )
    return [Document(id=doc_id, content=text, metadata=meta) for doc_id, text, meta in rows]
|
||
|
|
||
|
|
||
|
@pytest.fixture
def db_with_indexed_clxn(collection_name):
    """VectorDB with a collection created immediately.

    Args:
        collection_name (str): Name of the collection, provided by the ``collection_name`` fixture.

    Yields:
        tuple: ``(vectorstore, collection)`` where ``collection`` is ``vectorstore.db[collection_name]``.
    """
    # The context manager closes this helper client once teardown has finished,
    # so the fixture no longer leaks a connection per test.
    with MongoClient(MONGODB_URI) as client:
        database = client[MONGODB_DATABASE]
        # Make sure no leftover collection or search index with this name exists.
        _empty_collections_and_delete_indexes(database, [collection_name], wait=True)
        vectorstore = MongoDBAtlasVectorDB(
            connection_string=MONGODB_URI,
            database_name=MONGODB_DATABASE,
            wait_until_index_ready=TIMEOUT,
            collection_name=collection_name,
            overwrite=True,
        )
        yield vectorstore, vectorstore.db[collection_name]
        # Teardown: drop the collection and its search indexes again.
        _empty_collections_and_delete_indexes(database, [collection_name])
|
||
|
|
||
|
|
||
|
# Numeric suffixes already handed out by the collection_name fixture during this session.
_COLLECTION_NAMING_CACHE = []


@pytest.fixture
def collection_name():
    """Return a collection name with a numeric suffix not used before in this session.

    The suffix is drawn at random from 0..100 (inclusive), excluding every value
    already recorded in ``_COLLECTION_NAMING_CACHE``.

    Raises:
        RuntimeError: If all 101 suffixes have already been handed out.
    """
    used = set(_COLLECTION_NAMING_CACHE)
    free_ids = [candidate for candidate in range(101) if candidate not in used]
    if not free_ids:
        # Retrying random draws here would loop forever, so fail loudly instead.
        raise RuntimeError("collection_name fixture ran out of unused collection ids (0-100)")
    collection_id = random.choice(free_ids)
    _COLLECTION_NAMING_CACHE.append(collection_id)

    return f"{MONGODB_COLLECTION}_{collection_id}"
|
||
|
|
||
|
|
||
|
def test_create_collection(db, collection_name):
    """Exercise ``create_collection(collection_name, overwrite, get_or_create)``.

    Scenarios covered, in order:
    - Case 1. the collection does not exist yet: it is created.
    - Case 2. the collection exists and overwrite is True: it is overwritten.
    - Case 3. the collection exists and overwrite is False: the existing collection is returned.
    - Case 4. the collection exists, overwrite is False and get_or_create is False: a ValueError is raised.
    """
    # Case 1
    created = db.create_collection(collection_name=collection_name)
    assert created.name == collection_name

    # Case 2
    overwritten = db.create_collection(collection_name=collection_name, overwrite=True)
    assert overwritten.name == collection_name

    # Case 3
    existing = db.create_collection(collection_name=collection_name)
    assert existing.name == collection_name

    # Case 4
    with pytest.raises(ValueError):
        db.create_collection(collection_name=collection_name, overwrite=False, get_or_create=False)
|
||
|
|
||
|
|
||
|
def test_get_collection(db, collection_name):
    """``get_collection`` raises without a name/active collection and otherwise returns the created one."""
    # Nothing has been created yet and no name is given.
    with pytest.raises(ValueError):
        db.get_collection()

    created = db.create_collection(collection_name)
    assert isinstance(created, Collection)
    assert created.name == collection_name

    fetched = db.get_collection(collection_name)
    assert fetched.name == created.name
    # The fetched collection is also expected to be the store's active collection.
    assert fetched.name == db.active_collection.name
|
||
|
|
||
|
|
||
|
def test_delete_collection(db, collection_name):
    """A collection is listed after creation and no longer listed after deletion."""
    assert collection_name not in db.list_collections()

    created = db.create_collection(collection_name)
    assert collection_name in db.list_collections()

    db.delete_collection(created.name)
    assert collection_name not in db.list_collections()
|
||
|
|
||
|
|
||
|
def test_insert_docs(db, collection_name, example_documents):
    """Check ``insert_docs`` error handling, upsert mode and the shape of stored documents."""
    # Inserting with neither an active collection nor a collection name must fail.
    with pytest.raises(ValueError) as exc:
        db.insert_docs(example_documents)
    assert "No collection is specified" in str(exc.value)

    # Test upsert
    db.insert_docs(example_documents, collection_name, upsert=True)

    # Start again from a freshly created collection
    db.delete_collection(collection_name)
    collection = db.create_collection(collection_name)

    # Insert example documents
    db.insert_docs(example_documents, collection_name=collection_name)
    stored = list(collection.find({}))
    assert len(stored) == len(example_documents)

    # Stored documents carry "_id" and "embedding" but no "id" field
    expected_fields = {"_id", "content", "metadata", "embedding"}
    assert all(set(doc.keys()) == expected_fields for doc in stored)

    # Both the integer and the string ids are preserved
    assert {doc["_id"] for doc in stored} == {1, "1", 2, "2"}

    # Check embedding lengths
    assert len(stored[0]["embedding"]) == 384
|
||
|
|
||
|
|
||
|
def test_update_docs(db_with_indexed_clxn, example_documents):
    """Check ``update_docs`` with and without ``upsert`` on new and existing documents."""
    db, collection = db_with_indexed_clxn

    # With upsert=True, update_docs can be used to insert new documents
    db.update_docs(example_documents, collection.name, upsert=True)

    # The input documents themselves must be left untouched
    assert set(example_documents[0].keys()) == {"id", "content", "metadata"}
    assert collection.count_documents({}) == len(example_documents)

    stored = list(collection.find({}))
    # Stored documents carry "_id" and "embedding" but no "id" field
    expected_fields = {"_id", "content", "metadata", "embedding"}
    assert all(set(doc.keys()) == expected_fields for doc in stored)
    assert all(isinstance(doc["embedding"][0], float) for doc in stored)
    assert all(len(doc["embedding"]) == db.dimensions for doc in stored)
    # Check ids
    assert {doc["_id"] for doc in stored} == {1, "1", 2, "2"}

    # Update an *existing* Document
    replacement = Document(id=1, content="Cats are tough.", metadata={"a": 10})
    db.update_docs([replacement], collection.name)
    assert collection.find_one({"_id": 1})["content"] == "Cats are tough."

    # Upsert a *new* Document
    upserted_id = 3
    upserted = Document(id=upserted_id, content="Cats are tough.")
    db.update_docs([upserted], collection.name, upsert=True)
    assert collection.find_one({"_id": upserted_id})["content"] == "Cats are tough."

    # Attempting to use update to insert a new doc
    # *without* setting upsert set to True
    # is a no-op in MongoDB. # TODO Confirm behaviour and autogen's preference.
    missing_id = 4
    not_inserted = Document(id=missing_id, content="That is NOT a sandwich?")
    db.update_docs([not_inserted], collection.name)
    assert collection.find_one({"_id": missing_id}) is None
|
||
|
|
||
|
|
||
|
def test_delete_docs(db_with_indexed_clxn, example_documents):
    """Deleting ids 1 and "1" leaves only the documents with ids 2 and "2"."""
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)
    # Delete the 1s
    db.delete_docs(ids=[1, "1"], collection_name=clxn.name)
    # Confirm just the 2s remain
    remaining_ids = {doc["_id"] for doc in clxn.find({})}
    assert {2, "2"} == remaining_ids
|
||
|
|
||
|
|
||
|
def test_get_docs_by_ids(db_with_indexed_clxn, example_documents):
    """Check ``get_docs_by_ids`` for explicit ids, ``include``, an empty id list and ``ids=None``."""
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)

    # Without the "include" kwarg, id, content and metadata are returned
    wanted_ids = [2, "2"]
    docs = db.get_docs_by_ids(ids=[2, "2"], collection_name=clxn.name)
    assert len(docs) == 2
    assert all(doc["id"] in wanted_ids for doc in docs)
    assert set(docs[0].keys()) == {"id", "content", "metadata"}

    # With include, only the id and the requested field are returned
    docs = db.get_docs_by_ids(ids=[2], include=["content"], collection_name=clxn.name)
    assert len(docs) == 1
    assert set(docs[0].keys()) == {"id", "content"}

    # An empty ids list yields no documents
    docs = db.get_docs_by_ids(ids=[], include=["content"], collection_name=clxn.name)
    assert len(docs) == 0

    # ids=None is expected to yield all four inserted documents
    docs = db.get_docs_by_ids(ids=None, include=["content"], collection_name=clxn.name)
    assert len(docs) == 4
|
||
|
|
||
|
|
||
|
def test_retrieve_docs_empty(db_with_indexed_clxn):
    """Querying a collection that holds no documents is expected to return an empty list."""
    db, clxn = db_with_indexed_clxn
    results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=2)
    assert results == []
|
||
|
|
||
|
|
||
|
def test_retrieve_docs_populated_db_empty_query(db_with_indexed_clxn, example_documents):
    """With documents present, an empty list of queries still returns an empty list of results."""
    db, clxn = db_with_indexed_clxn
    db.insert_docs(example_documents, collection_name=clxn.name)
    no_queries = []
    assert db.retrieve_docs(queries=no_queries, collection_name=clxn.name, n_results=2) == []
|
||
|
|
||
|
|
||
|
def test_retrieve_docs(db_with_indexed_clxn, example_documents):
    """Begin testing Atlas Vector Search.

    NOTE: Indexing may take some time, so the test polls until the first query
    returns the requested number of results. The fixture already builds the store
    with wait_until_index_ready. Adding documents and querying them immediately
    is only standard for testing.
    """
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)

    n_results = 2  # Number of closest docs to return

    def _query():
        return db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results)

    # Poll until the single query yields n_results hits
    _wait_for_predicate(
        lambda: len(_query()[0]) == n_results,
        f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.",
    )

    hits = _query()[0]
    assert {hit[0]["id"] for hit in hits} == {1, 2}
    # Embeddings are not part of the returned documents unless requested
    assert all("embedding" not in hit[0] for hit in hits)
|
||
|
|
||
|
|
||
|
def test_retrieve_docs_with_embedding(db_with_indexed_clxn, example_documents):
    """Same as ``test_retrieve_docs`` but requesting the embeddings in the results.

    NOTE: Indexing may take some time, so the test polls until the first query
    returns the requested number of results. The fixture already builds the store
    with wait_until_index_ready. Adding documents and querying them immediately
    is only standard for testing.
    """
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)

    n_results = 2  # Number of closest docs to return

    # Poll (without embeddings) until the single query yields n_results hits
    _wait_for_predicate(
        lambda: len(db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results)[0])
        == n_results,
        f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.",
    )

    results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results, include_embedding=True)
    hits = results[0]
    assert {hit[0]["id"] for hit in hits} == {1, 2}
    # With include_embedding=True every returned document carries its embedding
    assert all("embedding" in hit[0] for hit in hits)
|
||
|
|
||
|
|
||
|
def test_retrieve_docs_multiple_queries(db_with_indexed_clxn, example_documents):
    """Two queries return two result lists, each holding the expected pair of documents."""
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)
    n_results = 2  # Number of closest docs to return

    queries = ["Some good pets", "What kind of Sandwich?"]

    def results_ready():
        # Ready once every query yields the requested number of hits
        per_query = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results)
        return all(len(hits) == n_results for hits in per_query)

    _wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.")

    results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=2)

    assert len(results) == len(queries)
    assert all(len(hits) == n_results for hits in results)
    pet_hits, sandwich_hits = results[0], results[1]
    assert {hit[0]["id"] for hit in pet_hits} == {1, 2}
    assert {hit[0]["id"] for hit in sandwich_hits} == {"1", "2"}
|
||
|
|
||
|
|
||
|
def test_retrieve_docs_with_threshold(db_with_indexed_clxn, example_documents):
    """A distance threshold keeps only the sufficiently similar results."""
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)

    n_results = 2  # Number of closest docs to return
    queries = ["Cats"]

    # Poll (without a threshold) until the query yields n_results hits
    _wait_for_predicate(
        lambda: len(db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results)[0])
        == n_results,
        f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.",
    )

    # Distance Threshold of .3 means that the score must be .7 or greater
    # only one result should be that value
    results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results, distance_threshold=0.3)
    hits = results[0]
    assert len(hits) == 1
    assert all(hit[1] >= 0.7 for hit in hits)
|
||
|
|
||
|
|
||
|
def test_wait_until_document_ready(collection_name, example_documents):
    """With ``wait_until_document_ready`` set, a query right after insertion returns results."""
    # The context manager closes this helper client when the test ends,
    # so the test no longer leaks a connection.
    with MongoClient(MONGODB_URI) as client:
        database = client[MONGODB_DATABASE]
        # Make sure no leftover collection or search index with this name exists.
        _empty_collections_and_delete_indexes(database, [collection_name], wait=True)
        try:
            vectorstore = MongoDBAtlasVectorDB(
                connection_string=MONGODB_URI,
                database_name=MONGODB_DATABASE,
                wait_until_index_ready=TIMEOUT,
                collection_name=collection_name,
                overwrite=True,
                wait_until_document_ready=TIMEOUT,
            )
            vectorstore.insert_docs(example_documents)
            # No polling here: the result of the first query must already be non-empty.
            assert vectorstore.retrieve_docs(queries=["Cats"], n_results=4)
        finally:
            # Clean up even when construction, insertion or the assertion fails.
            _empty_collections_and_delete_indexes(database, [collection_name])
|