mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2025-06-26 23:19:56 +00:00
feat: support milvus vector db (#188) #none
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
parent
76f2652d2a
commit
772186b6e5
@ -80,6 +80,7 @@ KH_DOCSTORE = {
|
|||||||
KH_VECTORSTORE = {
|
KH_VECTORSTORE = {
|
||||||
# "__type__": "kotaemon.storages.LanceDBVectorStore",
|
# "__type__": "kotaemon.storages.LanceDBVectorStore",
|
||||||
"__type__": "kotaemon.storages.ChromaVectorStore",
|
"__type__": "kotaemon.storages.ChromaVectorStore",
|
||||||
|
# "__type__": "kotaemon.storages.MilvusVectorStore",
|
||||||
"path": str(KH_USER_DATA_DIR / "vectorstore"),
|
"path": str(KH_USER_DATA_DIR / "vectorstore"),
|
||||||
}
|
}
|
||||||
KH_LLMS = {}
|
KH_LLMS = {}
|
||||||
|
@ -10,6 +10,7 @@ from .vectorstores import (
|
|||||||
ChromaVectorStore,
|
ChromaVectorStore,
|
||||||
InMemoryVectorStore,
|
InMemoryVectorStore,
|
||||||
LanceDBVectorStore,
|
LanceDBVectorStore,
|
||||||
|
MilvusVectorStore,
|
||||||
SimpleFileVectorStore,
|
SimpleFileVectorStore,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,4 +27,5 @@ __all__ = [
|
|||||||
"InMemoryVectorStore",
|
"InMemoryVectorStore",
|
||||||
"SimpleFileVectorStore",
|
"SimpleFileVectorStore",
|
||||||
"LanceDBVectorStore",
|
"LanceDBVectorStore",
|
||||||
|
"MilvusVectorStore",
|
||||||
]
|
]
|
||||||
|
@ -2,6 +2,7 @@ from .base import BaseVectorStore
|
|||||||
from .chroma import ChromaVectorStore
|
from .chroma import ChromaVectorStore
|
||||||
from .in_memory import InMemoryVectorStore
|
from .in_memory import InMemoryVectorStore
|
||||||
from .lancedb import LanceDBVectorStore
|
from .lancedb import LanceDBVectorStore
|
||||||
|
from .milvus import MilvusVectorStore
|
||||||
from .simple_file import SimpleFileVectorStore
|
from .simple_file import SimpleFileVectorStore
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -10,4 +11,5 @@ __all__ = [
|
|||||||
"InMemoryVectorStore",
|
"InMemoryVectorStore",
|
||||||
"SimpleFileVectorStore",
|
"SimpleFileVectorStore",
|
||||||
"LanceDBVectorStore",
|
"LanceDBVectorStore",
|
||||||
|
"MilvusVectorStore",
|
||||||
]
|
]
|
||||||
|
100
libs/kotaemon/kotaemon/storages/vectorstores/milvus.py
Normal file
100
libs/kotaemon/kotaemon/storages/vectorstores/milvus.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Optional, Type, cast
|
||||||
|
|
||||||
|
from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore
|
||||||
|
|
||||||
|
from kotaemon.base import DocumentWithEmbedding
|
||||||
|
|
||||||
|
from .base import LlamaIndexVectorStore
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusVectorStore(LlamaIndexVectorStore):
|
||||||
|
_li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
uri: str = "./milvus.db", # or "http://localhost:19530"
|
||||||
|
collection_name: str = "default",
|
||||||
|
token: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
self._uri = uri
|
||||||
|
self._collection_name = collection_name
|
||||||
|
self._token = token
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._path = kwargs.get("path", None)
|
||||||
|
self._inited = False
|
||||||
|
|
||||||
|
def _lazy_init(self, dim: Optional[int] = None):
|
||||||
|
"""
|
||||||
|
Lazy init the client.
|
||||||
|
Because the LlamaIndex init method requires the dim parameter,
|
||||||
|
we need to try to get the dim from the first embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim: Dimension of the vectors.
|
||||||
|
"""
|
||||||
|
if not self._inited:
|
||||||
|
if os.path.isdir(self._path) and not self._uri.startswith("http"):
|
||||||
|
uri = os.path.join(self._path, self._uri)
|
||||||
|
else:
|
||||||
|
uri = self._uri
|
||||||
|
super().__init__(
|
||||||
|
uri=uri,
|
||||||
|
token=self._token,
|
||||||
|
collection_name=self._collection_name,
|
||||||
|
dim=dim,
|
||||||
|
**self._kwargs,
|
||||||
|
)
|
||||||
|
self._client = cast(LIMilvusVectorStore, self._client)
|
||||||
|
self._inited = True
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
embeddings: list[list[float]] | list[DocumentWithEmbedding],
|
||||||
|
metadatas: Optional[list[dict]] = None,
|
||||||
|
ids: Optional[list[str]] = None,
|
||||||
|
):
|
||||||
|
if not self._inited:
|
||||||
|
if isinstance(embeddings[0], list):
|
||||||
|
dim = len(embeddings[0])
|
||||||
|
else:
|
||||||
|
dim = len(embeddings[0].embedding)
|
||||||
|
self._lazy_init(dim)
|
||||||
|
|
||||||
|
return super().add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
embedding: list[float],
|
||||||
|
top_k: int = 1,
|
||||||
|
ids: Optional[list[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[list[list[float]], list[float], list[str]]:
|
||||||
|
self._lazy_init(len(embedding))
|
||||||
|
|
||||||
|
return super().query(embedding=embedding, top_k=top_k, ids=ids, **kwargs)
|
||||||
|
|
||||||
|
def delete(self, ids: list[str], **kwargs):
|
||||||
|
self._lazy_init()
|
||||||
|
super().delete(ids=ids, **kwargs)
|
||||||
|
|
||||||
|
def drop(self):
|
||||||
|
self._client.client.drop_collection(self._collection_name)
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
try:
|
||||||
|
self._lazy_init()
|
||||||
|
except: # noqa: E722
|
||||||
|
return 0
|
||||||
|
return self._client.client.query(
|
||||||
|
collection_name=self._collection_name, output_fields=["count(*)"]
|
||||||
|
)[0]["count(*)"]
|
||||||
|
|
||||||
|
def __persist_flow__(self):
|
||||||
|
return {
|
||||||
|
"uri": self._uri,
|
||||||
|
"collection_name": self._collection_name,
|
||||||
|
"token": self._token,
|
||||||
|
**self._kwargs,
|
||||||
|
}
|
@ -34,6 +34,7 @@ dependencies = [
|
|||||||
"llama-index>=0.10.40,<0.11.0",
|
"llama-index>=0.10.40,<0.11.0",
|
||||||
"llama-index-vector-stores-chroma>=0.1.9",
|
"llama-index-vector-stores-chroma>=0.1.9",
|
||||||
"llama-index-vector-stores-lancedb",
|
"llama-index-vector-stores-lancedb",
|
||||||
|
"llama-index-vector-stores-milvus",
|
||||||
"openai>=1.23.6,<2",
|
"openai>=1.23.6,<2",
|
||||||
"openpyxl>=3.1.2,<3.2",
|
"openpyxl>=3.1.2,<3.2",
|
||||||
"pandas>=2.2.2,<2.3",
|
"pandas>=2.2.2,<2.3",
|
||||||
|
@ -5,6 +5,7 @@ from kotaemon.base import DocumentWithEmbedding
|
|||||||
from kotaemon.storages import (
|
from kotaemon.storages import (
|
||||||
ChromaVectorStore,
|
ChromaVectorStore,
|
||||||
InMemoryVectorStore,
|
InMemoryVectorStore,
|
||||||
|
MilvusVectorStore,
|
||||||
SimpleFileVectorStore,
|
SimpleFileVectorStore,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -153,3 +154,97 @@ class TestSimpleFileVectorStore:
|
|||||||
], "load function does not load data completely"
|
], "load function does not load data completely"
|
||||||
|
|
||||||
os.remove(tmp_path / collection_name)
|
os.remove(tmp_path / collection_name)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMilvusVectorStore:
|
||||||
|
def test_add(self, tmp_path):
|
||||||
|
"""Test that the DB add correctly"""
|
||||||
|
db = MilvusVectorStore(
|
||||||
|
path=str(tmp_path),
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
|
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
|
||||||
|
ids = ["1", "2"]
|
||||||
|
|
||||||
|
assert db.count() == 0, "Expected empty collection"
|
||||||
|
output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||||
|
assert output == ids, "Expected output to be the same as ids"
|
||||||
|
assert db.count() == 2, "Expected 2 added entries"
|
||||||
|
|
||||||
|
def test_add_from_docs(self, tmp_path):
|
||||||
|
db = MilvusVectorStore(
|
||||||
|
path=str(tmp_path),
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
|
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
|
||||||
|
documents = [
|
||||||
|
DocumentWithEmbedding(embedding=embedding, metadata=metadata)
|
||||||
|
for embedding, metadata in zip(embeddings, metadatas)
|
||||||
|
]
|
||||||
|
assert db.count() == 0, "Expected empty collection"
|
||||||
|
output = db.add(documents)
|
||||||
|
assert len(output) == 2, "Expected outputting 2 ids"
|
||||||
|
assert db.count() == 2, "Expected 2 added entries"
|
||||||
|
|
||||||
|
def test_delete(self, tmp_path):
|
||||||
|
db = MilvusVectorStore(
|
||||||
|
path=str(tmp_path),
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
|
||||||
|
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
|
||||||
|
ids = ["a", "b", "c"]
|
||||||
|
|
||||||
|
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||||
|
assert db.count() == 3, "Expected 3 added entries"
|
||||||
|
db.delete(ids=["a", "b"])
|
||||||
|
assert db.count() == 1, "Expected 1 remaining entry"
|
||||||
|
db.delete(ids=["c"])
|
||||||
|
assert db.count() == 0, "Expected 0 remaining entry"
|
||||||
|
|
||||||
|
def test_query(self, tmp_path):
|
||||||
|
db = MilvusVectorStore(path=str(tmp_path), overwrite=True)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1)
|
||||||
|
normalized_embeddings = (embeddings / norms[:, np.newaxis]).tolist()
|
||||||
|
|
||||||
|
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
|
||||||
|
ids = ["a", "b", "c"]
|
||||||
|
|
||||||
|
db.add(embeddings=normalized_embeddings, metadatas=metadatas, ids=ids)
|
||||||
|
|
||||||
|
_, sim, out_ids = db.query(embedding=normalized_embeddings[0], top_k=1)
|
||||||
|
assert sim == [1.0]
|
||||||
|
assert out_ids == ["a"]
|
||||||
|
|
||||||
|
query_embedding = [
|
||||||
|
normalized_embeddings[1][0] + 0.02,
|
||||||
|
normalized_embeddings[1][1] + 0.02,
|
||||||
|
normalized_embeddings[1][2] + 0.02,
|
||||||
|
]
|
||||||
|
_, _, out_ids = db.query(embedding=query_embedding, top_k=1)
|
||||||
|
assert out_ids == ["b"]
|
||||||
|
|
||||||
|
def test_save_load_delete(self, tmp_path):
|
||||||
|
"""Test that save/load func behave correctly."""
|
||||||
|
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
|
||||||
|
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
|
||||||
|
ids = ["1", "2", "3"]
|
||||||
|
db = MilvusVectorStore(path=str(tmp_path), overwrite=True)
|
||||||
|
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||||
|
|
||||||
|
db2 = MilvusVectorStore(path=str(tmp_path), overrides=False)
|
||||||
|
assert db2.count() == 3, "load function does not load data completely"
|
||||||
|
|
||||||
|
# test delete collection function
|
||||||
|
db2.drop()
|
||||||
|
# reinit the milvus with the same collection name
|
||||||
|
db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
|
||||||
|
assert db2.count() == 0, "delete collection function does not work correctly"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user