| 
									
										
										
										
											2024-01-18 21:39:12 +08:00
										 |  |  | import base64 | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  | from typing import Any, Optional, cast | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-23 19:10:11 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from sqlalchemy.exc import IntegrityError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-12 23:58:41 +08:00
										 |  |  | from configs import dify_config | 
					
						
							| 
									
										
										
										
											2024-10-17 19:12:42 +08:00
										 |  |  | from core.entities.embedding_type import EmbeddingInputType | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.model_manager import ModelInstance | 
					
						
							| 
									
										
										
										
											2024-01-19 21:37:54 +08:00
										 |  |  | from core.model_runtime.entities.model_entities import ModelPropertyKey | 
					
						
							|  |  |  | from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | 
					
						
							| 
									
										
										
										
											2024-10-17 19:12:42 +08:00
										 |  |  | from core.rag.embedding.embedding_base import Embeddings | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2024-01-18 21:39:12 +08:00
										 |  |  | from extensions.ext_redis import redis_client | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | from libs import helper | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  | from models.dataset import Embedding | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | logger = logging.getLogger(__name__) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | class CacheEmbedding(Embeddings): | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |     def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None: | 
					
						
							|  |  |  |         self._model_instance = model_instance | 
					
						
							|  |  |  |         self._user = user | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-09 15:21:33 +08:00
										 |  |  |     def embed_documents(self, texts: list[str]) -> list[list[float]]: | 
					
						
							| 
									
										
										
										
											2024-01-19 21:37:54 +08:00
										 |  |  |         """Embed search docs in batches of 10.""" | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  |         # use doc embedding cache or store if not exists | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         text_embeddings: list[Any] = [None for _ in range(len(texts))] | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  |         embedding_queue_indices = [] | 
					
						
							|  |  |  |         for i, text in enumerate(texts): | 
					
						
							|  |  |  |             hash = helper.generate_text_hash(text) | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             embedding = ( | 
					
						
							|  |  |  |                 db.session.query(Embedding) | 
					
						
							|  |  |  |                 .filter_by( | 
					
						
							|  |  |  |                     model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 .first() | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  |             if embedding: | 
					
						
							|  |  |  |                 text_embeddings[i] = embedding.get_embedding() | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 embedding_queue_indices.append(i) | 
					
						
							|  |  |  |         if embedding_queue_indices: | 
					
						
							|  |  |  |             embedding_queue_texts = [texts[i] for i in embedding_queue_indices] | 
					
						
							|  |  |  |             embedding_queue_embeddings = [] | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 model_schema = model_type_instance.get_model_schema( | 
					
						
							|  |  |  |                     self._model_instance.model, self._model_instance.credentials | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 max_chunks = ( | 
					
						
							|  |  |  |                     model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] | 
					
						
							|  |  |  |                     if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties | 
					
						
							|  |  |  |                     else 1 | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  |                 for i in range(0, len(embedding_queue_texts), max_chunks): | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     batch_texts = embedding_queue_texts[i : i + max_chunks] | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-24 21:53:50 +08:00
										 |  |  |                     embedding_result = self._model_instance.invoke_text_embedding( | 
					
						
							|  |  |  |                         texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT | 
					
						
							|  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     for vector in embedding_result.embeddings: | 
					
						
							|  |  |  |                         try: | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |                             # FIXME: type ignore for numpy here | 
					
						
							|  |  |  |                             normalized_embedding = (vector / np.linalg.norm(vector)).tolist()  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-12-19 20:50:20 +08:00
										 |  |  |                             # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan | 
					
						
							|  |  |  |                             if np.isnan(normalized_embedding).any(): | 
					
						
							|  |  |  |                                 # for issue #11827  float values are not json compliant | 
					
						
							|  |  |  |                                 logger.warning(f"Normalized embedding is nan: {normalized_embedding}") | 
					
						
							|  |  |  |                                 continue | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  |                             embedding_queue_embeddings.append(normalized_embedding) | 
					
						
							|  |  |  |                         except IntegrityError: | 
					
						
							|  |  |  |                             db.session.rollback() | 
					
						
							| 
									
										
										
										
											2025-02-17 17:05:13 +08:00
										 |  |  |                         except Exception: | 
					
						
							| 
									
										
										
										
											2024-11-15 15:41:40 +08:00
										 |  |  |                             logging.exception("Failed transform embedding") | 
					
						
							| 
									
										
										
										
											2024-04-07 15:20:58 +08:00
										 |  |  |                 cache_embeddings = [] | 
					
						
							| 
									
										
										
										
											2024-04-09 02:16:19 +08:00
										 |  |  |                 try: | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |                     for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): | 
					
						
							|  |  |  |                         text_embeddings[i] = n_embedding | 
					
						
							| 
									
										
										
										
											2024-04-09 02:16:19 +08:00
										 |  |  |                         hash = helper.generate_text_hash(texts[i]) | 
					
						
							|  |  |  |                         if hash not in cache_embeddings: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             embedding_cache = Embedding( | 
					
						
							|  |  |  |                                 model_name=self._model_instance.model, | 
					
						
							|  |  |  |                                 hash=hash, | 
					
						
							|  |  |  |                                 provider_name=self._model_instance.provider, | 
					
						
							|  |  |  |                             ) | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |                             embedding_cache.set_embedding(n_embedding) | 
					
						
							| 
									
										
										
										
											2024-04-09 02:16:19 +08:00
										 |  |  |                             db.session.add(embedding_cache) | 
					
						
							|  |  |  |                             cache_embeddings.append(hash) | 
					
						
							|  |  |  |                     db.session.commit() | 
					
						
							|  |  |  |                 except IntegrityError: | 
					
						
							|  |  |  |                     db.session.rollback() | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  |             except Exception as ex: | 
					
						
							|  |  |  |                 db.session.rollback() | 
					
						
							| 
									
										
										
										
											2024-11-15 15:41:40 +08:00
										 |  |  |                 logger.exception("Failed to embed documents: %s") | 
					
						
							| 
									
										
										
										
											2024-04-02 20:46:24 +08:00
										 |  |  |                 raise ex | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return text_embeddings | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-09 15:21:33 +08:00
										 |  |  |     def embed_query(self, text: str) -> list[float]: | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |         """Embed query text.""" | 
					
						
							|  |  |  |         # use doc embedding cache or store if not exists | 
					
						
							|  |  |  |         hash = helper.generate_text_hash(text) | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}" | 
					
						
							| 
									
										
										
										
											2024-01-18 21:39:12 +08:00
										 |  |  |         embedding = redis_client.get(embedding_cache_key) | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |         if embedding: | 
					
						
							| 
									
										
										
										
											2024-01-19 21:37:54 +08:00
										 |  |  |             redis_client.expire(embedding_cache_key, 600) | 
					
						
							| 
									
										
										
										
											2024-11-29 10:18:41 +09:00
										 |  |  |             decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float") | 
					
						
							|  |  |  |             return [float(x) for x in decoded_embedding] | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-09-24 21:53:50 +08:00
										 |  |  |             embedding_result = self._model_instance.invoke_text_embedding( | 
					
						
							|  |  |  |                 texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             embedding_results = embedding_result.embeddings[0] | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |             # FIXME: type ignore for numpy here | 
					
						
							|  |  |  |             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-12-20 09:28:32 +08:00
										 |  |  |             if np.isnan(embedding_results).any(): | 
					
						
							|  |  |  |                 raise ValueError("Normalized embedding is nan please try again") | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  |         except Exception as ex: | 
					
						
							| 
									
										
										
										
											2024-10-12 23:58:41 +08:00
										 |  |  |             if dify_config.DEBUG: | 
					
						
							| 
									
										
										
										
											2024-11-15 15:41:40 +08:00
										 |  |  |                 logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'") | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             raise ex | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-01-18 21:39:12 +08:00
										 |  |  |             # encode embedding to base64 | 
					
						
							|  |  |  |             embedding_vector = np.array(embedding_results) | 
					
						
							|  |  |  |             vector_bytes = embedding_vector.tobytes() | 
					
						
							|  |  |  |             # Transform to Base64 | 
					
						
							|  |  |  |             encoded_vector = base64.b64encode(vector_bytes) | 
					
						
							|  |  |  |             # Transform to string | 
					
						
							|  |  |  |             encoded_str = encoded_vector.decode("utf-8") | 
					
						
							| 
									
										
										
										
											2024-01-19 21:37:54 +08:00
										 |  |  |             redis_client.setex(embedding_cache_key, 600, encoded_str) | 
					
						
							| 
									
										
										
										
											2024-09-14 12:56:45 +08:00
										 |  |  |         except Exception as ex: | 
					
						
							| 
									
										
										
										
											2024-10-12 23:58:41 +08:00
										 |  |  |             if dify_config.DEBUG: | 
					
						
							| 
									
										
										
										
											2024-11-15 15:41:40 +08:00
										 |  |  |                 logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'") | 
					
						
							| 
									
										
										
										
											2024-10-12 23:58:41 +08:00
										 |  |  |             raise ex | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return embedding_results |