mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2025-06-26 23:19:56 +00:00
feat: support milvus vector db (#188) #none
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
parent
76f2652d2a
commit
772186b6e5
@ -80,6 +80,7 @@ KH_DOCSTORE = {
|
||||
# Vector store settings. "__type__" holds the dotted path of the vector store
# class to use; the commented-out "__type__" lines are alternative backends.
# NOTE(review): presumably the remaining keys are forwarded to that class as
# keyword arguments -- confirm against the code that consumes this dict.
KH_VECTORSTORE = {
    # "__type__": "kotaemon.storages.LanceDBVectorStore",
    "__type__": "kotaemon.storages.ChromaVectorStore",
    # "__type__": "kotaemon.storages.MilvusVectorStore",
    # Storage location: the "vectorstore" entry under the user data directory.
    "path": str(KH_USER_DATA_DIR / "vectorstore"),
}
|
||||
KH_LLMS = {}
|
||||
|
@ -10,6 +10,7 @@ from .vectorstores import (
|
||||
ChromaVectorStore,
|
||||
InMemoryVectorStore,
|
||||
LanceDBVectorStore,
|
||||
MilvusVectorStore,
|
||||
SimpleFileVectorStore,
|
||||
)
|
||||
|
||||
@ -26,4 +27,5 @@ __all__ = [
|
||||
"InMemoryVectorStore",
|
||||
"SimpleFileVectorStore",
|
||||
"LanceDBVectorStore",
|
||||
"MilvusVectorStore",
|
||||
]
|
||||
|
@ -2,6 +2,7 @@ from .base import BaseVectorStore
|
||||
from .chroma import ChromaVectorStore
|
||||
from .in_memory import InMemoryVectorStore
|
||||
from .lancedb import LanceDBVectorStore
|
||||
from .milvus import MilvusVectorStore
|
||||
from .simple_file import SimpleFileVectorStore
|
||||
|
||||
__all__ = [
|
||||
@ -10,4 +11,5 @@ __all__ = [
|
||||
"InMemoryVectorStore",
|
||||
"SimpleFileVectorStore",
|
||||
"LanceDBVectorStore",
|
||||
"MilvusVectorStore",
|
||||
]
|
||||
|
100
libs/kotaemon/kotaemon/storages/vectorstores/milvus.py
Normal file
100
libs/kotaemon/kotaemon/storages/vectorstores/milvus.py
Normal file
@ -0,0 +1,100 @@
|
||||
import os
|
||||
from typing import Any, Optional, Type, cast
|
||||
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore
|
||||
|
||||
from kotaemon.base import DocumentWithEmbedding
|
||||
|
||||
from .base import LlamaIndexVectorStore
|
||||
|
||||
|
||||
class MilvusVectorStore(LlamaIndexVectorStore):
    """Vector store backed by Milvus through the LlamaIndex Milvus integration.

    The underlying client is created lazily (see ``_lazy_init``) because the
    LlamaIndex constructor is given the vector dimension, which is only known
    once the first embedding is added or queried.
    """

    _li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore

    def __init__(
        self,
        uri: str = "./milvus.db",  # or "http://localhost:19530"
        collection_name: str = "default",
        token: Optional[str] = None,
        **kwargs: Any,
    ):
        """Record the connection settings without creating the client.

        Args:
            uri: Milvus URI, e.g. a local database file or an ``http`` address.
            collection_name: Name of the collection to operate on.
            token: Optional token forwarded to the client.
            **kwargs: Extra keyword arguments forwarded to the parent
                initializer. An optional ``path`` entry names a directory
                that a non-HTTP ``uri`` is joined onto.
        """
        self._uri = uri
        self._collection_name = collection_name
        self._token = token
        self._kwargs = kwargs
        self._path = kwargs.get("path", None)
        # The parent initializer is deliberately deferred to _lazy_init().
        self._inited = False

    def _lazy_init(self, dim: Optional[int] = None):
        """
        Lazy init the client.
        Because the LlamaIndex init method requires the dim parameter,
        we need to try to get the dim from the first embedding.

        Args:
            dim: Dimension of the vectors.
        """
        if self._inited:
            return

        # ``path`` is optional, and os.path.isdir(None) raises TypeError, so
        # only consult the filesystem when a path was actually supplied.
        join_with_path = (
            self._path is not None
            and os.path.isdir(self._path)
            and not self._uri.startswith("http")
        )
        uri = os.path.join(self._path, self._uri) if join_with_path else self._uri

        super().__init__(
            uri=uri,
            token=self._token,
            collection_name=self._collection_name,
            dim=dim,
            **self._kwargs,
        )
        self._client = cast(LIMilvusVectorStore, self._client)
        self._inited = True

    def add(
        self,
        embeddings: list[list[float]] | list[DocumentWithEmbedding],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[list[str]] = None,
    ):
        """Add embeddings, initializing the client from the first vector's size.

        Args:
            embeddings: Raw vectors or ``DocumentWithEmbedding`` objects.
            metadatas: Optional metadata, one dict per embedding.
            ids: Optional ids, one per embedding.

        Returns:
            Whatever the parent ``add`` returns; an empty list when nothing is
            given and the client has not been initialized yet.
        """
        if not self._inited:
            if len(embeddings) == 0:
                # Nothing to insert and no vector to infer the dimension from.
                return []
            first = embeddings[0]
            if isinstance(first, list):
                dim = len(first)
            else:
                dim = len(first.embedding)
            self._lazy_init(dim)

        return super().add(embeddings=embeddings, metadatas=metadatas, ids=ids)

    def query(
        self,
        embedding: list[float],
        top_k: int = 1,
        ids: Optional[list[str]] = None,
        **kwargs,
    ) -> tuple[list[list[float]], list[float], list[str]]:
        """Run a similarity query, initializing the client if needed.

        Args:
            embedding: Query vector; its length is used as ``dim`` on first use.
            top_k: Number of results requested.
            ids: Optional ids forwarded to the parent ``query``.
            **kwargs: Extra arguments forwarded to the parent ``query``.

        Returns:
            The parent ``query`` result.
        """
        self._lazy_init(len(embedding))

        return super().query(embedding=embedding, top_k=top_k, ids=ids, **kwargs)

    def delete(self, ids: list[str], **kwargs):
        """Delete the entries with the given ids via the parent ``delete``."""
        self._lazy_init()
        super().delete(ids=ids, **kwargs)

    def drop(self):
        """Drop the whole collection."""
        # Make sure the client exists, consistent with delete() and count();
        # without this a fresh instance fails on the missing ``_client``.
        self._lazy_init()
        self._client.client.drop_collection(self._collection_name)

    def count(self) -> int:
        """Return the number of entries in the collection.

        Returns 0 when the client cannot be initialized.
        """
        try:
            self._lazy_init()
        except Exception:
            # Best effort: initialization without ``dim`` can fail, presumably
            # when the collection does not exist yet -- report it as empty.
            return 0
        rows = self._client.client.query(
            collection_name=self._collection_name, output_fields=["count(*)"]
        )
        return rows[0]["count(*)"]

    def __persist_flow__(self):
        """Return the constructor arguments needed to recreate this store."""
        return {
            "uri": self._uri,
            "collection_name": self._collection_name,
            "token": self._token,
            **self._kwargs,
        }
|
@ -34,6 +34,7 @@ dependencies = [
|
||||
"llama-index>=0.10.40,<0.11.0",
|
||||
"llama-index-vector-stores-chroma>=0.1.9",
|
||||
"llama-index-vector-stores-lancedb",
|
||||
"llama-index-vector-stores-milvus",
|
||||
"openai>=1.23.6,<2",
|
||||
"openpyxl>=3.1.2,<3.2",
|
||||
"pandas>=2.2.2,<2.3",
|
||||
|
@ -5,6 +5,7 @@ from kotaemon.base import DocumentWithEmbedding
|
||||
from kotaemon.storages import (
|
||||
ChromaVectorStore,
|
||||
InMemoryVectorStore,
|
||||
MilvusVectorStore,
|
||||
SimpleFileVectorStore,
|
||||
)
|
||||
|
||||
@ -153,3 +154,97 @@ class TestSimpleFileVectorStore:
|
||||
], "load function does not load data completely"
|
||||
|
||||
os.remove(tmp_path / collection_name)
|
||||
|
||||
|
||||
class TestMilvusVectorStore:
    """Tests for MilvusVectorStore using a database under ``tmp_path``."""

    def test_add(self, tmp_path):
        """Test that the DB add correctly"""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = ["1", "2"]

        assert db.count() == 0, "Expected empty collection"
        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Expected output to be the same as ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_add_from_docs(self, tmp_path):
        """Test adding DocumentWithEmbedding objects instead of raw vectors."""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        documents = [
            DocumentWithEmbedding(embedding=embedding, metadata=metadata)
            for embedding, metadata in zip(embeddings, metadatas)
        ]
        assert db.count() == 0, "Expected empty collection"
        output = db.add(documents)
        assert len(output) == 2, "Expected outputting 2 ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_delete(self, tmp_path):
        """Test that deleting by id removes exactly the requested entries."""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["a", "b", "c"]

        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert db.count() == 3, "Expected 3 added entries"
        db.delete(ids=["a", "b"])
        assert db.count() == 1, "Expected 1 remaining entry"
        db.delete(ids=["c"])
        assert db.count() == 0, "Expected 0 remaining entry"

    def test_query(self, tmp_path):
        """Test that queries return the nearest stored vector."""
        db = MilvusVectorStore(path=str(tmp_path), overwrite=True)
        import numpy as np

        embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
        norms = np.linalg.norm(embeddings, axis=1)
        normalized_embeddings = (embeddings / norms[:, np.newaxis]).tolist()

        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["a", "b", "c"]

        db.add(embeddings=normalized_embeddings, metadatas=metadatas, ids=ids)

        _, sim, out_ids = db.query(embedding=normalized_embeddings[0], top_k=1)
        # Compare with a tolerance: the score is a float and exact equality
        # with 1.0 is brittle.
        assert len(sim) == 1
        assert abs(sim[0] - 1.0) < 1e-5
        assert out_ids == ["a"]

        # A small perturbation of the second vector should still match "b".
        query_embedding = [value + 0.02 for value in normalized_embeddings[1]]
        _, _, out_ids = db.query(embedding=query_embedding, top_k=1)
        assert out_ids == ["b"]

    def test_save_load_delete(self, tmp_path):
        """Test that save/load func behave correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["1", "2", "3"]
        db = MilvusVectorStore(path=str(tmp_path), overwrite=True)
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)

        # Reopen the same location without overwriting; the data must persist.
        db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
        assert db2.count() == 3, "load function does not load data completely"

        # test delete collection function
        db2.drop()
        # reinit the milvus with the same collection name
        db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
        assert db2.count() == 0, "delete collection function does not work correctly"
|
||||
|
Loading…
x
Reference in New Issue
Block a user