ragflow/agent/component/categorize.py
Yongteng Lei 9213568692
Feat: add mechanism to check cancellation in Agent (#10766)
### What problem does this PR solve?

Add mechanism to check cancellation in Agent.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-11 17:36:48 +08:00

149 lines
5.4 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
from abc import ABC
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component.llm import LLMParam, LLM
from common.connection_utils import timeout
from rag.llm.chat_model import ERROR_PREFIX
class CategorizeParam(LLMParam):
    """
    Define the categorize component parameters.

    Attributes set in ``__init__``:
        category_description: Mapping of category name to a dict. This class
            reads the keys ``"to"`` (required by ``check``), ``"examples"``
            and ``"description"`` from each value.
        query: Name of the variable holding the text to classify
            (defaults to ``"sys.query"``).
        message_history_window_size: Number handed to the canvas history
            lookup; validated as a positive integer.
    """

    def __init__(self):
        super().__init__()
        self.category_description = {}
        self.query = "sys.query"
        self.message_history_window_size = 1
        # Build an initial system prompt from the (still empty) category set.
        self.update_prompt()

    def check(self):
        """
        Validate the component parameters.

        Raises:
            ValueError: If a category name is empty or a category has no
                ``"to"`` target. The inherited ``check_*`` helpers presumably
                raise as well when their condition fails -- confirm in the
                base class.
        """
        self.check_positive_integer(self.message_history_window_size, "[Categorize] Message window size > 0")
        self.check_empty(self.category_description, "[Categorize] Category examples")
        for k, v in self.category_description.items():
            if not k:
                raise ValueError("[Categorize] Category name can not be empty!")
            if not v.get("to"):
                raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")

    def get_input_form(self) -> dict[str, dict]:
        """Return the input form description: a single-line ``query`` field."""
        return {
            "query": {
                "type": "line",
                "name": "Query"
            }
        }

    def update_prompt(self):
        """
        Rebuild ``self.sys_prompt`` from the current ``category_description``.

        The prompt lists the category names, then every non-empty category
        description, then the fixed instructions. When at least one non-empty
        example exists, an ``Examples`` section is appended with one
        ``USER: "<example>" -> <category>`` line per example.
        """
        example_lines = []
        for name, spec in self.category_description.items():
            for example in spec.get("examples", []):
                if not example:
                    continue
                # Keep each example on one line and separate it from its label;
                # without the separator the label is glued to the closing quote.
                example_lines.append('USER: "' + example.replace("\n", " ") + '" -> ' + name)

        descriptions = [
            "\n------\nCategory: {}\nDescription: {}".format(name, spec["description"])
            for name, spec in self.category_description.items()
            if spec.get("description")
        ]

        self.sys_prompt = """
You are an advanced classification system that categorizes user questions into specific types. Analyze the input question and classify it into ONE of the following categories:
{}
Here's description of each category:
- {}
---- Instructions ----
- Consider both explicit mentions and implied context
- Prioritize the most specific applicable category
- Return only the category name without explanations
- Use "Other" only when no other category fits
""".format(
            "\n - ".join(self.category_description),
            "\n".join(descriptions)
        )
        if example_lines:
            self.sys_prompt += """
---- Examples ----
{}
""".format("\n".join(example_lines))
class Categorize(LLM, ABC):
    """
    Agent component that asks a chat model to classify the current query into
    one of the configured categories and exposes the chosen category name and
    that category's ``"to"`` targets as outputs.
    """
    component_name = "Categorize"

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
    def _invoke(self, **kwargs):
        """
        Classify the query and set the ``category_name`` and ``_next`` outputs.

        The query is taken from ``kwargs["sys.query"]`` when that is truthy,
        otherwise from the canvas variable named by ``self._param.query``.
        It replaces the content of the last history message before the
        conversation is rendered into the user prompt.

        Category selection:
            1. If the answer, stripped of surrounding whitespace and quote
               characters, equals a category name (case-insensitive), that
               category is chosen.
            2. Otherwise the category whose name occurs most often in the
               answer (case-insensitive substring count) is chosen.
            3. If no category name occurs at all, the last configured
               category is used for both the reported name and the route.

        Returns:
            None. Results are published through ``set_output``; the method
            returns early whenever ``check_if_canceled`` reports cancellation.

        Raises:
            Exception: If the model answer contains ``ERROR_PREFIX``.
        """
        if self.check_if_canceled("Categorize processing"):
            return

        msg = self._canvas.get_history(self._param.message_history_window_size)
        if not msg:
            msg = [{"role": "user", "content": ""}]
        if kwargs.get("sys.query"):
            msg[-1]["content"] = kwargs["sys.query"]
            self.set_input_value("sys.query", kwargs["sys.query"])
        else:
            query_value = self._canvas.get_variable_value(self._param.query)
            # An unresolved variable must not crash the prompt rendering below.
            if query_value is None:
                query_value = ""
            msg[-1]["content"] = query_value
            self.set_input_value(self._param.query, query_value)
        self._param.update_prompt()
        chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)

        # Render every turn on a single line; coerce content to text so that
        # None or non-string values do not raise.
        turns = []
        for m in msg:
            content = "" if m["content"] is None else str(m["content"])
            turns.append("{}: \"{}\"".format(m["role"].upper(), content.replace("\n", "")))
        user_prompt = """
---- Real Data ----
{}
""".format(" | ".join(turns))

        if self.check_if_canceled("Categorize processing"):
            return
        ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
        logging.info("input: %s, answer: %s", user_prompt, ans)
        if ERROR_PREFIX in ans:
            raise Exception(ans)
        if self.check_if_canceled("Categorize processing"):
            return

        categories = self._param.category_description
        answer = ans.lower()
        # Count the number of times each category appears in the answer.
        category_counts = {c: answer.count(c.lower()) for c in categories}

        # Fallback: the last category supplies BOTH the reported name and the
        # route, so the two outputs can never disagree.
        max_category = list(categories)[-1]

        # An answer that is exactly one category name is unambiguous; prefer it
        # over substring counts, which tie when one name contains another.
        bare_answer = answer.strip().strip("\"'`")
        exact_match = next((c for c in categories if c.lower() == bare_answer), None)
        if exact_match is not None:
            max_category = exact_match
        elif any(category_counts.values()):
            max_category = max(category_counts.items(), key=lambda x: x[1])[0]
        cpn_ids = categories[max_category]["to"]

        self.set_output("category_name", max_category)
        self.set_output("_next", cpn_ids)

    def thoughts(self) -> str:
        """Return a short progress message listing the candidate categories."""
        return "Which should it falls into {}? ...".format(",".join(f"`{c}`" for c in self._param.category_description))