diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 06cc56975..1774bc285 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -111,11 +111,12 @@ class OpenAIEmbed(Base): return np.array(res.data[0].embedding), res.usage.total_tokens -class AzureEmbed(Base): +class AzureEmbed(OpenAIEmbed): def __init__(self, key, model_name, **kwargs): self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") self.model_name = model_name + class BaiChuanEmbed(OpenAIEmbed): def __init__(self, key, model_name='Baichuan-Text-Embedding',