diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index f00915dd2..46ea7b14e 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -179,7 +179,41 @@ class Base(ABC):
         except Exception:
             pass
         return 0
+
+    def _calculate_dynamic_ctx(self, history):
+        """Calculate dynamic context window size"""
+        def count_tokens(text):
+            """Calculate token count for text"""
+            # Simple calculation: 1 token per ASCII character
+            # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
+            total = 0
+            for char in text:
+                if ord(char) < 128:  # ASCII characters
+                    total += 1
+                else:  # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
+                    total += 2
+            return total
+        # Calculate total tokens for all messages
+        total_tokens = 0
+        for message in history:
+            content = message.get("content", "")
+            # Calculate content tokens
+            content_tokens = count_tokens(content)
+            # Add role marker token overhead
+            role_tokens = 4
+            total_tokens += content_tokens + role_tokens
+
+        # Apply 1.2x buffer ratio
+        total_tokens_with_buffer = int(total_tokens * 1.2)
+
+        if total_tokens_with_buffer <= 8192:
+            ctx_size = 8192
+        else:
+            ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
+            ctx_size = ctx_multiplier * 8192
+
+        return ctx_size
 
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -469,7 +503,7 @@ class ZhipuChat(Base):
 
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
-        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
+        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
@@ -478,7 +512,12 @@ class OllamaChat(Base):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
         try:
-            options = {"num_ctx": 32768}
+            # Calculate context size
+            ctx_size = self._calculate_dynamic_ctx(history)
+
+            options = {
+                "num_ctx": ctx_size
+            }
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf:
@@ -489,9 +528,11 @@ class OllamaChat(Base):
                 options["presence_penalty"] = gen_conf["presence_penalty"]
             if "frequency_penalty" in gen_conf:
                 options["frequency_penalty"] = gen_conf["frequency_penalty"]
-            response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
+
+            response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=10)
             ans = response["message"]["content"].strip()
-            return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
+            token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
+            return ans, token_count
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
@@ -500,28 +541,38 @@
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
-        options = {}
-        if "temperature" in gen_conf:
-            options["temperature"] = gen_conf["temperature"]
-        if "max_tokens" in gen_conf:
-            options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf:
-            options["top_p"] = gen_conf["top_p"]
-        if "presence_penalty" in gen_conf:
-            options["presence_penalty"] = gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            options["frequency_penalty"] = gen_conf["frequency_penalty"]
-        ans = ""
         try:
-            response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
-            for resp in response:
-                if resp["done"]:
-                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
-                ans = resp["message"]["content"]
-                yield ans
+            # Calculate context size
+            ctx_size = self._calculate_dynamic_ctx(history)
+            options = {
+                "num_ctx": ctx_size
+            }
+            if "temperature" in gen_conf:
+                options["temperature"] = gen_conf["temperature"]
+            if "max_tokens" in gen_conf:
+                options["num_predict"] = gen_conf["max_tokens"]
+            if "top_p" in gen_conf:
+                options["top_p"] = gen_conf["top_p"]
+            if "presence_penalty" in gen_conf:
+                options["presence_penalty"] = gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf:
+                options["frequency_penalty"] = gen_conf["frequency_penalty"]
+
+            ans = ""
+            try:
+                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
+                for resp in response:
+                    if resp["done"]:
+                        token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
+                        yield token_count
+                    ans = resp["message"]["content"]
+                    yield ans
+            except Exception as e:
+                yield ans + "\n**ERROR**: " + str(e)
+                yield 0
         except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-            yield 0
+            yield "**ERROR**: " + str(e)
+            yield 0
 
 
 class LocalAIChat(Base):
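For reviewers, here is a minimal standalone sketch of the sizing behavior introduced by `_calculate_dynamic_ctx` (the helper names below are illustrative, not part of the patch): roughly one token per ASCII character, two per non-ASCII character, 4 tokens of role overhead per message, a 1.2x buffer, then rounding up to a multiple of 8192 with 8192 as the floor.

```python
# Standalone approximation of the patch's sizing logic, for inspection only.
def estimate_tokens(text: str) -> int:
    # 1 token per ASCII char, 2 per non-ASCII char (CJK etc.), as in the patch.
    return sum(1 if ord(ch) < 128 else 2 for ch in text)

def dynamic_ctx(history: list[dict]) -> int:
    # Sum content tokens plus 4 tokens of role overhead per message.
    total = sum(estimate_tokens(m.get("content", "")) + 4 for m in history)
    total = int(total * 1.2)  # 1.2x safety buffer
    # Floor of 8192; otherwise round up to the next multiple of 8192.
    return 8192 if total <= 8192 else ((total // 8192) + 1) * 8192

history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "x" * 9000},
]
print(dynamic_ctx(history))  # -> 16384 (10844 buffered tokens round up to 2 * 8192)
```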