mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-06 03:47:41 +00:00
### What problem does this PR solve?
Fix invalid COMPONENT_EXEC_TIMEOUT. #10273
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
228 lines
9.5 KiB
Python
228 lines
9.5 KiB
Python
#
|
||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
#
|
||
import logging
|
||
import os
|
||
import time
|
||
from abc import ABC
|
||
from tavily import TavilyClient
|
||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||
from api.utils.api_utils import timeout
|
||
|
||
|
||
class TavilySearchParam(ToolParamBase):
    """
    Configuration for the Tavily web-search tool.

    ``self.meta`` holds the tool schema (name, description and the callable
    parameters query / topic / include_domains / exclude_domains); it is
    assigned before ``super().__init__()`` runs. The plain attributes set
    afterwards are the search defaults that ``TavilySearch._invoke`` falls
    back to when a call does not supply the corresponding keyword.
    """

    def __init__(self):
        def domain_list_spec(summary, item_summary):
            # Both domain filters share one shape: an optional list of strings.
            return {
                "type": "array",
                "description": summary,
                "default": [],
                "items": {
                    "type": "string",
                    "description": item_summary
                },
                "required": False
            }

        query_spec = {
            "type": "string",
            "description": "The search keywords to execute with Tavily. The keywords should be the most important words/terms(includes synonyms) from the original request.",
            "default": "{sys.query}",
            "required": True
        }
        topic_spec = {
            "type": "string",
            "description": "default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.",
            "enum": ["general", "news"],
            "default": "general",
            "required": False,
        }
        tool_description = """
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
When searching:
- Start with specific query which should focus on just a single aspect.
- Number of keywords in query should be less than 5.
- Broaden search terms if needed
- Cross-reference information from multiple sources
"""

        self.meta: ToolMeta = {
            "name": "tavily_search",
            "description": tool_description,
            "parameters": {
                "query": query_spec,
                "topic": topic_spec,
                "include_domains": domain_list_spec(
                    "default:[]. A list of domains only from which the search results can be included.",
                    "Domain name that must be included, e.g. www.yahoo.com",
                ),
                "exclude_domains": domain_list_spec(
                    "default:[]. A list of domains from which the search results can not be included",
                    "Domain name that must be excluded, e.g. www.yahoo.com",
                ),
            }
        }
        super().__init__()

        # Credentials and search defaults (assigned after the base initializer).
        self.api_key = ""
        self.search_depth = "basic"  # one of: basic / advanced
        self.max_results = 6
        self.days = 14
        self.include_answer = False
        self.include_raw_content = False
        self.include_images = False
        self.include_image_descriptions = False

    def check(self):
        """Validate topic, search depth, result count and day window."""
        enumerated = (
            (self.topic, "Tavily topic: should be in 'general/news'", ["general", "news"]),
            (self.search_depth, "Tavily search depth should be in 'basic/advanced'", ["basic", "advanced"]),
        )
        for value, message, allowed in enumerated:
            self.check_valid_value(value, message, allowed)
        # NOTE(review): only positivity is checked below; the messages mention
        # an upper bound of 20 and "greater than 1" that are not enforced here.
        self.check_positive_integer(self.max_results, "Tavily max result number should be within [1, 20]")
        self.check_positive_integer(self.days, "Tavily days should be greater than 1")

    def get_input_form(self) -> dict[str, dict]:
        """Return the manual-input form: a single one-line "Query" field."""
        query_field = {"name": "Query", "type": "line"}
        return {"query": query_field}
