2024-01-15 08:46:22 +08:00
|
|
|
#
|
2024-01-19 19:51:57 +08:00
|
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
2024-01-15 08:46:22 +08:00
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
#
|
2025-07-03 19:05:31 +08:00
|
|
|
import json
|
|
|
|
|
import os
|
2024-09-12 17:51:20 +08:00
|
|
|
import threading
|
2025-07-03 19:05:31 +08:00
|
|
|
from abc import ABC
|
2025-06-03 14:18:40 +08:00
|
|
|
from urllib.parse import urljoin
|
|
|
|
|
|
2025-07-03 19:05:31 +08:00
|
|
|
import dashscope
|
|
|
|
|
import google.generativeai as genai
|
|
|
|
|
import numpy as np
|
2024-05-29 16:50:02 +08:00
|
|
|
import requests
|
2024-04-08 19:20:57 +08:00
|
|
|
from ollama import Client
|
2024-01-15 08:46:22 +08:00
|
|
|
from openai import OpenAI
|
2025-07-03 19:05:31 +08:00
|
|
|
from zhipuai import ZhipuAI
|
2024-09-24 19:22:01 +08:00
|
|
|
|
2025-11-03 20:25:02 +08:00
|
|
|
from common.log_utils import log_exception
|
2025-11-03 08:50:05 +08:00
|
|
|
from common.token_utils import num_tokens_from_string, truncate
|
2025-11-05 11:07:54 +08:00
|
|
|
from common import globals
|
2025-10-23 23:02:27 +08:00
|
|
|
from api import settings
|
|
|
|
|
import logging
|
2024-03-27 11:33:46 +08:00
|
|
|
|
2024-11-25 11:37:56 +08:00
|
|
|
|
2024-01-15 08:46:22 +08:00
|
|
|
class Base(ABC):
    """Common interface shared by the embedding-model wrappers in this module.

    Subclasses override ``encode`` (a list of documents) and ``encode_queries``
    (a single query string); the implementations in this module return an
    ``(embeddings, token_count)`` tuple.
    """

    def __init__(self, key, model_name, **kwargs):
        """
        Constructor for abstract base class.

        Parameters are accepted for interface consistency but are not stored.
        Subclasses should implement their own initialization as needed.
        """
        pass

    def encode(self, texts: list):
        """Embed a batch of documents; must be overridden by subclasses."""
        raise NotImplementedError("Please implement encode method!")

    def encode_queries(self, text: str):
        """Embed a single query string; must be overridden by subclasses."""
        raise NotImplementedError("Please implement encode_queries method!")

    def total_token_count(self, resp):
        """Return the total token usage reported in a provider response.

        Two response shapes are tried in turn: attribute access
        (``resp.usage.total_tokens``) and mapping access
        (``resp["usage"]["total_tokens"]``). A missing or ``None`` value in one
        shape falls through to the next, and 0 is returned when neither yields a
        count, so the result can always be added to a running total.
        """
        # Best-effort lookups: providers differ in response shape, so any
        # failure here only means "this shape does not apply".
        try:
            tokens = resp.usage.total_tokens
            if tokens is not None:
                return tokens
        except Exception:
            pass
        try:
            tokens = resp["usage"]["total_tokens"]
            if tokens is not None:
                return tokens
        except Exception:
            pass
        return 0
