#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from volcengine.maas.v2 import MaasService
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
from groq import Groq
import os
import json
import requests
import asyncio


class Base(ABC):
    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content
                total_tokens = (
                    (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                    if not hasattr(resp, "usage") or not resp.usage
                    else resp.usage.get("total_tokens", total_tokens)
                )
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


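# A minimal usage sketch (hypothetical, not executed by this module). Every provider
# class below follows the Base interface: pass the provider key, a model name and,
# for OpenAI-compatible endpoints, a base_url. chat() returns (answer, token_count),
# while chat_streamly() yields growing partial answers and finally the token count.
# The key value here is a placeholder, not a real credential.
#
#     mdl = GptTurbo("sk-...", "gpt-3.5-turbo")
#     answer, used_tokens = mdl.chat(
#         "You are a helpful assistant.",
#         [{"role": "user", "content": "Hello!"}],
#         {"temperature": 0.7, "max_tokens": 512},
#     )
#     for chunk in mdl.chat_streamly(None, [{"role": "user", "content": "Hello!"}],
#                                    {"temperature": 0.7}):
#         print(chunk)  # the last yielded item is the total token count

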
class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)


class AzureChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name


class BaiChuanChat(Base):
    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content
                total_tokens = (
                    (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                    if not hasattr(resp, "usage")
                    else resp.usage["total_tokens"]
                )
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            if "presence_penalty" in gen_conf:
                del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # Map the generic generation parameters onto Ollama's option names.
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0


class LocalAIChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]


class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client

            self._connection = Client(
                (self.host, self.port), authkey=b"infiniflow-token4kevinhu"
            )

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                for _ in range(3):
                    try:
                        self._connection.send(pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception as e:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name):
        from jina import Client

        self.client = Client(port=12345, protocol="grpc", asyncio=True)

    def _prepare_prompt(self, system, history, gen_conf):
        from rag.svr.jina_server import Prompt, Generation
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        return Prompt(message=history, gen_conf=gen_conf)

    def _stream_response(self, endpoint, prompt):
        from rag.svr.jina_server import Prompt, Generation
        answer = ""
        try:
            res = self.client.stream_doc(
                on=endpoint, inputs=prompt, return_type=Generation
            )
            loop = asyncio.get_event_loop()
            try:
                while True:
                    answer = loop.run_until_complete(res.__anext__()).text
                    yield answer
            except StopAsyncIteration:
                pass
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)
        yield num_tokens_from_string(answer)

    def chat(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        chat_gen = self._stream_response("/chat", prompt)
        ans = next(chat_gen)
        total_tokens = next(chat_gen)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        return self._stream_response("/stream", prompt)


class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url):
        """
        Since we do not want to modify the original database fields, and the VolcEngine
        authentication method is quite special, the ak, sk and ep_id are assembled into
        api_key as a serialized dictionary, which is parsed here for use.
        model_name is for display only.
        """
        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
        self.volc_ak = eval(key).get('volc_ak', '')
        self.volc_sk = eval(key).get('volc_sk', '')
        self.client.set_ak(self.volc_ak)
        self.client.set_sk(self.volc_sk)
        self.model_name = eval(key).get('ep_id', '')
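
    # For illustration only (a hedged sketch, not a stored schema): the key passed in
    # is expected to evaluate to a dict such as
    #     "{'volc_ak': '<access key>', 'volc_sk': '<secret key>', 'ep_id': '<endpoint id>'}"
    # where ep_id is the VolcEngine endpoint that actually selects the model.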

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            response = self.client.chat(self.model_name, req)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            req = {
                "parameters": {
                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
                    "top_k": gen_conf.get("top_k", 0),
                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
                    "temperature": gen_conf.get("temperature", 0.1),
                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
                    "top_p": gen_conf.get("top_p", 0.3),
                },
                "messages": history
            }
            stream = self.client.stream_chat(self.model_name, req)
            for resp in stream:
                if not resp.choices[0].message.content:
                    continue
                ans += resp.choices[0].message.content
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count


class MiniMaxChat(Base):
    def __init__(
        self,
        key,
        model_name,
        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
    ):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        self.base_url = base_url
        self.model_name = model_name
        self.api_key = key

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps(
            {"model": self.model_name, "messages": history, **gen_conf}
        )
        try:
            response = requests.request(
                "POST", url=self.base_url, headers=headers, data=payload
            )
            response = response.json()
            ans = response["choices"][0]["message"]["content"].strip()
            if response["choices"][0]["finish_reason"] == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = json.dumps(
                {
                    "model": self.model_name,
                    "messages": history,
                    "stream": True,
                    **gen_conf,
                }
            )
            response = requests.request(
                "POST",
                url=self.base_url,
                headers=headers,
                data=payload,
            )
            for resp in response.text.split("\n\n")[:-1]:
                # Each SSE event starts with the 6-character prefix "data: "; strip it before parsing.
                resp = json.loads(resp[6:])
                text = ""
                if "choices" in resp and "delta" in resp["choices"][0]:
                    text = resp["choices"][0]["delta"]["content"]
                ans += text
                total_tokens = (
                    total_tokens + num_tokens_from_string(text)
                    if "usage" not in resp
                    else resp["usage"]["total_tokens"]
                )
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class BedrockChat(Base):
    def __init__(self, key, model_name, **kwargs):
        import boto3
        self.bedrock_ak = eval(key).get('bedrock_ak', '')
        self.bedrock_sk = eval(key).get('bedrock_sk', '')
        self.bedrock_region = eval(key).get('bedrock_region', '')
        self.model_name = model_name
        self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                   aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
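
    # For illustration only: the key is expected to evaluate to a dict along the lines of
    #     "{'bedrock_ak': '<access key>', 'bedrock_sk': '<secret key>', 'bedrock_region': 'us-east-1'}"
    # (the region value here is just an example).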

    def chat(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        for item in history:
            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
                item["content"] = [{"text": item["content"]}]

        try:
            # Send the message to the model, using a basic inference configuration.
            response = self.client.converse(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf,
                system=[{"text": system}] if system else None,
            )

            # Extract and return the response text.
            ans = response["output"]["message"]["content"][0]["text"]
            return ans, num_tokens_from_string(ans)
        except (ClientError, Exception) as e:
            return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

    def chat_streamly(self, system, history, gen_conf):
        from botocore.exceptions import ClientError
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        if "max_tokens" in gen_conf:
            gen_conf["maxTokens"] = gen_conf["max_tokens"]
            _ = gen_conf.pop("max_tokens")
        if "top_p" in gen_conf:
            gen_conf["topP"] = gen_conf["top_p"]
            _ = gen_conf.pop("top_p")
        for item in history:
            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
                item["content"] = [{"text": item["content"]}]

        if self.model_name.split('.')[0] == 'ai21':
            try:
                response = self.client.converse(
                    modelId=self.model_name,
                    messages=history,
                    inferenceConfig=gen_conf,
                    system=[{"text": system}] if system else None,
                )
                ans = response["output"]["message"]["content"][0]["text"]
                return ans, num_tokens_from_string(ans)
            except (ClientError, Exception) as e:
                return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

        ans = ""
        try:
            # Send the message to the model, using a basic inference configuration.
            streaming_response = self.client.converse_stream(
                modelId=self.model_name,
                messages=history,
                inferenceConfig=gen_conf
            )

            # Extract and yield the streamed response text in real time.
            for resp in streaming_response["stream"]:
                if "contentBlockDelta" in resp:
                    ans += resp["contentBlockDelta"]["delta"]["text"]
                    yield ans
        except (ClientError, Exception) as e:
            yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

        yield num_tokens_from_string(ans)


class GeminiChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from google.generativeai import client, GenerativeModel

        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = 'models/' + model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')

        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf)
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        ans = ""
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf, stream=True)
            for resp in response:
                ans += resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield response._chunks[-1].usage_metadata.total_token_count


class GroqChat:
    def __init__(self, key, model_name, base_url=''):
        self.client = Groq(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]

        ans = ""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


## openrouter
class OpenRouterChat(Base):
    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        super().__init__(key, model_name, base_url)


class StepFunChat(Base):
    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        super().__init__(key, model_name, base_url)


class NvidiaChat(Base):
    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1"
        super().__init__(key, model_name, base_url)


class LmStudioChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name


class OpenAI_APIChat(Base):
    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        model_name = model_name.split("___")[0]
        super().__init__(key, model_name, base_url)


class CoHereChat(Base):
    def __init__(self, key, model_name, base_url=""):
        from cohere import Client

        self.client = Client(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "top_p" in gen_conf:
            gen_conf["p"] = gen_conf.pop("top_p")
        if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
            gen_conf.pop("presence_penalty")
        for item in history:
            if "role" in item and item["role"] == "user":
                item["role"] = "USER"
            if "role" in item and item["role"] == "assistant":
                item["role"] = "CHATBOT"
            if "content" in item:
                item["message"] = item.pop("content")
        mes = history.pop()["message"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name, chat_history=history, message=mes, **gen_conf
            )
            ans = response.text
            if response.finish_reason == "MAX_TOKENS":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return (
                ans,
                response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "top_p" in gen_conf:
            gen_conf["p"] = gen_conf.pop("top_p")
        if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
            gen_conf.pop("presence_penalty")
        for item in history:
            if "role" in item and item["role"] == "user":
                item["role"] = "USER"
            if "role" in item and item["role"] == "assistant":
                item["role"] = "CHATBOT"
            if "content" in item:
                item["message"] = item.pop("content")
        mes = history.pop()["message"]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name, chat_history=history, message=mes, **gen_conf
            )
            for resp in response:
                if resp.event_type == "text-generation":
                    ans += resp.text
                    total_tokens += num_tokens_from_string(resp.text)
                elif resp.event_type == "stream-end":
                    if resp.finish_reason == "MAX_TOKENS":
                        ans += (
                            "...\nFor the content length reason, it stopped, continue?"
                            if is_english([ans])
                            else "······\n由于长度的原因,回答被截断了,要继续吗?"
                        )
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class LeptonAIChat(Base):
    def __init__(self, key, model_name, base_url=None):
        if not base_url:
            base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
        super().__init__(key, model_name, base_url)


class TogetherAIChat(Base):
    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url)


class PerfXCloudChat(Base):
    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        if not base_url:
            base_url = "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, base_url)


class UpstageChat(Base):
    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        if not base_url:
            base_url = "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, base_url)


class NovitaAIChat(Base):
    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai"
        super().__init__(key, model_name, base_url)


class SILICONFLOWChat(Base):
    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, base_url)


class YiChat(Base):
    def __init__(self, key, model_name, base_url="https://api.01.ai/v1"):
        if not base_url:
            base_url = "https://api.01.ai/v1"
        super().__init__(key, model_name, base_url)


class ReplicateChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)
        self.system = ""

    def chat(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        if system:
            self.system = system
        prompt = "\n".join(
            [item["role"] + ":" + item["content"] for item in history[-5:]]
        )
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
            )
            ans = "".join(response)
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        if system:
            self.system = system
        prompt = "\n".join(
            [item["role"] + ":" + item["content"] for item in history[-5:]]
        )
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
            )
            for resp in response:
                ans += resp
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield num_tokens_from_string(ans)


class HunyuanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client

        key = json.loads(key)
        sid = key.get("hunyuan_sid", "")
        sk = key.get("hunyuan_sk", "")
        cred = credential.Credential(sid, sk)
        self.model_name = model_name
        self.client = hunyuan_client.HunyuanClient(cred, "")
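
    # For illustration only: the key is expected to be a JSON string such as
    #     '{"hunyuan_sid": "<secret id>", "hunyuan_sk": "<secret key>"}'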

    def chat(self, system, history, gen_conf):
        from tencentcloud.hunyuan.v20230901 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        _gen_conf = {}
        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
        if system:
            _history.insert(0, {"Role": "system", "Content": system})
        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]

        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": _history, **_gen_conf}
        req.from_json_string(json.dumps(params))
        ans = ""
        try:
            response = self.client.ChatCompletions(req)
            ans = response.Choices[0].Message.Content
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        from tencentcloud.hunyuan.v20230901 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        _gen_conf = {}
        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
        if system:
            _history.insert(0, {"Role": "system", "Content": system})

        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        req = models.ChatCompletionsRequest()
        params = {
            "Model": self.model_name,
            "Messages": _history,
            "Stream": True,
            **_gen_conf,
        }
        req.from_json_string(json.dumps(params))
        ans = ""
        total_tokens = 0
        try:
            response = self.client.ChatCompletions(req)
            for resp in response:
                resp = json.loads(resp["data"])
                if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
                    continue
                ans += resp["Choices"][0]["Delta"]["Content"]
                total_tokens += 1
                yield ans
        except TencentCloudSDKException as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class SparkChat(Base):
    def __init__(
        self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
    ):
        if not base_url:
            base_url = "https://spark-api-open.xf-yun.com/v1"
        model2version = {
            "Spark-Max": "generalv3.5",
            "Spark-Lite": "general",
            "Spark-Pro": "generalv3",
            "Spark-Pro-128K": "pro-128k",
            "Spark-4.0-Ultra": "4.0Ultra",
        }
        model_version = model2version[model_name]
        super().__init__(key, model_version, base_url)


class BaiduYiyanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
        self.model_name = model_name.lower()
        self.system = ""
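
    # For illustration only: the key is expected to be a JSON string such as
    #     '{"yiyan_ak": "<api key>", "yiyan_sk": "<secret key>"}'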

    def chat(self, system, history, gen_conf):
        if system:
            self.system = system
        gen_conf["penalty_score"] = (
            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
        ) + 1
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""

        try:
            response = self.client.do(
                model=self.model_name,
                messages=history,
                system=self.system,
                **gen_conf
            ).body
            ans = response['result']
            return ans, response["usage"]["total_tokens"]
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            self.system = system
        gen_conf["penalty_score"] = (
            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
        ) + 1
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0

        try:
            response = self.client.do(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=True,
                **gen_conf
            )
            for resp in response:
                resp = resp.body
                ans += resp['result']
                total_tokens = resp["usage"]["total_tokens"]
                yield ans
        except Exception as e:
            # Mirror the other providers: report the error on the answer stream
            # instead of returning a value from inside the generator.
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens