From 373e90ee6d50b5ca71d59b0b476802c739e426f4 Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 2 Oct 2023 22:27:25 +0800
Subject: [PATCH] fix: detached model in completion thread (#1269)

---
 api/core/model_providers/models/llm/base.py |  2 -
 api/services/completion_service.py          | 49 +++++++++++++--------
 2 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py
index 0aacfe2602..c7cec88ff0 100644
--- a/api/core/model_providers/models/llm/base.py
+++ b/api/core/model_providers/models/llm/base.py
@@ -132,8 +132,6 @@ class BaseLLM(BaseProviderModel):
         if self.deduct_quota:
             self.model_provider.check_quota_over_limit()
 
-        db.session.commit()
-
         if not callbacks:
             callbacks = self.callbacks
         else:
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index d8ffd02ed4..c95905c6c8 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -3,7 +3,7 @@ import logging
 import threading
 import time
 import uuid
-from typing import Generator, Union, Any
+from typing import Generator, Union, Any, Optional
 
 from flask import current_app, Flask
 from redis.client import PubSub
@@ -141,12 +141,12 @@ class CompletionService:
         generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
             'flask_app': current_app._get_current_object(),
             'generate_task_id': generate_task_id,
-            'app_model': app_model,
+            'detached_app_model': app_model,
             'app_model_config': app_model_config,
             'query': query,
             'inputs': inputs,
-            'user': user,
-            'conversation': conversation,
+            'detached_user': user,
+            'detached_conversation': conversation,
             'streaming': streaming,
             'is_model_config_override': is_model_config_override,
             'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
@@ -171,18 +171,22 @@ class CompletionService:
         return user
 
     @classmethod
-    def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
-                        query: str, inputs: dict, user: Union[Account, EndUser],
-                        conversation: Conversation, streaming: bool, is_model_config_override: bool,
+    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
+                        query: str, inputs: dict, detached_user: Union[Account, EndUser],
+                        detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
                         retriever_from: str = 'dev'):
         with flask_app.app_context():
+            # fixed the state of the model object when it detached from the original session
+            user = db.session.merge(detached_user)
+            app_model = db.session.merge(detached_app_model)
+
+            if detached_conversation:
+                conversation = db.session.merge(detached_conversation)
+            else:
+                conversation = None
+
             try:
-                if conversation:
-                    # fixed the state of the conversation object when it detached from the original session
-                    conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
-
                 # run
-
                 Completion.generate(
                     task_id=generate_task_id,
                     app=app_model,
@@ -210,12 +214,14 @@ class CompletionService:
                 db.session.commit()
 
     @classmethod
-    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
+    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
         # wait for 10 minutes to close the thread
         timeout = 600
 
         def close_pubsub():
             with flask_app.app_context():
+                user = db.session.merge(detached_user)
+
                 sleep_iterations = 0
                 while sleep_iterations < timeout and worker_thread.is_alive():
                     if sleep_iterations > 0 and sleep_iterations % 10 == 0:
@@ -279,11 +285,11 @@ class CompletionService:
         generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
             'flask_app': current_app._get_current_object(),
             'generate_task_id': generate_task_id,
-            'app_model': app_model,
+            'detached_app_model': app_model,
             'app_model_config': app_model_config,
-            'message': message,
+            'detached_message': message,
             'pre_prompt': pre_prompt,
-            'user': user,
+            'detached_user': user,
             'streaming': streaming
         })
 
@@ -294,10 +300,15 @@ class CompletionService:
         return cls.compact_response(pubsub, streaming)
 
     @classmethod
-    def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
-                                       app_model_config: AppModelConfig, message: Message, pre_prompt: str,
-                                       user: Union[Account, EndUser], streaming: bool):
+    def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
+                                       app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
+                                       detached_user: Union[Account, EndUser], streaming: bool):
         with flask_app.app_context():
+            # fixed the state of the model object when it detached from the original session
+            user = db.session.merge(detached_user)
+            app_model = db.session.merge(detached_app_model)
+            message = db.session.merge(detached_message)
+
             try:
                 # run
                 Completion.generate_more_like_this(
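
Note (not part of the patch): below is a minimal, self-contained sketch of the pattern the diff applies, written with plain SQLAlchemy and threading rather than Flask or Dify code. ORM instances loaded in the request-handling session become detached once the worker thread opens its own session, so the worker re-attaches them with Session.merge() before use; the patch's db.session.merge() calls go through that same Session.merge() API. The App model, table, and engine setup here are illustrative stand-ins, not the project's real models.

# Illustrative sketch only -- model, table, and engine names are invented for
# this example; they are not Dify code.
import threading

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.pool import StaticPool

Base = declarative_base()


class App(Base):
    __tablename__ = "apps"

    id = Column(Integer, primary_key=True)
    name = Column(String(255))


# StaticPool + check_same_thread=False lets one in-memory SQLite database be
# shared across threads, standing in for a normal multi-threaded app database.
engine = create_engine(
    "sqlite://",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
Base.metadata.create_all(engine)

# "Request" session: create and load an object, then close the session.
# The instance is now detached; using it directly from another thread's
# session risks DetachedInstanceError on expired or lazy-loaded attributes.
with Session(engine, expire_on_commit=False) as request_session:
    app_model = App(id=1, name="demo")
    request_session.add(app_model)
    request_session.commit()


def generate_worker(detached_app_model: App) -> None:
    # The worker thread opens its own session and re-attaches the detached
    # object; merge() returns an instance bound to this session.
    with Session(engine) as worker_session:
        app_model = worker_session.merge(detached_app_model)
        print(app_model.id, app_model.name)  # safe: attached to worker_session


thread = threading.Thread(target=generate_worker,
                          kwargs={"detached_app_model": app_model})
thread.start()
thread.join()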