ragflow/graphrag/utils.py
Yongteng Lei 56cd576876
Refa: revise the implementation of LightRAG and enable response caching (#9828)
### What problem does this PR solve?

This revision performs a comprehensive check of the LightRAG implementation to ensure its correctness. It **does not cover** Entity Resolution or Community Report Generation. The example below, using the default entity types and the General chunking method, shows good results in both runtime and output quality. In addition, response caching is now enabled so that failed tasks can be resumed.
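
As a rough sketch of how the cache enables resuming (only `get_llm_cache`/`set_llm_cache` are real helpers from `graphrag/utils.py`; the `cached_chat` wrapper and the `chat` callable are illustrative):

```python
# Illustrative only: `chat` stands in for the LLM call made during extraction.
from graphrag.utils import get_llm_cache, set_llm_cache

def cached_chat(chat, llm_name, prompt, history, gen_conf):
    cached = get_llm_cache(llm_name, prompt, history, gen_conf)
    if cached:
        return cached  # a resumed task reuses the response cached in Redis
    response = chat(prompt, history, gen_conf)
    set_llm_cache(llm_name, prompt, response, history, gen_conf)
    return response
```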


[The-Necklace.pdf](https://github.com/user-attachments/files/22042432/The-Necklace.pdf)

After:


![img_v3_02pk_177dbc6a-e7cc-4732-b202-ad4682d171fg](https://github.com/user-attachments/assets/5ef1d93a-9109-4fe9-8a7b-a65add16f82b)


```bash
Begin at:
Fri, 29 Aug 2025 16:48:03 GMT
Duration:
222.31 s
Progress:
16:48:04 Task has been received.
16:48:06 Page(1~7): Start to parse.
16:48:06 Page(1~7): OCR started
16:48:08 Page(1~7): OCR finished (1.89s)
16:48:11 Page(1~7): Layout analysis (3.72s)
16:48:11 Page(1~7): Table analysis (0.00s)
16:48:11 Page(1~7): Text merged (0.00s)
16:48:11 Page(1~7): Finish parsing.
16:48:12 Page(1~7): Generate 7 chunks
16:48:12 Page(1~7): Embedding chunks (0.29s)
16:48:12 Page(1~7): Indexing done (0.04s). Task done (7.84s)
16:48:17 Start processing for f421fb06849e11f0bdd32724b93a52b2: She had no dresses, no je...
16:48:17 Start processing for f421fb06849e11f0bdd32724b93a52b2: Her husband, already half...
16:48:17 Start processing for f421fb06849e11f0bdd32724b93a52b2: And this life lasted ten ...
16:48:17 Start processing for f421fb06849e11f0bdd32724b93a52b2: Then she asked, hesitatin...
16:49:30 Completed processing for f421fb06849e11f0bdd32724b93a52b2: She had no dresses, no je... after 1 gleanings, 21985 tokens.
16:49:30 Entities extraction of chunk 3 1/7 done, 12 nodes, 13 edges, 21985 tokens.
16:49:40 Completed processing for f421fb06849e11f0bdd32724b93a52b2: Finally, she replied, hes... after 1 gleanings, 22584 tokens.
16:49:40 Entities extraction of chunk 5 2/7 done, 19 nodes, 19 edges, 22584 tokens.
16:50:02 Completed processing for f421fb06849e11f0bdd32724b93a52b2: Then she asked, hesitatin... after 1 gleanings, 24610 tokens.
16:50:02 Entities extraction of chunk 0 3/7 done, 16 nodes, 28 edges, 24610 tokens.
16:50:03 Completed processing for f421fb06849e11f0bdd32724b93a52b2: And this life lasted ten ... after 1 gleanings, 24031 tokens.
16:50:04 Entities extraction of chunk 1 4/7 done, 24 nodes, 22 edges, 24031 tokens.
16:50:14 Completed processing for f421fb06849e11f0bdd32724b93a52b2: So they begged the jewell... after 1 gleanings, 24635 tokens.
16:50:14 Entities extraction of chunk 6 5/7 done, 27 nodes, 26 edges, 24635 tokens.
16:50:29 Completed processing for f421fb06849e11f0bdd32724b93a52b2: Her husband, already half... after 1 gleanings, 25758 tokens.
16:50:29 Entities extraction of chunk 2 6/7 done, 25 nodes, 35 edges, 25758 tokens.
16:51:35 Completed processing for f421fb06849e11f0bdd32724b93a52b2: The Necklace By Guy de Ma... after 1 gleanings, 27491 tokens.
16:51:35 Entities extraction of chunk 4 7/7 done, 39 nodes, 37 edges, 27491 tokens.
16:51:35 Entities and relationships extraction done, 147 nodes, 177 edges, 171094 tokens, 198.58s.
16:51:35 Entities merging done, 0.01s.
16:51:35 Relationships merging done, 0.01s.
16:51:35 ignored 7 relations due to missing entities.
16:51:35 generated subgraph for doc f421fb06849e11f0bdd32724b93a52b2 in 198.68 seconds.
16:51:35 run_graphrag f421fb06849e11f0bdd32724b93a52b2 graphrag_task_lock acquired
16:51:35 set_graph removed 0 nodes and 0 edges from index in 0.00s.
16:51:35 Get embedding of nodes: 9/147
16:51:35 Get embedding of nodes: 109/147
16:51:37 Get embedding of edges: 9/170
16:51:37 Get embedding of edges: 109/170
16:51:40 set_graph converted graph change to 319 chunks in 4.21s.
16:51:40 Insert chunks: 4/319
16:51:40 Insert chunks: 104/319
16:51:40 Insert chunks: 204/319
16:51:40 Insert chunks: 304/319
16:51:40 set_graph added/updated 147 nodes and 170 edges from index in 0.53s.
16:51:40 merging subgraph for doc f421fb06849e11f0bdd32724b93a52b2 into the global graph done in 4.79 seconds.
16:51:40 Knowledge Graph done (204.29s)
```

Before:


![img_v3_02pk_63370edf-ecee-4ee8-8ac8-69c8d2c712fg](https://github.com/user-attachments/assets/1162eb0f-68c2-4de5-abe0-cdfa168f71de)

```bash
Begin at:
Fri, 29 Aug 2025 17:00:47 GMT
processDuration:
173.38 s
Progress:
17:00:49 Task has been received.
17:00:51 Page(1~7): Start to parse.
17:00:51 Page(1~7): OCR started
17:00:53 Page(1~7): OCR finished (1.82s)
17:00:57 Page(1~7): Layout analysis (3.64s)
17:00:57 Page(1~7): Table analysis (0.00s)
17:00:57 Page(1~7): Text merged (0.00s)
17:00:57 Page(1~7): Finish parsing.
17:00:57 Page(1~7): Generate 7 chunks
17:00:57 Page(1~7): Embedding chunks (0.31s)
17:00:57 Page(1~7): Indexing done (0.03s). Task done (7.88s)
17:00:57 created task graphrag
17:01:00 Task has been received.
17:02:17 Entities extraction of chunk 1 1/7 done, 9 nodes, 9 edges, 10654 tokens.
17:02:31 Entities extraction of chunk 2 2/7 done, 12 nodes, 13 edges, 11066 tokens.
17:02:33 Entities extraction of chunk 4 3/7 done, 9 nodes, 10 edges, 10433 tokens.
17:02:42 Entities extraction of chunk 5 4/7 done, 11 nodes, 14 edges, 11290 tokens.
17:02:52 Entities extraction of chunk 6 5/7 done, 13 nodes, 15 edges, 11039 tokens.
17:02:55 Entities extraction of chunk 3 6/7 done, 14 nodes, 13 edges, 11466 tokens.
17:03:32 Entities extraction of chunk 0 7/7 done, 19 nodes, 18 edges, 13107 tokens.
17:03:32 Entities and relationships extraction done, 71 nodes, 89 edges, 79055 tokens, 149.66s.
17:03:32 Entities merging done, 0.01s.
17:03:32 Relationships merging done, 0.01s.
17:03:32 ignored 1 relations due to missing entities.
17:03:32 generated subgraph for doc b1d9d3b6848711f0aacd7ddc0714c4d3 in 149.69 seconds.
17:03:32 run_graphrag b1d9d3b6848711f0aacd7ddc0714c4d3 graphrag_task_lock acquired
17:03:32 set_graph removed 0 nodes and 0 edges from index in 0.00s.
17:03:32 Get embedding of nodes: 9/71
17:03:33 Get embedding of edges: 9/88
17:03:34 set_graph converted graph change to 161 chunks in 2.27s.
17:03:34 Insert chunks: 4/161
17:03:34 Insert chunks: 104/161
17:03:34 set_graph added/updated 71 nodes and 88 edges from index in 0.28s.
17:03:34 merging subgraph for doc b1d9d3b6848711f0aacd7ddc0714c4d3 into the global graph done in 2.60 seconds.
17:03:34 Knowledge Graph done (153.18s)

```

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
- [x] Performance Improvement
2025-08-29 17:58:36 +08:00


# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
- [LightRag](https://github.com/HKUDS/LightRAG)
"""
import dataclasses
import html
import json
import logging
import os
import re
import time
from collections import defaultdict
from hashlib import md5
from typing import Any, Callable, Set, Tuple
import networkx as nx
import numpy as np
import trio
import xxhash
from networkx.readwrite import json_graph
from api import settings
from api.utils import get_uuid
from api.utils.api_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
GRAPH_FIELD_SEP = "<SEP>"
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
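# GraphChange tracks which nodes and edges were removed or added/updated during a
# merge, so set_graph() only has to delete and re-embed the affected index chunks.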
@dataclasses.dataclass
class GraphChange:
removed_nodes: Set[str] = dataclasses.field(default_factory=set)
added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
def perform_variable_replacements(input: str, history: list[dict] | None = None, variables: dict | None = None) -> str:
"""Perform variable replacements on the input string and in a chat log."""
if history is None:
history = []
if variables is None:
variables = {}
result = input
def replace_all(input: str) -> str:
result = input
for k, v in variables.items():
result = result.replace(f"{{{k}}}", str(v))
return result
result = replace_all(result)
for i, entry in enumerate(history):
if entry.get("role") == "system":
entry["content"] = replace_all(entry.get("content") or "")
return result
def clean_str(input: Any) -> str:
"""Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
# If we get non-string input, just give it back
if not isinstance(input, str):
return input
result = html.unescape(input.strip())
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)
def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]]) -> bool:
"""Return True if the given dictionary has the given keys with the given types."""
for field, field_type in expected_fields:
if field not in data:
return False
value = data[field]
if not isinstance(value, field_type):
return False
return True
def get_llm_cache(llmnm, txt, history, genconf):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))
hasher.update(str(history).encode("utf-8"))
hasher.update(str(genconf).encode("utf-8"))
k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return None
return bin
def set_llm_cache(llmnm, txt, v, history, genconf):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))
hasher.update(str(history).encode("utf-8"))
hasher.update(str(genconf).encode("utf-8"))
k = hasher.hexdigest()
REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600)
def get_embed_cache(llmnm, txt):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))
k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return
return np.array(json.loads(bin))
def set_embed_cache(llmnm, txt, arr):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))
k = hasher.hexdigest()
arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
REDIS_CONN.set(k, arr.encode("utf-8"), 24 * 3600)
def get_tags_from_cache(kb_ids):
hasher = xxhash.xxh64()
hasher.update(str(kb_ids).encode("utf-8"))
k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return
return bin
def set_tags_to_cache(kb_ids, tags):
hasher = xxhash.xxh64()
hasher.update(str(kb_ids).encode("utf-8"))
k = hasher.hexdigest()
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
"""
Ensure every node and edge in the graph has the essential attributes.
"""
def is_valid_item(node_attrs: dict) -> bool:
valid_node = True
for attr in ["description", "source_id"]:
if attr not in node_attrs:
valid_node = False
break
return valid_node
if check_attribute:
purged_nodes = []
for node, node_attrs in graph.nodes(data=True):
if not is_valid_item(node_attrs):
purged_nodes.append(node)
for node in purged_nodes:
graph.remove_node(node)
if purged_nodes and callback:
callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")
purged_edges = []
for source, target, attr in graph.edges(data=True):
if check_attribute:
if not is_valid_item(attr):
purged_edges.append((source, target))
if "keywords" not in attr:
attr["keywords"] = []
for source, target in purged_edges:
graph.remove_edge(source, target)
if purged_edges and callback:
callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
def get_from_to(node1, node2):
if node1 < node2:
return (node1, node2)
else:
return (node2, node1)
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
"""Merge graph g2 into g1 in place."""
for node_name, attr in g2.nodes(data=True):
change.added_updated_nodes.add(node_name)
if not g1.has_node(node_name):
g1.add_node(node_name, **attr)
continue
node = g1.nodes[node_name]
node["description"] += GRAPH_FIELD_SEP + attr["description"]
# A node's source_id indicates which chunks it came from.
node["source_id"] += attr["source_id"]
for source, target, attr in g2.edges(data=True):
change.added_updated_edges.add(get_from_to(source, target))
edge = g1.get_edge_data(source, target)
if edge is None:
g1.add_edge(source, target, **attr)
continue
edge["weight"] += attr.get("weight", 0)
edge["description"] += GRAPH_FIELD_SEP + attr["description"]
edge["keywords"] += attr["keywords"]
# An edge's source_id indicates which chunks it came from.
edge["source_id"] += attr["source_id"]
for node_degree in g1.degree:
g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
# A graph's source_id indicates which documents it came from.
if "source_id" not in g1.graph:
g1.graph["source_id"] = []
g1.graph["source_id"] += g2.graph.get("source_id", [])
return g1
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
def handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the graph
entity_name = clean_str(record_attributes[1].upper())
if not entity_name.strip():
return None
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])
entity_source_id = chunk_key
return dict(
entity_name=entity_name.upper(),
entity_type=entity_type.upper(),
description=entity_description,
source_id=entity_source_id,
)
def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
# add this record as an edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_keywords = clean_str(record_attributes[4])
edge_source_id = chunk_key
weight = float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
pair = sorted([source.upper(), target.upper()])
return dict(
src_id=pair[0],
tgt_id=pair[1],
weight=weight,
description=edge_description,
keywords=edge_keywords,
source_id=edge_source_id,
metadata={"created_at": time.time()},
)
def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [{"role": roles[i % 2], "content": content} for i, content in enumerate(args)]
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
return [content]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def chunk_id(chunk):
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
global chat_limiter
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
chunk = {
"id": get_uuid(),
"important_kwd": [ent_name],
"title_tks": rag_tokenizer.tokenize(ent_name),
"entity_kwd": ent_name,
"knowledge_graph_kwd": "entity",
"entity_type_kwd": meta["entity_type"],
"content_with_weight": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"source_id": meta["source_id"],
"kb_id": kb_id,
"available_int": 0,
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None:
async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
@timeout(3, 3)
def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
ents = from_ent_name
if isinstance(ents, str):
ents = [from_ent_name]
if isinstance(to_ent_name, str):
to_ent_name = [to_ent_name]
ents.extend(to_ent_name)
ents = list(set(ents))
conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
res = []
es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
for id in es_res.ids:
try:
if size == 1:
return json.loads(es_res.field[id]["content_with_weight"])
res.append(json.loads(es_res.field[id]["content_with_weight"]))
except Exception:
continue
return res
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
chunk = {
"id": get_uuid(),
"from_entity_kwd": from_ent_name,
"to_entity_kwd": to_ent_name,
"knowledge_graph_kwd": "relation",
"content_with_weight": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"important_kwd": meta["keywords"],
"source_id": meta["source_id"],
"weight_int": int(meta["weight"]),
"kb_id": kb_id,
"available_int": 0,
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
txt = f"{from_ent_name}->{to_ent_name}"
ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None:
async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 300000000):
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"]))
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
async def does_graph_contains(tenant_id, kb_id, doc_id):
# Get doc_ids of graph
fields = ["source_id"]
condition = {
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
return doc_id in graph_doc_ids
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
doc_ids = []
if res.total == 0:
return doc_ids
for id in res.ids:
doc_ids = res.field[id]["source_id"]
return doc_ids
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(settings.retrievaler.search, conds, search.index_name(tenant_id), [kb_id])
if not res.total == 0:
for id in res.ids:
try:
if res.field[id]["removed_kwd"] == "N":
g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
if "source_id" not in g.graph:
g.graph["source_id"] = res.field[id]["source_id"]
else:
g = await rebuild_graph(tenant_id, kb_id, exclude_rebuild)
return g
except Exception:
continue
result = None
return result
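# set_graph persists the in-memory graph to the doc store in four steps:
# 1) delete the previous "graph"/"subgraph" chunks and the chunks of removed nodes/edges,
# 2) serialize the whole graph plus one subgraph per source document,
# 3) embed every added/updated node and edge into entity/relation chunks,
# 4) bulk-insert the resulting chunks.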
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
global chat_limiter
start = trio.current_time()
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
if change.removed_nodes:
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
if change.removed_edges:
async def del_edges(from_node, to_node):
async with chat_limiter:
await trio.to_thread.run_sync(
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
)
async with trio.open_nursery() as nursery:
for from_node, to_node in change.removed_edges:
nursery.start_soon(del_edges, from_node, to_node)
now = trio.current_time()
if callback:
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
start = now
chunks = [
{
"id": get_uuid(),
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
"knowledge_graph_kwd": "graph",
"kb_id": kb_id,
"source_id": graph.graph.get("source_id", []),
"available_int": 0,
"removed_kwd": "N",
}
]
# generate updated subgraphs
for source in graph.graph["source_id"]:
subgraph = graph.subgraph([n for n in graph.nodes if source in graph.nodes[n]["source_id"]]).copy()
subgraph.graph["source_id"] = [source]
for n in subgraph.nodes:
subgraph.nodes[n]["source_id"] = [source]
chunks.append(
{
"id": get_uuid(),
"content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
"knowledge_graph_kwd": "subgraph",
"kb_id": kb_id,
"source_id": [source],
"available_int": 0,
"removed_kwd": "N",
}
)
async with trio.open_nursery() as nursery:
for ii, node in enumerate(change.added_updated_nodes):
node_attrs = graph.nodes[node]
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
async with trio.open_nursery() as nursery:
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
edge_attrs = graph.get_edge_data(from_node, to_node)
if not edge_attrs:
# added_updated_edges may record a non-existent edge if both from_node and to_node were involved in node merging.
continue
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
now = trio.current_time()
if callback:
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
start = now
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)
now = trio.current_time()
if callback:
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
def is_continuous_subsequence(subseq, seq):
def find_all_indexes(tup, value):
indexes = []
start = 0
while True:
try:
index = tup.index(value, start)
indexes.append(index)
start = index + 1
except ValueError:
break
return indexes
index_list = find_all_indexes(seq, subseq[0])
for idx in index_list:
if idx != len(seq) - 1:
if seq[idx + 1] == subseq[-1]:
return True
return False
def merge_tuples(list1, list2):
result = []
for tup in list1:
last_element = tup[-1]
if last_element in tup[:-1]:
result.append(tup)
else:
matching_tuples = [t for t in list2 if t[0] == last_element]
already_match_flag = 0
for match in matching_tuples:
matchh = (match[1], match[0])
if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
continue
already_match_flag = 1
merged_tuple = tup + match[1:]
result.append(merged_tuple)
if not already_match_flag:
result.append(tup)
return result
async def get_entity_type2sampels(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
res = defaultdict(list)
for id in es_res.ids:
smp = es_res.field[id].get("content_with_weight")
if not smp:
continue
try:
smp = json.loads(smp)
except Exception as e:
logging.exception(e)
continue  # skip entries that fail to parse
for ty, ents in smp.items():
res[ty].extend(ents)
return res
def flat_uniq_list(arr, key):
res = []
for a in arr:
a = a[key]
if isinstance(a, list):
res.extend(a)
else:
res.append(a)
return list(set(res))
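# rebuild_graph recomposes the global graph from the per-document "subgraph" chunks,
# merging source_id lists for shared nodes and skipping documents listed in exclude_rebuild.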
async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
graph = nx.Graph()
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
bs = 256
for i in range(0, 1024 * bs, bs):
es_res = await trio.to_thread.run_sync(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
)
# tot = settings.docStoreConn.getTotal(es_res)
es_res = settings.docStoreConn.getFields(es_res, flds)
if len(es_res) == 0:
break
for id, d in es_res.items():
assert d["knowledge_graph_kwd"] == "subgraph"
if isinstance(exclude_rebuild, list):
if sum([n in d["source_id"] for n in exclude_rebuild]):
continue
elif exclude_rebuild in d["source_id"]:
continue
next_graph = json_graph.node_link_graph(json.loads(d["content_with_weight"]), edges="edges")
merged_graph = nx.compose(graph, next_graph)
merged_source = {n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"] for n in graph.nodes & next_graph.nodes}
nx.set_node_attributes(merged_graph, merged_source, "source_id")
if "source_id" in graph.graph:
merged_graph.graph["source_id"] = graph.graph["source_id"] + next_graph.graph["source_id"]
else:
merged_graph.graph["source_id"] = next_graph.graph["source_id"]
graph = merged_graph
if len(graph.nodes) == 0:
return None
graph.graph["source_id"] = sorted(graph.graph["source_id"])
return graph