|
2025-01-26 13:54:26 +08:00
|
|
|
|
2024-01-15 08:46:22 +08:00
|
|
|
|
2025-10-23 23:02:27 +08:00
|
|
|
class BuiltinEmbed(Base):
    """Embedding wrapper for the built-in TEI deployment.

    The underlying ``HuggingFaceEmbed`` client is created lazily, at most once,
    and shared by all instances through class attributes. It is only created
    when ``COMPOSE_PROFILES`` contains ``"tei-"``.
    """

    _FACTORY_NAME = "Builtin"

    # Per-model token limits; unknown models fall back to 500.
    MAX_TOKENS = {"Qwen/Qwen3-Embedding-0.6B": 30000, "BAAI/bge-m3": 8000, "BAAI/bge-small-en-v1.5": 500}
    _model = None
    _model_name = ""
    _max_tokens = 500
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        logging.info("Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: %s", globals.EMBEDDING_CFG)
        embedding_cfg = globals.EMBEDDING_CFG
        if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
            with BuiltinEmbed._model_lock:
                # Re-check under the lock: another thread may have completed the
                # initialisation while this one was waiting for the lock.
                if not BuiltinEmbed._model:
                    BuiltinEmbed._model_name = settings.EMBEDDING_MDL
                    BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
                    # Assigned last so a reader that sees _model set also sees
                    # the name and token limit.
                    BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
        self._model = BuiltinEmbed._model
        self._model_name = BuiltinEmbed._model_name
        self._max_tokens = BuiltinEmbed._max_tokens

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16.

        Returns:
            (embeddings, token_count). ``embeddings`` is the row-wise
            concatenation of the per-batch results, or an empty array when
            ``texts`` is empty.
        """
        batch_size = 16
        # TEI is able to auto truncate inputs according to https://github.com/huggingface/text-embeddings-inference.
        token_count = 0
        chunks = []
        for i in range(0, len(texts), batch_size):
            embeddings, token_count_delta = self._model.encode(texts[i : i + batch_size])
            token_count += token_count_delta
            chunks.append(embeddings)
        if not chunks:
            return np.array([]), token_count
        # Concatenate once at the end instead of re-copying the accumulated
        # array on every batch.
        return np.concatenate(chunks, axis=0), token_count

    def encode_queries(self, text: str):
        """Delegate query embedding to the shared underlying model."""
        return self._model.encode_queries(text)
|
2024-01-17 20:20:42 +08:00
|
|
|
|
2024-01-15 08:46:22 +08:00
|
|
|
|
|
|
|
|
class OpenAIEmbed(Base):
    """Embeddings through the OpenAI client (default base URL: api.openai.com)."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` (truncated to 8191 tokens each) in batches of 16.

        Returns:
            (np.ndarray of embeddings, total token count).

        Raises:
            Whatever the client or response parsing raises; a failed batch is
            logged and re-raised rather than skipped.
        """
        # OpenAI requires batch size <=16
        batch_size = 16
        texts = [truncate(t, 8191) for t in texts]
        ress = []
        total_tokens = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
            try:
                ress.extend([d.embedding for d in res.data])
                total_tokens += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
                # Never return fewer vectors than inputs: a skipped batch would
                # misalign embeddings with their source texts.
                raise
        return np.array(ress), total_tokens

    def encode_queries(self, text):
        """Embed a single query (truncated to 8191 tokens)."""
        res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
        return np.array(res.data[0].embedding), self.total_token_count(res)
|
2024-01-15 08:46:22 +08:00
|
|
|
|
|
|
|
|
|
2024-07-19 15:50:28 +08:00
|
|
|
class LocalAIEmbed(Base):
    """Embeddings from a local OpenAI-compatible server reached at ``<base_url>/v1``."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local embedding model url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        # Only the part of the model name before the first "___" is sent.
        self.model_name = model_name.split("___")[0]

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16.

        Returns:
            (np.ndarray of embeddings, 1024). The token count is a fixed
            placeholder because the local server's usage is not counted.
        """
        batch_size = 16
        ress = []
        for i in range(0, len(texts), batch_size):
            res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
            try:
                ress.extend([d.embedding for d in res.data])
            except Exception as _e:
                log_exception(_e, res)
                # Re-raise so a failed batch cannot silently shorten the result.
                raise
        # local embedding for LmStudio donot count tokens
        return np.array(ress), 1024

    def encode_queries(self, text):
        """Embed one query by delegating to ``encode``."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
|
2024-07-25 10:23:35 +08:00
|
|
|
|
2024-07-19 15:50:28 +08:00
|
|
|
|
2024-07-19 09:22:59 +08:00
|
|
|
class AzureEmbed(OpenAIEmbed):
    """Azure OpenAI embeddings.

    ``key`` is a JSON string holding ``api_key`` and, optionally,
    ``api_version`` (default "2024-02-01"); ``kwargs["base_url"]`` is the Azure
    endpoint. Encoding is inherited from ``OpenAIEmbed``.
    """

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, **kwargs):
        from openai.lib.azure import AzureOpenAI

        credentials = json.loads(key)
        self.client = AzureOpenAI(
            api_key=credentials.get("api_key", ""),
            azure_endpoint=kwargs["base_url"],
            api_version=credentials.get("api_version", "2024-02-01"),
        )
        self.model_name = model_name
|
|
|
|
|
|
2024-07-19 09:22:59 +08:00
|
|
|
|
2024-05-28 09:09:37 +08:00
|
|
|
class BaiChuanEmbed(OpenAIEmbed):
    """Baichuan embeddings through its OpenAI-compatible endpoint."""

    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
        # An empty/None URL falls back to the public Baichuan endpoint.
        super().__init__(key, model_name, base_url or "https://api.baichuan-ai.com/v1")
