# Mirror of https://github.com/langgenius/dify.git (synced 2025-10-31 19:03:09 +00:00)

import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.node import LLMNode
from models.account import Tenant


class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
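    """
    Backwards invocation entry points that let a plugin call the host application's
    model capabilities (LLM, text embedding, rerank, TTS, speech-to-text,
    moderation and summarization) on behalf of a tenant.
    """
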
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        invoke llm
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.completion_params,
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
        )

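        # A streaming invocation yields chunks and deducts LLM quota as usage arrives on
        # each chunk; a blocking invocation deducts quota once and is then re-wrapped as a
        # single-chunk generator so callers always consume the same streaming shape.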
        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                yield LLMResultChunk(
                    model=response.model,
                    prompt_messages=response.prompt_messages,
                    system_fingerprint=response.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=response.message,
                        usage=response.usage,
                        finish_reason="",
                    ),
                )

            return handle_non_streaming(response)

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        invoke text embedding
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        invoke rerank
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        invoke tts
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )

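        # TTS yields raw audio bytes; hex-encode each chunk so the result can travel
        # back to the plugin in a plain text/JSON response.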
        def handle() -> Generator[dict, None, None]:
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        invoke speech2text
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
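        # payload.file arrives hex-encoded; decode it into a temporary .mp3 file because
        # the speech-to-text model expects a readable file object, not a byte string.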
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        invoke moderation
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }

    @classmethod
    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
        """
        get system model max tokens
        """
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)

    @classmethod
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
        """
        get prompt tokens
        """
        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)

    @classmethod
    def invoke_system_model(
        cls,
        user_id: str,
        tenant: Tenant,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        invoke system model
        """
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=tenant.id,
            tool_type=ToolProviderType.PLUGIN,
            tool_name="plugin",
            prompt_messages=prompt_messages,
        )

    @classmethod
    def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary):
        """
        invoke summary
        """
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
        content = payload.text

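        # Strategy: if the text already fits comfortably (under ~60% of the system model's
        # context window) return it unchanged; otherwise split it into chunks, summarize
        # each chunk, and recurse on the joined summaries until the result fits.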
        SUMMARY_PROMPT = """You are a professional language researcher. You are interested in language,
and you can quickly identify the main point of a webpage and reproduce it in your own words while
retaining the original meaning and keeping the key points.
However, the text you received is too long, and it may be only a part of the full text.
Please summarize the text you received.

Here is the extra instruction you need to follow:
<extra_instruction>
{payload.instruction}
</extra_instruction>
"""

        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=content)],
            )
            < max_tokens * 0.6
        ):
            return content

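        # The token counter and summarizer below both run against the workspace's system
        # model (via ModelInvocationUtils), pairing the summary prompt with the content.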
        def get_prompt_tokens(content: str) -> int:
            return cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )

        def summarize(content: str) -> str:
            summary = cls.invoke_system_model(
                user_id=user_id,
                tenant=tenant,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )

            assert isinstance(summary.message.content, str)
            return summary.message.content

        lines = content.split("\n")
        new_lines: list[str] = []
        # split long lines into multiple shorter lines
        for line in lines:
            if not line.strip():
                continue
            if len(line) < max_tokens * 0.5:
                new_lines.append(line)
            elif get_prompt_tokens(line) > max_tokens * 0.7:
                while get_prompt_tokens(line) > max_tokens * 0.7:
                    new_lines.append(line[: int(max_tokens * 0.5)])
                    line = line[int(max_tokens * 0.5) :]
                new_lines.append(line)
            else:
                new_lines.append(line)

        # merge lines into messages within the token budget
        messages: list[str] = []
        for line in new_lines:
            if not messages:
                messages.append(line)
            elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
                # the current chunk is full, start a new one
                messages.append(line)
            else:
                messages[-1] += line

        summaries = []
        for message in messages:
            summaries.append(summarize(message))

        result = "\n".join(summaries)

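        # If the concatenated summaries are still too large for the context window,
        # summarize them again recursively.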
        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=result)],
            )
            > max_tokens * 0.7
        ):
            return cls.invoke_summary(
                user_id=user_id,
                tenant=tenant,
                payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
            )

        return result