#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import random
import re
import time
from abc import ABC

import openai
import requests
from dashscope import Generation
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI

from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string

# Error message constants
ERROR_PREFIX = "**ERROR**"
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
ERROR_SERVER = "SERVER_ERROR"
ERROR_TIMEOUT = "TIMEOUT"
ERROR_CONNECTION = "CONNECTION_ERROR"
ERROR_MODEL = "MODEL_ERROR"
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
ERROR_QUOTA = "QUOTA_EXCEEDED"
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
ERROR_GENERIC = "GENERIC_ERROR"

LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."


class Base(ABC):
    def __init__(self, key, model_name, base_url):
        timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
        self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
        self.model_name = model_name
        # Configure retry parameters
        self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
        self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))

    def _get_delay(self, attempt):
        """Calculate the retry delay: exponential backoff plus random jitter."""
        return self.base_delay * (2**attempt) + random.uniform(0, 0.5)

    def _classify_error(self, error):
        """Classify an error based on its message content."""
        error_str = str(error).lower()

        if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
            return ERROR_RATE_LIMIT
        elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
            return ERROR_AUTHENTICATION
        elif "invalid" in error_str or "bad request" in error_str or "400" in error_str or "format" in error_str or "malformed" in error_str or "parameter" in error_str:
            return ERROR_INVALID_REQUEST
        elif "server" in error_str or "502" in error_str or "503" in error_str or "504" in error_str or "500" in error_str or "unavailable" in error_str:
            return ERROR_SERVER
        elif "timeout" in error_str or "timed out" in error_str:
            return ERROR_TIMEOUT
        elif "connect" in error_str or "network" in error_str or "unreachable" in error_str or "dns" in error_str:
            return ERROR_CONNECTION
        elif "quota" in error_str or "capacity" in error_str or "credit" in error_str or "billing" in error_str or ("limit" in error_str and "rate" not in error_str):
            return ERROR_QUOTA
        elif "filter" in error_str or "content" in error_str or "policy" in error_str or "blocked" in error_str or "safety" in error_str or "inappropriate" in error_str:
            return ERROR_CONTENT_FILTER
        elif "model" in error_str or "not found" in error_str or "does not exist" in error_str or "not available" in error_str:
            return ERROR_MODEL
        else:
            return ERROR_GENERIC
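
    # Illustrative backoff sketch (an assumption based on the defaults above, not part of
    # the original code): with base_delay=2.0 and max_retries=5, the sleeps before the
    # 2nd-5th attempts are roughly 2s, 4s, 8s and 16s, each plus up to 0.5s of jitter.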
    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]

        # Keep a reference to the last response so the error path below can report it
        # even when the request itself raised before assignment.
        response = None
        # Implement exponential backoff retry strategy
        for attempt in range(self.max_retries):
            try:
                response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
                if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
                    return "", 0
                ans = response.choices[0].message.content.strip()
                if response.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                return ans, self.total_token_count(response)
            except Exception as e:
                logging.exception("chat_model.Base.chat got exception")
                # Classify the error
                error_code = self._classify_error(e)

                # Retry only on rate-limit or server errors, and only if attempts remain
                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1

                if should_retry:
                    delay = self._get_delay(attempt)
                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
                    time.sleep(delay)
                else:
                    # For non-retryable errors or the last attempt, return an error message
                    if attempt == self.max_retries - 1:
                        error_code = ERROR_MAX_RETRIES
                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}. response: {response}", 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0
        reasoning_start = False
        try:
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                    ans = ""
                    if not reasoning_start:
                        reasoning_start = True
                        ans = "<think>"
                    ans += resp.choices[0].delta.reasoning_content + "</think>"
                else:
                    reasoning_start = False
                    ans = resp.choices[0].delta.content

                tol = self.total_token_count(resp)
                if not tol:
                    total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                else:
                    total_tokens = tol

                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

    def total_token_count(self, resp):
        try:
            return resp.usage.total_tokens
        except Exception:
            pass
        try:
            return resp["usage"]["total_tokens"]
        except Exception:
            pass
        return 0
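
    # Note: total_token_count above accepts either an SDK response object (resp.usage)
    # or a plain dict payload (resp["usage"]); a return of 0 means usage was unavailable.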

    def _calculate_dynamic_ctx(self, history):
        """Calculate a dynamic context window size."""

        def count_tokens(text):
            """Calculate the token count for text."""
            # Simple calculation: 1 token per ASCII character,
            # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
            total = 0
            for char in text:
                if ord(char) < 128:  # ASCII characters
                    total += 1
                else:  # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
                    total += 2
            return total

        # Calculate total tokens for all messages
        total_tokens = 0
        for message in history:
            content = message.get("content", "")
            # Calculate content tokens
            content_tokens = count_tokens(content)
            # Add role marker token overhead
            role_tokens = 4
            total_tokens += content_tokens + role_tokens

        # Apply 1.2x buffer ratio and round up to the next multiple of 8192
        total_tokens_with_buffer = int(total_tokens * 1.2)
        if total_tokens_with_buffer <= 8192:
            ctx_size = 8192
        else:
            ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
            ctx_size = ctx_multiplier * 8192

        return ctx_size
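
    # Worked example (illustrative only): a history whose contents count to about 10,000
    # tokens under count_tokens gets a 1.2x buffer of 12,000, and since
    # 12,000 // 8192 + 1 == 2, the context window is sized to 2 * 8192 == 16384.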


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        super().__init__(key, model_name, base_url)


class HuggingFaceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        super().__init__(key, model_name.split("___")[0], base_url)


class ModelScopeChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = base_url.rstrip("/")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        super().__init__(key, model_name.split("___")[0], base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)


class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        api_key = json.loads(key).get("api_key", "")
        api_version = json.loads(key).get("api_version", "2024-02-01")
        super().__init__(key, model_name, kwargs["base_url"])
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
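
    # Note (an assumption inferred from the parsing above, not documented elsewhere in this
    # module): `key` is expected to be a JSON string such as
    # {"api_key": "...", "api_version": "2024-02-01"}.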


class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
                **self._format_params(gen_conf),
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                if is_chinese([ans]):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            return ans, self.total_token_count(response)
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
                stream=True,
                **self._format_params(gen_conf),
            )
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans = resp.choices[0].delta.content
                tol = self.total_token_count(resp)
                if not tol:
                    total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                else:
                    total_tokens = tol
                if resp.choices[0].finish_reason == "length":
                    if is_chinese([ans]):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope

        dashscope.api_key = key
        self.model_name = model_name
        if self.is_reasoning_model(self.model_name):
            super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1")

    def chat(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        if self.is_reasoning_model(self.model_name):
            return super().chat(system, history, gen_conf)

        stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
        if not stream_flag:
            from http import HTTPStatus

            if system:
                history.insert(0, {"role": "system", "content": system})

            response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
            ans = ""
            tk_count = 0
            if response.status_code == HTTPStatus.OK:
                ans += response.output.choices[0]["message"]["content"]
                tk_count += self.total_token_count(response)
                if response.output.choices[0].get("finish_reason", "") == "length":
                    if is_chinese([ans]):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                return ans, tk_count

            return "**ERROR**: " + response.message, tk_count
        else:
            g = self._chat_streamly(system, history, gen_conf, incremental_output=True)
            result_list = list(g)
            error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
            if len(error_msg_list) > 0:
                return "**ERROR**: " + "".join(error_msg_list), 0
            else:
                return "".join(result_list[:-1]), result_list[-1]

    def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
        from http import HTTPStatus

        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]["message"]["content"]
                    tk_count = self.total_token_count(resp)
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        if is_chinese(ans):
                            ans += LENGTH_NOTIFICATION_CN
                        else:
                            ans += LENGTH_NOTIFICATION_EN
                    yield ans
                else:
                    yield (
                        ans + "\n**ERROR**: " + resp.message
                        if not re.search(r"(key|quota)", str(resp.message).lower())
                        else "Out of credit. Please set the API key in **settings > Model providers.**"
                    )
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    def chat_streamly(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        if self.is_reasoning_model(self.model_name):
            return super().chat_streamly(system, history, gen_conf)

        return self._chat_streamly(system, history, gen_conf)

    @staticmethod
    def is_reasoning_model(model_name: str) -> bool:
        return any(
            [
                model_name.lower().find("deepseek") >= 0,
                model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
            ]
        )
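
    # Reasoning-style models (DeepSeek variants, and QwQ other than qwq-32b-preview) are
    # routed through DashScope's OpenAI-compatible endpoint via the Base implementation,
    # as set up in __init__ above.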


class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        try:
            if "presence_penalty" in gen_conf:
                del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            return ans, self.total_token_count(response)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans = delta
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    tk_count = self.total_token_count(resp)
                if resp.choices[0].finish_reason == "stop":
                    tk_count = self.total_token_count(resp)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        try:
            # Calculate context size
            ctx_size = self._calculate_dynamic_ctx(history)

            options = {
                "num_ctx": ctx_size
            }
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=10)
            ans = response["message"]["content"].strip()
            token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
            return ans, token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0
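
    # Note: both chat and chat_streamly size num_ctx with _calculate_dynamic_ctx and pass a
    # short keep_alive of 10, so Ollama keeps the model loaded only briefly between calls.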

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        try:
            # Calculate context size
            ctx_size = self._calculate_dynamic_ctx(history)
            options = {
                "num_ctx": ctx_size
            }
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]

            ans = ""
            try:
                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
                for resp in response:
                    if resp["done"]:
                        token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                        yield token_count
                    ans = resp["message"]["content"]
                    yield ans
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)
            yield 0
        except Exception as e:
            yield "**ERROR**: " + str(e)
            yield 0


class LocalAIChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]


class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client

            self._connection = Client((self.host, self.port), authkey=b"infiniflow-token4kevinhu")

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                for _ in range(3):
                    try:
                        self._connection.send(pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name):
        from jina import Client

        self.client = Client(port=12345, protocol="grpc", asyncio=True)

    def _prepare_prompt(self, system, history, gen_conf):
        from rag.svr.jina_server import Prompt

        if system:
            history.insert(0, {"role": "system", "content": system})
        return Prompt(message=history, gen_conf=gen_conf)

    def _stream_response(self, endpoint, prompt):
        from rag.svr.jina_server import Generation

        answer = ""
        try:
            res = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation)
            loop = asyncio.get_event_loop()
            try:
                while True:
                    answer = loop.run_until_complete(res.__anext__()).text
                    yield answer
            except StopAsyncIteration:
                pass
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)
        yield num_tokens_from_string(answer)

    def chat(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        prompt = self._prepare_prompt(system, history, gen_conf)
        chat_gen = self._stream_response("/chat", prompt)
        ans = next(chat_gen)
        total_tokens = next(chat_gen)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        prompt = self._prepare_prompt(system, history, gen_conf)
        return self._stream_response("/stream", prompt)
2024-05-23 11:15:29 +08:00
class VolcEngineChat ( Base ) :
2025-03-26 19:33:14 +08:00
def __init__ ( self , key , model_name , base_url = " https://ark.cn-beijing.volces.com/api/v3 " ) :
2024-05-23 11:15:29 +08:00
"""
Since do not want to modify the original database fields , and the VolcEngine authentication method is quite special ,
2024-08-26 13:34:29 +08:00
Assemble ark_api_key , ep_id into api_key , store it as a dictionary type , and parse it for use
2024-05-23 11:15:29 +08:00
model_name is for display only
"""
2025-03-26 19:33:14 +08:00
base_url = base_url if base_url else " https://ark.cn-beijing.volces.com/api/v3 "
ark_api_key = json . loads ( key ) . get ( " ark_api_key " , " " )
model_name = json . loads ( key ) . get ( " ep_id " , " " ) + json . loads ( key ) . get ( " endpoint_id " , " " )
super ( ) . __init__ ( ark_api_key , model_name , base_url )
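# A minimal sketch of the serialized `key` VolcEngineChat expects. The concrete
# values are placeholders (assumptions); only the field names "ark_api_key" and
# "ep_id"/"endpoint_id" come from the parsing above:
#
#   volc_key = json.dumps({"ark_api_key": "ak-xxxx", "ep_id": "ep-2024xxxx"})
#   chat_model = VolcEngineChat(volc_key, "Doubao-pro-32k (display only)")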
class MiniMaxChat ( Base ) :
def __init__ (
self ,
key ,
model_name ,
base_url = " https://api.minimax.chat/v1/text/chatcompletion_v2 " ,
) :
if not base_url :
base_url = " https://api.minimax.chat/v1/text/chatcompletion_v2 "
self . base_url = base_url
self . model_name = model_name
self . api_key = key
def chat ( self , system , history , gen_conf ) :
if system :
history . insert ( 0 , { " role " : " system " , " content " : system } )
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_tokens " ] :
del gen_conf [ k ]
headers = {
" Authorization " : f " Bearer { self . api_key } " ,
" Content-Type " : " application/json " ,
}
payload = json . dumps ( { " model " : self . model_name , " messages " : history , * * gen_conf } )
try :
response = requests . request ( " POST " , url = self . base_url , headers = headers , data = payload )
response = response . json ( )
ans = response [ " choices " ] [ 0 ] [ " message " ] [ " content " ] . strip ( )
if response [ " choices " ] [ 0 ] [ " finish_reason " ] == " length " :
if is_chinese ( ans ) :
ans + = LENGTH_NOTIFICATION_CN
else :
ans + = LENGTH_NOTIFICATION_EN
return ans , self . total_token_count ( response )
except Exception as e :
return " **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
if system :
history . insert ( 0 , { " role " : " system " , " content " : system } )
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_tokens " ] :
del gen_conf [ k ]
ans = " "
total_tokens = 0
try :
headers = {
" Authorization " : f " Bearer { self . api_key } " ,
" Content-Type " : " application/json " ,
}
payload = json . dumps (
{
" model " : self . model_name ,
" messages " : history ,
" stream " : True ,
* * gen_conf ,
}
)
response = requests . request (
" POST " ,
url = self . base_url ,
headers = headers ,
data = payload ,
)
for resp in response . text . split ( " \n \n " ) [ : - 1 ] :
resp = json . loads ( resp [ 6 : ] )
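# The MiniMax endpoint streams Server-Sent Events, so each chunk arrives as a line
# of the form `data: {...}`; the `[6:]` slice above strips the leading "data: "
# prefix before the JSON payload is parsed.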
text = " "
if " choices " in resp and " delta " in resp [ " choices " ] [ 0 ] :
text = resp [ " choices " ] [ 0 ] [ " delta " ] [ " content " ]
ans = text
tol = self . total_token_count ( resp )
if not tol :
total_tokens + = num_tokens_from_string ( text )
else :
total_tokens = tol
yield ans
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield total_tokens
class MistralChat ( Base ) :
def __init__ ( self , key , model_name , base_url = None ) :
from mistralai . client import MistralClient
self . client = MistralClient ( api_key = key )
self . model_name = model_name
def chat ( self , system , history , gen_conf ) :
if system :
history . insert ( 0 , { " role " : " system " , " content " : system } )
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_tokens " ] :
del gen_conf [ k ]
try :
response = self . client . chat ( model = self . model_name , messages = history , * * gen_conf )
ans = response . choices [ 0 ] . message . content
if response . choices [ 0 ] . finish_reason == " length " :
if is_chinese ( ans ) :
ans + = LENGTH_NOTIFICATION_CN
else :
ans + = LENGTH_NOTIFICATION_EN
return ans , self . total_token_count ( response )
except openai . APIError as e :
return " **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
if system :
history . insert ( 0 , { " role " : " system " , " content " : system } )
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_tokens " ] :
del gen_conf [ k ]
ans = " "
total_tokens = 0
try :
response = self . client . chat_stream ( model = self . model_name , messages = history , * * gen_conf )
for resp in response :
if not resp . choices or not resp . choices [ 0 ] . delta . content :
continue
ans = resp . choices [ 0 ] . delta . content
total_tokens + = 1
if resp . choices [ 0 ] . finish_reason == " length " :
if is_chinese ( ans ) :
ans + = LENGTH_NOTIFICATION_CN
else :
ans + = LENGTH_NOTIFICATION_EN
yield ans
except openai . APIError as e :
yield ans + " \n **ERROR**: " + str ( e )
yield total_tokens
class BedrockChat ( Base ) :
def __init__ ( self , key , model_name , * * kwargs ) :
import boto3
self . bedrock_ak = json . loads ( key ) . get ( " bedrock_ak " , " " )
self . bedrock_sk = json . loads ( key ) . get ( " bedrock_sk " , " " )
self . bedrock_region = json . loads ( key ) . get ( " bedrock_region " , " " )
self . model_name = model_name
if self . bedrock_ak == " " or self . bedrock_sk == " " or self . bedrock_region == " " :
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self . client = boto3 . client ( " bedrock-runtime " )
else :
self . client = boto3 . client ( service_name = " bedrock-runtime " , region_name = self . bedrock_region , aws_access_key_id = self . bedrock_ak , aws_secret_access_key = self . bedrock_sk )
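# A minimal sketch of the serialized `key` BedrockChat expects. The values are
# placeholders (assumptions); only the field names "bedrock_ak", "bedrock_sk" and
# "bedrock_region" come from the parsing above, and leaving them empty falls back
# to the default boto3 credential chain:
#
#   bedrock_key = json.dumps({
#       "bedrock_ak": "AKIA...",
#       "bedrock_sk": "wJalr...",
#       "bedrock_region": "us-east-1",
#   })
#   chat_model = BedrockChat(bedrock_key, "anthropic.claude-3-sonnet-20240229-v1:0")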
def chat ( self , system , history , gen_conf ) :
from botocore . exceptions import ClientError
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " ] :
del gen_conf [ k ]
for item in history :
if not isinstance ( item [ " content " ] , list ) and not isinstance ( item [ " content " ] , tuple ) :
item [ " content " ] = [ { " text " : item [ " content " ] } ]
try :
# Send the message to the model, using a basic inference configuration.
response = self . client . converse (
modelId = self . model_name ,
messages = history ,
inferenceConfig = gen_conf ,
system = [ { " text " : ( system if system else " Answer the user ' s message. " ) } ] ,
)
# Extract and return the response text.
ans = response [ " output " ] [ " message " ] [ " content " ] [ 0 ] [ " text " ]
return ans , num_tokens_from_string ( ans )
except ( ClientError , Exception ) as e :
return f " ERROR: Can ' t invoke ' { self . model_name } ' . Reason: { e } " , 0
def chat_streamly ( self , system , history , gen_conf ) :
from botocore . exceptions import ClientError
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " ] :
del gen_conf [ k ]
for item in history :
if not isinstance ( item [ " content " ] , list ) and not isinstance ( item [ " content " ] , tuple ) :
item [ " content " ] = [ { " text " : item [ " content " ] } ]
if self . model_name . split ( " . " ) [ 0 ] == " ai21 " :
try :
response = self . client . converse ( modelId = self . model_name , messages = history , inferenceConfig = gen_conf , system = [ { " text " : ( system if system else " Answer the user ' s message. " ) } ] )
ans = response [ " output " ] [ " message " ] [ " content " ] [ 0 ] [ " text " ]
yield ans
yield num_tokens_from_string ( ans )
return
except ( ClientError , Exception ) as e :
yield f " ERROR: Can ' t invoke ' { self . model_name } ' . Reason: { e } "
yield 0
return
ans = " "
try :
# Send the message to the model, using a basic inference configuration.
streaming_response = self . client . converse_stream (
modelId = self . model_name , messages = history , inferenceConfig = gen_conf , system = [ { " text " : ( system if system else " Answer the user ' s message. " ) } ]
)
# Extract and yield the streamed response text in real time.
for resp in streaming_response [ " stream " ] :
if " contentBlockDelta " in resp :
ans = resp [ " contentBlockDelta " ] [ " delta " ] [ " text " ]
yield ans
except ( ClientError , Exception ) as e :
yield ans + f " ERROR: Can ' t invoke ' { self . model_name } ' . Reason: { e } "
yield num_tokens_from_string ( ans )
class GeminiChat ( Base ) :
def __init__ ( self , key , model_name , base_url = None ) :
from google . generativeai import GenerativeModel , client
client . configure ( api_key = key )
_client = client . get_default_generative_client ( )
self . model_name = " models/ " + model_name
self . model = GenerativeModel ( model_name = self . model_name )
self . model . _client = _client
def chat ( self , system , history , gen_conf ) :
from google . generativeai . types import content_types
if system :
self . model . _system_instruction = content_types . to_content ( system )
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_tokens " ] :
del gen_conf [ k ]
for item in history :
if " role " in item and item [ " role " ] == " assistant " :
item [ " role " ] = " model "
if " role " in item and item [ " role " ] == " system " :
item [ " role " ] = " user "
if " content " in item :
item [ " parts " ] = item . pop ( " content " )
try :
response = self . model . generate_content ( history , generation_config = gen_conf )
ans = response . text
return ans , response . usage_metadata . total_token_count
except Exception as e :
return " **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
from google . generativeai . types import content_types
if system :
self . model . _system_instruction = content_types . to_content ( system )
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_tokens " ] :
del gen_conf [ k ]
for item in history :
if " role " in item and item [ " role " ] == " assistant " :
item [ " role " ] = " model "
if " content " in item :
item [ " parts " ] = item . pop ( " content " )
ans = " "
try :
response = self . model . generate_content ( history , generation_config = gen_conf , stream = True )
for resp in response :
ans = resp . text
yield ans
yield response . _chunks [ - 1 ] . usage_metadata . total_token_count
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield 0
class GroqChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " " ) :
from groq import Groq
self . client = Groq ( api_key = key )
self . model_name = model_name
def chat ( self , system , history , gen_conf ) :
if system :
history . insert ( 0 , { " role " : " system " , " content " : system } )
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_tokens " ] :
del gen_conf [ k ]
ans = " "
try :
response = self . client . chat . completions . create ( model = self . model_name , messages = history , * * gen_conf )
ans = response . choices [ 0 ] . message . content
if response . choices [ 0 ] . finish_reason == " length " :
if is_chinese ( ans ) :
ans + = LENGTH_NOTIFICATION_CN
else :
ans + = LENGTH_NOTIFICATION_EN
return ans , self . total_token_count ( response )
except Exception as e :
return ans + " \n **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
if system :
history . insert ( 0 , { " role " : " system " , " content " : system } )
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_tokens " ] :
del gen_conf [ k ]
ans = " "
total_tokens = 0
try :
response = self . client . chat . completions . create ( model = self . model_name , messages = history , stream = True , * * gen_conf )
for resp in response :
if not resp . choices or not resp . choices [ 0 ] . delta . content :
continue
ans = resp . choices [ 0 ] . delta . content
total_tokens + = 1
if resp . choices [ 0 ] . finish_reason == " length " :
if is_chinese ( ans ) :
ans + = LENGTH_NOTIFICATION_CN
else :
ans + = LENGTH_NOTIFICATION_EN
yield ans
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield total_tokens
## openrouter
class OpenRouterChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://openrouter.ai/api/v1 " ) :
if not base_url :
base_url = " https://openrouter.ai/api/v1 "
super ( ) . __init__ ( key , model_name , base_url )
class StepFunChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://api.stepfun.com/v1 " ) :
if not base_url :
base_url = " https://api.stepfun.com/v1 "
super ( ) . __init__ ( key , model_name , base_url )
class NvidiaChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://integrate.api.nvidia.com/v1 " ) :
if not base_url :
base_url = " https://integrate.api.nvidia.com/v1 "
super ( ) . __init__ ( key , model_name , base_url )
class LmStudioChat ( Base ) :
def __init__ ( self , key , model_name , base_url ) :
if not base_url :
raise ValueError ( " Local llm url cannot be None " )
if base_url . split ( " / " ) [ - 1 ] != " v1 " :
base_url = os . path . join ( base_url , " v1 " )
super ( ) . __init__ ( key , model_name , base_url )
self . client = OpenAI ( api_key = " lm-studio " , base_url = base_url )
self . model_name = model_name
class OpenAI_APIChat ( Base ) :
def __init__ ( self , key , model_name , base_url ) :
if not base_url :
raise ValueError ( " url cannot be None " )
model_name = model_name . split ( " ___ " ) [ 0 ]
super ( ) . __init__ ( key , model_name , base_url )
class PPIOChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://api.ppinfra.com/v3/openai " ) :
if not base_url :
base_url = " https://api.ppinfra.com/v3/openai "
super ( ) . __init__ ( key , model_name , base_url )
class CoHereChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " " ) :
from cohere import Client
self . client = Client ( api_key = key )
self . model_name = model_name
def chat ( self , system , history , gen_conf ) :
if system :
history . insert ( 0 , { " role " : " system " , " content " : system } )
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
if " top_p " in gen_conf :
gen_conf [ " p " ] = gen_conf . pop ( " top_p " )
if " frequency_penalty " in gen_conf and " presence_penalty " in gen_conf :
gen_conf . pop ( " presence_penalty " )
for item in history :
if " role " in item and item [ " role " ] == " user " :
item [ " role " ] = " USER "
if " role " in item and item [ " role " ] == " assistant " :
item [ " role " ] = " CHATBOT "
if " content " in item :
item [ " message " ] = item . pop ( " content " )
mes = history . pop ( ) [ " message " ]
ans = " "
try :
response = self . client . chat ( model = self . model_name , chat_history = history , message = mes , * * gen_conf )
ans = response . text
if response . finish_reason == " MAX_TOKENS " :
ans + = " ... \n For the content length reason, it stopped, continue? " if is_english ( [ ans ] ) else " ······ \n 由于长度的原因,回答被截断了,要继续吗? "
return (
ans ,
response . meta . tokens . input_tokens + response . meta . tokens . output_tokens ,
)
except Exception as e :
return ans + " \n **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
if system :
history . insert ( 0 , { " role " : " system " , " content " : system } )
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
if " top_p " in gen_conf :
gen_conf [ " p " ] = gen_conf . pop ( " top_p " )
if " frequency_penalty " in gen_conf and " presence_penalty " in gen_conf :
gen_conf . pop ( " presence_penalty " )
for item in history :
if " role " in item and item [ " role " ] == " user " :
item [ " role " ] = " USER "
if " role " in item and item [ " role " ] == " assistant " :
item [ " role " ] = " CHATBOT "
if " content " in item :
item [ " message " ] = item . pop ( " content " )
mes = history . pop ( ) [ " message " ]
ans = " "
total_tokens = 0
try :
response = self . client . chat_stream ( model = self . model_name , chat_history = history , message = mes , * * gen_conf )
for resp in response :
if resp . event_type == " text-generation " :
ans = resp . text
total_tokens + = num_tokens_from_string ( resp . text )
elif resp . event_type == " stream-end " :
if resp . finish_reason == " MAX_TOKENS " :
ans + = " ... \n For the content length reason, it stopped, continue? " if is_english ( [ ans ] ) else " ······ \n 由于长度的原因,回答被截断了,要继续吗? "
yield ans
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield total_tokens
class LeptonAIChat ( Base ) :
def __init__ ( self , key , model_name , base_url = None ) :
if not base_url :
base_url = os . path . join ( " https:// " + model_name + " .lepton.run " , " api " , " v1 " )
super ( ) . __init__ ( key , model_name , base_url )
class TogetherAIChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://api.together.xyz/v1 " ) :
if not base_url :
base_url = " https://api.together.xyz/v1 "
super ( ) . __init__ ( key , model_name , base_url )
class PerfXCloudChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://cloud.perfxlab.cn/v1 " ) :
if not base_url :
base_url = " https://cloud.perfxlab.cn/v1 "
super ( ) . __init__ ( key , model_name , base_url )
class UpstageChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://api.upstage.ai/v1/solar " ) :
if not base_url :
base_url = " https://api.upstage.ai/v1/solar "
super ( ) . __init__ ( key , model_name , base_url )
class NovitaAIChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://api.novita.ai/v3/openai " ) :
if not base_url :
base_url = " https://api.novita.ai/v3/openai "
super ( ) . __init__ ( key , model_name , base_url )
class SILICONFLOWChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://api.siliconflow.cn/v1 " ) :
if not base_url :
base_url = " https://api.siliconflow.cn/v1 "
super ( ) . __init__ ( key , model_name , base_url )
class YiChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://api.lingyiwanwu.com/v1 " ) :
if not base_url :
base_url = " https://api.lingyiwanwu.com/v1 "
super ( ) . __init__ ( key , model_name , base_url )
class ReplicateChat ( Base ) :
def __init__ ( self , key , model_name , base_url = None ) :
from replicate . client import Client
self . model_name = model_name
self . client = Client ( api_token = key )
self . system = " "
def chat ( self , system , history , gen_conf ) :
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
if system :
self . system = system
prompt = " \n " . join ( [ item [ " role " ] + " : " + item [ " content " ] for item in history [ - 5 : ] ] )
ans = " "
try :
response = self . client . run (
self . model_name ,
input = { " system_prompt " : self . system , " prompt " : prompt , * * gen_conf } ,
)
ans = " " . join ( response )
return ans , num_tokens_from_string ( ans )
except Exception as e :
return ans + " \n **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
if system :
self . system = system
prompt = " \n " . join ( [ item [ " role " ] + " : " + item [ " content " ] for item in history [ - 5 : ] ] )
ans = " "
try :
response = self . client . run (
self . model_name ,
input = { " system_prompt " : self . system , " prompt " : prompt , * * gen_conf } ,
)
for resp in response :
ans = resp
yield ans
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield num_tokens_from_string ( ans )
class HunyuanChat ( Base ) :
def __init__ ( self , key , model_name , base_url = None ) :
from tencentcloud . common import credential
from tencentcloud . hunyuan . v20230901 import hunyuan_client
key = json . loads ( key )
sid = key . get ( " hunyuan_sid " , " " )
sk = key . get ( " hunyuan_sk " , " " )
cred = credential . Credential ( sid , sk )
self . model_name = model_name
self . client = hunyuan_client . HunyuanClient ( cred , " " )
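# A minimal sketch of the serialized `key` HunyuanChat expects. The values are
# placeholders (assumptions); only the field names "hunyuan_sid" and "hunyuan_sk"
# come from the parsing above:
#
#   hunyuan_key = json.dumps({"hunyuan_sid": "AKID...", "hunyuan_sk": "secret..."})
#   chat_model = HunyuanChat(hunyuan_key, "hunyuan-standard")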
def chat ( self , system , history , gen_conf ) :
from tencentcloud . common . exception . tencent_cloud_sdk_exception import (
TencentCloudSDKException ,
)
from tencentcloud . hunyuan . v20230901 import models
_gen_conf = { }
_history = [ { k . capitalize ( ) : v for k , v in item . items ( ) } for item in history ]
if system :
_history . insert ( 0 , { " Role " : " system " , " Content " : system } )
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
if " temperature " in gen_conf :
_gen_conf [ " Temperature " ] = gen_conf [ " temperature " ]
if " top_p " in gen_conf :
_gen_conf [ " TopP " ] = gen_conf [ " top_p " ]
req = models . ChatCompletionsRequest ( )
params = { " Model " : self . model_name , " Messages " : _history , * * _gen_conf }
req . from_json_string ( json . dumps ( params ) )
ans = " "
try :
response = self . client . ChatCompletions ( req )
ans = response . Choices [ 0 ] . Message . Content
return ans , response . Usage . TotalTokens
except TencentCloudSDKException as e :
return ans + " \n **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
from tencentcloud . common . exception . tencent_cloud_sdk_exception import (
TencentCloudSDKException ,
)
from tencentcloud . hunyuan . v20230901 import models
_gen_conf = { }
_history = [ { k . capitalize ( ) : v for k , v in item . items ( ) } for item in history ]
if system :
_history . insert ( 0 , { " Role " : " system " , " Content " : system } )
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
if " temperature " in gen_conf :
_gen_conf [ " Temperature " ] = gen_conf [ " temperature " ]
if " top_p " in gen_conf :
_gen_conf [ " TopP " ] = gen_conf [ " top_p " ]
req = models . ChatCompletionsRequest ( )
params = {
" Model " : self . model_name ,
" Messages " : _history ,
" Stream " : True ,
* * _gen_conf ,
}
req . from_json_string ( json . dumps ( params ) )
ans = " "
total_tokens = 0
try :
response = self . client . ChatCompletions ( req )
for resp in response :
resp = json . loads ( resp [ " data " ] )
if not resp [ " Choices " ] or not resp [ " Choices " ] [ 0 ] [ " Delta " ] [ " Content " ] :
continue
ans = resp [ " Choices " ] [ 0 ] [ " Delta " ] [ " Content " ]
total_tokens + = 1
yield ans
except TencentCloudSDKException as e :
yield ans + " \n **ERROR**: " + str ( e )
yield total_tokens
class SparkChat ( Base ) :
def __init__ ( self , key , model_name , base_url = " https://spark-api-open.xf-yun.com/v1 " ) :
if not base_url :
base_url = " https://spark-api-open.xf-yun.com/v1 "
model2version = {
" Spark-Max " : " generalv3.5 " ,
" Spark-Lite " : " general " ,
" Spark-Pro " : " generalv3 " ,
" Spark-Pro-128K " : " pro-128k " ,
" Spark-4.0-Ultra " : " 4.0Ultra " ,
}
version2model = { v : k for k , v in model2version . items ( ) }
assert model_name in model2version or model_name in version2model , f " The given model name is not supported yet. Support: { list ( model2version . keys ( ) ) } "
if model_name in model2version :
model_version = model2version [ model_name ]
else :
model_version = model_name
super ( ) . __init__ ( key , model_version , base_url )
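# SparkChat accepts either a display name or a raw version string, so the two
# calls below are equivalent and both resolve to the "generalv3.5" endpoint
# (the key value is a placeholder/assumption):
#
#   SparkChat("APIKey:APISecret", "Spark-Max")
#   SparkChat("APIKey:APISecret", "generalv3.5")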
class BaiduYiyanChat ( Base ) :
def __init__ ( self , key , model_name , base_url = None ) :
import qianfan
key = json . loads ( key )
ak = key . get ( " yiyan_ak " , " " )
sk = key . get ( " yiyan_sk " , " " )
self . client = qianfan . ChatCompletion ( ak = ak , sk = sk )
self . model_name = model_name . lower ( )
self . system = " "
def chat ( self , system , history , gen_conf ) :
if system :
self . system = system
gen_conf [ " penalty_score " ] = ( ( gen_conf . get ( " presence_penalty " , 0 ) + gen_conf . get ( " frequency_penalty " , 0 ) ) / 2 ) + 1
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
ans = " "
try :
response = self . client . do ( model = self . model_name , messages = history , system = self . system , * * gen_conf ) . body
ans = response [ " result " ]
return ans , self . total_token_count ( response )
except Exception as e :
return ans + " \n **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
if system :
self . system = system
gen_conf [ " penalty_score " ] = ( ( gen_conf . get ( " presence_penalty " , 0 ) + gen_conf . get ( " frequency_penalty " , 0 ) ) / 2 ) + 1
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
ans = " "
total_tokens = 0
try :
response = self . client . do ( model = self . model_name , messages = history , system = self . system , stream = True , * * gen_conf )
for resp in response :
resp = resp . body
ans = resp [ " result " ]
total_tokens = self . total_token_count ( resp )
yield ans
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield total_tokens
class AnthropicChat ( Base ) :
def __init__ ( self , key , model_name , base_url = None ) :
import anthropic
self . client = anthropic . Anthropic ( api_key = key )
self . model_name = model_name
self . system = " "
def chat ( self , system , history , gen_conf ) :
if system :
self . system = system
if " presence_penalty " in gen_conf :
del gen_conf [ " presence_penalty " ]
if " frequency_penalty " in gen_conf :
del gen_conf [ " frequency_penalty " ]
gen_conf [ " max_tokens " ] = 8192
if " haiku " in self . model_name or " opus " in self . model_name :
gen_conf [ " max_tokens " ] = 4096
ans = " "
try :
response = self . client . messages . create (
model = self . model_name ,
messages = history ,
system = self . system ,
stream = False ,
* * gen_conf ,
) . to_dict ( )
ans = response [ " content " ] [ 0 ] [ " text " ]
if response [ " stop_reason " ] == " max_tokens " :
ans + = " ... \n For the content length reason, it stopped, continue? " if is_english ( [ ans ] ) else " ······ \n 由于长度的原因,回答被截断了,要继续吗? "
return (
ans ,
response [ " usage " ] [ " input_tokens " ] + response [ " usage " ] [ " output_tokens " ] ,
)
except Exception as e :
return ans + " \n **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
if system :
self . system = system
if " presence_penalty " in gen_conf :
del gen_conf [ " presence_penalty " ]
if " frequency_penalty " in gen_conf :
del gen_conf [ " frequency_penalty " ]
gen_conf [ " max_tokens " ] = 8192
if " haiku " in self . model_name or " opus " in self . model_name :
gen_conf [ " max_tokens " ] = 4096
ans = " "
total_tokens = 0
reasoning_start = False
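# The streaming loop below handles two kinds of deltas: extended-thinking deltas,
# which are wrapped in <think>...</think> markers so callers can tell the reasoning
# trace apart from the answer, and ordinary text deltas, which are yielded as-is.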
try :
response = self . client . messages . create (
model = self . model_name ,
messages = history ,
2025-03-24 12:34:57 +08:00
system = system ,
2024-08-29 13:30:06 +08:00
stream = True ,
* * gen_conf ,
)
for res in response :
if res . type == " content_block_delta " :
if res . delta . type == " thinking_delta " and res . delta . thinking :
ans = " "
if not reasoning_start :
reasoning_start = True
ans = " <think> "
ans + = res . delta . thinking + " </think> "
else :
reasoning_start = False
text = res . delta . text
ans = text
total_tokens + = num_tokens_from_string ( text )
yield ans
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield total_tokens
class GoogleChat ( Base ) :
def __init__ ( self , key , model_name , base_url = None ) :
import base64
from google . oauth2 import service_account
key = json . loads ( key )
access_token = json . loads ( base64 . b64decode ( key . get ( " google_service_account_key " , " " ) ) )
project_id = key . get ( " google_project_id " , " " )
region = key . get ( " google_region " , " " )
scopes = [ " https://www.googleapis.com/auth/cloud-platform " ]
self . model_name = model_name
self . system = " "
if " claude " in self . model_name :
from anthropic import AnthropicVertex
from google . auth . transport . requests import Request
if access_token :
credits = service_account . Credentials . from_service_account_info ( access_token , scopes = scopes )
request = Request ( )
credits . refresh ( request )
token = credits . token
self . client = AnthropicVertex ( region = region , project_id = project_id , access_token = token )
else :
self . client = AnthropicVertex ( region = region , project_id = project_id )
else :
import vertexai . generative_models as glm
from google . cloud import aiplatform
if access_token :
credits = service_account . Credentials . from_service_account_info ( access_token )
aiplatform . init ( credentials = credits , project = project_id , location = region )
else :
aiplatform . init ( project = project_id , location = region )
self . client = glm . GenerativeModel ( model_name = self . model_name )
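# A minimal sketch of the serialized `key` GoogleChat expects: the service-account
# JSON is base64-encoded before being embedded. Every value below is a placeholder
# (assumption); only the field names come from the parsing above:
#
#   import base64
#   google_key = json.dumps({
#       "google_service_account_key": base64.b64encode(b'{"type": "service_account"}').decode(),
#       "google_project_id": "my-project",
#       "google_region": "us-central1",
#   })
#   chat_model = GoogleChat(google_key, "gemini-1.5-pro")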
def chat ( self , system , history , gen_conf ) :
if system :
self . system = system
if " claude " in self . model_name :
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
try :
response = self . client . messages . create (
model = self . model_name ,
messages = history ,
system = self . system ,
stream = False ,
* * gen_conf ,
) . json ( )
ans = response [ " content " ] [ 0 ] [ " text " ]
if response [ " stop_reason " ] == " max_tokens " :
ans + = " ... \n For the content length reason, it stopped, continue? " if is_english ( [ ans ] ) else " ······ \n 由于长度的原因,回答被截断了,要继续吗? "
return (
ans ,
response [ " usage " ] [ " input_tokens " ] + response [ " usage " ] [ " output_tokens " ] ,
)
except Exception as e :
return " \n **ERROR**: " + str ( e ) , 0
else :
self . client . _system_instruction = self . system
if " max_tokens " in gen_conf :
gen_conf [ " max_output_tokens " ] = gen_conf [ " max_tokens " ]
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_output_tokens " ] :
del gen_conf [ k ]
for item in history :
if " role " in item and item [ " role " ] == " assistant " :
item [ " role " ] = " model "
if " content " in item :
item [ " parts " ] = item . pop ( " content " )
try :
response = self . client . generate_content ( history , generation_config = gen_conf )
ans = response . text
return ans , response . usage_metadata . total_token_count
except Exception as e :
return " **ERROR**: " + str ( e ) , 0
def chat_streamly ( self , system , history , gen_conf ) :
if system :
self . system = system
if " claude " in self . model_name :
if " max_tokens " in gen_conf :
del gen_conf [ " max_tokens " ]
ans = " "
total_tokens = 0
try :
response = self . client . messages . create (
model = self . model_name ,
messages = history ,
system = self . system ,
stream = True ,
* * gen_conf ,
)
for res in response . iter_lines ( ) :
res = res . decode ( " utf-8 " )
if " content_block_delta " in res and " data " in res :
text = json . loads ( res [ 6 : ] ) [ " delta " ] [ " text " ]
ans = text
total_tokens += num_tokens_from_string ( text )
yield ans
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield total_tokens
else :
self . client . _system_instruction = self . system
if " max_tokens " in gen_conf :
gen_conf [ " max_output_tokens " ] = gen_conf [ " max_tokens " ]
for k in list ( gen_conf . keys ( ) ) :
if k not in [ " temperature " , " top_p " , " max_output_tokens " ] :
del gen_conf [ k ]
for item in history :
if " role " in item and item [ " role " ] == " assistant " :
item [ " role " ] = " model "
if " content " in item :
item [ " parts " ] = item . pop ( " content " )
ans = " "
try :
response = self . client . generate_content ( history , generation_config = gen_conf , stream = True )
for resp in response :
ans = resp . text
yield ans
except Exception as e :
yield ans + " \n **ERROR**: " + str ( e )
yield response . _chunks [ - 1 ] . usage_metadata . total_token_count
class GPUStackChat ( Base ) :
def __init__ ( self , key = None , model_name = " " , base_url = " " ) :
if not base_url :
raise ValueError ( " Local llm url cannot be None " )
if base_url . split ( " / " ) [ - 1 ] != " v1 " :
base_url = os . path . join ( base_url , " v1 " )
super ( ) . __init__ ( key , model_name , base_url )