# LightRAG/lightrag/utils.py
from __future__ import annotations
import asyncio
import html
import io
import csv
import json
import logging
import os
import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Callable
import xml.etree.ElementTree as ET
import bs4
import numpy as np
import tiktoken
from lightrag.prompt import PROMPTS
VERBOSE_DEBUG = False


def verbose_debug(msg: str, *args, **kwargs):
    """Function for outputting detailed debug information.
    When VERBOSE_DEBUG=True, outputs the complete message.
    When VERBOSE_DEBUG=False, outputs only the first 30 characters.
    """
    if VERBOSE_DEBUG:
        logger.debug(msg, *args, **kwargs)
    else:
        # Only log a short preview of the message when verbose debug is disabled
        logger.debug(msg[:30] + ("..." if len(msg) > 30 else ""), *args, **kwargs)


def set_verbose_debug(enabled: bool):
    """Enable or disable verbose debug output"""
    global VERBOSE_DEBUG
    VERBOSE_DEBUG = enabled
class UnlimitedSemaphore:
"""A context manager that allows unlimited access."""
async def __aenter__(self):
pass
async def __aexit__(self, exc_type, exc, tb):
pass
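# Illustrative usage sketch (not part of the module; `do_work` is hypothetical):
# UnlimitedSemaphore is a drop-in stand-in for asyncio.Semaphore when no
# concurrency cap is wanted.
#
#     sem = UnlimitedSemaphore()  # or asyncio.Semaphore(8) to cap concurrency
#     async with sem:
#         await do_work()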
ENCODER = None
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
logger = logging.getLogger("lightrag")
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def set_logger(log_file: str):
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
if not logger.handlers:
logger.addHandler(file_handler)
@dataclass
class EmbeddingFunc:
    embedding_dim: int
    max_token_size: int
    func: Callable[..., Any]
    # concurrent_limit: int = 16

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        return await self.func(*args, **kwargs)
@dataclass
class ReasoningResponse:
reasoning_content: str | None
response_content: str
tag: str
def locate_json_string_body_from_string(content: str) -> str | None:
"""Locate the JSON string body from a string"""
try:
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
if maybe_json_str is not None:
maybe_json_str = maybe_json_str.group(0)
maybe_json_str = maybe_json_str.replace("\\n", "")
maybe_json_str = maybe_json_str.replace("\n", "")
maybe_json_str = maybe_json_str.replace("'", '"')
# json.loads(maybe_json_str) # don't check here, cannot validate schema after all
return maybe_json_str
except Exception:
pass
# try:
# content = (
# content.replace(kw_prompt[:-1], "")
# .replace("user", "")
# .replace("model", "")
# .strip()
# )
# maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
# json.loads(maybe_json_str)
return None
def convert_response_to_json(response: str) -> dict[str, Any]:
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
try:
data = json.loads(json_str)
return data
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON: {json_str}")
raise e from None
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns:
str: Hash string
"""
    # Convert all arguments to strings and join them
    args_str = "".join([str(arg) for arg in args])
    if cache_type:
        args_str = f"{cache_type}:{args_str}"

    # Compute the MD5 hash (md5 is imported at module level)
    return md5(args_str.encode()).hexdigest()
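# Illustrative usage sketch (prompt text is an example): the cache_type prefix
# keeps hashes for different cache kinds from colliding on identical arguments.
#
#     h1 = compute_args_hash("what is lightrag?", cache_type="query")
#     h2 = compute_args_hash("what is lightrag?", cache_type="keywords")
#     assert h1 != h2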
def compute_mdhash_id(content: str, prefix: str = "") -> str:
"""
Compute a unique ID for a given content string.
The ID is a combination of the given prefix and the MD5 hash of the content string.
"""
return prefix + md5(content.encode()).hexdigest()
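# Illustrative usage sketch (prefix values are examples only): content gets a
# stable ID by prefixing the MD5 digest of its text.
#
#     chunk_id = compute_mdhash_id("some chunk text", prefix="chunk-")
#     entity_id = compute_mdhash_id("SOME ENTITY", prefix="ent-")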
def limit_async_func_call(max_size: int):
"""Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
def final_decro(func):
sem = asyncio.Semaphore(max_size)
@wraps(func)
async def wait_func(*args, **kwargs):
async with sem:
result = await func(*args, **kwargs)
return result
return wait_func
return final_decro
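# Illustrative usage sketch (the decorated coroutine is hypothetical):
#
#     @limit_async_func_call(max_size=16)
#     async def call_llm(prompt: str) -> str:
#         ...  # at most 16 calls run concurrently; extra callers wait on the semaphore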
def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""
def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func)
return new_func
return final_decro
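# Illustrative usage sketch (dimensions and the embedding client are assumptions):
#
#     @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
#     async def my_embed(texts: list[str]) -> np.ndarray:
#         return await some_embedding_client(texts)  # hypothetical client call
#
# The decorator returns an EmbeddingFunc, so my_embed.embedding_dim is available
# and `await my_embed(texts)` still invokes the wrapped coroutine.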
def load_json(file_name):
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)
def write_json(json_obj, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
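# Illustrative round-trip sketch: both helpers share the lazily initialised
# module-level ENCODER, so the first call fixes the tokenizer model.
#
#     tokens = encode_string_by_tiktoken("hello world")
#     assert decode_tokens_by_tiktoken(tokens) == "hello world"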
def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
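# Illustrative sketch: positional arguments alternate user/assistant roles.
#
#     pack_user_ass_to_openai_messages("Hi", "Hello!", "How are you?")
#     # -> [{"role": "user", "content": "Hi"},
#     #     {"role": "assistant", "content": "Hello!"},
#     #     {"role": "user", "content": "How are you?"}]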
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
return [content]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
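# Illustrative sketch (delimiters are made up):
#
#     split_string_by_multi_markers("a<SEP>b##c", ["<SEP>", "##"])
#     # -> ["a", "b", "c"]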
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
"""Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
# If we get non-string input, just give it back
if not isinstance(input, str):
return input
result = html.unescape(input.strip())
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value: str) -> bool:
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def truncate_list_by_token_size(
    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
) -> list[Any]:
"""Truncate a list of data by token size"""
if max_token_size <= 0:
return []
tokens = 0
for i, data in enumerate(list_data):
tokens += len(encode_string_by_tiktoken(key(data)))
if tokens > max_token_size:
return list_data[:i]
return list_data
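# Illustrative usage sketch (`records` and the key function are assumptions):
# keep only the leading items whose accumulated token count fits the budget.
#
#     truncated = truncate_list_by_token_size(
#         records, key=lambda r: r["description"], max_token_size=4000
#     )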
def list_of_list_to_csv(data: list[list[str]]) -> str:
output = io.StringIO()
writer = csv.writer(
output,
quoting=csv.QUOTE_ALL, # Quote all fields
escapechar="\\", # Use backslash as escape character
quotechar='"', # Use double quotes
lineterminator="\n", # Explicit line terminator
)
writer.writerows(data)
return output.getvalue()
def csv_string_to_list(csv_string: str) -> list[list[str]]:
# Clean the string by removing NUL characters
cleaned_string = csv_string.replace("\0", "")
output = io.StringIO(cleaned_string)
reader = csv.reader(
output,
quoting=csv.QUOTE_ALL, # Match the writer configuration
escapechar="\\", # Use backslash as escape character
quotechar='"', # Use double quotes
)
try:
return [row for row in reader]
except csv.Error as e:
        raise ValueError(f"Failed to parse CSV string: {str(e)}") from e
finally:
output.close()
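# Illustrative round-trip sketch: csv_string_to_list mirrors the writer settings
# of list_of_list_to_csv, so output from one can be read back by the other.
#
#     csv_text = list_of_list_to_csv([["id", "entity"], ["1", "LightRAG"]])
#     rows = csv_string_to_list(csv_text)
#     # rows == [["id", "entity"], ["1", "LightRAG"]]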
def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
root = tree.getroot()
# Print the root element's tag and attributes to confirm the file has been correctly loaded
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
data = {"nodes": [], "edges": []}
# Use namespace
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
for node in root.findall(".//node", namespace):
node_data = {
"id": node.get("id").strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
if node.find("./data[@key='d0']", namespace) is not None
else "",
"description": node.find("./data[@key='d1']", namespace).text
if node.find("./data[@key='d1']", namespace) is not None
else "",
"source_id": node.find("./data[@key='d2']", namespace).text
if node.find("./data[@key='d2']", namespace) is not None
else "",
}
data["nodes"].append(node_data)
for edge in root.findall(".//edge", namespace):
edge_data = {
"source": edge.get("source").strip('"'),
"target": edge.get("target").strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text)
if edge.find("./data[@key='d3']", namespace) is not None
else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text
if edge.find("./data[@key='d4']", namespace) is not None
else "",
"keywords": edge.find("./data[@key='d5']", namespace).text
if edge.find("./data[@key='d5']", namespace) is not None
else "",
"source_id": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
}
data["edges"].append(edge_data)
# Print the number of nodes and edges found
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
return data
except ET.ParseError as e:
print(f"Error parsing XML file: {e}")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None
def process_combine_contexts(hl: str, ll: str) -> str:
    header = None
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
if list_hl:
header = list_hl[0]
list_hl = list_hl[1:]
if list_ll:
header = list_ll[0]
list_ll = list_ll[1:]
if header is None:
return ""
if list_hl:
list_hl = [",".join(item[1:]) for item in list_hl if item]
if list_ll:
list_ll = [",".join(item[1:]) for item in list_ll if item]
combined_sources = []
seen = set()
for item in list_hl + list_ll:
if item and item not in seen:
combined_sources.append(item)
seen.add(item)
combined_sources_result = [",\t".join(header)]
for i, item in enumerate(combined_sources, start=1):
combined_sources_result.append(f"{i},\t{item}")
combined_sources_result = "\n".join(combined_sources_result)
return combined_sources_result
async def get_best_cached_response(
hashing_kv,
current_embedding,
similarity_threshold=0.95,
mode="default",
use_llm_check=False,
llm_func=None,
original_prompt=None,
cache_type=None,
) -> str | None:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None
best_similarity = -1
best_response = None
best_prompt = None
best_cache_id = None
# Only iterate through cache entries for this mode
for cache_id, cache_data in mode_cache.items():
# Skip if cache_type doesn't match
if cache_type and cache_data.get("cache_type") != cache_type:
continue
if cache_data["embedding"] is None:
continue
# Convert cached embedding list to ndarray
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])
cached_embedding = dequantize_embedding(
cached_quantized,
cache_data["embedding_min"],
cache_data["embedding_max"],
)
similarity = cosine_similarity(current_embedding, cached_embedding)
if similarity > best_similarity:
best_similarity = similarity
best_response = cache_data["return"]
best_prompt = cache_data["original_prompt"]
best_cache_id = cache_id
if best_similarity > similarity_threshold:
# If LLM check is enabled and all required parameters are provided
if (
use_llm_check
and llm_func
and original_prompt
and best_prompt
and best_response is not None
):
compare_prompt = PROMPTS["similarity_check"].format(
original_prompt=original_prompt, cached_prompt=best_prompt
)
try:
llm_result = await llm_func(compare_prompt)
llm_result = llm_result.strip()
llm_similarity = float(llm_result)
# Replace vector similarity with LLM similarity score
best_similarity = llm_similarity
if best_similarity < similarity_threshold:
log_data = {
"event": "cache_rejected_by_llm",
"type": cache_type,
"mode": mode,
"original_question": original_prompt[:100] + "..."
if len(original_prompt) > 100
else original_prompt,
"cached_question": best_prompt[:100] + "..."
if len(best_prompt) > 100
else best_prompt,
"similarity_score": round(best_similarity, 4),
"threshold": similarity_threshold,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
return None
except Exception as e: # Catch all possible exceptions
logger.warning(f"LLM similarity check failed: {e}")
return None # Return None directly when LLM check fails
prompt_display = (
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
)
log_data = {
"event": "cache_hit",
"type": cache_type,
"mode": mode,
"similarity": round(best_similarity, 4),
"cache_id": best_cache_id,
"original_prompt": prompt_display,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
return best_response
return None
def cosine_similarity(v1, v2):
"""Calculate cosine similarity between two vectors"""
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
return dot_product / (norm1 * norm2)
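# Illustrative sketch (vectors are examples): parallel vectors give 1.0 and
# orthogonal vectors give 0.0; callers are expected to pass non-zero vectors.
#
#     cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 0.0]))  # -> 1.0
#     cosine_similarity(np.array([1.0, 0.0]), np.array([0.0, 1.0]))  # -> 0.0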
def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
"""Quantize embedding to specified bits"""
# Convert list to numpy array if needed
if isinstance(embedding, list):
embedding = np.array(embedding)
# Calculate min/max values for reconstruction
min_val = embedding.min()
max_val = embedding.max()
# Quantize to 0-255 range
scale = (2**bits - 1) / (max_val - min_val)
quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
return quantized, min_val, max_val
def dequantize_embedding(
    quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8
) -> np.ndarray:
"""Restore quantized embedding"""
scale = (max_val - min_val) / (2**bits - 1)
return (quantized * scale + min_val).astype(np.float32)
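# Illustrative round-trip sketch (the vector is an example): quantization is
# lossy but approximately reversible, which is enough for cache-similarity checks.
#
#     q, lo, hi = quantize_embedding(np.random.rand(1536).astype(np.float32))
#     approx = dequantize_embedding(q, lo, hi)  # close to the original vector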
async def handle_cache(
hashing_kv,
args_hash,
prompt,
mode="default",
cache_type=None,
force_llm_cache=False,
):
"""Generic cache handling function"""
if hashing_kv is None or not (
force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
):
return None, None, None, None
if mode != "default":
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
use_llm_check = embedding_cache_config.get("use_llm_check", False)
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
current_embedding = await hashing_kv.embedding_func([prompt])
llm_model_func = hashing_kv.global_config.get("llm_model_func")
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
use_llm_check=use_llm_check,
llm_func=llm_model_func if use_llm_check else None,
original_prompt=prompt,
cache_type=cache_type,
)
            if best_cached_response is not None:
                logger.info(f"Embedding cache hit (mode:{mode} type:{cache_type})")
                return best_cached_response, None, None, None
            else:
                # If embedding caching is enabled, return the quantized embedding
                # so it can be saved to the cache later
                logger.info(f"Embedding cache miss (mode:{mode} type:{cache_type})")
                return None, quantized, min_val, max_val
    # For the default mode, or when the embedding cache is disabled, use the regular cache.
    # The default mode is used by extract_entities and naive queries.
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
    if args_hash in mode_cache:
        logger.info(f"Non-embedding cache hit (mode:{mode} type:{cache_type})")
        return mode_cache[args_hash]["return"], None, None, None

    logger.info(f"Non-embedding cache miss (mode:{mode} type:{cache_type})")
return None, None, None, None
@dataclass
class CacheData:
args_hash: str
content: str
prompt: str
quantized: np.ndarray | None = None
min_val: float | None = None
max_val: float | None = None
mode: str = "default"
cache_type: str = "query"
async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
return
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = (
await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
or {}
)
else:
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"cache_type": cache_data.cache_type,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,
"embedding_shape": cache_data.quantized.shape
if cache_data.quantized is not None
else None,
"embedding_min": cache_data.min_val,
"embedding_max": cache_data.max_val,
"original_prompt": cache_data.prompt,
}
await hashing_kv.upsert({cache_data.mode: mode_cache})
def safe_unicode_decode(content: bytes) -> str:
# Regular expression to find all Unicode escape sequences of the form \uXXXX
unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})")
# Function to replace the Unicode escape with the actual character
def replace_unicode_escape(match):
# Convert the matched hexadecimal value into the actual Unicode character
return chr(int(match.group(1), 16))
# Perform the substitution
decoded_content = unicode_escape_pattern.sub(
replace_unicode_escape, content.decode("utf-8")
)
return decoded_content
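# Illustrative sketch: literal \uXXXX escapes inside a UTF-8 byte payload are
# replaced with the characters they encode.
#
#     safe_unicode_decode(b"caf\\u00e9")  # -> "café"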
def exists_func(obj, func_name: str) -> bool:
    """Check whether an object exposes a callable attribute with the given name.

    :param obj: Object to inspect
    :param func_name: Name of the attribute to look for
    :return: True if the attribute exists and is callable, False otherwise
    """
    return callable(getattr(obj, func_name, None))
def get_conversation_turns(
conversation_history: list[dict[str, Any]], num_turns: int
) -> str:
"""
Process conversation history to get the specified number of complete turns.
Args:
conversation_history: List of conversation messages in chronological order
num_turns: Number of complete turns to include
Returns:
Formatted string of the conversation history
"""
# Group messages into turns
turns: list[list[dict[str, Any]]] = []
messages: list[dict[str, Any]] = []
# First, filter out keyword extraction messages
for msg in conversation_history:
if msg["role"] == "assistant" and (
msg["content"].startswith('{ "high_level_keywords"')
or msg["content"].startswith("{'high_level_keywords'")
):
continue
messages.append(msg)
    # Then process messages in chronological order
    i = 0
    while i < len(messages) - 1:
        msg1 = messages[i]
        msg2 = messages[i + 1]

        # Check if we have a user-assistant or assistant-user pair
        if (msg1["role"] == "user" and msg2["role"] == "assistant") or (
            msg1["role"] == "assistant" and msg2["role"] == "user"
        ):
            # Always put user message first in the turn
            if msg1["role"] == "assistant":
                turn = [msg2, msg1]  # user, assistant
            else:
                turn = [msg1, msg2]  # user, assistant
            turns.append(turn)
            i += 2
        else:
            # Unpaired message: advance by one so the loop cannot stall
            i += 1
# Keep only the most recent num_turns
if len(turns) > num_turns:
turns = turns[-num_turns:]
# Format the turns into a string
formatted_turns: list[str] = []
for turn in turns:
formatted_turns.extend(
[f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
)
return "\n".join(formatted_turns)
def extract_reasoning(response: str, tag: str) -> ReasoningResponse:
"""Extract the reasoning section and the following section from the LLM response.
Args:
response: LLM response
tag: Tag to extract
Returns:
ReasoningResponse: Reasoning section and following section
"""
soup = bs4.BeautifulSoup(response, "html.parser")
reasoning_section = soup.find(tag)
if reasoning_section is None:
return ReasoningResponse(None, response, tag)
reasoning_content = reasoning_section.get_text().strip()
after_reasoning_section = reasoning_section.next_sibling
if after_reasoning_section is None:
return ReasoningResponse(reasoning_content, "", tag)
after_reasoning_content = after_reasoning_section.get_text().strip()
return ReasoningResponse(reasoning_content, after_reasoning_content, tag)
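# Illustrative sketch (the <think> tag is an example): reasoning-style output is
# split into the tagged reasoning section and the visible answer that follows it.
#
#     r = extract_reasoning("<think>chain of thought</think>final answer", "think")
#     # r.reasoning_content == "chain of thought"
#     # r.response_content == "final answer"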