autogen/test/agentchat/contrib/vectordb/test_pgvectordb.py
Audel Rouhi 6604ca511b
PGVector Support for Custom Connection Object (#2566)
* Added fixes and tests for basic auth format

* User can provide their own connection object. Added test for it.

* Updated instructions on how to use. Fully tested all 3 authentication methods successfully.

* Get password from gitlab secrets.

* Hide passwords.

* Update notebook/agentchat_pgvector_RetrieveChat.ipynb

Co-authored-by: Li Jiang <bnujli@gmail.com>

* Hide passwords.

* Added connection_string test. 3 tests total for auth.

* Fixed quotes on db config params. No other changes found.

* Ran notebook

* Ran pre-commits and updated setup to include psycopg[binary] for windows and mac.

* Corrected list extension.

* Separate connection establishment function. Testing pending.

* Fixed pgvectordb auth

* Update agentchat_pgvector_RetrieveChat.ipynb

Added autocommit=True in example

* Rerun notebook

---------

Co-authored-by: Li Jiang <bnujli@gmail.com>
Co-authored-by: Li Jiang <lijiang1@microsoft.com>
2024-05-24 17:58:56 +00:00

135 lines
5.1 KiB
Python

import os
import sys
import urllib.parse
import pytest
from conftest import reason
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
try:
import pgvector
import psycopg
import sentence_transformers
from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB
except ImportError:
skip = True
else:
skip = False
reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason
@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip,
reason=reason,
)
def test_pgvector():
# test db config
db_config = {
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
}
# test create collection with connection_string authentication
db = PGVectorDB(
connection_string=db_config["connection_string"],
)
collection_name = "test_collection"
collection = db.create_collection(collection_name=collection_name, overwrite=True, get_or_create=True)
assert collection.name == collection_name
# test create collection with conn object authentication
parsed_connection = urllib.parse.urlparse(db_config["connection_string"])
encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
connection_string_encoded = (
f"{parsed_connection.scheme}://{encoded_username}:{encoded_password}"
f"@{encoded_host}:{parsed_connection.port}/{encoded_database}"
)
conn = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
db = PGVectorDB(conn=conn)
collection_name = "test_collection"
collection = db.create_collection(collection_name=collection_name, overwrite=True, get_or_create=True)
assert collection.name == collection_name
# test create collection with basic authentication
db_config = {
"username": "postgres",
"password": os.environ.get("POSTGRES_PASSWORD", default="postgres"),
"host": "localhost",
"port": 5432,
"dbname": "postgres",
}
db = PGVectorDB(
username=db_config["username"],
password=db_config["password"],
port=db_config["port"],
host=db_config["host"],
dbname=db_config["dbname"],
)
collection_name = "test_collection"
collection = db.create_collection(collection_name=collection_name, overwrite=True, get_or_create=True)
assert collection.name == collection_name
# test_delete_collection
db.delete_collection(collection_name)
assert collection.table_exists(table_name=collection_name) is False
# test more create collection
collection = db.create_collection(collection_name, overwrite=False, get_or_create=False)
assert collection.name == collection_name
pytest.raises(ValueError, db.create_collection, collection_name, overwrite=False, get_or_create=False)
collection = db.create_collection(collection_name, overwrite=True, get_or_create=False)
assert collection.name == collection_name
collection = db.create_collection(collection_name, overwrite=False, get_or_create=True)
assert collection.name == collection_name
# test_get_collection
collection = db.get_collection(collection_name)
assert collection.name == collection_name
# test_insert_docs
docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
db.insert_docs(docs, collection_name, upsert=False)
res = db.get_collection(collection_name).get(ids=["1", "2"])
final_results = [result.get("content") for result in res]
assert final_results == ["doc1", "doc2"]
# test_update_docs
docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
db.update_docs(docs, collection_name)
res = db.get_collection(collection_name).get(["1", "2"])
final_results = [result.get("content") for result in res]
assert final_results == ["doc11", "doc2"]
# test_delete_docs
ids = ["1"]
collection_name = "test_collection"
db.delete_docs(ids, collection_name)
res = db.get_collection(collection_name).get(ids)
final_results = [result.get("content") for result in res]
assert final_results == []
# test_retrieve_docs
queries = ["doc2", "doc3"]
collection_name = "test_collection"
res = db.retrieve_docs(queries, collection_name)
assert [[r[0]["id"] for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
assert [[r[0]["id"] for r in rr] for rr in res] == [["2"], ["3"]]
# test_get_docs_by_ids
res = db.get_docs_by_ids(["1", "2"], collection_name)
assert [r["id"] for r in res] == ["2"] # "1" has been deleted
res = db.get_docs_by_ids(collection_name=collection_name)
assert set([r["id"] for r in res]) == set(["2", "3"]) # All Docs returned
if __name__ == "__main__":
test_pgvector()