mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-06-26 22:19:57 +00:00
Fix: AzureChat model code (#8426)
### What problem does this PR solve? - Simplify the AzureChat constructor by passing base_url directly instead of reading it from kwargs - Clean up spacing and formatting in chat_model.py - Remove redundant parentheses and improve code consistency - Fixes #8423 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
4760e317d5
commit
244d8a47b9
@ -157,9 +157,9 @@ class Base(ABC):
|
||||
tk_count = 0
|
||||
hist = deepcopy(history)
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries+1):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = hist
|
||||
for _ in range(self.max_rounds*2):
|
||||
for _ in range(self.max_rounds * 2):
|
||||
try:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
|
||||
tk_count += self.total_token_count(response)
|
||||
@ -185,7 +185,6 @@ class Base(ABC):
|
||||
except Exception as e:
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
|
||||
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
@ -198,7 +197,7 @@ class Base(ABC):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries+1):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self._chat(history, gen_conf)
|
||||
except Exception as e:
|
||||
@ -232,9 +231,9 @@ class Base(ABC):
|
||||
total_tokens = 0
|
||||
hist = deepcopy(history)
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries+1):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = hist
|
||||
for _ in range(self.max_rounds*2):
|
||||
for _ in range(self.max_rounds * 2):
|
||||
reasoning_start = False
|
||||
try:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
|
||||
@ -453,11 +452,11 @@ class DeepSeekChat(Base):
|
||||
|
||||
|
||||
class AzureChat(Base):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
super().__init__(key, model_name, kwargs["base_url"], **kwargs)
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
@ -925,10 +924,10 @@ class LocalAIChat(Base):
|
||||
|
||||
|
||||
class LocalLLM(Base):
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
super().__init__(key, model_name, base_url=base_url, **kwargs)
|
||||
from jina import Client
|
||||
|
||||
self.client = Client(port=12345, protocol="grpc", asyncio=True)
|
||||
|
||||
def _prepare_prompt(self, system, history, gen_conf):
|
||||
@ -985,13 +984,7 @@ class VolcEngineChat(Base):
|
||||
|
||||
|
||||
class MiniMaxChat(Base):
|
||||
def __init__(
|
||||
self,
|
||||
key,
|
||||
model_name,
|
||||
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
|
||||
super().__init__(key, model_name, base_url=base_url, **kwargs)
|
||||
|
||||
if not base_url:
|
||||
@ -1223,6 +1216,7 @@ class GeminiChat(Base):
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
from google.generativeai.types import content_types
|
||||
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
hist = []
|
||||
for item in history:
|
||||
@ -1880,4 +1874,4 @@ class GPUStackChat(Base):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user