#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string
from groq import Groq
import os
import json
import requests
import asyncio

LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,要继续吗?"
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"

class Base(ABC):
    def __init__(self, key, model_name, base_url):
        timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
        self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices: continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content

                if not hasattr(resp, "usage") or not resp.usage:
                    total_tokens = (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                elif isinstance(resp.usage, dict):
                    total_tokens = resp.usage.get("total_tokens", total_tokens)
                else: total_tokens = resp.usage.total_tokens

                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

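# Illustrative usage sketch (assumed caller contract, not a definitive API):
# provider subclasses take (key, model_name, base_url); `history` is a list of
# {"role": ..., "content": ...} dicts and `gen_conf` carries OpenAI-style sampling
# parameters such as temperature / top_p / max_tokens.
#
#   model = GptTurbo(key="<api-key>", model_name="gpt-3.5-turbo", base_url="")
#   ans, tokens = model.chat("You are helpful.",
#                            [{"role": "user", "content": "Hello"}],
#                            {"temperature": 0.7})
#   for chunk in model.chat_streamly("You are helpful.",
#                                    [{"role": "user", "content": "Hello"}],
#                                    {"temperature": 0.7}):
#       pass  # each yield is the accumulated answer; the final yield is the token count
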
class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url: base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url: base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)

class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        super().__init__(key, model_name, base_url)


class HuggingFaceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url: base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)

class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        api_key = json.loads(key).get('api_key', '')
        api_version = json.loads(key).get('api_version', '2024-02-01')
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
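
# Note: AzureChat expects `key` to be a JSON string bundling the credentials, e.g.
# '{"api_key": "<azure-openai-key>", "api_version": "2024-02-01"}' (field names taken
# from the parsing above; the values are placeholders).
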
class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                if is_chinese([ans]):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                if not resp.choices: continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content
                total_tokens = (
                    (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                    if not hasattr(resp, "usage")
                    else resp.usage["total_tokens"]
                )
                if resp.choices[0].finish_reason == "length":
                    if is_chinese([ans]):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
        if not stream_flag:
            from http import HTTPStatus
            if system:
                history.insert(0, {"role": "system", "content": system})

            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                **gen_conf
            )
            ans = ""
            tk_count = 0
            if response.status_code == HTTPStatus.OK:
                ans += response.output.choices[0]['message']['content']
                tk_count += response.usage.total_tokens
                if response.output.choices[0].get("finish_reason", "") == "length":
                    if is_chinese([ans]):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                return ans, tk_count

            return "**ERROR**: " + response.message, tk_count
        else:
            g = self._chat_streamly(system, history, gen_conf, incremental_output=True)
            result_list = list(g)
            error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
            if len(error_msg_list) > 0:
                return "**ERROR**: " + "".join(error_msg_list), 0
            else:
                return "".join(result_list[:-1]), result_list[-1]

    def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                incremental_output=incremental_output,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        if is_chinese(ans):
                            ans += LENGTH_NOTIFICATION_CN
                        else:
                            ans += LENGTH_NOTIFICATION_EN
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if not re.search(r"(key|quota)", str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    def chat_streamly(self, system, history, gen_conf):
        return self._chat_streamly(system, history, gen_conf)

class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content: continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            options = {}
            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0

class LocalAIChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]

class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client

            self._connection = Client(
                (self.host, self.port), authkey=b"infiniflow-token4kevinhu"
            )

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                for _ in range(3):
                    try:
                        self._connection.send(pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name):
        from jina import Client

        self.client = Client(port=12345, protocol="grpc", asyncio=True)

    def _prepare_prompt(self, system, history, gen_conf):
        from rag.svr.jina_server import Prompt
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        return Prompt(message=history, gen_conf=gen_conf)

    def _stream_response(self, endpoint, prompt):
        from rag.svr.jina_server import Generation
        answer = ""
        try:
            res = self.client.stream_doc(
                on=endpoint, inputs=prompt, return_type=Generation
            )
            loop = asyncio.get_event_loop()
            try:
                while True:
                    answer = loop.run_until_complete(res.__anext__()).text
                    yield answer
            except StopAsyncIteration:
                pass
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)
        yield num_tokens_from_string(answer)

    def chat(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        chat_gen = self._stream_response("/chat", prompt)
        ans = next(chat_gen)
        total_tokens = next(chat_gen)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        return self._stream_response("/stream", prompt)

class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'):
        """
        Since we do not want to modify the original database fields, and the VolcEngine
        authentication method is quite special, ark_api_key and ep_id are assembled into
        api_key, stored as a dictionary type, and parsed here for use.
        model_name is for display only.
        """
        base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
        ark_api_key = json.loads(key).get('ark_api_key', '')
        model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
        super().__init__(ark_api_key, model_name, base_url)
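
# Illustrative example of the packed VolcEngine credential described above (field names
# taken from the parsing code; the values are placeholders):
#   key = '{"ark_api_key": "<ark-api-key>", "ep_id": "<endpoint-id>", "endpoint_id": ""}'
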
class MiniMaxChat(Base):
    def __init__(
            self,
            key,
            model_name,
            base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
    ):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps(
            {"model": self.model_name, "messages": history, **gen_conf}
        )
        try:
            response = requests.request(
                "POST", url=self.base_url, headers=headers, data=payload
            )
            response = response.json()
            ans = response["choices"][0]["message"]["content"].strip()
            if response["choices"][0]["finish_reason"] == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = json.dumps(
                {
                    "model": self.model_name,
                    "messages": history,
                    "stream": True,
                    **gen_conf,
                }
            )
            response = requests.request(
                "POST",
                url=self.base_url,
                headers=headers,
                data=payload,
            )
            for resp in response.text.split("\n\n")[:-1]:
                resp = json.loads(resp[6:])
                text = ""
                if "choices" in resp and "delta" in resp["choices"][0]:
                    text = resp["choices"][0]["delta"]["content"]
                ans += text
                total_tokens = (
                    total_tokens + num_tokens_from_string(text)
                    if "usage" not in resp
                    else resp["usage"]["total_tokens"]
                )
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content: continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class BedrockChat(Base):
    def __init__(self, key, model_name, **kwargs):
        import boto3
        self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
        self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
        self.bedrock_region = json.loads(key).get('bedrock_region', '')
        self.model_name = model_name
        self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                   aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
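
    # The Bedrock credential is likewise a packed JSON string; an illustrative example
    # (field names from the parsing above, values are placeholders):
    #   key = '{"bedrock_ak": "<access-key>", "bedrock_sk": "<secret-key>", "bedrock_region": "us-east-1"}'
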
    def chat(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        for item in history:
            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
                item["content"] = [{"text": item["content"]}]

        try:
            # Send the message to the model, using a basic inference configuration.
            response = self.client.converse(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf,
                system=[{"text": (system if system else "Answer the user's message.")}],
            )

            # Extract and print the response text.
            ans = response["output"]["message"]["content"][0]["text"]
            return ans, num_tokens_from_string(ans)
        except (ClientError, Exception) as e:
            return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

    def chat_streamly(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        for item in history:
            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
                item["content"] = [{"text": item["content"]}]

        if self.model_name.split('.')[0] == 'ai21':
            try:
                response = self.client.converse(
                    modelId=self.model_name,
                    messages=history,
                    inferenceConfig=gen_conf,
                    system=[{"text": (system if system else "Answer the user's message.")}]
                )
                ans = response["output"]["message"]["content"][0]["text"]
                return ans, num_tokens_from_string(ans)
            except (ClientError, Exception) as e:
                return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

        ans = ""
        try:
            # Send the message to the model, using a basic inference configuration.
            streaming_response = self.client.converse_stream(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf,
                system=[{"text": (system if system else "Answer the user's message.")}]
            )
            # Extract and print the streamed response text in real-time.
            for resp in streaming_response["stream"]:
                if "contentBlockDelta" in resp:
                    ans += resp["contentBlockDelta"]["delta"]["text"]
                    yield ans

        except (ClientError, Exception) as e:
            yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

        yield num_tokens_from_string(ans)

class GeminiChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.generativeai import client, GenerativeModel

        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = 'models/' + model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client

    def chat(self, system, history, gen_conf):
        from google.generativeai.types import content_types

        if system:
            self.model._system_instruction = content_types.to_content(system)

        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'role' in item and item['role'] == 'system':
                item['role'] = 'user'
            if 'content' in item:
                item['parts'] = item.pop('content')

        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf)
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        from google.generativeai.types import content_types

        if system:
            self.model._system_instruction = content_types.to_content(system)
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        ans = ""
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf, stream=True)
            for resp in response:
                ans += resp.text
                yield ans

            yield response._chunks[-1].usage_metadata.total_token_count
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield 0

class GroqChat:
    def __init__(self, key, model_name, base_url=''):
        self.client = Groq(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                if is_chinese(ans):
                    ans += LENGTH_NOTIFICATION_CN
                else:
                    ans += LENGTH_NOTIFICATION_EN
            return ans, response.usage.total_tokens
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

## openrouter
class OpenRouterChat(Base):
    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        super().__init__(key, model_name, base_url)


class StepFunChat(Base):
    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        super().__init__(key, model_name, base_url)


class NvidiaChat(Base):
    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1"
        super().__init__(key, model_name, base_url)


class LmStudioChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name


class OpenAI_APIChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        model_name = model_name.split("___")[0]
        super().__init__(key, model_name, base_url)


class CoHereChat(Base):
    def __init__(self, key, model_name, base_url=""):
        from cohere import Client
        self.client = Client(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "top_p" in gen_conf:
            gen_conf["p"] = gen_conf.pop("top_p")
        if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
            gen_conf.pop("presence_penalty")
        for item in history:
            if "role" in item and item["role"] == "user":
                item["role"] = "USER"
            if "role" in item and item["role"] == "assistant":
                item["role"] = "CHATBOT"
            if "content" in item:
                item["message"] = item.pop("content")
        mes = history.pop()["message"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name, chat_history=history, message=mes, **gen_conf
            )
            ans = response.text
            if response.finish_reason == "MAX_TOKENS":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return (
                ans,
                response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "top_p" in gen_conf:
            gen_conf["p"] = gen_conf.pop("top_p")
        if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
            gen_conf.pop("presence_penalty")
        for item in history:
            if "role" in item and item["role"] == "user":
                item["role"] = "USER"
            if "role" in item and item["role"] == "assistant":
                item["role"] = "CHATBOT"
            if "content" in item:
                item["message"] = item.pop("content")
        mes = history.pop()["message"]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name, chat_history=history, message=mes, **gen_conf
            )
            for resp in response:
                if resp.event_type == "text-generation":
                    ans += resp.text
                    total_tokens += num_tokens_from_string(resp.text)
                elif resp.event_type == "stream-end":
                    if resp.finish_reason == "MAX_TOKENS":
                        ans += (
                            "...\nFor the content length reason, it stopped, continue?"
                            if is_english([ans])
                            else "······\n由于长度的原因,回答被截断了,要继续吗?"
                        )
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class LeptonAIChat(Base):
    def __init__(self, key, model_name, base_url=None):
        if not base_url:
            base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
        super().__init__(key, model_name, base_url)


class TogetherAIChat(Base):
    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url)


class PerfXCloudChat(Base):
    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        if not base_url:
            base_url = "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, base_url)


class UpstageChat(Base):
    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        if not base_url:
            base_url = "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, base_url)


class NovitaAIChat(Base):
    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai"
        super().__init__(key, model_name, base_url)


class SILICONFLOWChat(Base):
    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, base_url)


class YiChat(Base):
    def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1"):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, base_url)

class ReplicateChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client
        self.model_name = model_name
        self.client = Client(api_token=key)
        self.system = ""

    def chat(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        if system:
            self.system = system
        prompt = "\n".join(
            [item["role"] + ":" + item["content"] for item in history[-5:]]
        )
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
            )
            ans = "".join(response)
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        if system:
            self.system = system
        prompt = "\n".join(
            [item["role"] + ":" + item["content"] for item in history[-5:]]
        )
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
            )
            for resp in response:
                ans += resp
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield num_tokens_from_string(ans)

class HunyuanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client

        key = json.loads(key)
        sid = key.get("hunyuan_sid", "")
        sk = key.get("hunyuan_sk", "")
        cred = credential.Credential(sid, sk)
        self.model_name = model_name
        self.client = hunyuan_client.HunyuanClient(cred, "")
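
    # The Hunyuan credential is a packed JSON string; an illustrative example
    # (field names from the parsing above, values are placeholders):
    #   key = '{"hunyuan_sid": "<secret-id>", "hunyuan_sk": "<secret-key>"}'
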
    def chat(self, system, history, gen_conf):
        from tencentcloud.hunyuan.v20230901 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        _gen_conf = {}
        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
        if system:
            _history.insert(0, {"Role": "system", "Content": system})
        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": _history, **_gen_conf}
        req.from_json_string(json.dumps(params))
        ans = ""
        try:
            response = self.client.ChatCompletions(req)
            ans = response.Choices[0].Message.Content
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        from tencentcloud.hunyuan.v20230901 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        _gen_conf = {}
        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
        if system:
            _history.insert(0, {"Role": "system", "Content": system})

        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        req = models.ChatCompletionsRequest()
        params = {
            "Model": self.model_name,
            "Messages": _history,
            "Stream": True,
            **_gen_conf,
        }
        req.from_json_string(json.dumps(params))
        ans = ""
        total_tokens = 0
        try:
            response = self.client.ChatCompletions(req)
            for resp in response:
                resp = json.loads(resp["data"])
                if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
                    continue
                ans += resp["Choices"][0]["Delta"]["Content"]
                total_tokens += 1
                yield ans
        except TencentCloudSDKException as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class SparkChat(Base):
    def __init__(
            self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
    ):
        if not base_url:
            base_url = "https://spark-api-open.xf-yun.com/v1"
        model2version = {
            "Spark-Max": "generalv3.5",
            "Spark-Lite": "general",
            "Spark-Pro": "generalv3",
            "Spark-Pro-128K": "pro-128k",
            "Spark-4.0-Ultra": "4.0Ultra",
        }
        version2model = {v: k for k, v in model2version.items()}
        assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
        if model_name in model2version:
            model_version = model2version[model_name]
        else:
            model_version = model_name
        super().__init__(key, model_version, base_url)

class BaiduYiyanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
        self.model_name = model_name.lower()
        self.system = ""
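
    # The ERNIE (Yiyan) credential is a packed JSON string; an illustrative example
    # (field names from the parsing above, values are placeholders):
    #   key = '{"yiyan_ak": "<api-key>", "yiyan_sk": "<secret-key>"}'
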
    def chat(self, system, history, gen_conf):
        if system:
            self.system = system
        gen_conf["penalty_score"] = (
            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
                                                                0)) / 2
        ) + 1
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""

        try:
            response = self.client.do(
                model=self.model_name,
                messages=history,
                system=self.system,
                **gen_conf
            ).body
            ans = response['result']
            return ans, response["usage"]["total_tokens"]

        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            self.system = system
        gen_conf["penalty_score"] = (
            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
                                                                0)) / 2
        ) + 1
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0

        try:
            response = self.client.do(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=True,
                **gen_conf
            )
            for resp in response:
                resp = resp.body
                ans += resp['result']
                total_tokens = resp["usage"]["total_tokens"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class AnthropicChat(Base):
    def __init__(self, key, model_name, base_url=None):
        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        self.system = ""

    def chat(self, system, history, gen_conf):
        if system:
            self.system = system
        if "max_tokens" not in gen_conf:
            gen_conf["max_tokens"] = 4096
        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]

        ans = ""
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=False,
                **gen_conf,
            ).to_dict()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            self.system = system
        if "max_tokens" not in gen_conf:
            gen_conf["max_tokens"] = 4096
        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]

        ans = ""
        total_tokens = 0
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=True,
                **gen_conf,
            )
            for res in response:
                if res.type == 'content_block_delta':
                    text = res.delta.text
                    ans += text
                    total_tokens += num_tokens_from_string(text)
                    yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

class GoogleChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.oauth2 import service_account
        import base64

        key = json.loads(key)
        access_token = json.loads(
            base64.b64decode(key.get("google_service_account_key", ""))
        )
        project_id = key.get("google_project_id", "")
        region = key.get("google_region", "")

        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self.model_name = model_name
        self.system = ""

        if "claude" in self.model_name:
            from anthropic import AnthropicVertex
            from google.auth.transport.requests import Request

            if access_token:
                credits = service_account.Credentials.from_service_account_info(
                    access_token, scopes=scopes
                )
                request = Request()
                credits.refresh(request)
                token = credits.token
                self.client = AnthropicVertex(
                    region=region, project_id=project_id, access_token=token
                )
            else:
                self.client = AnthropicVertex(region=region, project_id=project_id)
        else:
            from google.cloud import aiplatform
            import vertexai.generative_models as glm

            if access_token:
                credits = service_account.Credentials.from_service_account_info(
                    access_token
                )
                aiplatform.init(
                    credentials=credits, project=project_id, location=region
                )
            else:
                aiplatform.init(project=project_id, location=region)
            self.client = glm.GenerativeModel(model_name=self.model_name)
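
    # The Google/Vertex credential is a packed JSON string; an illustrative example
    # (field names from the parsing above, values are placeholders; the service-account
    # key is base64-encoded JSON):
    #   key = '{"google_service_account_key": "<base64 service-account json>",
    #           "google_project_id": "<project-id>", "google_region": "us-central1"}'
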
    def chat(self, system, history, gen_conf):
        if system:
            self.system = system
        if "claude" in self.model_name:
            if "max_tokens" not in gen_conf:
                gen_conf["max_tokens"] = 4096
            try:
                response = self.client.messages.create(
                    model=self.model_name,
                    messages=history,
                    system=self.system,
                    stream=False,
                    **gen_conf,
                ).json()
                ans = response["content"][0]["text"]
                if response["stop_reason"] == "max_tokens":
                    ans += (
                        "...\nFor the content length reason, it stopped, continue?"
                        if is_english([ans])
                        else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    )
                return (
                    ans,
                    response["usage"]["input_tokens"]
                    + response["usage"]["output_tokens"],
                )
            except Exception as e:
                return "\n**ERROR**: " + str(e), 0
        else:
            self.client._system_instruction = self.system
            if "max_tokens" in gen_conf:
                gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
            for k in list(gen_conf.keys()):
                if k not in ["temperature", "top_p", "max_output_tokens"]:
                    del gen_conf[k]
            for item in history:
                if "role" in item and item["role"] == "assistant":
                    item["role"] = "model"
                if "content" in item:
                    item["parts"] = item.pop("content")
            try:
                response = self.client.generate_content(
                    history, generation_config=gen_conf
                )
                ans = response.text
                return ans, response.usage_metadata.total_token_count
            except Exception as e:
                return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            self.system = system
        if "claude" in self.model_name:
            if "max_tokens" not in gen_conf:
                gen_conf["max_tokens"] = 4096
            ans = ""
            total_tokens = 0
            try:
                response = self.client.messages.create(
                    model=self.model_name,
                    messages=history,
                    system=self.system,
                    stream=True,
                    **gen_conf,
                )
                for res in response.iter_lines():
                    res = res.decode("utf-8")
                    if "content_block_delta" in res and "data" in res:
                        text = json.loads(res[6:])["delta"]["text"]
                        ans += text
                        total_tokens += num_tokens_from_string(text)
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)

            yield total_tokens
        else:
            self.client._system_instruction = self.system
            if "max_tokens" in gen_conf:
                gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
            for k in list(gen_conf.keys()):
                if k not in ["temperature", "top_p", "max_output_tokens"]:
                    del gen_conf[k]
            for item in history:
                if "role" in item and item["role"] == "assistant":
                    item["role"] = "model"
                if "content" in item:
                    item["parts"] = item.pop("content")
            ans = ""
            try:
                response = self.client.generate_content(
                    history, generation_config=gen_conf, stream=True
                )
                for resp in response:
                    ans += resp.text
                    yield ans
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)

            yield response._chunks[-1].usage_metadata.total_token_count