ragflow/agent/canvas.py
Billy Bao 534fa60b2a
Fix: Agent.reset() argument wrong #10463 & Unable to converse with agent through Python API. #10415 (#10472)
### What problem does this PR solve?
Fix: Agent.reset() argument wrong #10463 & Unable to converse with agent
through Python API. #10415

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-10-10 20:44:05 +08:00

513 lines
20 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import json
import logging
import re
import time
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.utils import get_uuid, hash_str2int
from rag.prompts.generator import chunks_format
from rag.utils.redis_conn import REDIS_CONN
class Graph:
"""
dsl = {
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {},
},
"downstream": ["answer_0"],
"upstream": [],
},
"retrieval_0": {
"obj": {
"component_name": "Retrieval",
"params": {}
},
"downstream": ["generate_0"],
"upstream": ["answer_0"],
},
"generate_0": {
"obj": {
"component_name": "Generate",
"params": {}
},
"downstream": ["answer_0"],
"upstream": ["retrieval_0"],
}
},
"history": [],
"path": ["begin"],
"retrieval": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
}
"""
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.path = []
self.components = {}
self.error = ""
self.dsl = json.loads(dsl)
self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid()
self.load()
def load(self):
self.components = self.dsl["components"]
cpn_nms = set([])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
param.update(cpn["obj"]["params"])
try:
param.check()
except Exception as e:
raise ValueError(self.get_component_name(k) + f": {e}")
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
self.path = self.dsl["path"]
def __str__(self):
self.dsl["path"] = self.path
self.dsl["task_id"] = self.task_id
dsl = {
"components": {}
}
for k in self.dsl.keys():
if k in ["components"]:
continue
dsl[k] = deepcopy(self.dsl[k])
for k, cpn in self.components.items():
if k not in dsl["components"]:
dsl["components"][k] = {}
for c in cpn.keys():
if c == "obj":
dsl["components"][k][c] = json.loads(str(cpn["obj"]))
continue
dsl["components"][k][c] = deepcopy(cpn[c])
return json.dumps(dsl, ensure_ascii=False)
def reset(self):
self.path = []
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
try:
REDIS_CONN.delete(f"{self.task_id}-logs")
except Exception as e:
logging.exception(e)
def get_component_name(self, cid):
for n in self.dsl.get("graph", {}).get("nodes", []):
if cid == n["id"]:
return n["data"]["name"]
return ""
def run(self, **kwargs):
raise NotImplementedError()
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
return self.components.get(cpn_id)
def get_component_obj(self, cpn_id) -> ComponentBase:
return self.components.get(cpn_id)["obj"]
def get_component_type(self, cpn_id) -> str:
return self.components.get(cpn_id)["obj"].component_name
def get_component_input_form(self, cpn_id) -> dict:
return self.components.get(cpn_id)["obj"].get_input_form()
def get_tenant_id(self):
return self._tenant_id
def get_variable_value(self, exp: str) -> Any:
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
if exp.find("@") < 0:
return self.globals[exp]
cpn_id, var_nm = exp.split("@")
cpn = self.get_component(cpn_id)
if not cpn:
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
return cpn["obj"].output(var_nm)
class Canvas(Graph):
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
super().__init__(dsl, tenant_id, task_id)
def load(self):
super().load()
self.history = self.dsl["history"]
if "globals" in self.dsl:
self.globals = self.dsl["globals"]
else:
self.globals = {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])
def __str__(self):
self.dsl["history"] = self.history
self.dsl["retrieval"] = self.retrieval
self.dsl["memory"] = self.memory
return super().__str__()
def reset(self, mem=False):
super().reset()
if not mem:
self.history = []
self.retrieval = []
self.memory = []
for k in self.globals.keys():
if isinstance(self.globals[k], str):
self.globals[k] = ""
elif isinstance(self.globals[k], int):
self.globals[k] = 0
elif isinstance(self.globals[k], float):
self.globals[k] = 0
elif isinstance(self.globals[k], list):
self.globals[k] = []
elif isinstance(self.globals[k], dict):
self.globals[k] = {}
else:
self.globals[k] = None
def run(self, **kwargs):
st = time.perf_counter()
self.message_id = get_uuid()
created_at = int(time.time())
self.add_user_input(kwargs.get("query"))
for k, cpn in self.components.items():
self.components[k]["obj"].reset(True)
for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files":
self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
else:
self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] :
self.globals["sys.conversation_turns"] = 0
self.globals["sys.conversation_turns"] += 1
def decorate(event, dt):
nonlocal created_at
return {
"event": event,
#"conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
"message_id": self.message_id,
"created_at": created_at,
"task_id": self.task_id,
"data": dt
}
if not self.path or self.path[-1].lower().find("userfillup") < 0:
self.path.append("begin")
self.retrieval.append({"chunks": [], "doc_aggs": []})
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
def _run_batch(f, t):
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for i in range(f, t):
cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]:
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
for t in thr:
t.result()
def _node_finished(cpn_obj):
return decorate("node_finished",{
"inputs": cpn_obj.get_input_values(),
"outputs": cpn_obj.output(),
"component_id": cpn_obj._id,
"component_name": self.get_component_name(cpn_obj._id),
"component_type": self.get_component_type(cpn_obj._id),
"error": cpn_obj.error(),
"elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
"created_at": cpn_obj.output("_created_time"),
})
self.error = ""
idx = len(self.path) - 1
partials = []
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
yield decorate("node_started", {
"inputs": None, "created_at": int(time.time()),
"component_id": self.path[i],
"component_name": self.get_component_name(self.path[i]),
"component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i])
})
_run_batch(idx, to)
# post processing of components invocation
for i in range(idx, to):
cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message":
if isinstance(cpn_obj.output("content"), partial):
_m = ""
for m in cpn_obj.output("content")():
if not m:
continue
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
yield decorate("message", {"content": cpn_obj.output("content")})
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
while partials:
_cpn_obj = self.get_component_obj(partials[0])
if isinstance(_cpn_obj.output("content"), partial):
break
yield _node_finished(_cpn_obj)
partials.pop(0)
other_branch = False
if cpn_obj.error():
ex = cpn_obj.exception_handler()
if ex and ex["goto"]:
self.path.extend(ex["goto"])
other_branch = True
elif ex and ex["default_value"]:
yield decorate("message", {"content": ex["default_value"]})
yield decorate("message_end", {})
else:
self.error = cpn_obj.error()
if cpn_obj.component_name.lower() != "iteration":
if isinstance(cpn_obj.output("content"), partial):
if self.error:
cpn_obj.set_output("content", None)
yield _node_finished(cpn_obj)
else:
partials.append(self.path[i])
else:
yield _node_finished(cpn_obj)
def _append_path(cpn_id):
nonlocal other_branch
if other_branch:
return
if self.path[-1] == cpn_id:
return
self.path.append(cpn_id)
def _extend_path(cpn_ids):
nonlocal other_branch
if other_branch:
return
for cpn_id in cpn_ids:
_append_path(cpn_id)
if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
iter = cpn_obj.get_parent()
yield _node_finished(iter)
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
_extend_path(cpn_obj.output("_next"))
elif cpn_obj.component_name.lower() == "iteration":
_append_path(cpn_obj.get_start())
elif not cpn["downstream"] and cpn_obj.get_parent():
_append_path(cpn_obj.get_parent().get_start())
else:
_extend_path(cpn["downstream"])
if self.error:
logging.error(f"Runtime Error: {self.error}")
break
idx = to
if any([self.get_component_obj(c).component_name.lower() == "userfillup" for c in self.path[idx:]]):
path = [c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() == "userfillup"]
path.extend([c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() != "userfillup"])
another_inputs = {}
tips = ""
for c in path:
o = self.get_component_obj(c)
if o.component_name.lower() == "userfillup":
another_inputs.update(o.get_input_elements())
if o.get_param("enable_tips"):
tips = o.get_param("tips")
self.path = path
yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
return
self.path = self.path[:idx]
if not self.error:
yield decorate("workflow_finished",
{
"inputs": kwargs.get("inputs"),
"outputs": self.get_component_obj(self.path[-1]).output(),
"elapsed_time": time.perf_counter() - st,
"created_at": st,
})
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
def is_reff(self, exp: str) -> bool:
exp = exp.strip("{").strip("}")
if exp.find("@") < 0:
return exp in self.globals
arr = exp.split("@")
if len(arr) != 2:
return False
if self.get_component(arr[0]) is None:
return False
return True
def get_history(self, window_size):
convs = []
if window_size <= 0:
return convs
for role, obj in self.history[window_size * -2:]:
if isinstance(obj, dict):
convs.append({"role": role, "content": obj.get("content", "")})
else:
convs.append({"role": role, "content": str(obj)})
return convs
def add_user_input(self, question):
self.history.append(("user", question))
def get_prologue(self):
return self.components["begin"]["obj"]._param.prologue
def get_mode(self):
return self.components["begin"]["obj"]._param.mode
def set_global_param(self, **kwargs):
self.globals.update(kwargs)
def get_preset_param(self):
return self.components["begin"]["obj"]._param.inputs
def get_component_input_elements(self, cpnnm):
return self.components[cpnnm]["obj"].get_input_elements()
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
if not files:
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
exe = ThreadPoolExecutor(max_workers=5)
threads = []
for file in files:
if file["mime_type"].find("image") >=0:
threads.append(exe.submit(image_to_base64, file))
continue
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
return [th.result() for th in threads]
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->")
agent_name = self.get_component_name(agent_ids[0])
path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:])
try:
bin = REDIS_CONN.get(f"{self.task_id}-{self.message_id}-logs")
if bin:
obj = json.loads(bin.encode("utf-8"))
if obj[-1]["component_id"] == agent_ids[0]:
obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time})
else:
obj.append({
"component_id": agent_ids[0],
"trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
})
else:
obj = [{
"component_id": agent_ids[0],
"trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
}]
REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10)
except Exception as e:
logging.exception(e)
def add_reference(self, chunks: list[object], doc_infos: list[object]):
if not self.retrieval:
self.retrieval = [{"chunks": {}, "doc_aggs": {}}]
r = self.retrieval[-1]
for ck in chunks_format({"chunks": chunks}):
cid = hash_str2int(ck["id"], 500)
# cid = uuid.uuid5(uuid.NAMESPACE_DNS, ck["id"])
if cid not in r:
r["chunks"][cid] = ck
for doc in doc_infos:
if doc["doc_name"] not in r:
r["doc_aggs"][doc["doc_name"]] = doc
def get_reference(self):
if not self.retrieval:
return {"chunks": {}, "doc_aggs": {}}
return self.retrieval[-1]
def add_memory(self, user:str, assist:str, summ: str):
self.memory.append((user, assist, summ))
def get_memory(self) -> list[Tuple]:
return self.memory
def get_component_thoughts(self, cpn_id) -> str:
return self.components.get(cpn_id)["obj"].thoughts()