|
||
|
||
class TavilySearch(ToolBase, ABC):
    """Tool component that runs a Tavily web search.

    Outputs written by ``_invoke``:
      * ``json`` -- the ``results`` list of the Tavily response;
      * ``formalized_content`` -- returned on success; presumably populated by
        ``_retrieve_chunks`` on the base class (TODO confirm);
      * ``_ERROR`` -- the last exception message when every attempt failed.
    """
    component_name = "TavilySearch"

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
    def _invoke(self, **kwargs):
        """Search Tavily for ``kwargs["query"]``, retrying on failure.

        Any search option not passed in ``kwargs`` falls back to the value
        configured on ``self._param``.

        Returns:
            ``""`` when no query is given; the ``formalized_content`` output on
            success; otherwise ``"Tavily error: <exception>"`` after every
            attempt failed (the message is also stored in ``_ERROR``).
        """
        if not kwargs.get("query"):
            self.set_output("formalized_content", "")
            return ""

        self.tavily_client = TavilyClient(api_key=self._param.api_key)

        for fld in ["search_depth", "topic", "max_results", "days", "include_answer",
                    "include_raw_content", "include_images", "include_image_descriptions",
                    "include_domains", "exclude_domains"]:
            if fld not in kwargs:
                kwargs[fld] = getattr(self._param, fld)
        # Images and raw page content are always disabled, whatever the
        # configuration or the caller asked for.
        kwargs["include_images"] = False
        kwargs["include_raw_content"] = False

        # Always make at least one attempt, even for a negative max_retries
        # (previously that produced an empty loop and an `assert False`).
        attempts = max(self._param.max_retries, 0) + 1
        last_e = None
        for attempt in range(attempts):
            try:
                res = self.tavily_client.search(**kwargs)
                self._retrieve_chunks(res["results"],
                                      get_title=lambda r: r["title"],
                                      get_url=lambda r: r["url"],
                                      # raw_content is not requested above, so it may be
                                      # null or absent; fall back to the snippet.
                                      get_content=lambda r: r.get("raw_content") or r["content"],
                                      get_score=lambda r: r["score"])
                self.set_output("json", res["results"])
                return self.output("formalized_content")
            except Exception as e:
                last_e = e
                logging.exception("Tavily error: %s", e)
                # Back off only if another attempt follows; sleeping after the
                # last failure just delays the error and eats into the timeout.
                if attempt + 1 < attempts:
                    time.sleep(self._param.delay_after_error)

        # Every attempt failed, so last_e is set here.
        self.set_output("_ERROR", str(last_e))
        return f"Tavily error: {last_e}"

    def thoughts(self) -> str:
        """Return a short progress note quoting the query ("-_-!" if absent)."""
        return """
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))
|
||
|
||
|
||
class TavilyExtractParam(ToolParamBase):
    """
    Configuration for the Tavily Extract tool (page content from URLs).

    ``self.meta`` holds the tool schema with the callable parameters
    urls / extract_depth / format and is assigned before
    ``super().__init__()``; the attributes set afterwards are the defaults
    ``TavilyExtract._invoke`` uses when a call omits them.
    """

    def __init__(self):
        urls_spec = {
            "type": "array",
            "description": "The URLs to extract content from.",
            # NOTE(review): declared as an array but defaults to "" (a string).
            "default": "",
            "items": {
                "type": "string",
                "description": "The URL to extract content from, e.g. www.yahoo.com"
            },
            "required": True
        }
        depth_spec = {
            "type": "string",
            "description": "The depth of the extraction process. advanced extraction retrieves more data, including tables and embedded content, with higher success but may increase latency.basic extraction costs 1 credit per 5 successful URL extractions, while advanced extraction costs 2 credits per 5 successful URL extractions.",
            "enum": ["basic", "advanced"],
            "default": "basic",
            "required": False,
        }
        format_spec = {
            "type": "string",
            "description": "The format of the extracted web page content. markdown returns content in markdown format. text returns plain text and may increase latency.",
            "enum": ["markdown", "text"],
            "default": "markdown",
            "required": False,
        }

        self.meta: ToolMeta = {
            "name": "tavily_extract",
            "description": "Extract web page content from one or more specified URLs using Tavily Extract.",
            "parameters": {
                "urls": urls_spec,
                "extract_depth": depth_spec,
                "format": format_spec
            }
        }
        super().__init__()

        # Credentials and extraction defaults (assigned after the base initializer).
        self.api_key = ""
        self.extract_depth = "basic"  # one of: basic / advanced
        self.urls = []
        self.format = "markdown"
        self.include_images = False

    def check(self):
        """Validate the enumerated options: extraction depth and output format."""
        enumerated = (
            (self.extract_depth, "Tavily extract depth should be in 'basic/advanced'", ["basic", "advanced"]),
            (self.format, "Tavily extract format should be in 'markdown/text'", ["markdown", "text"]),
        )
        for value, message, allowed in enumerated:
            self.check_valid_value(value, message, allowed)

    def get_input_form(self) -> dict[str, dict]:
        """Return the manual-input form: a single one-line "URLs" field."""
        urls_field = {"name": "URLs", "type": "line"}
        return {"urls": urls_field}
|
||
|
||
class TavilyExtract(ToolBase, ABC):
    """Tool component that extracts web page content through Tavily Extract.

    On success the ``json`` output holds the ``results`` list of the Tavily
    response; when every attempt fails, ``_ERROR`` holds the last error.
    """
    component_name = "TavilyExtract"

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
    def _invoke(self, **kwargs):
        """Extract content for ``kwargs["urls"]``, retrying on failure.

        ``urls``, ``extract_depth`` and ``format`` fall back to the values
        configured on ``self._param`` when not passed. ``urls`` may be a list
        or a comma-separated string.

        Returns:
            The ``json`` output on success, otherwise
            ``"Tavily error: <exception>"`` after every attempt failed (the
            message is also stored in ``_ERROR``).
        """
        self.tavily_client = TavilyClient(api_key=self._param.api_key)

        for fld in ["urls", "extract_depth", "format"]:
            if fld not in kwargs:
                kwargs[fld] = getattr(self._param, fld)

        urls = kwargs.get("urls")
        if urls and isinstance(urls, str):
            # Split a comma-separated string, trimming the whitespace around
            # each URL ("a.com, b.com") and dropping empty entries.
            kwargs["urls"] = [u.strip() for u in urls.split(",") if u.strip()]
        # Images are never requested, whatever the caller passed.
        kwargs["include_images"] = False

        # Always make at least one attempt, even for a negative max_retries
        # (previously that produced an empty loop and an `assert False`).
        attempts = max(self._param.max_retries, 0) + 1
        last_e = None
        for attempt in range(attempts):
            try:
                res = self.tavily_client.extract(**kwargs)
                self.set_output("json", res["results"])
                return self.output("json")
            except Exception as e:
                last_e = e
                logging.exception("Tavily error: %s", e)
                # Wait between attempts (as TavilySearch does), but not after
                # the final one.
                if attempt + 1 < attempts:
                    time.sleep(self._param.delay_after_error)

        # Every attempt failed, so last_e is set here.
        self.set_output("_ERROR", str(last_e))
        return f"Tavily error: {last_e}"

    def thoughts(self) -> str:
        """Return a short progress note naming the URLs ("-_-!" if absent)."""
        return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|