|
|
|
|
|
|
|
|
|
|
|
2024-01-15 08:46:22 +08:00
|
|
|
class QWenEmbed(Base):
    """Embeddings via DashScope ``TextEmbedding`` (Tongyi-Qianwen models)."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="text_embedding_v2", **kwargs):
        self.key = key
        self.model_name = model_name

    @staticmethod
    def _has_embeddings(resp):
        """Return True when ``resp["output"]["embeddings"]`` is present."""
        output = resp["output"]
        return output is not None and output.get("embeddings") is not None

    def _call_with_retry(self, inputs, text_type, retry_max=5, retry_delay=10):
        """Call DashScope, retrying while the response carries no embeddings.

        When rate limits are hit, ``output`` comes back as None; this is usually
        temporary, so the call is retried up to ``retry_max`` times with
        ``retry_delay`` seconds between attempts.

        Raises:
            ValueError: when every attempt returned no embeddings.
        """
        import time

        resp = dashscope.TextEmbedding.call(model=self.model_name, input=inputs, api_key=self.key, text_type=text_type)
        while not self._has_embeddings(resp) and retry_max > 0:
            time.sleep(retry_delay)
            resp = dashscope.TextEmbedding.call(model=self.model_name, input=inputs, api_key=self.key, text_type=text_type)
            retry_max -= 1
        if not self._has_embeddings(resp):
            if resp.get("message"):
                err = ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}")
            else:
                err = ValueError("Retry_max reached, calling embedding model failed")
            log_exception(err)
            # Raise explicitly: there is no active exception here for a bare
            # ``raise`` to re-raise.
            raise err
        return resp

    def encode(self, texts: list):
        """Embed ``texts`` (truncated to 2048 tokens each) in batches of 4.

        Returns:
            (np.ndarray of embeddings in input order, total token count).
        """
        batch_size = 4
        res = []
        token_count = 0
        texts = [truncate(t, 2048) for t in texts]
        for i in range(0, len(texts), batch_size):
            resp = self._call_with_retry(texts[i : i + batch_size], "document")
            try:
                items = resp["output"]["embeddings"]
                # Place each vector at its reported text_index rather than
                # relying on the order of the response list.
                embds = [[] for _ in range(len(items))]
                for e in items:
                    embds[e["text_index"]] = e["embedding"]
                res.extend(embds)
                token_count += self.total_token_count(resp)
            except Exception as _e:
                log_exception(_e, resp)
                raise
        return np.array(res), token_count

    def encode_queries(self, text):
        """Embed one query, truncated to 2048 tokens like ``encode`` inputs."""
        resp = self._call_with_retry(truncate(text, 2048), "query")
        try:
            return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
        except Exception as _e:
            log_exception(_e, resp)
            # Re-raise instead of implicitly returning None to the caller.
            raise
|
2024-02-08 17:01:01 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class ZhipuEmbed(Base):
    """Embeddings via the ZhipuAI client, one request per text."""

    _FACTORY_NAME = "ZHIPU-AI"

    # Truncation limit passed to ``truncate`` per (lower-cased) model name;
    # models not listed are sent untruncated.
    _MAX_LEN = {"embedding-2": 512, "embedding-3": 3072}

    def __init__(self, key, model_name="embedding-2", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def _truncate(self, text):
        """Apply the model-specific truncation limit, if any."""
        max_len = self._MAX_LEN.get(self.model_name.lower(), -1)
        return truncate(text, max_len) if max_len > 0 else text

    def encode(self, texts: list):
        """Embed each text individually.

        Returns:
            (np.ndarray of embeddings, total token count).
        """
        arr = []
        tks_num = 0
        for txt in (self._truncate(t) for t in texts):
            res = self.client.embeddings.create(input=txt, model=self.model_name)
            try:
                arr.append(res.data[0].embedding)
                tks_num += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
                # Re-raise so a failure cannot silently shorten the result.
                raise
        return np.array(arr), tks_num

    def encode_queries(self, text):
        """Embed one query, truncated with the same limit as documents."""
        res = self.client.embeddings.create(input=self._truncate(text), model=self.model_name)
        try:
            return np.array(res.data[0].embedding), self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)
            raise
|
2024-04-08 19:20:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class OllamaEmbed(Base):
    """Embeddings served by an Ollama host, one request per text."""

    _FACTORY_NAME = "Ollama"

    # Tokens removed from every input before it is sent.
    _special_tokens = ["<|endoftext|>"]

    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name
        # Only consult OLLAMA_KEEP_ALIVE when no explicit value was passed, so a
        # malformed environment variable cannot break an explicit setting.
        if "ollama_keep_alive" in kwargs:
            self.keep_alive = kwargs["ollama_keep_alive"]
        else:
            self.keep_alive = int(os.environ.get("OLLAMA_KEEP_ALIVE", -1))

    @classmethod
    def _strip_special_tokens(cls, text):
        """Remove every entry of ``_special_tokens`` from ``text``."""
        for token in cls._special_tokens:
            text = text.replace(token, "")
        return text

    def _embed(self, text):
        """Request the embedding for a single text; log and re-raise on bad responses."""
        res = self.client.embeddings(prompt=self._strip_special_tokens(text), model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
        try:
            return res["embedding"]
        except Exception as _e:
            log_exception(_e, res)
            raise

    def encode(self, texts: list):
        """Embed each text.

        Returns:
            (np.ndarray of embeddings, 128 * len(texts)). The token count is a
            fixed placeholder of 128 per text; Ollama usage is not read here.
        """
        arr = [self._embed(txt) for txt in texts]
        return np.array(arr), 128 * len(texts)

    def encode_queries(self, text):
        """Embed one query; the token count is the fixed placeholder 128."""
        return np.array(self._embed(text)), 128
|
2024-04-11 18:22:25 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class XinferenceEmbed(Base):
    """Embeddings from an Xinference server via its OpenAI-compatible ``/v1`` API."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", base_url=""):
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16.

        Returns:
            (np.ndarray of embeddings, total token count).
        """
        batch_size = 16
        ress = []
        total_tokens = 0
        for i in range(0, len(texts), batch_size):
            res = None
            try:
                res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
                ress.extend([d.embedding for d in res.data])
                total_tokens += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
                # Re-raise so a failed batch cannot silently shorten the result.
                raise
        return np.array(ress), total_tokens

    def encode_queries(self, text):
        """Embed one query."""
        res = None
        try:
            res = self.client.embeddings.create(input=[text], model=self.model_name)
            return np.array(res.data[0].embedding), self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)
            # Re-raise instead of implicitly returning None to the caller.
            raise
|
2024-04-15 13:28:06 +05:30
|
|
|
|
2024-04-16 16:42:19 +08:00
|
|
|
|
2024-04-25 14:14:28 +08:00
|
|
|
class YoudaoEmbed(Base):
    """Youdao BCE embedding wrapper backed by the class-level ``_client``."""

    _FACTORY_NAME = "Youdao"
    # NOTE(review): nothing in this class assigns _client, so encode and
    # encode_queries fail unless it is set elsewhere — confirm.
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        pass

    def encode(self, texts: list):
        """Embed ``texts`` in chunks of 10.

        Returns:
            (np.ndarray of embeddings, locally counted tokens over all texts).
        """
        token_count = sum(num_tokens_from_string(t) for t in texts)
        step = 10
        vectors = []
        for start in range(0, len(texts), step):
            vectors.extend(YoudaoEmbed._client.encode(texts[start : start + step]))
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query; tokens are counted locally."""
        first = YoudaoEmbed._client.encode([text])[0]
        return np.array(first), num_tokens_from_string(text)
