class DeepResearcher:
    """Multi-step "reason, search, summarize" loop driven by a chat model.

    NOTE(review): only the methods below were visible in the text this was
    reconstructed from. The constructor begins outside the visible region
    (only its last two assignments, to ``_kb_retrieve`` and ``_kg_retrieve``,
    were shown) and is therefore not reproduced here. The methods also read
    ``self.chat_mdl`` and ``self.prompt_config``.
    """

    @staticmethod
    def _remove_query_tags(text):
        """Return *text* with every BEGIN_SEARCH_QUERY...END_SEARCH_QUERY span removed.

        The pattern is non-greedy and is applied without ``re.DOTALL``, so a
        span that crosses a newline is left untouched.
        """
        pattern = re.escape(BEGIN_SEARCH_QUERY) + r"(.*?)" + re.escape(END_SEARCH_QUERY)
        return re.sub(pattern, "", text)

    @staticmethod
    def _remove_result_tags(text):
        """Return *text* with every BEGIN_SEARCH_RESULT...END_SEARCH_RESULT span removed.

        Same matching rules as ``_remove_query_tags`` (non-greedy, single line).
        """
        pattern = re.escape(BEGIN_SEARCH_RESULT) + r"(.*?)" + re.escape(END_SEARCH_RESULT)
        return re.sub(pattern, "", text)

    def _generate_reasoning(self, msg_history):
        """Stream the model's next reasoning step.

        Mutates *msg_history* in place before calling the model: a
        "continue reasoning" instruction is appended as a new user message,
        or concatenated onto the last message when that one is already a
        user message.

        Yields each non-empty (post-filter) answer produced by
        ``self.chat_mdl.chat_streamly``. The generator's return value is the
        last answer yielded (``""`` if none was).
        """
        query_think = ""
        if msg_history[-1]["role"] != "user":
            msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
        else:
            msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"

        for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
            # NOTE(review): as written, r".*" with DOTALL matches the whole
            # string, so `ans` is always emptied and nothing is ever yielded.
            # This looks like markup-like text (presumably a tag delimiting the
            # model's hidden reasoning) was stripped from the pattern during
            # extraction -- confirm against the upstream file before relying
            # on this line.
            ans = re.sub(r".*", "", ans, flags=re.DOTALL)
            if not ans:
                continue
            query_think = ans
            yield query_think
        return query_think

    def _extract_search_queries(self, query_think, question, step_index):
        """Return the search queries embedded in *query_think*.

        Queries are whatever ``extract_between`` finds between
        BEGIN_SEARCH_QUERY and END_SEARCH_QUERY. When nothing is found on the
        very first step (``step_index == 0``), the original *question* is used
        as the single query; on later steps the empty result is returned
        unchanged so the caller can decide to stop.
        """
        queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
        if not queries and step_index == 0:
            queries = [question]
        return queries

    def _truncate_previous_reasoning(self, all_reasoning_steps):
        """Render prior reasoning steps as text, eliding the middle when long.

        Every step is rendered as ``"Step <n>: <step>"`` followed by a blank
        line, and the result is then re-split on blank lines. If that gives
        five segments or fewer, everything is kept. Otherwise only the first
        segment, the last four segments, and any segment containing a search
        query/result begin marker are kept; each run of dropped segments is
        replaced by a single ``...`` line.

        Returns the text with leading/trailing newlines stripped.
        """
        numbered = "".join(
            f"Step {i + 1}: {step}\n\n" for i, step in enumerate(all_reasoning_steps)
        )
        # Splitting the rendered text (rather than iterating the step list)
        # is deliberate: a step that itself contains a blank line contributes
        # several segments.
        prev_steps = numbered.split('\n\n')
        if len(prev_steps) <= 5:
            return numbered.strip('\n')

        ellipsis_block = '\n\n...\n\n'
        tail_start = len(prev_steps) - 4
        kept = ''
        for i, step in enumerate(prev_steps):
            if i == 0 or i >= tail_start or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
                kept += step + '\n\n'
            elif not kept.endswith(ellipsis_block):
                # Collapse consecutive dropped segments into one ellipsis.
                kept += '...\n\n'
        return kept.strip('\n')

    def _retrieve_information(self, search_query):
        """Collect retrieval results for *search_query* from the configured sources.

        1. Knowledge base: ``self._kb_retrieve(question=...)`` when set,
           otherwise an empty ``{"chunks": [], "doc_aggs": []}``.
        2. Web: when ``prompt_config["tavily_api_key"]`` is truthy, Tavily
           chunks and doc aggregates are appended.
        3. Knowledge graph: when ``prompt_config["use_kg"]`` is truthy and
           ``self._kg_retrieve`` is set, its result is inserted as the first
           chunk, provided its ``"content_with_weight"`` is truthy.

        Returns the (possibly extended) knowledge-base result dict. Note that
        the dict returned by ``_kb_retrieve`` is extended in place.
        """
        if self._kb_retrieve:
            kbinfos = self._kb_retrieve(question=search_query)
        else:
            kbinfos = {"chunks": [], "doc_aggs": []}

        if self.prompt_config.get("tavily_api_key"):
            tav = Tavily(self.prompt_config["tavily_api_key"])
            tav_res = tav.retrieve_chunks(search_query)
            kbinfos["chunks"].extend(tav_res["chunks"])
            kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])

        if self.prompt_config.get("use_kg") and self._kg_retrieve:
            ck = self._kg_retrieve(question=search_query)
            if ck["content_with_weight"]:
                kbinfos["chunks"].insert(0, ck)

        return kbinfos

    def _update_chunk_info(self, chunk_info, kbinfos):
        """Merge freshly retrieved *kbinfos* into the caller's *chunk_info* (in place).

        On the first retrieval (``chunk_info["chunks"]`` is empty) every key
        already present in *chunk_info* is overwritten with the corresponding
        *kbinfos* value; the list objects are shared, not copied.

        On later retrievals, chunks and doc aggregates are appended only when
        their ``"chunk_id"`` / ``"doc_id"`` has not been seen yet. Ids are
        tracked in sets that are updated as items are appended, so duplicates
        inside a single *kbinfos* batch are also dropped (the earlier list
        based check let those through). Assumes ids are hashable.
        """
        if not chunk_info["chunks"]:
            for key in chunk_info:
                chunk_info[key] = kbinfos[key]
            return

        seen_chunk_ids = {c["chunk_id"] for c in chunk_info["chunks"]}
        for chunk in kbinfos["chunks"]:
            chunk_id = chunk["chunk_id"]
            if chunk_id not in seen_chunk_ids:
                seen_chunk_ids.add(chunk_id)
                chunk_info["chunks"].append(chunk)

        seen_doc_ids = {d["doc_id"] for d in chunk_info["doc_aggs"]}
        for doc in kbinfos["doc_aggs"]:
            doc_id = doc["doc_id"]
            if doc_id not in seen_doc_ids:
                seen_doc_ids.add(doc_id)
                chunk_info["doc_aggs"].append(doc)

    def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
        """Stream the model's summary of *kbinfos* with respect to *search_query*.

        The system prompt is RELEVANT_EXTRACTION_PROMPT filled with the
        previous reasoning, the query, and the documents rendered by
        ``kb_prompt(kbinfos, 4096)``.

        Yields each non-empty (post-filter) answer; the generator's return
        value is the last answer yielded (``""`` if none was).
        """
        summary_think = ""
        system_prompt = RELEVANT_EXTRACTION_PROMPT.format(
            prev_reasoning=truncated_prev_reasoning,
            search_query=search_query,
            document="\n".join(kb_prompt(kbinfos, 4096))
        )
        user_msg = {
            "role": "user",
            "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.',
        }
        for ans in self.chat_mdl.chat_streamly(system_prompt, [user_msg], {"temperature": 0.7}):
            # NOTE(review): same concern as in _generate_reasoning -- r".*"
            # with DOTALL erases every chunk; the pattern has probably lost
            # markup-like text in extraction. Confirm against upstream.
            ans = re.sub(r".*", "", ans, flags=re.DOTALL)
            if not ans:
                continue
            summary_think = ans
            yield summary_think

        return summary_think

    def thinking(self, chunk_info: dict, question: str):
        """Run the reason/search/summarize loop for *question*, streaming progress.

        For each step (bounded by MAX_SEARCH_LIMIT): generate reasoning,
        extract search queries from it, and for every new query retrieve
        information, merge it into *chunk_info* (mutated in place so the
        caller can build citations), and summarize it back into the
        conversation. A query that was already executed gets a canned
        "already searched" result instead of a new retrieval. The loop stops
        when a step after the first produces no queries or when the search
        limit is hit.

        Yields ``{"answer": str, "reference": {}, "audio_binary": None}``
        dicts while working, and finally the accumulated text as a bare
        string.

        NOTE(review): the ``+ ""`` concatenations and the initial
        ``think = ""`` are kept exactly as found; together with the ``r".*"``
        patterns elsewhere in this class they look like opening/closing
        markup literals lost in extraction -- confirm against upstream.
        """
        executed_search_queries = set()
        msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
        all_reasoning_steps = []
        think = ""

        for step_index in range(MAX_SEARCH_LIMIT + 1):
            # Stop searching once the limit is reached and tell the model so.
            if step_index == MAX_SEARCH_LIMIT - 1:
                summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
                yield {"answer": think + summary_think + "", "reference": {}, "audio_binary": None}
                all_reasoning_steps.append(summary_think)
                msg_history.append({"role": "assistant", "content": summary_think})
                break

            # Step 1: generate reasoning, streaming it with query tags hidden.
            query_think = ""
            for ans in self._generate_reasoning(msg_history):
                query_think = ans
                yield {"answer": think + self._remove_query_tags(query_think) + "", "reference": {}, "audio_binary": None}

            think += self._remove_query_tags(query_think)
            all_reasoning_steps.append(query_think)

            # Step 2: extract search queries; none after the first step means done.
            queries = self._extract_search_queries(query_think, question, step_index)
            if not queries and step_index > 0:
                break

            for search_query in queries:
                logging.info("[THINK]Query: %s. %s", step_index, search_query)
                msg_history.append({"role": "assistant", "content": search_query})
                think += f"\n\n> {step_index + 1}. {search_query}\n\n"
                yield {"answer": think + "", "reference": {}, "audio_binary": None}

                # Do not re-run a query that was already executed.
                if search_query in executed_search_queries:
                    summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
                    yield {"answer": think + summary_think + "", "reference": {}, "audio_binary": None}
                    all_reasoning_steps.append(summary_think)
                    msg_history.append({"role": "user", "content": summary_think})
                    think += summary_think
                    continue

                executed_search_queries.add(search_query)

                # Step 3: condense earlier reasoning for the extraction prompt.
                truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)

                # Step 4: retrieve from KB / web / KG.
                kbinfos = self._retrieve_information(search_query)

                # Step 5: accumulate chunks for citations.
                self._update_chunk_info(chunk_info, kbinfos)

                # Step 6: summarize what was retrieved, streaming with result tags hidden.
                think += "\n\n"
                summary_think = ""
                for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
                    summary_think = ans
                    yield {"answer": think + self._remove_result_tags(summary_think) + "", "reference": {}, "audio_binary": None}

                all_reasoning_steps.append(summary_think)
                msg_history.append(
                    {"role": "user",
                     "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
                think += self._remove_result_tags(summary_think)
                logging.info("[THINK]Summary: %s. %s", step_index, summary_think)

        yield think + ""