diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 48a55b674..74127b3cb 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -25,7 +25,6 @@ import base64 from io import BytesIO import json import requests -from transformers import GenerationConfig from rag.nlp import is_english from api.utils import get_uuid @@ -510,6 +509,7 @@ class GeminiCV(Base): return res.text,res.usage_metadata.total_token_count def chat(self, system, history, gen_conf, image=""): + from transformers import GenerationConfig if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] try: @@ -533,6 +533,7 @@ class GeminiCV(Base): return "**ERROR**: " + str(e), 0 def chat_streamly(self, system, history, gen_conf, image=""): + from transformers import GenerationConfig if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]