feat: support milvus vector db (#188) #none

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
ChengZi 2024-09-04 21:22:50 +08:00 committed by GitHub
parent 76f2652d2a
commit 772186b6e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 201 additions and 0 deletions

View File

@ -80,6 +80,7 @@ KH_DOCSTORE = {
# Vector store backend configuration.
# "__type__" is the dotted import path of the vector store class to use; the
# commented-out "__type__" lines list the alternative backends (presumably
# only one should be active at a time -- a dict keeps a single "__type__").
# "path" is the "vectorstore" folder under KH_USER_DATA_DIR.
KH_VECTORSTORE = {
# "__type__": "kotaemon.storages.LanceDBVectorStore",
"__type__": "kotaemon.storages.ChromaVectorStore",
# "__type__": "kotaemon.storages.MilvusVectorStore",
"path": str(KH_USER_DATA_DIR / "vectorstore"),
}
KH_LLMS = {}

View File

@ -10,6 +10,7 @@ from .vectorstores import (
ChromaVectorStore,
InMemoryVectorStore,
LanceDBVectorStore,
MilvusVectorStore,
SimpleFileVectorStore,
)
@ -26,4 +27,5 @@ __all__ = [
"InMemoryVectorStore",
"SimpleFileVectorStore",
"LanceDBVectorStore",
"MilvusVectorStore",
]

View File

@ -2,6 +2,7 @@ from .base import BaseVectorStore
from .chroma import ChromaVectorStore
from .in_memory import InMemoryVectorStore
from .lancedb import LanceDBVectorStore
from .milvus import MilvusVectorStore
from .simple_file import SimpleFileVectorStore
__all__ = [
@ -10,4 +11,5 @@ __all__ = [
"InMemoryVectorStore",
"SimpleFileVectorStore",
"LanceDBVectorStore",
"MilvusVectorStore",
]

View File

@ -0,0 +1,100 @@
import os
from typing import Any, Optional, Type, cast
from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore
from kotaemon.base import DocumentWithEmbedding
from .base import LlamaIndexVectorStore
class MilvusVectorStore(LlamaIndexVectorStore):
    """Vector store backed by Milvus through the LlamaIndex integration.

    The underlying LlamaIndex client is created lazily (see ``_lazy_init``)
    because its constructor takes the vector dimension, which is only known
    once the first embedding is seen.

    Args:
        uri: Milvus location, either a local database file name/path or an
            ``http...`` server address.
        collection_name: Name of the collection to operate on.
        token: Optional token forwarded to the LlamaIndex client.
        **kwargs: Extra options forwarded to the LlamaIndex client. If a
            ``path`` entry is given and is an existing directory, a
            non-``http`` ``uri`` is resolved relative to it.
    """

    _li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore

    def __init__(
        self,
        uri: str = "./milvus.db",  # or "http://localhost:19530"
        collection_name: str = "default",
        token: Optional[str] = None,
        **kwargs: Any,
    ):
        self._uri = uri
        self._collection_name = collection_name
        self._token = token
        self._kwargs = kwargs
        self._path = kwargs.get("path", None)
        # The real client is built on first use, see ``_lazy_init``.
        self._inited = False

    def _lazy_init(self, dim: Optional[int] = None):
        """
        Lazy init the client.

        Because the LlamaIndex init method requires the dim parameter,
        we need to try to get the dim from the first embedding.

        Args:
            dim: Dimension of the vectors.
        """
        if self._inited:
            return

        # ``path`` is optional: guard against None before probing the
        # filesystem, otherwise os.path.isdir raises TypeError.
        has_local_dir = self._path is not None and os.path.isdir(self._path)
        if has_local_dir and not self._uri.startswith("http"):
            uri = os.path.join(self._path, self._uri)
        else:
            uri = self._uri

        super().__init__(
            uri=uri,
            token=self._token,
            collection_name=self._collection_name,
            dim=dim,
            **self._kwargs,
        )
        self._client = cast(LIMilvusVectorStore, self._client)
        self._inited = True

    def add(
        self,
        embeddings: list[list[float]] | list[DocumentWithEmbedding],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[list[str]] = None,
    ):
        """Add embeddings to the collection, creating the client if needed.

        Args:
            embeddings: Raw vectors or documents carrying an ``embedding``.
            metadatas: Optional metadata, forwarded to the parent ``add``.
            ids: Optional ids, forwarded to the parent ``add``.

        Returns:
            Whatever the parent ``add`` returns; an empty list when there is
            nothing to add.
        """
        if not embeddings:
            # Nothing to insert, and no first item to infer the dimension from.
            return []

        if not self._inited:
            first = embeddings[0]
            if isinstance(first, list):
                dim = len(first)
            else:
                dim = len(first.embedding)
            self._lazy_init(dim)

        return super().add(embeddings=embeddings, metadatas=metadatas, ids=ids)

    def query(
        self,
        embedding: list[float],
        top_k: int = 1,
        ids: Optional[list[str]] = None,
        **kwargs,
    ) -> tuple[list[list[float]], list[float], list[str]]:
        """Run a similarity query, initialising the client from the query size.

        Args:
            embedding: Query vector; its length is used as the dimension if
                the client has not been created yet.
            top_k: Number of results requested.
            ids: Optional ids, forwarded to the parent ``query``.
            **kwargs: Extra options forwarded to the parent ``query``.

        Returns:
            The result of the parent ``query``.
        """
        self._lazy_init(len(embedding))
        return super().query(embedding=embedding, top_k=top_k, ids=ids, **kwargs)

    def delete(self, ids: list[str], **kwargs):
        """Delete the entries with the given ids.

        Args:
            ids: Ids of the entries to delete.
            **kwargs: Extra options forwarded to the parent ``delete``.
        """
        self._lazy_init()
        super().delete(ids=ids, **kwargs)

    def drop(self):
        """Drop the whole collection."""
        # Make sure the client exists, like the other operations do; without
        # this, calling drop() on a freshly constructed store fails because
        # ``_client`` has not been created yet.
        self._lazy_init()
        self._client.client.drop_collection(self._collection_name)

    def count(self) -> int:
        """Return the number of entries in the collection.

        Returns:
            The entry count, or 0 when the client cannot be initialised.
        """
        try:
            self._lazy_init()
        except Exception:
            # Best effort: initialisation without a known dimension can fail;
            # report an empty store instead of raising. Only ``Exception`` is
            # caught so KeyboardInterrupt/SystemExit still propagate.
            return 0

        rows = self._client.client.query(
            collection_name=self._collection_name, output_fields=["count(*)"]
        )
        return rows[0]["count(*)"]

    def __persist_flow__(self):
        """Return the keyword arguments this store was constructed with."""
        return {
            "uri": self._uri,
            "collection_name": self._collection_name,
            "token": self._token,
            **self._kwargs,
        }

View File

@ -34,6 +34,7 @@ dependencies = [
"llama-index>=0.10.40,<0.11.0",
"llama-index-vector-stores-chroma>=0.1.9",
"llama-index-vector-stores-lancedb",
"llama-index-vector-stores-milvus",
"openai>=1.23.6,<2",
"openpyxl>=3.1.2,<3.2",
"pandas>=2.2.2,<2.3",

View File

@ -5,6 +5,7 @@ from kotaemon.base import DocumentWithEmbedding
from kotaemon.storages import (
ChromaVectorStore,
InMemoryVectorStore,
MilvusVectorStore,
SimpleFileVectorStore,
)
@ -153,3 +154,97 @@ class TestSimpleFileVectorStore:
], "load function does not load data completely"
os.remove(tmp_path / collection_name)
class TestMilvusVectorStore:
    """Tests for MilvusVectorStore using a database under ``tmp_path``."""

    def test_add(self, tmp_path):
        """Test that the DB add correctly"""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = ["1", "2"]

        assert db.count() == 0, "Expected empty collection"
        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Expected output to be the same as ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_add_from_docs(self, tmp_path):
        """Test adding DocumentWithEmbedding objects instead of raw vectors."""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        documents = [
            DocumentWithEmbedding(embedding=embedding, metadata=metadata)
            for embedding, metadata in zip(embeddings, metadatas)
        ]

        assert db.count() == 0, "Expected empty collection"
        output = db.add(documents)
        assert len(output) == 2, "Expected outputting 2 ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_delete(self, tmp_path):
        """Test that deleting by id removes exactly those entries."""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["a", "b", "c"]

        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert db.count() == 3, "Expected 3 added entries"
        db.delete(ids=["a", "b"])
        assert db.count() == 1, "Expected 1 remaining entry"
        db.delete(ids=["c"])
        assert db.count() == 0, "Expected 0 remaining entry"

    def test_query(self, tmp_path):
        """Test that querying returns the nearest stored vector."""
        db = MilvusVectorStore(path=str(tmp_path), overwrite=True)
        import numpy as np

        # Unit-normalise the vectors before storing and querying them.
        embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
        norms = np.linalg.norm(embeddings, axis=1)
        normalized_embeddings = (embeddings / norms[:, np.newaxis]).tolist()
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["a", "b", "c"]

        db.add(embeddings=normalized_embeddings, metadatas=metadatas, ids=ids)

        # Querying with a stored vector must return that vector itself.
        _, sim, out_ids = db.query(embedding=normalized_embeddings[0], top_k=1)
        assert sim == [1.0]
        assert out_ids == ["a"]

        # A slightly perturbed copy of "b" must still resolve to "b".
        query_embedding = [
            normalized_embeddings[1][0] + 0.02,
            normalized_embeddings[1][1] + 0.02,
            normalized_embeddings[1][2] + 0.02,
        ]
        _, _, out_ids = db.query(embedding=query_embedding, top_k=1)
        assert out_ids == ["b"]

    def test_save_load_delete(self, tmp_path):
        """Test that save/load func behave correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["1", "2", "3"]

        db = MilvusVectorStore(path=str(tmp_path), overwrite=True)
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)

        # Reopen the same location without overwriting to check persistence.
        db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
        assert db2.count() == 3, "load function does not load data completely"

        # test delete collection function
        db2.drop()
        # reinit the milvus with the same collection name
        db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
        assert db2.count() == 0, "delete collection function does not work correctly"