fix: add guidance parameters for LC wrapper models (#255)

* fix: add docstring to LC wrapper models

* fix: fix metadata passing with LC embedding wrapper
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-09-09 14:15:34 +07:00 committed by GitHub
parent ce489725d8
commit 96d2086017
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 54 additions and 18 deletions

View File

@ -208,6 +208,7 @@ KH_EMBEDDINGS["cohere"] = {
"__type__": "kotaemon.embeddings.LCCohereEmbeddings",
"model": "embed-multilingual-v2.0",
"cohere_api_key": "your-key",
"user_agent": "default",
},
"default": False,
}

View File

@ -1,6 +1,6 @@
from typing import Optional
from kotaemon.base import Document, DocumentWithEmbedding
from kotaemon.base import DocumentWithEmbedding, Param
from .base import BaseEmbeddings
@ -19,25 +19,14 @@ class LCEmbeddingMixin:
super().__init__()
def run(self, text):
input_: list[str] = []
if not isinstance(text, list):
text = [text]
for item in text:
if isinstance(item, str):
input_.append(item)
elif isinstance(item, Document):
input_.append(item.text)
else:
raise ValueError(
f"Invalid input type {type(item)}, should be str or Document"
)
input_docs = self.prepare_input(text)
input_ = [doc.text for doc in input_docs]
embeddings = self._obj.embed_documents(input_)
return [
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
for each_text, each_embedding in zip(input_, embeddings)
DocumentWithEmbedding(content=doc, embedding=each_embedding)
for doc, each_embedding in zip(input_docs, embeddings)
]
def __repr__(self):
@ -162,6 +151,20 @@ class LCAzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
cohere_api_key: str = Param(
help="API key (https://dashboard.cohere.com/api-keys)",
default=None,
required=True,
)
model: str = Param(
help="Model name to use (https://docs.cohere.com/docs/models)",
default=None,
required=True,
)
user_agent: str = Param(
help="User agent (leave default)", default="default", required=True
)
def __init__(
self,
model: str = "embed-english-v2.0",
@ -190,6 +193,15 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
model_name: str = Param(
help=(
"Model name to use (https://huggingface.co/models?"
"pipeline_tag=sentence-similarity&sort=trending)"
),
default=None,
required=True,
)
def __init__(
self,
model_name: str = "sentence-transformers/all-mpnet-base-v2",

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from typing import AsyncGenerator, Iterator
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param
from .base import ChatLLM
@ -224,6 +224,17 @@ class LCAzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
api_key: str = Param(
help="API key (https://console.anthropic.com/settings/keys)", required=True
)
model_name: str = Param(
help=(
"Model name to use "
"(https://docs.anthropic.com/en/docs/about-claude/models)"
),
required=True,
)
def __init__(
self,
api_key: str | None = None,
@ -248,6 +259,17 @@ class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore
api_key: str = Param(
help="API key (https://aistudio.google.com/app/apikey)", required=True
)
model_name: str = Param(
help=(
"Model name to use (https://cloud.google"
".com/vertex-ai/generative-ai/docs/learn/models)"
),
required=True,
)
def __init__(
self,
api_key: str | None = None,

View File

@ -50,6 +50,7 @@ class EmbeddingManager:
}
if item.default:
self._default = item.name
self._models["default"] = self._models[item.name]
def load_vendors(self):
from kotaemon.embeddings import (

View File

@ -344,7 +344,7 @@ class FileIndex(BaseIndex):
def get_admin_settings(cls):
from ktem.embeddings.manager import embedding_models_manager
embedding_default = embedding_models_manager.get_default_name()
embedding_default = "default"
embedding_choices = list(embedding_models_manager.options().keys())
return {