Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-11-11 15:23:59 +00:00)
# Dynamic Context Window Size for Ollama Chat (#6582)
## Problem Statement
Previously, the Ollama chat implementation used a fixed context window
size of 32768 tokens. This caused two main issues:
1. Performance degradation due to unnecessarily large context windows
for small conversations
2. Potential business logic failures when using smaller fixed sizes
(e.g., 2048 tokens)
## Solution
Implemented a dynamic context window size calculation that:
1. Uses a base context size of 8192 tokens
2. Applies a 1.2x buffer ratio to the total token count
3. Adds multiples of 8192 tokens based on the buffered token count
4. Implements a smart context size update strategy
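For example, a conversation whose messages count to roughly 10,000 tokens is buffered to 12,000 tokens and is therefore allocated a 16,384-token (2 × 8192) context window.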
## Implementation Details
### Token Counting Logic
```python
def count_tokens(text):
    """Calculate token count for text"""
    # Simple calculation: 1 token per ASCII character,
    # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
    total = 0
    for char in text:
        if ord(char) < 128:  # ASCII characters
            total += 1
        else:  # Non-ASCII characters
            total += 2
    return total
```
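As a quick illustration of this heuristic (a standalone sketch, not part of the patch), mixed ASCII/CJK strings count as follows:

```python
# Illustrative values for the heuristic above (not part of the patch).
assert count_tokens("hello") == 5        # 5 ASCII characters -> 5 tokens
assert count_tokens("你好") == 4          # 2 CJK characters -> 4 tokens
assert count_tokens("hello 你好") == 10   # 6 ASCII (incl. space) + 2 CJK -> 10 tokens
```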
### Dynamic Context Calculation
```python
def _calculate_dynamic_ctx(self, history):
    """Calculate dynamic context window size"""
    # Calculate total tokens for all messages
    total_tokens = 0
    for message in history:
        content = message.get("content", "")
        content_tokens = count_tokens(content)
        role_tokens = 4  # Role marker token overhead
        total_tokens += content_tokens + role_tokens

    # Apply 1.2x buffer ratio
    total_tokens_with_buffer = int(total_tokens * 1.2)

    # Calculate context size in multiples of 8192
    if total_tokens_with_buffer <= 8192:
        ctx_size = 8192
    else:
        ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
        ctx_size = ctx_multiplier * 8192
    return ctx_size
```
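To make the scaling concrete, here is a small standalone sketch that restates the same arithmetic (the helper name `expected_ctx` is illustrative, not part of the patch):

```python
def expected_ctx(total_tokens):
    """Restates the rounding rule from _calculate_dynamic_ctx for illustration."""
    buffered = int(total_tokens * 1.2)
    if buffered <= 8192:
        return 8192
    return ((buffered // 8192) + 1) * 8192

assert expected_ctx(3_000) == 8192    # 3,600 buffered tokens fit the base window
assert expected_ctx(10_000) == 16384  # 12,000 buffered -> 2 x 8192
assert expected_ctx(20_000) == 24576  # 24,000 buffered -> 3 x 8192
```

Note that because the multiplier is `(buffered // 8192) + 1`, a buffered count that lands exactly on a multiple of 8192 is bumped to the next increment, which leaves some headroom for the model's output tokens.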
### Integration in Chat Method
```python
def chat(self, system, history, gen_conf):
    if system:
        history.insert(0, {"role": "system", "content": system})
    if "max_tokens" in gen_conf:
        del gen_conf["max_tokens"]
    try:
        # Calculate new context size
        new_ctx_size = self._calculate_dynamic_ctx(history)
        # Prepare options with context size
        options = {
            "num_ctx": new_ctx_size
        }
        # Add other generation options
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        # Make API call with dynamic context size
        response = self.client.chat(
            model=self.model_name,
            messages=history,
            options=options,
            keep_alive=60
        )
        return response["message"]["content"].strip(), response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
    except Exception as e:
        return "**ERROR**: " + str(e), 0
```
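A minimal usage sketch, assuming the class is `OllamaChat` in `rag.llm.chat_model` as in the RAGFlow repository and that a local Ollama server is reachable; the model name and URL below are placeholders:

```python
# Hedged usage sketch: import path, model name, and URL are assumptions.
from rag.llm.chat_model import OllamaChat

mdl = OllamaChat(key="x", model_name="llama3.1", base_url="http://localhost:11434")
history = [{"role": "user", "content": "Give me a one-line summary of RAGFlow."}]

answer, used_tokens = mdl.chat(
    system="You are a concise assistant.",
    history=history,
    gen_conf={"temperature": 0.2},
)
print(answer, used_tokens)
```

The dynamic `num_ctx` is computed internally from `history`, so callers do not need to pass a context size.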
## Benefits
1. **Improved Performance**: Uses appropriate context windows based on
conversation length
2. **Better Resource Utilization**: Context window size scales with
content
3. **Maintained Compatibility**: Works with existing business logic
4. **Predictable Scaling**: Context growth in 8192-token increments
5. **Smart Updates**: Context size updates are optimized to reduce
unnecessary model reloads
## Future Considerations
1. Fine-tune buffer ratio based on usage patterns
2. Add monitoring for context window utilization
3. Consider language-specific token counting optimizations
4. Implement adaptive threshold based on conversation patterns
5. Add metrics for context size update frequency
---------
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Parent: 1fbc4870f0 · Commit: c61df5dd25
```diff
@@ -179,7 +179,41 @@ class Base(ABC):
             except Exception:
                 pass
         return 0
 
+    def _calculate_dynamic_ctx(self, history):
+        """Calculate dynamic context window size"""
+        def count_tokens(text):
+            """Calculate token count for text"""
+            # Simple calculation: 1 token per ASCII character
+            # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
+            total = 0
+            for char in text:
+                if ord(char) < 128:  # ASCII characters
+                    total += 1
+                else:  # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
+                    total += 2
+            return total
+
+        # Calculate total tokens for all messages
+        total_tokens = 0
+        for message in history:
+            content = message.get("content", "")
+            # Calculate content tokens
+            content_tokens = count_tokens(content)
+            # Add role marker token overhead
+            role_tokens = 4
+            total_tokens += content_tokens + role_tokens
+
+        # Apply 1.2x buffer ratio
+        total_tokens_with_buffer = int(total_tokens * 1.2)
+
+        if total_tokens_with_buffer <= 8192:
+            ctx_size = 8192
+        else:
+            ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
+            ctx_size = ctx_multiplier * 8192
+
+        return ctx_size
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -469,7 +503,7 @@ class ZhipuChat(Base):
 
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
-        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
+        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
@@ -478,7 +512,12 @@ class OllamaChat(Base):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
         try:
-            options = {"num_ctx": 32768}
+            # Calculate context size
+            ctx_size = self._calculate_dynamic_ctx(history)
+
+            options = {
+                "num_ctx": ctx_size
+            }
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf:
@@ -489,9 +528,11 @@ class OllamaChat(Base):
                 options["presence_penalty"] = gen_conf["presence_penalty"]
             if "frequency_penalty" in gen_conf:
                 options["frequency_penalty"] = gen_conf["frequency_penalty"]
-            response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
+
+            response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=10)
             ans = response["message"]["content"].strip()
-            return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
+            token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
+            return ans, token_count
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
@@ -500,28 +541,38 @@ class OllamaChat(Base):
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
-        options = {}
-        if "temperature" in gen_conf:
-            options["temperature"] = gen_conf["temperature"]
-        if "max_tokens" in gen_conf:
-            options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf:
-            options["top_p"] = gen_conf["top_p"]
-        if "presence_penalty" in gen_conf:
-            options["presence_penalty"] = gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            options["frequency_penalty"] = gen_conf["frequency_penalty"]
-        ans = ""
         try:
-            response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
-            for resp in response:
-                if resp["done"]:
-                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
-                ans = resp["message"]["content"]
-                yield ans
+            # Calculate context size
+            ctx_size = self._calculate_dynamic_ctx(history)
+            options = {
+                "num_ctx": ctx_size
+            }
+            if "temperature" in gen_conf:
+                options["temperature"] = gen_conf["temperature"]
+            if "max_tokens" in gen_conf:
+                options["num_predict"] = gen_conf["max_tokens"]
+            if "top_p" in gen_conf:
+                options["top_p"] = gen_conf["top_p"]
+            if "presence_penalty" in gen_conf:
+                options["presence_penalty"] = gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf:
+                options["frequency_penalty"] = gen_conf["frequency_penalty"]
+
+            ans = ""
+            try:
+                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
+                for resp in response:
+                    if resp["done"]:
+                        token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
+                        yield token_count
+                    ans = resp["message"]["content"]
+                    yield ans
+            except Exception as e:
+                yield ans + "\n**ERROR**: " + str(e)
+                yield 0
         except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
+            yield "**ERROR**: " + str(e)
             yield 0
 
 
 class LocalAIChat(Base):
```