Mirror of https://github.com/Cinnamon/kotaemon.git, synced 2025-06-26 23:19:56 +00:00

Provide embedding manager (#16)

* Provide the Embedding management UI
* Update Fastembed documentation
* Add validation when adding / updating embeddings
* Stop using the old ktem embeddings manager
* Set default local embedding models
* Move the local embeddings below in flowsettings
* Update flowsettings

This commit is contained in:
parent ed10020ea3
commit 7b3307e3c4
@@ -193,7 +193,7 @@ information panel.
 You can access users' collections of LLMs and embedding models with:

 ```python
-from ktem.components import embeddings
+from ktem.embeddings.manager import embeddings
 from ktem.llms.manager import llms
 ```
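
For orientation, a minimal sketch of the new manager-based access pattern that this commit introduces. It assumes a configured kotaemon installation with at least one embedding model registered, and only uses methods defined by the manager added later in this commit.

```python
# Minimal sketch, assuming ktem is installed and at least one embedding is registered.
from ktem.embeddings.manager import embeddings

print(embeddings.get_default_name())       # name of the default embedding model
default_model = embeddings.get_default()   # the corresponding BaseEmbeddings instance
print(list(embeddings.options().keys()))   # all registered model names
```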
@@ -18,10 +18,11 @@ class FastEmbedEmbeddings(BaseEmbeddings):
     model_name: str = Param(
         "BAAI/bge-small-en-v1.5",
         help=(
-            "Model name for fastembed. "
-            "Supported model: "
-            "https://qdrant.github.io/fastembed/examples/Supported_Models/"
+            "Model name for fastembed. Please refer "
+            "[here](https://qdrant.github.io/fastembed/examples/Supported_Models/) "
+            "for the list of supported models."
         ),
+        required=True,
     )
     batch_size: int = Param(
         256,
@@ -34,7 +35,7 @@ class FastEmbedEmbeddings(BaseEmbeddings):
             "If > 1, data-parallel encoding will be used. "
             "If 0, use all available CPUs. "
             "If None, use default onnxruntime threading. "
-            "Defaults to None"
+            "Defaults to None."
         ),
     )
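
A minimal instantiation sketch for the updated class, assuming the `fastembed` dependency is installed; the model name reuses the default shown above, and the downstream wiring is only indicated in comments.

```python
# Sketch: constructing the updated FastEmbedEmbeddings (model_name is now required).
from kotaemon.embeddings import FastEmbedEmbeddings

embedding = FastEmbedEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",  # any entry from the supported-models list
    batch_size=256,
)
# The instance can then be registered through the embedding manager added in this commit,
# or assigned to an indexing/retrieval pipeline as shown further below.
```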
@@ -57,7 +57,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
     if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
         KH_EMBEDDINGS["azure"] = {
             "spec": {
-                "__type__": "kotaemon.embeddings.LCAzureOpenAIEmbeddings",
+                "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
                 "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
                 "api_key": config("AZURE_OPENAI_API_KEY", default=""),
                 "api_version": config("OPENAI_API_VERSION", default="")
@@ -68,8 +68,6 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
                 "timeout": 10,
             },
             "default": False,
-            "accuracy": 5,
-            "cost": 5,
         }

 if config("OPENAI_API_KEY", default=""):
@@ -88,7 +86,7 @@ if config("OPENAI_API_KEY", default=""):
     if len(KH_EMBEDDINGS) < 1:
         KH_EMBEDDINGS["openai"] = {
             "spec": {
-                "__type__": "kotaemon.embeddings.LCOpenAIEmbeddings",
+                "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
                 "base_url": config("OPENAI_API_BASE", default="")
                 or "https://api.openai.com/v1",
                 "api_key": config("OPENAI_API_KEY", default=""),
@@ -120,6 +118,14 @@ if config("LOCAL_MODEL", default=""):
         "cost": 0,
     }

+if len(KH_EMBEDDINGS) < 1:
+    KH_EMBEDDINGS["local-mxbai-large-v1"] = {
+        "spec": {
+            "__type__": "kotaemon.embeddings.FastEmbedEmbeddings",
+            "model_name": "mixedbread-ai/mxbai-embed-large-v1",
+        },
+        "default": True,
+    }
+
 KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
 KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
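
As a reference for adding further entries, a hedged sketch of what an extra `KH_EMBEDDINGS` entry in flowsettings could look like; the entry name `my-fastembed` and the chosen model are illustrative, only the `spec`/`default` shape mirrors the entries above.

```python
# Sketch: an additional embedding entry in flowsettings.py (name and model are examples).
# The "spec" dict is deserialized into a kotaemon embeddings object the first time the
# embedding manager populates its database table.
KH_EMBEDDINGS["my-fastembed"] = {
    "spec": {
        "__type__": "kotaemon.embeddings.FastEmbedEmbeddings",
        "model_name": "BAAI/bge-small-en-v1.5",
    },
    "default": False,
}
```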
@@ -182,7 +182,5 @@ class ModelPool:
         return self._models[self._cost[0]]


 llms = ModelPool("LLMs", settings.KH_LLMS)
-embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
-
 reasonings: dict = {}
 tools = ModelPool("Tools", {})
New file: libs/ktem/ktem/embeddings/__init__.py (0 lines)
New file: libs/ktem/ktem/embeddings/db.py (36 lines)
@@ -0,0 +1,36 @@
from typing import Type

from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string


class Base(DeclarativeBase):
    pass


class BaseEmbeddingTable(Base):
    """Base table to store embedding models"""

    __abstract__ = True

    name = Column(String, primary_key=True, unique=True)
    spec = Column(JSON, default={})
    default = Column(Boolean, default=False)


_base_llm: Type[BaseEmbeddingTable] = (
    import_dotted_string(flowsettings.KH_EMBEDDING_LLM, safe=False)
    if hasattr(flowsettings, "KH_EMBEDDING_LLM")
    else BaseEmbeddingTable
)


class EmbeddingTable(_base_llm):  # type: ignore
    __tablename__ = "embedding"


if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
    EmbeddingTable.metadata.create_all(engine)
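
For illustration, a hedged sketch of inspecting the resulting `embedding` table directly. It assumes the ktem database engine has been initialized as in the module above and is meant for debugging; the manager introduced next is the intended interface.

```python
# Sketch: reading rows from the "embedding" table created above.
from sqlalchemy import select
from sqlalchemy.orm import Session

from ktem.db.engine import engine
from ktem.embeddings.db import EmbeddingTable

with Session(engine) as sess:
    for (row,) in sess.execute(select(EmbeddingTable)):
        print(row.name, row.default, row.spec)
```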
New file: libs/ktem/ktem/embeddings/manager.py (199 lines)
@@ -0,0 +1,199 @@
from typing import Optional, Type

from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize

from kotaemon.embeddings.base import BaseEmbeddings

from .db import EmbeddingTable, engine


class EmbeddingManager:
    """Represent a pool of models"""

    def __init__(self):
        self._models: dict[str, BaseEmbeddings] = {}
        self._info: dict[str, dict] = {}
        self._default: str = ""
        self._vendors: list[Type] = []

        # populate the pool if empty
        if hasattr(flowsettings, "KH_EMBEDDINGS"):
            with Session(engine) as sess:
                count = sess.query(EmbeddingTable).count()
                if not count:
                    for name, model in flowsettings.KH_EMBEDDINGS.items():
                        self.add(
                            name=name,
                            spec=model["spec"],
                            default=model.get("default", False),
                        )

        self.load()
        self.load_vendors()

    def load(self):
        """Load the model pool from database"""
        self._models, self._info, self._default = {}, {}, ""
        with Session(engine) as sess:
            stmt = select(EmbeddingTable)
            items = sess.execute(stmt)

            for (item,) in items:
                self._models[item.name] = deserialize(item.spec, safe=False)
                self._info[item.name] = {
                    "name": item.name,
                    "spec": item.spec,
                    "default": item.default,
                }
                if item.default:
                    self._default = item.name

    def load_vendors(self):
        from kotaemon.embeddings import (
            AzureOpenAIEmbeddings,
            FastEmbedEmbeddings,
            OpenAIEmbeddings,
        )

        self._vendors = [AzureOpenAIEmbeddings, OpenAIEmbeddings, FastEmbedEmbeddings]

    def __getitem__(self, key: str) -> BaseEmbeddings:
        """Get model by name"""
        return self._models[key]

    def __contains__(self, key: str) -> bool:
        """Check if model exists"""
        return key in self._models

    def get(
        self, key: str, default: Optional[BaseEmbeddings] = None
    ) -> Optional[BaseEmbeddings]:
        """Get model by name with default value"""
        return self._models.get(key, default)

    def settings(self) -> dict:
        """Present model pools option for gradio"""
        return {
            "label": "Embedding",
            "choices": list(self._models.keys()),
            "value": self.get_default_name(),
        }

    def options(self) -> dict:
        """Present a dict of models"""
        return self._models

    def get_random_name(self) -> str:
        """Get the name of a random model

        Returns:
            str: random model name in the pool
        """
        import random

        if not self._models:
            raise ValueError("No models in pool")

        return random.choice(list(self._models.keys()))

    def get_default_name(self) -> str:
        """Get the name of the default model

        In case there is no default model, choose a random model from the pool.
        In case there are multiple default models, choose randomly among them.

        Returns:
            str: model name
        """
        if not self._models:
            raise ValueError("No models in pool")

        if not self._default:
            return self.get_random_name()

        return self._default

    def get_random(self) -> BaseEmbeddings:
        """Get a random model"""
        return self._models[self.get_random_name()]

    def get_default(self) -> BaseEmbeddings:
        """Get the default model

        In case there is no default model, choose a random model from the pool.
        In case there are multiple default models, choose randomly among them.

        Returns:
            BaseEmbeddings: model
        """
        return self._models[self.get_default_name()]

    def info(self) -> dict:
        """List all models"""
        return self._info

    def add(self, name: str, spec: dict, default: bool):
        """Add a new model to the pool"""
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as sess:
                if default:
                    # turn all models to non-default
                    sess.query(EmbeddingTable).update({"default": False})
                    sess.commit()

                item = EmbeddingTable(name=name, spec=spec, default=default)
                sess.add(item)
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to add model {name}: {e}")

        self.load()

    def delete(self, name: str):
        """Delete a model from the pool"""
        try:
            with Session(engine) as sess:
                item = sess.query(EmbeddingTable).filter_by(name=name).first()
                sess.delete(item)
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to delete model {name}: {e}")

        self.load()

    def update(self, name: str, spec: dict, default: bool):
        """Update a model in the pool"""
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as sess:
                if default:
                    # turn all models to non-default
                    sess.query(EmbeddingTable).update({"default": False})
                    sess.commit()

                item = sess.query(EmbeddingTable).filter_by(name=name).first()
                if not item:
                    raise ValueError(f"Model {name} not found")
                item.spec = spec
                item.default = default
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to update model {name}: {e}")

        self.load()

    def vendors(self) -> dict:
        """Return list of vendors"""
        return {vendor.__qualname__: vendor for vendor in self._vendors}


embeddings = EmbeddingManager()
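
A hedged usage sketch of the manager's CRUD surface as defined above; the model name and spec values are placeholders, not settings from this repository.

```python
# Sketch: programmatic management through the module-level singleton defined above.
from ktem.embeddings.manager import embeddings

embeddings.add(
    name="openai-small",  # example name
    spec={
        "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
        "api_key": "sk-...",  # placeholder
    },
    default=True,
)

model = embeddings.get_default()    # resolves the (single) default entry
same = embeddings["openai-small"]   # dict-style lookup by name
embeddings.delete("openai-small")   # removes the row and reloads the pool
```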
New file: libs/ktem/ktem/embeddings/ui.py (325 lines)
@@ -0,0 +1,325 @@
from copy import deepcopy

import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage

from .manager import embeddings


def format_description(cls):
    params = cls.describe()["params"]
    params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
    for key, value in params.items():
        if isinstance(value["auto_callback"], str):
            continue
        params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
    return f"{cls.__doc__}\n\n" + "\n".join(params_lines)


class EmbeddingManagement(BasePage):
    def __init__(self, app):
        self._app = app
        self.spec_desc_default = (
            "# Spec description\n\nSelect a model to view the spec description."
        )
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Tab(label="View"):
            self.emb_list = gr.DataFrame(
                headers=["name", "vendor", "default"],
                interactive=False,
            )

            with gr.Column(visible=False) as self._selected_panel:
                self.selected_emb_name = gr.Textbox(value="", visible=False)
                with gr.Row():
                    with gr.Column():
                        self.edit_default = gr.Checkbox(
                            label="Set default",
                            info=(
                                "Set this Embedding model as default. This default "
                                "Embedding will be used by other components by default "
                                "if no Embedding is specified for such components."
                            ),
                        )
                        self.edit_spec = gr.Textbox(
                            label="Specification",
                            info="Specification of the Embedding model in YAML format",
                            lines=10,
                        )

                        with gr.Row(visible=False) as self._selected_panel_btn:
                            with gr.Column():
                                self.btn_edit_save = gr.Button(
                                    "Save", min_width=10, variant="primary"
                                )
                            with gr.Column():
                                self.btn_delete = gr.Button(
                                    "Delete", min_width=10, variant="stop"
                                )
                                with gr.Row():
                                    self.btn_delete_yes = gr.Button(
                                        "Confirm Delete",
                                        variant="stop",
                                        visible=False,
                                        min_width=10,
                                    )
                                    self.btn_delete_no = gr.Button(
                                        "Cancel", visible=False, min_width=10
                                    )
                            with gr.Column():
                                self.btn_close = gr.Button("Close", min_width=10)

                    with gr.Column():
                        self.edit_spec_desc = gr.Markdown("# Spec description")

        with gr.Tab(label="Add"):
            with gr.Row():
                with gr.Column(scale=2):
                    self.name = gr.Textbox(
                        label="Name",
                        info=(
                            "Must be unique and non-empty. "
                            "The name will be used to identify the embedding model."
                        ),
                    )
                    self.emb_choices = gr.Dropdown(
                        label="Vendors",
                        info=(
                            "Choose the vendor of the Embedding model. Each vendor "
                            "has different specification."
                        ),
                    )
                    self.spec = gr.Textbox(
                        label="Specification",
                        info="Specification of the Embedding model in YAML format.",
                    )
                    self.default = gr.Checkbox(
                        label="Set default",
                        info=(
                            "Set this Embedding model as default. This default "
                            "Embedding will be used by other components by default "
                            "if no Embedding is specified for such components."
                        ),
                    )
                    self.btn_new = gr.Button("Add", variant="primary")

                with gr.Column(scale=3):
                    self.spec_desc = gr.Markdown(self.spec_desc_default)

    def _on_app_created(self):
        """Called when the app is created"""
        self._app.app.load(
            self.list_embeddings,
            inputs=None,
            outputs=[self.emb_list],
        )
        self._app.app.load(
            lambda: gr.update(choices=list(embeddings.vendors().keys())),
            outputs=[self.emb_choices],
        )

    def on_emb_vendor_change(self, vendor):
        vendor = embeddings.vendors()[vendor]

        required: dict = {}
        desc = vendor.describe()
        for key, value in desc["params"].items():
            if value.get("required", False):
                required[key] = value.get("default", None)

        return yaml.dump(required), format_description(vendor)

    def on_register_events(self):
        self.emb_choices.select(
            self.on_emb_vendor_change,
            inputs=[self.emb_choices],
            outputs=[self.spec, self.spec_desc],
        )
        self.btn_new.click(
            self.create_emb,
            inputs=[self.name, self.emb_choices, self.spec, self.default],
            outputs=None,
        ).success(self.list_embeddings, inputs=None, outputs=[self.emb_list]).success(
            lambda: ("", None, "", False, self.spec_desc_default),
            outputs=[
                self.name,
                self.emb_choices,
                self.spec,
                self.default,
                self.spec_desc,
            ],
        )
        self.emb_list.select(
            self.select_emb,
            inputs=self.emb_list,
            outputs=[self.selected_emb_name],
            show_progress="hidden",
        )
        self.selected_emb_name.change(
            self.on_selected_emb_change,
            inputs=[self.selected_emb_name],
            outputs=[
                self._selected_panel,
                self._selected_panel_btn,
                # delete section
                self.btn_delete,
                self.btn_delete_yes,
                self.btn_delete_no,
                # edit section
                self.edit_spec,
                self.edit_spec_desc,
                self.edit_default,
            ],
            show_progress="hidden",
        )
        self.btn_delete.click(
            self.on_btn_delete_click,
            inputs=None,
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_delete_yes.click(
            self.delete_emb,
            inputs=[self.selected_emb_name],
            outputs=[self.selected_emb_name],
            show_progress="hidden",
        ).then(
            self.list_embeddings,
            inputs=None,
            outputs=[self.emb_list],
        )
        self.btn_delete_no.click(
            lambda: (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            ),
            inputs=None,
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_edit_save.click(
            self.save_emb,
            inputs=[
                self.selected_emb_name,
                self.edit_default,
                self.edit_spec,
            ],
            show_progress="hidden",
        ).then(
            self.list_embeddings,
            inputs=None,
            outputs=[self.emb_list],
        )
        self.btn_close.click(
            lambda: "",
            outputs=[self.selected_emb_name],
        )

    def create_emb(self, name, choices, spec, default):
        try:
            spec = yaml.safe_load(spec)
            spec["__type__"] = (
                embeddings.vendors()[choices].__module__
                + "."
                + embeddings.vendors()[choices].__qualname__
            )

            embeddings.add(name, spec=spec, default=default)
            gr.Info(f'Create Embedding model "{name}" successfully')
        except Exception as e:
            raise gr.Error(f"Failed to create Embedding model {name}: {e}")

    def list_embeddings(self):
        """List the Embedding models"""
        items = []
        for item in embeddings.info().values():
            record = {}
            record["name"] = item["name"]
            record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
            record["default"] = item["default"]
            items.append(record)

        if items:
            emb_list = pd.DataFrame.from_records(items)
        else:
            emb_list = pd.DataFrame.from_records(
                [{"name": "-", "vendor": "-", "default": "-"}]
            )

        return emb_list

    def select_emb(self, emb_list, ev: gr.SelectData):
        if ev.value == "-" and ev.index[0] == 0:
            gr.Info("No embedding model is loaded. Please add first")
            return ""

        if not ev.selected:
            return ""

        return emb_list["name"][ev.index[0]]

    def on_selected_emb_change(self, selected_emb_name):
        if selected_emb_name == "":
            _selected_panel = gr.update(visible=False)
            _selected_panel_btn = gr.update(visible=False)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)
            edit_spec = gr.update(value="")
            edit_spec_desc = gr.update(value="")
            edit_default = gr.update(value=False)
        else:
            _selected_panel = gr.update(visible=True)
            _selected_panel_btn = gr.update(visible=True)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)

            info = deepcopy(embeddings.info()[selected_emb_name])
            vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
            vendor = embeddings.vendors()[vendor_str]

            edit_spec = yaml.dump(info["spec"])
            edit_spec_desc = format_description(vendor)
            edit_default = info["default"]

        return (
            _selected_panel,
            _selected_panel_btn,
            btn_delete,
            btn_delete_yes,
            btn_delete_no,
            edit_spec,
            edit_spec_desc,
            edit_default,
        )

    def on_btn_delete_click(self):
        btn_delete = gr.update(visible=False)
        btn_delete_yes = gr.update(visible=True)
        btn_delete_no = gr.update(visible=True)

        return btn_delete, btn_delete_yes, btn_delete_no

    def save_emb(self, selected_emb_name, default, spec):
        try:
            spec = yaml.safe_load(spec)
            spec["__type__"] = embeddings.info()[selected_emb_name]["spec"]["__type__"]
            embeddings.update(selected_emb_name, spec=spec, default=default)
            gr.Info(f'Save Embedding model "{selected_emb_name}" successfully')
        except Exception as e:
            gr.Error(f'Failed to save Embedding model "{selected_emb_name}": {e}')

    def delete_emb(self, selected_emb_name):
        try:
            embeddings.delete(selected_emb_name)
        except Exception as e:
            gr.Error(f'Failed to delete Embedding model "{selected_emb_name}": {e}')
            return selected_emb_name

        return ""
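
To connect the UI to the manager, a hedged sketch of what the Add flow does with the pasted YAML specification; the YAML string and the chosen vendor are illustrative.

```python
# Sketch: the essence of create_emb() above, outside of Gradio.
import yaml

raw_spec = "model_name: BAAI/bge-small-en-v1.5\nbatch_size: 256\n"  # pasted YAML
spec = yaml.safe_load(raw_spec)                                     # -> plain dict
spec["__type__"] = "kotaemon.embeddings.FastEmbedEmbeddings"        # from the vendor dropdown

# from ktem.embeddings.manager import embeddings
# embeddings.add(name="local-fastembed", spec=spec, default=False)  # as create_emb() does
```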
@@ -236,17 +236,26 @@ class FileIndex(BaseIndex):
         """Create the index for the first time

         For the file index, this will:
-            1. Create the index and the source table if not already exists
-            2. Create the vectorstore
-            3. Create the docstore
+            1. Postprocess the config
+            2. Create the index and the source table if not already exists
+            3. Create the vectorstore
+            4. Create the docstore
         """
-        file_types_str = self.config.get(
-            "supported_file_types",
-            self.get_admin_settings()["supported_file_types"]["value"],
-        )
-        file_types = [each.strip() for each in file_types_str.split(",")]
-        self.config["supported_file_types"] = file_types
+        # default user's value
+        config = {}
+        for key, value in self.get_admin_settings().items():
+            config[key] = value["value"]
+
+        # user's modification
+        config.update(self.config)
+
+        # clean
+        file_types_str = config["supported_file_types"]
+        file_types = [each.strip() for each in file_types_str.split(",")]
+        config["supported_file_types"] = file_types
+        self.config = config

         # create the resources
         self._resources["Source"].metadata.create_all(engine)  # type: ignore
         self._resources["Index"].metadata.create_all(engine)  # type: ignore
         self._fs_path.mkdir(parents=True, exist_ok=True)
@@ -285,7 +294,7 @@ class FileIndex(BaseIndex):

     @classmethod
     def get_admin_settings(cls):
-        from ktem.components import embeddings
+        from ktem.embeddings.manager import embeddings

         embedding_default = embeddings.get_default_name()
         embedding_choices = list(embeddings.options().keys())
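
The config post-processing added to `on_create` merges admin defaults with per-index overrides before normalizing `supported_file_types`; a standalone sketch of that precedence, with invented values for illustration:

```python
# Sketch of the merge order introduced above: admin defaults, then user overrides, then cleanup.
admin_defaults = {
    "supported_file_types": {"value": ".pdf, .txt"},
    "embedding": {"value": "default"},
}
user_config = {"supported_file_types": ".pdf,.md"}

config = {key: value["value"] for key, value in admin_defaults.items()}
config.update(user_config)
config["supported_file_types"] = [s.strip() for s in config["supported_file_types"].split(",")]
print(config)  # {'supported_file_types': ['.pdf', '.md'], 'embedding': 'default'}
```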
@@ -10,8 +10,9 @@ from pathlib import Path
 from typing import Optional

 import gradio as gr
-from ktem.components import embeddings, filestorage_path
+from ktem.components import filestorage_path
 from ktem.db.models import engine
+from ktem.embeddings.manager import embeddings
 from llama_index.vector_stores import (
     FilterCondition,
     FilterOperator,
@@ -68,9 +69,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             for surrounding tables (e.g. within the page)
     """

-    vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
-        embedding=embeddings.get_default(),
-    )
+    vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
     reranker: BaseReranking
     get_extra_table: bool = False

@@ -226,6 +225,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
         if not user_settings["use_reranking"]:
             retriever.reranker = None  # type: ignore

+        retriever.vector_retrieval.embedding = embeddings[index_settings["embedding"]]
         kwargs = {
             ".top_k": int(user_settings["num_retrieval"]),
             ".mmr": user_settings["mmr"],
@@ -248,9 +248,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
         file_ingestor: ingestor to ingest the documents
     """

-    indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx(
-        embedding=embeddings.get_default(),
-    )
+    indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx()
     file_ingestor: DocumentIngestor = DocumentIngestor.withx()

     def run(
@@ -438,6 +436,8 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
         if chunk_overlap:
             obj.file_ingestor.text_splitter.chunk_overlap = chunk_overlap

+        obj.indexing_vector_pipeline.embedding = embeddings[index_settings["embedding"]]
+
         return obj

     def set_resources(self, resources: dict):
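
The net effect of these hunks is that the embedding is resolved per index when the pipeline is built, instead of being baked into the class attribute; a short hedged sketch (the index setting value is an example):

```python
# Sketch: per-index embedding resolution, as done in get_pipeline() above.
from ktem.embeddings.manager import embeddings

index_settings = {"embedding": "local-mxbai-large-v1"}  # example index-level setting
embedding = embeddings[index_settings["embedding"]]     # lookup by name in the pool
# retriever.vector_retrieval.embedding = embedding      # retrieval side
# obj.indexing_vector_pipeline.embedding = embedding    # indexing side
```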
@@ -432,7 +432,7 @@ class FileIndexPage(BasePage):
                 "name": each[0].name,
                 "size": each[0].size,
                 "text_length": each[0].text_length,
-                "date_created": each[0].date_created,
+                "date_created": each[0].date_created.strftime("%Y-%m-%d %H:%M:%S"),
             }
             for each in session.execute(statement).all()
         ]
@@ -1,6 +1,7 @@
 import gradio as gr
 from ktem.app import BasePage
 from ktem.db.models import User, engine
+from ktem.embeddings.ui import EmbeddingManagement
 from ktem.llms.ui import LLMManagement
 from sqlmodel import Session, select

@@ -17,9 +18,12 @@ class AdminPage(BasePage):
         with gr.Tab("User Management", visible=False) as self.user_management_tab:
             self.user_management = UserManagement(self._app)

-        with gr.Tab("LLM Management") as self.llm_management_tab:
+        with gr.Tab("LLMs") as self.llm_management_tab:
             self.llm_management = LLMManagement(self._app)

+        with gr.Tab("Embeddings") as self.llm_management_tab:
+            self.emb_management = EmbeddingManagement(self._app)
+
     def on_subscribe_public_events(self):
         if self._app.f_user_management:
             self._app.subscribe_event(