kotaemon/scripts/migrate/migrate_chroma_db.py
cin-cris 5b828c213c
fix: fix Application UI using UTC time (#472) bump:patch
* use tzlocal to get the local time

* delete tmp folder

* update date_created and date_updated with current timezone

* pass precommit

* update date_created field default by local time
2024-11-11 16:51:38 +07:00

193 lines
5.7 KiB
Python

import uuid
from datetime import datetime
import chromadb
from ktem.index.models import Index
from sqlalchemy import (
JSON,
Column,
DateTime,
Integer,
String,
UniqueConstraint,
create_engine,
select,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import Session
from tzlocal import get_localzone
def _init_resource(private: bool = True, id: int = 1):
    """Build the hard-coded ORM schemas of one ktem file index.

    Both tables are declared on a fresh declarative base, so this script can
    read them without loading the application.

    Args:
        private: When True, ``Source`` carries a ``(name, user)`` unique
            constraint named ``_name_user_uc``. When False, ``name`` alone is
            declared unique.
        id: Numeric index id used in the table names ``index__{id}__source``
            and ``index__{id}__index``.

    Returns:
        ``{"Source": <Source ORM class>, "Index": <IndexTable ORM class>}``.
    """
    Base = declarative_base()

    source_attrs = {
        "__tablename__": f"index__{id}__source",
        "id": Column(
            String,
            primary_key=True,
            default=lambda: str(uuid.uuid4()),
            unique=True,
        ),
        # Private indices scope name uniqueness per user (table constraint
        # added below); non-private indices require globally unique names.
        "name": Column(String) if private else Column(String, unique=True),
        "path": Column(String),
        "size": Column(Integer, default=0),
        # A callable, so the timestamp is taken when each row is inserted.
        # Passing ``datetime.now(...)`` directly would freeze one value at
        # the moment this schema was built.
        "date_created": Column(
            DateTime(timezone=True),
            default=lambda: datetime.now(get_localzone()),
        ),
        "user": Column(Integer, default=1),
        "note": Column(
            MutableDict.as_mutable(JSON),  # type: ignore
            default={},
        ),
    }
    if private:
        source_attrs["__table_args__"] = (
            UniqueConstraint("name", "user", name="_name_user_uc"),
        )
    Source = type("Source", (Base,), source_attrs)

    # Named IndexTable locally to avoid shadowing the imported ktem ``Index``.
    IndexTable = type(
        "IndexTable",
        (Base,),
        {
            "__tablename__": f"index__{id}__index",
            "id": Column(Integer, primary_key=True, autoincrement=True),
            "source_id": Column(String),
            "target_id": Column(String),
            "relation_type": Column(String),
            "user": Column(Integer, default=1),
        },
    )
    return {"Source": Source, "Index": IndexTable}
def get_chromadb_collection(
    db_dir: str = "../ktem_app_data/user_data/vectorstore",
    collection_name: str = "index_1",
):
    """Open a collection from a persistent Chroma store, creating it if absent.

    Args:
        db_dir: Directory holding the persistent Chroma database.
        collection_name: Name of the collection to open or create.

    Returns:
        The collection object returned by ``get_or_create_collection``.
    """
    return chromadb.PersistentClient(path=db_dir).get_or_create_collection(
        collection_name
    )
def update_metadata(metadata, file_id):
    """Set ``metadata["file_id"]`` to *file_id* and return the metadata.

    A mapping passed in is modified in place and returned. Metadata entries
    read back from Chroma may be ``None`` (e.g. a vector stored without
    metadata). In that case a new dict is created instead of raising
    ``TypeError``.

    Args:
        metadata: Metadata mapping of one vector, or ``None``.
        file_id: Id of the source document the vector belongs to.

    Returns:
        The metadata mapping with ``"file_id"`` set.
    """
    if metadata is None:
        metadata = {}
    metadata["file_id"] = file_id
    return metadata
def migrate_chroma_db(
    chroma_db_dir: str, sqlite_path: str, is_private: bool = True, int_index: int = 1
):
    """Update chromadb with metadata.file_id for one ktem file index.

    Steps:

    1. Read the ids of all documents in the index's ``Source`` table.
    2. Read the ``relation_type == "vector"`` rows of the index's ``Index``
       table, which link each document to its vector ids.
    3. Write the document id into ``metadata["file_id"]`` of those vectors in
       the Chroma collection ``index_{int_index}``.

    Args:
        chroma_db_dir: Directory of the persistent Chroma database.
        sqlite_path: SQLAlchemy database URL of the ktem SQL store.
        is_private: Whether the index is private. This only selects which
            schema variant ``_init_resource`` builds for reading.
        int_index: Numeric id of the file index to migrate.
    """
    chroma_collection_name = f"index_{int_index}"
    engine = create_engine(sqlite_path)
    resource = _init_resource(private=is_private, id=int_index)
    Source = resource["Source"]
    IndexTable = resource["Index"]
    print("Load sqlalchemy engine successfully!")

    chroma_db_collection = get_chromadb_collection(
        db_dir=chroma_db_dir, collection_name=chroma_collection_name
    )
    print(
        f"Load chromadb collection: {chroma_collection_name}, "
        f"path: {chroma_db_dir} successfully!"
    )

    try:
        with Session(engine) as session:
            # Ids of every document registered in this index.
            doc_ids = [row[0].id for row in session.execute(select(Source)).all()]
            # All doc -> vector links in a single query, grouped by document,
            # instead of issuing one query per document.
            stmt = select(IndexTable.source_id, IndexTable.target_id).where(
                IndexTable.relation_type == "vector"
            )
            vs_ids_by_doc = {}
            for source_id, target_id in session.execute(stmt).all():
                vs_ids_by_doc.setdefault(source_id, []).append(target_id)
    finally:
        # Only SQL reads happen above; release pooled connections now.
        engine.dispose()

    print(f"Retrieve n-docs: {len(doc_ids)}")
    print(doc_ids)

    for doc_id in doc_ids:
        print("-")
        vs_ids = vs_ids_by_doc.get(doc_id, [])
        print(f"Got {len(vs_ids)} vs_ids for doc {doc_id}")
        if not vs_ids:
            continue

        batch = chroma_db_collection.get(ids=vs_ids, include=["metadatas"])
        # Write through the collection itself. ``batch`` is only the dict
        # returned by ``get``, so ``batch.update(...)`` would be a plain
        # ``dict.update`` and persist nothing.
        chroma_db_collection.update(
            ids=batch["ids"],
            metadatas=[
                update_metadata(metadata, doc_id) for metadata in batch["metadatas"]
            ],
        )
        print(f"doc-{doc_id} got updated")
def main(chroma_db_dir: str, sqlite_path: str):
    """Run ``migrate_chroma_db`` for every index stored in the ktem database.

    Args:
        chroma_db_dir: Directory of the persistent Chroma database.
        sqlite_path: SQLAlchemy database URL of the ktem SQL store.
    """
    engine = create_engine(sqlite_path)
    try:
        with Session(engine) as session:
            # Copy out plain (id, private) pairs while the session is open, so
            # no attribute is later read from a detached ORM instance.
            targets = []
            for row in session.execute(select(Index)).all():
                file_index = row[0]
                config = file_index.config or {}
                # NOTE(review): a missing "private" key is treated as a
                # non-private index instead of raising KeyError. Confirm this
                # matches how ktem builds the index tables.
                targets.append((file_index.id, config.get("private", False)))
    finally:
        engine.dispose()

    for index_id, is_private in targets:
        print(f"Migrating for Index id: {index_id}, is_private: {is_private}")
        migrate_chroma_db(
            chroma_db_dir=chroma_db_dir,
            sqlite_path=sqlite_path,
            is_private=is_private,
            int_index=index_id,
        )
if __name__ == "__main__":
    import argparse

    # Defaults are the previously hard-coded paths, so running the script
    # with no arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Backfill metadata.file_id on ktem Chroma vectors."
    )
    parser.add_argument(
        "--chroma-db-dir",
        default="./vectorstore/kan_db",
        help="Directory of the persistent Chroma database.",
    )
    parser.add_argument(
        "--sqlite-path",
        default="sqlite:///../ktem_app_data/user_data/sql.db",
        help="SQLAlchemy database URL of the ktem SQL store.",
    )
    args = parser.parse_args()
    main(args.chroma_db_dir, args.sqlite_path)