#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
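
# Chat-model wrappers: each class below adapts one LLM backend (OpenAI,
# Moonshot, Xinference, DeepSeek, Qwen/DashScope, ZhipuAI, Ollama, or a
# local RPC server) to the same chat()/chat_streamly() interface, where
# chat() returns (answer, token_count) and chat_streamly() yields the
# growing answer followed by a final token count.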
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_english
from rag.utils import num_tokens_from_string


class Base(ABC):
    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
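
    # Single-shot completion: returns (answer, total token count). When the
    # model stops on length, a "continue?" hint is appended in English or Chinese.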
    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0
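
    # Streaming variant: yields the accumulated answer chunk by chunk and, as
    # the final yield, the token count (approximated here as one per chunk).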
    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if len(resp.choices) == 0: continue
                if not resp.choices[0].delta.content: continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
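

# The OpenAI-compatible providers below only pin a default model name and
# base_url; all request handling lives in Base. A minimal usage sketch
# (the key value is a placeholder):
#
#   mdl = GptTurbo("sk-...", model_name="gpt-3.5-turbo")
#   answer, tokens = mdl.chat("You are a helpful assistant.",
#                             [{"role": "user", "content": "Hi"}],
#                             {"temperature": 0.7})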
class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url: base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url: base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        # The OpenAI client requires a non-empty api_key, so a placeholder
        # is passed for the local Xinference endpoint.
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url: base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)
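

# Qwen models via the DashScope SDK: requests go through Generation.call()
# rather than the OpenAI-compatible client, so chat()/chat_streamly() are
# reimplemented here.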
class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count
        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count
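

# ZhipuAI GLM models: uses the OpenAI-style client shipped with the zhipuai
# SDK; presence_penalty / frequency_penalty are stripped from gen_conf before
# the request.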
class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content: continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count
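

# Ollama: talks to a self-hosted Ollama server; generic gen_conf keys are
# mapped onto Ollama's option names before the request.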
class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            options = {}
            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options
            )
            for resp in response:
                if resp["done"]:
                    # The final chunk carries the token statistics.
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield 0
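

# LocalLLM: forwards chat requests to a local model server over a pickled
# multiprocessing connection; RPCProxy reconnects and retries up to 3 times.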
class LocalLLM(Base):
    class RPCProxy:
        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception as e:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        token_count = 0
        answer = ""
        try:
            for ans in self.client.chat_streamly(history, gen_conf):
                answer += ans
                token_count += 1
                yield answer
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)

        yield token_count