diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py
index b3470d6aa..979b636af 100644
--- a/agent/component/agent_with_tools.py
+++ b/agent/component/agent_with_tools.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import logging
 import os
 import re
@@ -29,7 +30,7 @@
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.mcp_server_service import MCPServerService
 from common.connection_utils import timeout
 from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
-    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
+    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
 from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
 from agent.component.llm import LLMParam, LLM
@@ -137,6 +138,29 @@ class Agent(LLM, ToolBase):
             res.update(cpn.get_input_form())
         return res
 
+    def _get_output_schema(self):
+        try:
+            cand = self._param.outputs.get("structured")
+        except Exception:
+            return None
+
+        if isinstance(cand, dict):
+            if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
+                return cand
+            for k in ("schema", "structured"):
+                if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
+                    return cand[k]
+
+        return None
+
+    def _force_format_to_schema(self, text: str, schema_prompt: str) -> str:
+        fmt_msgs = [
+            {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
+            {"role": "user", "content": text},
+        ]
+        _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
+        return self._generate(fmt_msgs)
+
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
     def _invoke(self, **kwargs):
         if self.check_if_canceled("Agent processing"):
@@ -160,17 +184,22 @@ class Agent(LLM, ToolBase):
             return LLM._invoke(self, **kwargs)
 
         prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
+        output_schema = self._get_output_schema()
+        schema_prompt = ""
+        if output_schema:
+            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
+            schema_prompt = structured_output_prompt(schema)
 
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
+        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
             self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
             return
 
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         use_tools = []
         ans = ""
-        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
+        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
             if self.check_if_canceled("Agent processing"):
                 return
             ans += delta_ans
@@ -183,6 +212,28 @@ class Agent(LLM, ToolBase):
             self.set_output("_ERROR", ans)
             return
 
+        if output_schema:
+            error = ""
+            for _ in range(self._param.max_retries + 1):
+                try:
+                    def clean_formated_answer(ans: str) -> str:
+                        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+                        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
+                        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
+                    obj = json_repair.loads(clean_formated_answer(ans))
+                    self.set_output("structured", obj)
+                    if use_tools:
+                        self.set_output("use_tools", use_tools)
+                    return obj
+                except Exception:
+                    error = "The answer cannot be parsed as JSON"
+                    ans = self._force_format_to_schema(ans, schema_prompt)
+                    if ans.find("**ERROR**") >= 0:
+                        continue
+
+            self.set_output("_ERROR", error)
+            return
+
         self.set_output("content", ans)
         if use_tools:
             self.set_output("use_tools", use_tools)
@@ -219,7 +270,7 @@ class Agent(LLM, ToolBase):
             ]):
                 yield delta_ans
 
-    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
+    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
         token_count = 0
         tool_metas = self.tool_meta
         hist = deepcopy(history)
@@ -256,9 +307,13 @@ class Agent(LLM, ToolBase):
         def complete():
             nonlocal hist
             need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
+            if schema_prompt:
+                need2cite = False
             cited = False
-            if hist[0]["role"] == "system" and need2cite:
-                if len(hist) < 7:
+            if hist and hist[0]["role"] == "system":
+                if schema_prompt:
+                    hist[0]["content"] += "\n" + schema_prompt
+                if need2cite and len(hist) < 7:
                     hist[0]["content"] += citation_prompt()
                     cited = True
             yield "", token_count
@@ -369,7 +424,7 @@ Respond immediately with your final comprehensive answer.
 """
         for k in self._param.outputs.keys():
             self._param.outputs[k]["value"] = None
-        
+
         for k, cpn in self.tools.items():
             if hasattr(cpn, "reset") and callable(cpn.reset):
                 cpn.reset()
diff --git a/agent/component/llm.py b/agent/component/llm.py
index 807bbc288..0f5317676 100644
--- a/agent/component/llm.py
+++ b/agent/component/llm.py
@@ -222,7 +222,7 @@ class LLM(ComponentBase):
             output_structure = self._param.outputs['structured']
         except Exception:
             pass
-        if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
+        if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
             schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
             prompt += structured_output_prompt(schema)
             for _ in range(self._param.max_retries+1):
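
Reviewer note: the parse/repair/retry loop added to `_invoke` is easiest to sanity-check in isolation. The sketch below mirrors that flow under stated assumptions: `clean_formated_answer` and the `json_repair.loads` call are copied from the patch, while `parse_structured` and `reformat_with_llm` are hypothetical stand-ins for the loop body and `Agent._force_format_to_schema` (each retry costs one extra LLM round-trip). It is an illustration, not part of the patch.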
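```python
# Standalone sketch of the structured-output retry flow (assumptions noted above).
import re

import json_repair


def clean_formated_answer(ans: str) -> str:
    # Strip a leading <think>...</think> block, then any ```json fence.
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
    return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)


def parse_structured(ans: str, reformat_with_llm, max_retries: int = 2):
    """Parse `ans` as JSON; on failure, ask the LLM to reformat and try again."""
    for _ in range(max_retries + 1):
        try:
            return json_repair.loads(clean_formated_answer(ans))
        except Exception:
            ans = reformat_with_llm(ans)  # stand-in for _force_format_to_schema
    return None


# A think-prefixed, fenced answer still parses:
raw = '<think>reasoning...</think>```json\n{"title": "RAGFlow", "ok": true}\n```'
print(parse_structured(raw, reformat_with_llm=lambda s: s))
# -> {'title': 'RAGFlow', 'ok': True}
```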