Provide embedding manager (#16)

* Provide the Embedding management UI

* Update Fastembed documentation

* Add validation when adding / updating embeddings

* Stop using the old ktem embeddings manager

* Set default local embedding models

* Move the local embeddings below in flowsettings

* Update flowsettings
This commit is contained in:
Duc Nguyen (john) 2024-04-10 15:11:44 +07:00 committed by GitHub
parent ed10020ea3
commit 7b3307e3c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 608 additions and 30 deletions

View File

@ -193,7 +193,7 @@ information panel.
You can access users' collections of LLMs and embedding models with:
```python
from ktem.components import embeddings
from ktem.embeddings.manager import embeddings
from ktem.llms.manager import llms

View File

@ -18,10 +18,11 @@ class FastEmbedEmbeddings(BaseEmbeddings):
model_name: str = Param(
"BAAI/bge-small-en-v1.5",
help=(
"Model name for fastembed. "
"Supported model: "
"https://qdrant.github.io/fastembed/examples/Supported_Models/"
"Model name for fastembed. Please refer "
"[here](https://qdrant.github.io/fastembed/examples/Supported_Models/) "
"for the list of supported models."
),
required=True,
)
batch_size: int = Param(
256,
@ -34,7 +35,7 @@ class FastEmbedEmbeddings(BaseEmbeddings):
"If > 1, data-parallel encoding will be used. "
"If 0, use all available CPUs. "
"If None, use default onnxruntime threading. "
"Defaults to None"
"Defaults to None."
),
)

View File

@ -57,7 +57,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
KH_EMBEDDINGS["azure"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCAzureOpenAIEmbeddings",
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_version": config("OPENAI_API_VERSION", default="")
@ -68,8 +68,6 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
"timeout": 10,
},
"default": False,
"accuracy": 5,
"cost": 5,
}
if config("OPENAI_API_KEY", default=""):
@ -88,7 +86,7 @@ if config("OPENAI_API_KEY", default=""):
if len(KH_EMBEDDINGS) < 1:
KH_EMBEDDINGS["openai"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCOpenAIEmbeddings",
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": config("OPENAI_API_BASE", default="")
or "https://api.openai.com/v1",
"api_key": config("OPENAI_API_KEY", default=""),
@ -120,6 +118,14 @@ if config("LOCAL_MODEL", default=""):
"cost": 0,
}
if len(KH_EMBEDDINGS) < 1:
KH_EMBEDDINGS["local-mxbai-large-v1"] = {
"spec": {
"__type__": "kotaemon.embeddings.FastEmbedEmbeddings",
"model_name": "mixedbread-ai/mxbai-embed-large-v1",
},
"default": True,
}
KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(

View File

@ -182,7 +182,5 @@ class ModelPool:
return self._models[self._cost[0]]
llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})

View File

View File

@ -0,0 +1,36 @@
from typing import Type
from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string
class Base(DeclarativeBase):
pass
class BaseEmbeddingTable(Base):
"""Base table to store language model"""
__abstract__ = True
name = Column(String, primary_key=True, unique=True)
spec = Column(JSON, default={})
default = Column(Boolean, default=False)
_base_llm: Type[BaseEmbeddingTable] = (
import_dotted_string(flowsettings.KH_EMBEDDING_LLM, safe=False)
if hasattr(flowsettings, "KH_EMBEDDING_LLM")
else BaseEmbeddingTable
)
class EmbeddingTable(_base_llm): # type: ignore
__tablename__ = "embedding"
if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
EmbeddingTable.metadata.create_all(engine)

View File

@ -0,0 +1,199 @@
from typing import Optional, Type
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize
from kotaemon.embeddings.base import BaseEmbeddings
from .db import EmbeddingTable, engine
class EmbeddingManager:
"""Represent a pool of models"""
def __init__(self):
self._models: dict[str, BaseEmbeddings] = {}
self._info: dict[str, dict] = {}
self._default: str = ""
self._vendors: list[Type] = []
# populate the pool if empty
if hasattr(flowsettings, "KH_EMBEDDINGS"):
with Session(engine) as sess:
count = sess.query(EmbeddingTable).count()
if not count:
for name, model in flowsettings.KH_EMBEDDINGS.items():
self.add(
name=name,
spec=model["spec"],
default=model.get("default", False),
)
self.load()
self.load_vendors()
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._defaut = {}, {}, ""
with Session(engine) as sess:
stmt = select(EmbeddingTable)
items = sess.execute(stmt)
for (item,) in items:
self._models[item.name] = deserialize(item.spec, safe=False)
self._info[item.name] = {
"name": item.name,
"spec": item.spec,
"default": item.default,
}
if item.default:
self._default = item.name
def load_vendors(self):
from kotaemon.embeddings import (
AzureOpenAIEmbeddings,
FastEmbedEmbeddings,
OpenAIEmbeddings,
)
self._vendors = [AzureOpenAIEmbeddings, OpenAIEmbeddings, FastEmbedEmbeddings]
def __getitem__(self, key: str) -> BaseEmbeddings:
"""Get model by name"""
return self._models[key]
def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models
def get(
self, key: str, default: Optional[BaseEmbeddings] = None
) -> Optional[BaseEmbeddings]:
"""Get model by name with default value"""
return self._models.get(key, default)
def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
"label": "Embedding",
"choices": list(self._models.keys()),
"value": self.get_default_name(),
}
def options(self) -> dict:
"""Present a dict of models"""
return self._models
def get_random_name(self) -> str:
"""Get the name of random model
Returns:
str: random model name in the pool
"""
import random
if not self._models:
raise ValueError("No models in pool")
return random.choice(list(self._models.keys()))
def get_default_name(self) -> str:
"""Get the name of default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
str: model name
"""
if not self._models:
raise ValueError("No models in pool")
if not self._default:
return self.get_random_name()
return self._default
def get_random(self) -> BaseEmbeddings:
"""Get random model"""
return self._models[self.get_random_name()]
def get_default(self) -> BaseEmbeddings:
"""Get default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
BaseEmbeddings: model
"""
return self._models[self.get_default_name()]
def info(self) -> dict:
"""List all models"""
return self._info
def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool"""
if not name:
raise ValueError("Name must not be empty")
try:
with Session(engine) as sess:
if default:
# turn all models to non-default
sess.query(EmbeddingTable).update({"default": False})
sess.commit()
item = EmbeddingTable(name=name, spec=spec, default=default)
sess.add(item)
sess.commit()
except Exception as e:
raise ValueError(f"Failed to add model {name}: {e}")
self.load()
def delete(self, name: str):
"""Delete a model from the pool"""
try:
with Session(engine) as sess:
item = sess.query(EmbeddingTable).filter_by(name=name).first()
sess.delete(item)
sess.commit()
except Exception as e:
raise ValueError(f"Failed to delete model {name}: {e}")
self.load()
def update(self, name: str, spec: dict, default: bool):
"""Update a model in the pool"""
if not name:
raise ValueError("Name must not be empty")
try:
with Session(engine) as sess:
if default:
# turn all models to non-default
sess.query(EmbeddingTable).update({"default": False})
sess.commit()
item = sess.query(EmbeddingTable).filter_by(name=name).first()
if not item:
raise ValueError(f"Model {name} not found")
item.spec = spec
item.default = default
sess.commit()
except Exception as e:
raise ValueError(f"Failed to update model {name}: {e}")
self.load()
def vendors(self) -> dict:
"""Return list of vendors"""
return {vendor.__qualname__: vendor for vendor in self._vendors}
embeddings = EmbeddingManager()

View File

@ -0,0 +1,325 @@
from copy import deepcopy
import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage
from .manager import embeddings
def format_description(cls):
params = cls.describe()["params"]
params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
for key, value in params.items():
if isinstance(value["auto_callback"], str):
continue
params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
return f"{cls.__doc__}\n\n" + "\n".join(params_lines)
class EmbeddingManagement(BasePage):
def __init__(self, app):
self._app = app
self.spec_desc_default = (
"# Spec description\n\nSelect a model to view the spec description."
)
self.on_building_ui()
def on_building_ui(self):
with gr.Tab(label="View"):
self.emb_list = gr.DataFrame(
headers=["name", "vendor", "default"],
interactive=False,
)
with gr.Column(visible=False) as self._selected_panel:
self.selected_emb_name = gr.Textbox(value="", visible=False)
with gr.Row():
with gr.Column():
self.edit_default = gr.Checkbox(
label="Set default",
info=(
"Set this Embedding model as default. This default "
"Embedding will be used by other components by default "
"if no Embedding is specified for such components."
),
)
self.edit_spec = gr.Textbox(
label="Specification",
info="Specification of the Embedding model in YAML format",
lines=10,
)
with gr.Row(visible=False) as self._selected_panel_btn:
with gr.Column():
self.btn_edit_save = gr.Button(
"Save", min_width=10, variant="primary"
)
with gr.Column():
self.btn_delete = gr.Button(
"Delete", min_width=10, variant="stop"
)
with gr.Row():
self.btn_delete_yes = gr.Button(
"Confirm Delete",
variant="stop",
visible=False,
min_width=10,
)
self.btn_delete_no = gr.Button(
"Cancel", visible=False, min_width=10
)
with gr.Column():
self.btn_close = gr.Button("Close", min_width=10)
with gr.Column():
self.edit_spec_desc = gr.Markdown("# Spec description")
with gr.Tab(label="Add"):
with gr.Row():
with gr.Column(scale=2):
self.name = gr.Textbox(
label="Name",
info=(
"Must be unique and non-empty. "
"The name will be used to identify the embedding model."
),
)
self.emb_choices = gr.Dropdown(
label="Vendors",
info=(
"Choose the vendor of the Embedding model. Each vendor "
"has different specification."
),
)
self.spec = gr.Textbox(
label="Specification",
info="Specification of the Embedding model in YAML format.",
)
self.default = gr.Checkbox(
label="Set default",
info=(
"Set this Embedding model as default. This default "
"Embedding will be used by other components by default "
"if no Embedding is specified for such components."
),
)
self.btn_new = gr.Button("Add", variant="primary")
with gr.Column(scale=3):
self.spec_desc = gr.Markdown(self.spec_desc_default)
def _on_app_created(self):
"""Called when the app is created"""
self._app.app.load(
self.list_embeddings,
inputs=None,
outputs=[self.emb_list],
)
self._app.app.load(
lambda: gr.update(choices=list(embeddings.vendors().keys())),
outputs=[self.emb_choices],
)
def on_emb_vendor_change(self, vendor):
vendor = embeddings.vendors()[vendor]
required: dict = {}
desc = vendor.describe()
for key, value in desc["params"].items():
if value.get("required", False):
required[key] = value.get("default", None)
return yaml.dump(required), format_description(vendor)
def on_register_events(self):
self.emb_choices.select(
self.on_emb_vendor_change,
inputs=[self.emb_choices],
outputs=[self.spec, self.spec_desc],
)
self.btn_new.click(
self.create_emb,
inputs=[self.name, self.emb_choices, self.spec, self.default],
outputs=None,
).success(self.list_embeddings, inputs=None, outputs=[self.emb_list]).success(
lambda: ("", None, "", False, self.spec_desc_default),
outputs=[
self.name,
self.emb_choices,
self.spec,
self.default,
self.spec_desc,
],
)
self.emb_list.select(
self.select_emb,
inputs=self.emb_list,
outputs=[self.selected_emb_name],
show_progress="hidden",
)
self.selected_emb_name.change(
self.on_selected_emb_change,
inputs=[self.selected_emb_name],
outputs=[
self._selected_panel,
self._selected_panel_btn,
# delete section
self.btn_delete,
self.btn_delete_yes,
self.btn_delete_no,
# edit section
self.edit_spec,
self.edit_spec_desc,
self.edit_default,
],
show_progress="hidden",
)
self.btn_delete.click(
self.on_btn_delete_click,
inputs=None,
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
self.btn_delete_yes.click(
self.delete_emb,
inputs=[self.selected_emb_name],
outputs=[self.selected_emb_name],
show_progress="hidden",
).then(
self.list_embeddings,
inputs=None,
outputs=[self.emb_list],
)
self.btn_delete_no.click(
lambda: (
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
self.btn_edit_save.click(
self.save_emb,
inputs=[
self.selected_emb_name,
self.edit_default,
self.edit_spec,
],
show_progress="hidden",
).then(
self.list_embeddings,
inputs=None,
outputs=[self.emb_list],
)
self.btn_close.click(
lambda: "",
outputs=[self.selected_emb_name],
)
def create_emb(self, name, choices, spec, default):
try:
spec = yaml.safe_load(spec)
spec["__type__"] = (
embeddings.vendors()[choices].__module__
+ "."
+ embeddings.vendors()[choices].__qualname__
)
embeddings.add(name, spec=spec, default=default)
gr.Info(f'Create Embedding model "{name}" successfully')
except Exception as e:
raise gr.Error(f"Failed to create Embedding model {name}: {e}")
def list_embeddings(self):
"""List the Embedding models"""
items = []
for item in embeddings.info().values():
record = {}
record["name"] = item["name"]
record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
record["default"] = item["default"]
items.append(record)
if items:
emb_list = pd.DataFrame.from_records(items)
else:
emb_list = pd.DataFrame.from_records(
[{"name": "-", "vendor": "-", "default": "-"}]
)
return emb_list
def select_emb(self, emb_list, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No embedding model is loaded. Please add first")
return ""
if not ev.selected:
return ""
return emb_list["name"][ev.index[0]]
def on_selected_emb_change(self, selected_emb_name):
if selected_emb_name == "":
_selected_panel = gr.update(visible=False)
_selected_panel_btn = gr.update(visible=False)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
edit_spec = gr.update(value="")
edit_spec_desc = gr.update(value="")
edit_default = gr.update(value=False)
else:
_selected_panel = gr.update(visible=True)
_selected_panel_btn = gr.update(visible=True)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
info = deepcopy(embeddings.info()[selected_emb_name])
vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
vendor = embeddings.vendors()[vendor_str]
edit_spec = yaml.dump(info["spec"])
edit_spec_desc = format_description(vendor)
edit_default = info["default"]
return (
_selected_panel,
_selected_panel_btn,
btn_delete,
btn_delete_yes,
btn_delete_no,
edit_spec,
edit_spec_desc,
edit_default,
)
def on_btn_delete_click(self):
btn_delete = gr.update(visible=False)
btn_delete_yes = gr.update(visible=True)
btn_delete_no = gr.update(visible=True)
return btn_delete, btn_delete_yes, btn_delete_no
def save_emb(self, selected_emb_name, default, spec):
try:
spec = yaml.safe_load(spec)
spec["__type__"] = embeddings.info()[selected_emb_name]["spec"]["__type__"]
embeddings.update(selected_emb_name, spec=spec, default=default)
gr.Info(f'Save Embedding model "{selected_emb_name}" successfully')
except Exception as e:
gr.Error(f'Failed to save Embedding model "{selected_emb_name}": {e}')
def delete_emb(self, selected_emb_name):
try:
embeddings.delete(selected_emb_name)
except Exception as e:
gr.Error(f'Failed to delete Embedding model "{selected_emb_name}": {e}')
return selected_emb_name
return ""

View File

@ -236,17 +236,26 @@ class FileIndex(BaseIndex):
"""Create the index for the first time
For the file index, this will:
1. Create the index and the source table if not already exists
2. Create the vectorstore
3. Create the docstore
1. Postprocess the config
2. Create the index and the source table if not already exists
3. Create the vectorstore
4. Create the docstore
"""
file_types_str = self.config.get(
"supported_file_types",
self.get_admin_settings()["supported_file_types"]["value"],
)
file_types = [each.strip() for each in file_types_str.split(",")]
self.config["supported_file_types"] = file_types
# default user's value
config = {}
for key, value in self.get_admin_settings().items():
config[key] = value["value"]
# user's modification
config.update(self.config)
# clean
file_types_str = config["supported_file_types"]
file_types = [each.strip() for each in file_types_str.split(",")]
config["supported_file_types"] = file_types
self.config = config
# create the resources
self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True)
@ -285,7 +294,7 @@ class FileIndex(BaseIndex):
@classmethod
def get_admin_settings(cls):
from ktem.components import embeddings
from ktem.embeddings.manager import embeddings
embedding_default = embeddings.get_default_name()
embedding_choices = list(embeddings.options().keys())

View File

@ -10,8 +10,9 @@ from pathlib import Path
from typing import Optional
import gradio as gr
from ktem.components import embeddings, filestorage_path
from ktem.components import filestorage_path
from ktem.db.models import engine
from ktem.embeddings.manager import embeddings
from llama_index.vector_stores import (
FilterCondition,
FilterOperator,
@ -68,9 +69,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
for surrounding tables (e.g. within the page)
"""
vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
embedding=embeddings.get_default(),
)
vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
reranker: BaseReranking
get_extra_table: bool = False
@ -226,6 +225,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore
retriever.vector_retrieval.embedding = embeddings[index_settings["embedding"]]
kwargs = {
".top_k": int(user_settings["num_retrieval"]),
".mmr": user_settings["mmr"],
@ -248,9 +248,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
file_ingestor: ingestor to ingest the documents
"""
indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx(
embedding=embeddings.get_default(),
)
indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx()
file_ingestor: DocumentIngestor = DocumentIngestor.withx()
def run(
@ -438,6 +436,8 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
if chunk_overlap:
obj.file_ingestor.text_splitter.chunk_overlap = chunk_overlap
obj.indexing_vector_pipeline.embedding = embeddings[index_settings["embedding"]]
return obj
def set_resources(self, resources: dict):

View File

@ -432,7 +432,7 @@ class FileIndexPage(BasePage):
"name": each[0].name,
"size": each[0].size,
"text_length": each[0].text_length,
"date_created": each[0].date_created,
"date_created": each[0].date_created.strftime("%Y-%m-%d %H:%M:%S"),
}
for each in session.execute(statement).all()
]

View File

@ -1,6 +1,7 @@
import gradio as gr
from ktem.app import BasePage
from ktem.db.models import User, engine
from ktem.embeddings.ui import EmbeddingManagement
from ktem.llms.ui import LLMManagement
from sqlmodel import Session, select
@ -17,9 +18,12 @@ class AdminPage(BasePage):
with gr.Tab("User Management", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
with gr.Tab("LLM Management") as self.llm_management_tab:
with gr.Tab("LLMs") as self.llm_management_tab:
self.llm_management = LLMManagement(self._app)
with gr.Tab("Embeddings") as self.llm_management_tab:
self.emb_management = EmbeddingManagement(self._app)
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(