|
2024-05-29 16:50:02 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class JinaEmbed(Base):
    """Embeddings via the Jina HTTP embeddings endpoint."""

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
        # NOTE(review): the base_url argument is ignored; requests always go to
        # the public Jina endpoint.
        self.base_url = "https://api.jina.ai/v1/embeddings"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` (truncated to 8196 tokens each) in batches of 16.

        Returns:
            (np.ndarray of embeddings, total token count).
        """
        texts = [truncate(t, 8196) for t in texts]
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
            response = requests.post(self.base_url, headers=self.headers, json=data)
            try:
                res = response.json()
                ress.extend([d["embedding"] for d in res["data"]])
                token_count += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, response)
                # Re-raise so a failed batch cannot silently shorten the result.
                raise
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query by delegating to ``encode``."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
|
|
|
|
|
|
|
|
|
|
|
2024-06-14 11:32:58 +08:00
|
|
|
class MistralEmbed(Base):
    """Embeddings via the Mistral client, with randomized retry back-off."""

    _FACTORY_NAME = "Mistral"

    def __init__(self, key, model_name="mistral-embed", base_url=None):
        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def _embed_batch(self, inputs, retry_max=5):
        """Embed ``inputs``, retrying failed attempts after a random 20–60 s pause.

        Returns:
            (list of embedding vectors, token count for the request).

        Raises:
            The last error, once ``retry_max`` attempts have failed. No sleep
            is done after the final attempt.
        """
        import random
        import time

        attempts = max(1, retry_max)
        for attempt in range(attempts):
            try:
                res = self.client.embeddings(input=inputs, model=self.model_name)
                return [d.embedding for d in res.data], self.total_token_count(res)
            except Exception as _e:
                if attempt == attempts - 1:
                    log_exception(_e)
                    raise
                time.sleep(random.uniform(20, 60))

    def encode(self, texts: list):
        """Embed ``texts`` (truncated to 8196 tokens each) in batches of 16.

        Returns:
            (np.ndarray of embeddings, total token count).
        """
        texts = [truncate(t, 8196) for t in texts]
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            vectors, tokens = self._embed_batch(texts[i : i + batch_size])
            ress.extend(vectors)
            token_count += tokens
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query (truncated to 8196 tokens)."""
        vectors, tokens = self._embed_batch([truncate(text, 8196)])
        return np.array(vectors[0]), tokens
|
2024-07-08 09:37:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class BedrockEmbed(Base):
    """Embeddings via AWS Bedrock ``invoke_model`` for Amazon and Cohere models.

    ``key`` is a JSON string with ``bedrock_ak``, ``bedrock_sk`` and
    ``bedrock_region``. If any of them is empty, the default boto3 credential
    chain is used.
    """

    _FACTORY_NAME = "Bedrock"

    def __init__(self, key, model_name, **kwargs):
        import boto3

        credentials = json.loads(key)
        self.bedrock_ak = credentials.get("bedrock_ak", "")
        self.bedrock_sk = credentials.get("bedrock_sk", "")
        self.bedrock_region = credentials.get("bedrock_region", "")
        self.model_name = model_name
        # The provider is taken from the model id prefix, e.g. "amazon.<...>".
        self.is_amazon = self.model_name.split(".")[0] == "amazon"
        self.is_cohere = self.model_name.split(".")[0] == "cohere"

        if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
            # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
            self.client = boto3.client("bedrock-runtime")
        else:
            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def _request_body(self, text, input_type):
        """Build the provider-specific request body for a single text."""
        if self.is_amazon:
            return {"inputText": text}
        if self.is_cohere:
            return {"texts": [text], "input_type": input_type}
        # Previously this fell through to an UnboundLocalError on ``body``.
        raise ValueError(f"Unsupported Bedrock embedding model: {self.model_name}")

    def _invoke(self, text, input_type):
        """Invoke the model for one text and return its embedding vector."""
        body = self._request_body(text, input_type)
        response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
        try:
            model_response = json.loads(response["body"].read())
            if "embedding" in model_response:
                return model_response["embedding"]
            # Cohere embed responses hold a list of vectors under "embeddings"
            # (or a {"float": [...]} mapping when embedding types are requested).
            embeddings = model_response["embeddings"]
            if isinstance(embeddings, dict):
                embeddings = embeddings["float"]
            return embeddings[0]
        except Exception as _e:
            log_exception(_e, response)
            raise

    def encode(self, texts: list):
        """Embed each text (truncated to 8196 tokens).

        Returns:
            (np.ndarray of embeddings, locally counted tokens).
        """
        texts = [truncate(t, 8196) for t in texts]
        embeddings = []
        token_count = 0
        for text in texts:
            embeddings.append(self._invoke(text, "search_document"))
            token_count += num_tokens_from_string(text)
        return np.array(embeddings), token_count

    def encode_queries(self, text):
        """Embed one query (truncated to 8196 tokens); tokens counted locally."""
        token_count = num_tokens_from_string(text)
        embedding = self._invoke(truncate(text, 8196), "search_query")
        return np.array(embedding), token_count
|
|
|
|
|
|
2025-01-06 14:41:29 +08:00
|
|
|
|
2024-07-11 15:41:00 +08:00
|
|
|
class GeminiEmbed(Base):
    """Embeddings via ``google.generativeai.embed_content``."""

    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
        self.key = key
        # Accept names with or without the "models/" prefix (the default
        # already has it); prefixing unconditionally produced "models/models/...".
        self.model_name = model_name if model_name.startswith("models/") else "models/" + model_name

    def encode(self, texts: list):
        """Embed ``texts`` (truncated to 2048 tokens each) as retrieval documents, in batches of 16.

        Returns:
            (np.ndarray of embeddings, locally counted tokens).
        """
        texts = [truncate(t, 2048) for t in texts]
        token_count = sum(num_tokens_from_string(text) for text in texts)
        genai.configure(api_key=self.key)
        batch_size = 16
        ress = []
        for i in range(0, len(texts), batch_size):
            result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
            try:
                ress.extend(result["embedding"])
            except Exception as _e:
                log_exception(_e, result)
                # Re-raise so a failed batch cannot silently shorten the result.
                raise
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query (truncated to 2048 tokens) as a retrieval query.

        ``title`` is omitted because Gemini only accepts it for the
        retrieval_document task type.
        """
        genai.configure(api_key=self.key)
        result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_query")
        token_count = num_tokens_from_string(text)
        try:
            return np.array(result["embedding"]), token_count
        except Exception as _e:
            log_exception(_e, result)
            raise
|
2025-01-06 14:41:29 +08:00
|
|
|
|
2024-07-23 10:43:09 +08:00
|
|
|
|
|
|
|
|
class NvidiaEmbed(Base):
    """Embeddings via NVIDIA's hosted HTTP embeddings endpoints."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1/embeddings"
        self.api_key = key
        self.base_url = base_url
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "authorization": f"Bearer {self.api_key}",
        }
        self.model_name = model_name
        # These two models are served from dedicated retrieval endpoints.
        if model_name == "nvidia/embed-qa-4":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
            self.model_name = "NV-Embed-QA"
        if model_name == "snowflake/arctic-embed-l":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16 (server-side truncation at END).

        Returns:
            (np.ndarray of embeddings, total token count).
        """
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            payload = {
                "input": texts[i : i + batch_size],
                # NOTE(review): documents are also sent with input_type "query";
                # asymmetric models may expect "passage" here — confirm.
                "input_type": "query",
                "model": self.model_name,
                "encoding_format": "float",
                "truncate": "END",
            }
            response = requests.post(self.base_url, headers=self.headers, json=payload)
            try:
                res = response.json()
                ress.extend([d["embedding"] for d in res["data"]])
                token_count += self.total_token_count(res)
            except Exception as _e:
                # Parsing and use of ``res`` stay inside the try: previously a
                # failed parse fell through to a stale (or undefined) ``res``
                # from an earlier batch.
                log_exception(_e, response)
                raise
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query by delegating to ``encode``."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
|
2024-07-24 12:46:43 +08:00
|
|
|
|
|
|
|
|
|
2024-07-25 10:23:35 +08:00
|
|
|
class LmStudioEmbed(LocalAIEmbed):
    """Embeddings from a local LM Studio server through its OpenAI-compatible API."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url):
        """Build an OpenAI client pointed at ``<base_url>`` joined with ``v1``.

        The ``key`` argument is not used; a fixed placeholder API key is sent.

        Raises:
            ValueError: if ``base_url`` is empty or None.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        api_root = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=api_root)
        self.model_name = model_name
|
2024-08-06 16:20:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAI_APIEmbed(OpenAIEmbed):
    """Embeddings from any OpenAI-compatible endpoint (e.g. vLLM)."""

    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url):
        """Create the client against ``<base_url>`` joined with ``v1``.

        Raises:
            ValueError: if ``base_url`` is empty or None.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        api_root = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=api_root)
        # Model names may carry a "___<suffix>" tag; only the part before it is sent.
        self.model_name = model_name.partition("___")[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CoHereEmbed(Base):
    """Embedding model backed by the Cohere ``embed`` endpoint (float vectors)."""

    _FACTORY_NAME = "Cohere"

    def __init__(self, key, model_name, base_url=None):
        """Create a Cohere client; ``base_url`` is accepted but unused."""
        from cohere import Client

        self.client = Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed documents (``input_type="search_document"``) 16 at a time.

        Returns:
            ``(embeddings, token_count)`` where the count is the sum of
            ``meta.billed_units.input_tokens`` across responses.
        """
        step = 16
        vectors = []
        used_tokens = 0
        for start in range(0, len(texts), step):
            resp = self.client.embed(
                texts=texts[start : start + step],
                model=self.model_name,
                input_type="search_document",
                embedding_types=["float"],
            )
            try:
                vectors += list(resp.embeddings.float)
                used_tokens += resp.meta.billed_units.input_tokens
            except Exception as _e:
                log_exception(_e, resp)
        return np.array(vectors), used_tokens

    def encode_queries(self, text):
        """Embed one query (``input_type="search_query"``) and return ``(vector, tokens)``."""
        resp = self.client.embed(
            texts=[text],
            model=self.model_name,
            input_type="search_query",
            embedding_types=["float"],
        )
        try:
            vector = np.array(resp.embeddings.float[0])
            return vector, int(resp.meta.billed_units.input_tokens)
        except Exception as _e:
            log_exception(_e, resp)
|
2024-08-12 10:11:50 +08:00
|
|
|
|
|
|
|
|
|
2025-01-24 10:29:30 +08:00
|
|
|
class TogetherAIEmbed(OpenAIEmbed):
    """Embeddings from Together AI's OpenAI-compatible endpoint."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        # An empty or None URL falls back to the public Together endpoint.
        super().__init__(key, model_name, base_url=base_url or "https://api.together.xyz/v1")
|
2024-08-12 10:15:21 +08:00
|
|
|
|
2024-08-19 10:36:57 +08:00
|
|
|
|
2024-08-12 10:11:50 +08:00
|
|
|
class PerfXCloudEmbed(OpenAIEmbed):
    """Embeddings from PerfXCloud's OpenAI-compatible endpoint."""

    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        # An empty or None URL falls back to the public PerfXCloud endpoint.
        endpoint = base_url or "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, endpoint)
|
2024-08-12 11:06:25 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class UpstageEmbed(OpenAIEmbed):
    """Embeddings from Upstage Solar's OpenAI-compatible endpoint."""

    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        # An empty or None URL falls back to the public Upstage endpoint.
        endpoint = base_url or "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, endpoint)
|
2024-08-13 16:09:10 +08:00
|
|
|
|
|
|
|
|
|
2024-09-11 12:17:44 +08:00
|
|
|
class SILICONFLOWEmbed(Base):
    """Embeddings via SiliconFlow's HTTP ``/embeddings`` endpoint.

    ``base_url`` is the full embeddings URL; requests are POSTed to it directly.
    """

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
        """Store auth headers, endpoint and model name.

        Args:
            key: API key, sent as a Bearer token.
            model_name: Model id sent in every request.
            base_url: Full embeddings URL; empty/None falls back to the default.
        """
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/embeddings"
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }
        self.base_url = base_url
        self.model_name = model_name

    def _prepare_texts(self, texts):
        """Return ``texts`` normalized for the API (shared by encode and encode_queries)."""
        # Blank / whitespace-only inputs are replaced by a single space —
        # presumably because the endpoint rejects empty strings (TODO confirm).
        if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]:
            # These models accept at most 512 input tokens; texts are truncated
            # to 256 tokens to leave a conservative safety margin.
            return [" " if not text.strip() else truncate(text, 256) for text in texts]
        return [" " if not text.strip() else text for text in texts]

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16.

        Returns:
            ``(embeddings, token_count)``: an ``np.ndarray`` of vectors in input
            order and the sum of ``self.total_token_count`` over all responses.
        """
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            payload = {
                "model": self.model_name,
                "input": self._prepare_texts(texts[i : i + batch_size]),
                "encoding_format": "float",
            }
            response = requests.post(self.base_url, json=payload, headers=self.headers)
            try:
                res = response.json()
                ress.extend([d["embedding"] for d in res["data"]])
                token_count += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, response)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query and return ``(vector, token_count)``."""
        # Apply the same normalization as ``encode`` so a blank query, or an
        # over-long one for the bge-large models, is not sent raw.
        payload = {
            "model": self.model_name,
            "input": self._prepare_texts([text])[0],
            "encoding_format": "float",
        }
        response = requests.post(self.base_url, json=payload, headers=self.headers)
        try:
            res = response.json()
            return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, response)
|
2024-08-19 10:36:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReplicateEmbed(Base):
    """Embedding model executed on Replicate via ``replicate.client.Client.run``."""

    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None):
        """Create a Replicate client; ``base_url`` is accepted but unused."""
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16.

        Returns:
            ``(embeddings, token_count)``; the count is estimated locally with
            ``num_tokens_from_string`` since no usage is read from the response.
        """
        batch_size = 16
        token_count = sum(num_tokens_from_string(text) for text in texts)
        ress = []
        for i in range(0, len(texts), batch_size):
            # Assumes the model returns one embedding per input text — TODO confirm per model.
            res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
            ress.extend(res)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query; returns a 1-D vector and a locally estimated token count."""
        # The Replicate client exposes ``run`` (as used in ``encode``); it has
        # no ``embed`` method, so the previous call failed with AttributeError.
        res = self.client.run(self.model_name, input={"texts": [text]})
        # ``run`` yields one embedding per input text (see ``encode``); take the
        # only entry so callers get a 1-D vector like the other encode_queries.
        return np.array(res[0]), num_tokens_from_string(text)
|
2024-08-22 16:45:15 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaiduYiyanEmbed(Base):
    """Embedding model backed by Baidu Qianfan (Yiyan) via ``qianfan.Embedding``."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        """Create the Qianfan client.

        Args:
            key: JSON string with ``yiyan_ak`` and ``yiyan_sk`` (missing keys default to "").
            model_name: Qianfan embedding model name.
            base_url: Accepted for interface compatibility; unused.
        """
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.Embedding(ak=ak, sk=sk)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=16):
        """Embed ``texts``, sending at most ``batch_size`` texts per request.

        The ``batch_size`` parameter was previously ignored and every text went
        out in a single request; Qianfan caps the number of texts per call
        (16 per its docs — confirm per model), so large inputs failed.

        Returns:
            ``(embeddings, token_count)``: an ``np.ndarray`` of vectors in input
            order and the sum of ``self.total_token_count`` over all responses.
        """
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.do(model=self.model_name, texts=texts[i : i + batch_size]).body
            try:
                ress.extend([r["embedding"] for r in res["data"]])
                token_count += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query and return ``(vector, token_count)``."""
        res = self.client.do(model=self.model_name, texts=[text]).body
        try:
            # Return the single embedding as a 1-D vector, matching the other
            # encode_queries implementations (previously a 1xN 2-D array).
            return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)
|
2024-08-29 16:14:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class VoyageEmbed(Base):
    """Embedding model backed by the ``voyageai`` client."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        """Create a Voyage client; ``base_url`` is accepted but unused."""
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed documents (``input_type="document"``) 16 at a time.

        Returns:
            ``(embeddings, token_count)`` where the count sums ``total_tokens``
            from each response.
        """
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
            try:
                ress.extend(res.embeddings)
                token_count += res.total_tokens
            except Exception as _e:
                log_exception(_e, res)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query (``input_type="query"``) and return ``(vector, tokens)``."""
        # ``Client.embed`` takes a list of strings; wrap the single query
        # instead of passing a bare str as ``texts``.
        res = self.client.embed(texts=[text], model=self.model_name, input_type="query")
        try:
            return np.array(res.embeddings[0]), res.total_tokens
        except Exception as _e:
            log_exception(_e, res)
|
2024-09-27 19:15:38 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class HuggingFaceEmbed(Base):
    """Embeddings from a HuggingFace embedding server exposing ``POST /embed``.

    Requests send ``{"inputs": ...}`` to ``{base_url}/embed``; presumably a
    text-embeddings-inference (TEI) server — confirm against deployment docs.
    """

    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Store connection settings.

        Args:
            key: Stored but not sent with requests.
            model_name: Required; any "___<suffix>" tag is stripped.
            base_url: Server root; defaults to ``http://127.0.0.1:8080``.

        Raises:
            ValueError: if ``model_name`` is empty or None.
        """
        if not model_name:
            raise ValueError("Model name cannot be None")
        self.key = key
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url or "http://127.0.0.1:8080"

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16 and return ``(embeddings, token_count)``.

        The token count is estimated locally with ``num_tokens_from_string``.

        Raises:
            Exception: on any non-200 response, with status code and body.
        """
        # Send in batches like the sibling embedders: posting every text in one
        # request can exceed the server's per-request client batch limit
        # (TEI's --max-client-batch-size, 32 by default).
        batch_size = 16
        embeddings = []
        for i in range(0, len(texts), batch_size):
            response = requests.post(f"{self.base_url}/embed", json={"inputs": texts[i : i + batch_size]}, headers={"Content-Type": "application/json"})
            if response.status_code != 200:
                raise Exception(f"Error: {response.status_code} - {response.text}")
            embeddings.extend(response.json())
        return np.array(embeddings), sum(num_tokens_from_string(text) for text in texts)

    def encode_queries(self, text: str):
        """Embed one query and return ``(vector, token_count)``.

        Raises:
            Exception: on any non-200 response, with status code and body.
        """
        response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
        if response.status_code == 200:
            embedding = response.json()[0]
            return np.array(embedding), num_tokens_from_string(text)
        else:
            raise Exception(f"Error: {response.status_code} - {response.text}")
|
|
|
|
|
|
2024-12-05 13:28:42 +08:00
|
|
|
|
2024-11-27 09:30:49 +08:00
|
|
|
class VolcEngineEmbed(OpenAIEmbed):
    """Embeddings from Volcengine Ark through its OpenAI-compatible API.

    ``key`` is a JSON string holding ``ark_api_key`` plus the endpoint id
    (``ep_id`` and/or ``endpoint_id``), which is used as the model name.
    """

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        endpoint = base_url or "https://ark.cn-beijing.volces.com/api/v3"
        credentials = json.loads(key)
        api_key = credentials.get("ark_api_key", "")
        # The incoming model_name is replaced: the two id fields (each
        # defaulting to "") are concatenated to form the Ark model/endpoint id.
        endpoint_id = credentials.get("ep_id", "") + credentials.get("endpoint_id", "")
        super().__init__(api_key, endpoint_id, endpoint)
|
2025-01-15 14:15:58 +08:00
|
|
|
|
2025-06-12 17:53:59 +08:00
|
|
|
|
2025-01-15 14:15:58 +08:00
|
|
|
class GPUStackEmbed(OpenAIEmbed):
    """Embeddings from a GPUStack server through its OpenAI-compatible API."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        """Create the client against ``<base_url>`` joined with ``v1``.

        Raises:
            ValueError: if ``base_url`` is empty or None.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        api_root = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=api_root)
        self.model_name = model_name
|
2025-06-13 15:42:17 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class NovitaEmbed(SILICONFLOWEmbed):
    """Embeddings from Novita AI, reusing the SiliconFlow-style HTTP client."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
        # An empty or None URL falls back to Novita's public embeddings URL.
        endpoint = base_url or "https://api.novita.ai/v3/openai/embeddings"
        super().__init__(key, model_name, endpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GiteeEmbed(SILICONFLOWEmbed):
    """Embeddings from Gitee AI, reusing the SiliconFlow-style HTTP client."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
        # An empty or None URL falls back to Gitee AI's public embeddings URL.
        endpoint = base_url or "https://ai.gitee.com/v1/embeddings"
        super().__init__(key, model_name, endpoint)
|
2025-09-10 13:02:53 +08:00
|
|
|
|
2025-07-23 18:10:35 +08:00
|
|
|
class DeepInfraEmbed(OpenAIEmbed):
    """Embeddings from DeepInfra's OpenAI-compatible endpoint."""

    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai"):
        # An empty or None URL falls back to DeepInfra's public endpoint.
        endpoint = base_url or "https://api.deepinfra.com/v1/openai"
        super().__init__(key, model_name, endpoint)
|
2025-07-31 14:48:30 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class Ai302Embed(SILICONFLOWEmbed):
    """Embeddings from 302.AI's ``/v1/embeddings`` endpoint.

    Reuses SILICONFLOWEmbed, which POSTs directly to a full embeddings URL —
    the same URL shape as this class's default (like NovitaEmbed/GiteeEmbed).
    """

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"):
        """Configure the client; an empty or None URL falls back to the default."""
        if not base_url:
            base_url = "https://api.302.ai/v1/embeddings"
        # Previously this class derived from Base, whose __init__ takes only
        # (key, model_name, **kwargs): passing base_url positionally raised
        # TypeError, and the class defined no encode/encode_queries of its own.
        # SILICONFLOWEmbed accepts (key, model_name, base_url) and implements both.
        super().__init__(key, model_name, base_url)
|
2025-09-18 09:51:29 +08:00
|
|
|
|
|
|
|
|
|
2025-09-26 10:50:56 +08:00
|
|
|
class CometAPIEmbed(OpenAIEmbed):
    """Embeddings from CometAPI's OpenAI-compatible endpoint."""

    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
        # An empty or None URL falls back to CometAPI's public endpoint.
        endpoint = base_url or "https://api.cometapi.com/v1"
        super().__init__(key, model_name, endpoint)
|
2025-10-09 11:14:49 +08:00
|
|
|
|
|
|
|
|
class DeerAPIEmbed(OpenAIEmbed):
    """Embeddings from DeerAPI's OpenAI-compatible endpoint."""

    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1"):
        # An empty or None URL falls back to DeerAPI's public endpoint.
        endpoint = base_url or "https://api.deerapi.com/v1"
        super().__init__(key, model_name, endpoint)
|