mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2025-06-26 23:19:56 +00:00

* use tzlocal to get the local time
* delete tmp folder
* update date_created and date_updated with current timezone
* pass pre-commit
* update date_created field default to local time
193 lines
5.7 KiB
Python
193 lines
5.7 KiB
Python
import uuid
|
|
from datetime import datetime
|
|
|
|
import chromadb
|
|
from ktem.index.models import Index
|
|
from sqlalchemy import (
|
|
JSON,
|
|
Column,
|
|
DateTime,
|
|
Integer,
|
|
String,
|
|
UniqueConstraint,
|
|
create_engine,
|
|
select,
|
|
)
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.ext.mutable import MutableDict
|
|
from sqlalchemy.orm import Session
|
|
from tzlocal import get_localzone
|
|
|
|
|
|
def _init_resource(private: bool = True, id: int = 1):
    """Build the SQLAlchemy models for the tables of one file index.

    Table names are hard-coded as ``index__{id}__source`` and
    ``index__{id}__index``. A fresh declarative base is created on every
    call, so the returned classes are independent between calls.

    Args:
        private: When truthy, a source name only has to be unique per user
            (composite ``(name, user)`` constraint). Otherwise ``name`` alone
            carries a unique constraint.
        id: Integer id of the index, interpolated into the table names.

    Returns:
        A dict with the keys ``"Source"`` and ``"Index"`` mapping to the
        generated declarative classes.
    """
    Base = declarative_base()

    def _now_local():
        # Must be a callable: passing ``datetime.now(...)`` directly would
        # freeze the timestamp at model-definition time for every row.
        return datetime.now(get_localzone())

    source_attrs = {"__tablename__": f"index__{id}__source"}
    if private:
        source_attrs["__table_args__"] = (
            UniqueConstraint("name", "user", name="_name_user_uc"),
        )

    # Columns are created in the same order as before so the generated table
    # layout does not change.
    source_attrs["id"] = Column(
        String,
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
        unique=True,
    )
    source_attrs["name"] = (
        Column(String) if private else Column(String, unique=True)
    )
    source_attrs["path"] = Column(String)
    source_attrs["size"] = Column(Integer, default=0)
    source_attrs["date_created"] = Column(
        DateTime(timezone=True), default=_now_local
    )
    source_attrs["user"] = Column(Integer, default=1)
    source_attrs["note"] = Column(
        MutableDict.as_mutable(JSON),  # type: ignore
        # A factory avoids handing the very same dict object to every row.
        default=lambda: {},
    )

    source_cls = type("Source", (Base,), source_attrs)

    # Named ``index_cls`` locally so the imported ``Index`` model is not
    # shadowed inside this function.
    index_cls = type(
        "IndexTable",
        (Base,),
        {
            "__tablename__": f"index__{id}__index",
            "id": Column(Integer, primary_key=True, autoincrement=True),
            "source_id": Column(String),
            "target_id": Column(String),
            "relation_type": Column(String),
            "user": Column(Integer, default=1),
        },
    )

    return {"Source": source_cls, "Index": index_cls}
|
|
|
|
|
|
def get_chromadb_collection(
    db_dir: str = "../ktem_app_data/user_data/vectorstore",
    collection_name: str = "index_1",
):
    """Open the persistent Chroma store at ``db_dir`` and return a collection.

    The collection called ``collection_name`` is fetched, or created when it
    does not exist yet.
    """
    return chromadb.PersistentClient(path=db_dir).get_or_create_collection(
        collection_name
    )
|
|
|
|
|
|
def update_metadata(metadata, file_id):
    """Set ``metadata["file_id"]`` to ``file_id`` in place and return it.

    The mapping passed in is mutated; the very same object is returned.
    """
    metadata.update({"file_id": file_id})
    return metadata
|
|
|
|
|
|
def migrate_chroma_db(
    chroma_db_dir: str, sqlite_path: str, is_private: bool = True, int_index: int = 1
):
    """Update chromadb with metadata.file_id

    For every document in the ``index__{int_index}__source`` table, look up
    the vector ids linked to it in ``index__{int_index}__index`` (rows whose
    ``relation_type`` is ``"vector"``) and write the document id into the
    ``file_id`` metadata field of those vectors in the Chroma collection
    ``index_{int_index}``.

    Args:
        chroma_db_dir: Directory of the persistent Chroma store.
        sqlite_path: SQLAlchemy database URL passed to ``create_engine``.
        is_private: Forwarded to ``_init_resource`` to select the Source
            schema variant.
        int_index: Integer id of the index to migrate.
    """
    chroma_collection_name = f"index_{int_index}"

    engine = create_engine(sqlite_path)
    resource = _init_resource(private=is_private, id=int_index)
    source_cls = resource["Source"]
    index_cls = resource["Index"]
    print("Load sqlalchemy engine successfully!")

    chroma_db_collection = get_chromadb_collection(
        db_dir=chroma_db_dir, collection_name=chroma_collection_name
    )
    print(
        f"Load chromadb collection: {chroma_collection_name}, "
        f"path: {chroma_db_dir} successfully!"
    )

    # Load docs id of user
    with Session(engine) as session:
        doc_ids = [row[0].id for row in session.execute(select(source_cls)).all()]
    print(f"Retrieve n-docs: {len(doc_ids)}")
    print(doc_ids)

    for doc_id in doc_ids:
        print("-")
        # Find corresponding vector ids
        with Session(engine) as session:
            stmt = select(index_cls).where(
                index_cls.relation_type == "vector",
                index_cls.source_id == doc_id,
            )
            vs_ids = [row[0].target_id for row in session.execute(stmt).all()]

        print(f"Got {len(vs_ids)} vs_ids for doc {doc_id}")

        if not vs_ids:
            continue

        # Update file_id
        batch = chroma_db_collection.get(ids=vs_ids, include=["metadatas"])
        # The write must go to the collection itself. Calling ``update`` on
        # ``batch`` would only modify the dict returned by ``get`` and leave
        # the stored vectors untouched.
        chroma_db_collection.update(
            ids=batch["ids"],
            metadatas=[
                # Entries stored without metadata may come back as None;
                # start from an empty dict in that case.
                update_metadata(metadata or {}, doc_id)
                for metadata in batch["metadatas"]
            ],
        )

        # Assert file_id. Skip
        print(f"doc-{doc_id} got updated")
|
|
|
|
|
|
def main(chroma_db_dir: str, sqlite_path: str):
    """Run the Chroma ``file_id`` migration for every registered index.

    All ``Index`` rows are read from the database at ``sqlite_path``; for
    each one, ``migrate_chroma_db`` is called with the row's ``id`` and the
    ``"private"`` entry of its ``config``.
    """
    db_engine = create_engine(sqlite_path)

    with Session(db_engine) as db_session:
        registered = db_session.execute(select(Index)).scalars().all()

    for entry in registered:
        index_id = entry.id
        private_flag = entry.config["private"]

        print(f"Migrating for Index id: {index_id}, is_private: {private_flag}")

        migrate_chroma_db(
            chroma_db_dir=chroma_db_dir,
            sqlite_path=sqlite_path,
            is_private=private_flag,
            int_index=index_id,
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: migrate the hard-coded Chroma store using the
    # index definitions found in the hard-coded SQLite database.
    main(
        chroma_db_dir="./vectorstore/kan_db",
        sqlite_path="sqlite:///../ktem_app_data/user_data/sql.db",
    )
|