dify/api/core/rag/embedding/cached_embedding.py

149 lines
6.9 KiB
Python
Raw Normal View History

import base64
import logging
import pickle
from typing import Any, cast
import numpy as np
from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Embedding
logger = logging.getLogger(__name__)
class CacheEmbedding(Embeddings):
def __init__(self, model_instance: ModelInstance, user: str | None = None):
self._model_instance = model_instance
self._user = user
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
# use doc embedding cache or store if not exists
2024-12-24 18:38:51 +08:00
text_embeddings: list[Any] = [None for _ in range(len(texts))]
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
)
.first()
)
if embedding:
text_embeddings[i] = embedding.get_embedding()
else:
embedding_queue_indices.append(i)
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
if embedding_queue_indices:
embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
else 1
)
for i in range(0, len(embedding_queue_texts), max_chunks):
batch_texts = embedding_queue_texts[i : i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
)
for vector in embedding_result.embeddings:
try:
2024-12-24 18:38:51 +08:00
# FIXME: type ignore for numpy here
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
logger.warning("Normalized embedding is nan: %s", normalized_embedding)
continue
embedding_queue_embeddings.append(normalized_embedding)
except IntegrityError:
db.session.rollback()
Introduce Plugins (#13836) Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld <lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
2025-02-17 17:05:13 +08:00
except Exception:
logger.exception("Failed transform embedding")
cache_embeddings = []
try:
2024-12-24 18:38:51 +08:00
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = n_embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.exception("Failed to embed documents")
raise ex
return text_embeddings
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
2024-12-24 18:38:51 +08:00
# FIXME: type ignore for numpy here
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logger.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text))
raise ex
try:
# encode embedding to base64
embedding_vector = np.array(embedding_results)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex:
if dify_config.DEBUG:
logger.exception(
"Failed to add embedding to redis for the text '%s...(%s chars)'", text[:10], len(text)
)
raise ex
2025-06-11 01:01:33 +08:00
return embedding_results # type: ignore