From 00c7ddbc9be0ffb1f9cdbea2379244b13b4db6bf Mon Sep 17 00:00:00 2001 From: zhxlp <1573635222@qq.com> Date: Tue, 18 Feb 2025 13:42:22 +0800 Subject: [PATCH] Fix: The max tokens defined by the tenant are not used (#4297) (#2817) (#5066) ### What problem does this PR solve? Fix: The max tokens defined by the tenant are not used (#4297) (#2817) ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Kevin Hu --- api/db/services/dialog_service.py | 21 ++++++--------------- api/db/services/llm_service.py | 15 +++++++++------ 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 0e8771e33..835ed80ca 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -29,7 +29,7 @@ from api.db.db_models import Dialog, DB from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle +from api.db.services.llm_service import TenantLLMService, LLMBundle from api import settings from graphrag.utils import get_tags_from_cache, set_tags_to_cache from rag.app.resume import forbidden_select_fields4resume @@ -172,21 +172,12 @@ def chat(dialog, messages, stream=True, **kwargs): chat_start_ts = timer() - # Get llm model name and model provider name - llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id) - - # Get llm model instance by model and provide name - llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider) - - if not llm: - # Model name is provided by tenant, but not system built-in - llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \ - TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider) - if not llm: - raise LookupError("LLM(%s) not found" % dialog.llm_id) - max_tokens = 8192 + if llm_id2llm_type(dialog.llm_id) == "image2text": + llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: - max_tokens = llm[0].max_tokens + llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + + max_tokens = llm_model_config.get("max_tokens", 8192) check_llm_ts = timer() diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index efca4339d..5a61e4ed2 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -86,8 +86,7 @@ class TenantLLMService(CommonService): @classmethod @DB.connection_context() - def model_instance(cls, tenant_id, llm_type, - llm_name=None, lang="Chinese"): + def get_model_config(cls, tenant_id, llm_type, llm_name=None): e, tenant = TenantService.get_by_id(tenant_id) if not e: raise LookupError("Tenant not found") @@ -124,7 +123,13 @@ class TenantLLMService(CommonService): if not mdlnm: raise LookupError(f"Type of {llm_type} model is not set.") raise LookupError("Model({}) not authorized".format(mdlnm)) + return model_config + @classmethod + @DB.connection_context() + def model_instance(cls, tenant_id, llm_type, + llm_name=None, lang="Chinese"): + model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) if llm_type == LLMType.EMBEDDING.value: if model_config["llm_factory"] not in EmbeddingModel: return @@ -228,10 +233,8 @@ class LLMBundle(object): tenant_id, llm_type, llm_name, lang=lang) assert self.mdl, "Can't find model for {}/{}/{}".format( tenant_id, llm_type, llm_name) - self.max_length = 8192 - for lm in LLMService.query(llm_name=llm_name): - self.max_length = lm.max_tokens - break + model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) + self.max_length = model_config.get("max_tokens", 8192) def encode(self, texts: list): embeddings, used_tokens = self.mdl.encode(texts)