run pre-commit

This commit is contained in:
DavIvek 2025-07-08 20:21:20 +02:00
parent 4438897b6b
commit 08eb68b8ed

View File

@ -73,8 +73,12 @@ class MemgraphStorage(BaseGraphStorage):
# Create index for base nodes on entity_id if it doesn't exist
try:
workspace_label = self._get_workspace_label()
await session.run(f"""CREATE INDEX ON :{workspace_label}(entity_id)""")
logger.info(f"Created index on :{workspace_label}(entity_id) in Memgraph.")
await session.run(
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
)
logger.info(
f"Created index on :{workspace_label}(entity_id) in Memgraph."
)
except Exception as e:
# Index may already exist, which is not an error
logger.warning(
@ -112,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error checking the node existence.
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
@ -122,7 +128,9 @@ class MemgraphStorage(BaseGraphStorage):
result = await session.run(query, entity_id=node_id)
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
return single_result["node_exists"] if single_result is not None else False
return (
single_result["node_exists"] if single_result is not None else False
)
except Exception as e:
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
await result.consume() # Ensure the result is consumed even on error
@ -143,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error checking the edge existence.
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
@ -153,10 +163,16 @@ class MemgraphStorage(BaseGraphStorage):
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) # type: ignore
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
) # type: ignore
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
return single_result["edgeExists"] if single_result is not None else False
return (
single_result["edgeExists"] if single_result is not None else False
)
except Exception as e:
logger.error(
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
@ -178,13 +194,17 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
query = (
f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
)
result = await session.run(query, entity_id=node_id)
try:
records = await result.fetch(
@ -228,7 +248,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
@ -265,7 +287,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
@ -302,7 +326,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
@ -366,7 +392,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
@ -414,7 +442,9 @@ class MemgraphStorage(BaseGraphStorage):
node_data: Dictionary of node properties
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
properties = node_data
entity_type = properties["entity_type"]
if "entity_id" not in properties:
@ -423,14 +453,13 @@ class MemgraphStorage(BaseGraphStorage):
try:
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
async def execute_upsert(tx: AsyncManagedTransaction):
query = (
f"""
query = f"""
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
SET n += $properties
SET n:`{entity_type}`
"""
)
result = await tx.run(
query, entity_id=node_id, properties=properties
)
@ -458,7 +487,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
@ -499,7 +530,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async def _do_delete(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
@ -525,7 +558,9 @@ class MemgraphStorage(BaseGraphStorage):
nodes: List of node labels to be deleted
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
for node in nodes:
await self.delete_node(node)
@ -539,7 +574,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
for source, target in edges:
async def _do_delete_edge(tx: AsyncManagedTransaction):
@ -575,17 +612,23 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
result = await session.run(query)
await result.consume()
logger.info(f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}")
logger.info(
f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
)
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}")
logger.error(
f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@ -599,7 +642,9 @@ class MemgraphStorage(BaseGraphStorage):
int: Sum of the degrees of both nodes
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
@ -621,7 +666,9 @@ class MemgraphStorage(BaseGraphStorage):
An empty list if no matching nodes are found.
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
@ -653,7 +700,9 @@ class MemgraphStorage(BaseGraphStorage):
An empty list if no matching edges are found.
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
@ -701,7 +750,9 @@ class MemgraphStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
result = KnowledgeGraph()
seen_nodes = set()
@ -761,7 +812,7 @@ class MemgraphStorage(BaseGraphStorage):
else:
bfs_query = f"""
MATCH (start:`{workspace_label}`)
MATCH (start:`{workspace_label}`)
WHERE start.entity_id = $entity_id
WITH start
CALL {{
@ -774,7 +825,7 @@ class MemgraphStorage(BaseGraphStorage):
RETURN all_nodes, all_rels
}}
WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
WITH
WITH
CASE
WHEN total_nodes <= {max_nodes} THEN nodes
ELSE nodes[0..{max_nodes}]
@ -782,7 +833,7 @@ class MemgraphStorage(BaseGraphStorage):
relationships,
total_nodes,
total_nodes > {max_nodes} AS is_truncated
RETURN
RETURN
[node IN limited_nodes | {{node: node}}] AS node_info,
relationships,
total_nodes,