mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-06-27 03:20:08 +00:00
fix(builder): fix std and llm_client (#497)
* 修复bug:#278,#495,ollama未自测 * 解决代码规范问题 --------- Co-authored-by: e <ling.liu@chinacreator.com>
This commit is contained in:
parent
af7c8fe0ac
commit
7d9bbc74e2
@ -129,12 +129,29 @@ class OpenIEEntitystandardizationdPrompt(PromptABC):
|
||||
entities_with_offical_name = set()
|
||||
merged = []
|
||||
entities = kwargs.get("named_entities", [])
|
||||
|
||||
# The caller does not have a unified input data structure, so multiple structures are treated uniformly here
|
||||
if "entities" in entities:
|
||||
entities = entities["entities"]
|
||||
if isinstance(entities, dict):
|
||||
_entities = []
|
||||
for category in entities:
|
||||
_e = entities[category]
|
||||
if isinstance(_e, list):
|
||||
for _e2 in _e:
|
||||
_entities.append({"name": _e2, "category": category})
|
||||
elif isinstance(_e, str):
|
||||
_entities.append({"name": _e, "category": category})
|
||||
else:
|
||||
pass
|
||||
|
||||
for entity in standardized_entity:
|
||||
merged.append(entity)
|
||||
entities_with_offical_name.add(entity["name"])
|
||||
# in case llm ignores some entities
|
||||
for entity in entities:
|
||||
if entity["name"] not in entities_with_offical_name:
|
||||
# Ignore entities without a name attribute
|
||||
if "name" in entity and entity["name"] not in entities_with_offical_name:
|
||||
entity["official_name"] = entity["name"]
|
||||
merged.append(entity)
|
||||
return merged
|
||||
|
@ -103,6 +103,7 @@ class OllamaClient(LLMClient):
|
||||
messages=messages,
|
||||
stream=self.stream,
|
||||
tools=tools,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
if not self.stream:
|
||||
# reasoning_content = getattr(
|
||||
@ -179,6 +180,7 @@ class OllamaClient(LLMClient):
|
||||
messages=messages,
|
||||
stream=self.stream,
|
||||
tools=tools,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
if not self.stream:
|
||||
# reasoning_content = getattr(
|
||||
|
@ -76,6 +76,7 @@ class OpenAIClient(LLMClient):
|
||||
logger.debug(
|
||||
f"Initialize OpenAIClient with rate limit {max_rate} every {time_period}s"
|
||||
)
|
||||
logger.info(f"OpenAIClient max_tokens={self.max_tokens}")
|
||||
|
||||
def __call__(self, prompt: str = "", image_url: str = None, **kwargs):
|
||||
"""
|
||||
@ -117,6 +118,7 @@ class OpenAIClient(LLMClient):
|
||||
temperature=self.temperature,
|
||||
timeout=self.timeout,
|
||||
tools=tools,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
if not self.stream:
|
||||
# reasoning_content = getattr(
|
||||
@ -207,6 +209,7 @@ class OpenAIClient(LLMClient):
|
||||
temperature=self.temperature,
|
||||
timeout=self.timeout,
|
||||
tools=tools,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
if not self.stream:
|
||||
# reasoning_content = getattr(
|
||||
@ -355,6 +358,7 @@ class AzureOpenAIClient(LLMClient):
|
||||
stream=self.stream,
|
||||
temperature=self.temperature,
|
||||
timeout=self.timeout,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
rsp = response.choices[0].message.content
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
@ -400,6 +404,7 @@ class AzureOpenAIClient(LLMClient):
|
||||
stream=self.stream,
|
||||
temperature=self.temperature,
|
||||
timeout=self.timeout,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
rsp = response.choices[0].message.content
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
|
@ -38,6 +38,7 @@ class LLMClient(Registrable):
|
||||
super().__init__(**kwargs)
|
||||
self.limiter = RATE_LIMITER_MANGER.get_rate_limiter(name, max_rate, time_period)
|
||||
self.enable_check = kwargs.get("enable_check", True)
|
||||
self.max_tokens = kwargs.get("max_tokens", 32768)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
@ -101,6 +102,11 @@ class LLMClient(Registrable):
|
||||
_end = res.rfind("```")
|
||||
_start = res.find("```json")
|
||||
if _end != -1 and _start != -1:
|
||||
if _end == _start:
|
||||
logger.error(
|
||||
f"response is not intact, please set max_tokens. res={res}"
|
||||
)
|
||||
return res
|
||||
json_str = res[_start + len("```json") : _end].strip()
|
||||
else:
|
||||
json_str = res
|
||||
@ -133,6 +139,11 @@ class LLMClient(Registrable):
|
||||
_end = res.rfind("```")
|
||||
_start = res.find("```json")
|
||||
if _end != -1 and _start != -1:
|
||||
if _end == _start:
|
||||
logger.error(
|
||||
f"response is not intact, please set max_tokens. res={res}"
|
||||
)
|
||||
return res
|
||||
json_str = res[_start + len("```json") : _end].strip()
|
||||
else:
|
||||
json_str = res
|
||||
|
Loading…
x
Reference in New Issue
Block a user