fix(reasoner): fix __path__ format (#175)

This commit is contained in:
wenchengyao 2024-03-25 10:57:52 +08:00 committed by GitHub
parent e6c963bd6f
commit aba1ea1cca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 237 additions and 3 deletions

View File

@ -17,6 +17,10 @@ import com.google.common.collect.Sets;
import java.util.Set;
public class Constants {
public static final String CONTEXT_TYPE = "__type__";
public static final String CONTEXT_ALIAS = "__alias__";
public static final String CONTEXT_LABEL = "__label__";
/** edge from id key */
public static final String EDGE_FROM_ID_KEY = "__from_id__";

View File

@ -15,12 +15,12 @@ package com.antgroup.openspg.reasoner.lube.logical.validate.semantic
import com.antgroup.openspg.reasoner.lube.block.Block
import com.antgroup.openspg.reasoner.lube.logical.planning.LogicalPlannerContext
import com.antgroup.openspg.reasoner.lube.logical.validate.semantic.rules.{ConceptExplain, NodeIdTransform, SpatioTemporalExplain}
import com.antgroup.openspg.reasoner.lube.logical.validate.semantic.rules.{ConceptExplain, PathExplain, SpatioTemporalExplain}
object SemanticExplainer {
var SEMANTIC_EXPLAINS: Seq[Explain] =
Seq(ConceptExplain, SpatioTemporalExplain)
Seq(ConceptExplain, SpatioTemporalExplain, PathExplain)
def explain(input: Block, optRuleList: Seq[Explain])(implicit
context: LogicalPlannerContext): Block = {

View File

@ -0,0 +1,75 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.lube.logical.validate.semantic.rules
import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.lube.block.{Block, MatchBlock, TableResultBlock}
import com.antgroup.openspg.reasoner.lube.common.graph.{IREdge, IRNode, IRPath}
import com.antgroup.openspg.reasoner.lube.logical.planning.LogicalPlannerContext
import com.antgroup.openspg.reasoner.lube.logical.validate.semantic.Explain
import scala.collection.mutable.ListBuffer
object PathExplain extends Explain {
override def explain(implicit context: LogicalPlannerContext): PartialFunction[Block, Block] = {
case tableResultBlock@TableResultBlock(dependencies, selectList, asList) =>
if (selectList.fields.isEmpty) {
tableResultBlock
} else {
val pathNodes = ListBuffer[String]()
val pathEdges = ListBuffer[String]()
val newSelectFields = selectList.fields.map {
case path@IRPath(_, elements) =>
val newPathField = elements.map {
case node@IRNode(name, fields) =>
pathNodes.+=(name)
node.copy(fields = fields + Constants.PROPERTY_JSON_KEY)
case edge@IREdge(name, fields) =>
pathEdges.+=(name)
edge.copy(fields = fields + Constants.PROPERTY_JSON_KEY)
case other => other
}
path.copy(elements = newPathField)
case other => other
}
val newSelectList = selectList.copy(orderedFields = newSelectFields)
val newTableResultBlock = TableResultBlock(dependencies, newSelectList, asList)
newTableResultBlock.rewriteTopDown(explainMatch(pathNodes, pathEdges))
}
}
private def explainMatch(pathNodes: ListBuffer[String],
pathEdges: ListBuffer[String]): PartialFunction[Block, Block] = {
case matchBlock@MatchBlock(dependencies, patterns) =>
if (patterns.isEmpty) {
matchBlock
} else {
val newPatterns = patterns.map {
p =>
val pattern = p._2.graphPattern
val newProperties = pattern.properties.map {
case (key, value) =>
if (pathNodes.contains(key) || pathEdges.contains(key)) {
(key, value + Constants.PROPERTY_JSON_KEY)
} else {
(key, value)
}
}
val newPath = p._2.copy(graphPattern = pattern.copy(properties = newProperties))
(p._1, newPath)
}
MatchBlock(dependencies, newPatterns)
}
}
}

View File

@ -13,6 +13,8 @@
package com.antgroup.openspg.reasoner.runner.local.main.transitive;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.antgroup.openspg.reasoner.common.constants.Constants;
import com.antgroup.openspg.reasoner.common.graph.edge.IEdge;
import com.antgroup.openspg.reasoner.common.graph.property.IProperty;
@ -124,6 +126,33 @@ public class KgReasonerTransitiveTest {
Assert.assertEquals(result.getRows().get(0)[1], "C5");
}
@Test
public void testTransitiveWithPathLongestWithPath() {
String dsl =
"GraphStructure {\n"
+ " A [RelatedParty, __start__='true']\n"
+ " B [RelatedParty]\n"
+ " A->B [holdShare] repeat(1,10) as e\n"
+ "}\n"
+ "Rule {\n"
+ " R1(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n"
+ "}\n"
+ "Action {\n"
+ " get(A.id,B.id,__path__) \n"
+ "}";
LocalReasonerResult result = doProcess(dsl);
// check result
Assert.assertEquals(1, result.getRows().size());
Assert.assertEquals(3, result.getRows().get(0).length);
Assert.assertEquals(result.getRows().get(0)[0], "P1");
Assert.assertTrue("C2,C5".contains(result.getRows().get(0)[1].toString()));
// check path format
JSONArray path = JSON.parseArray(result.getRows().get(0)[2].toString());
Assert.assertEquals(path.size(), 9);
Assert.assertEquals(path.getJSONObject(0).get(Constants.CONTEXT_TYPE), "vertex");
Assert.assertEquals(path.getJSONObject(5).get(Constants.CONTEXT_TYPE), "edge");
}
@Test
public void testTransitiveWithPathLongestWithTailRule() {
String dsl =
@ -306,6 +335,38 @@ public class KgReasonerTransitiveTest {
Assert.assertEquals(result.getRows().get(0)[1], "C5");
}
@Test
public void testTransitiveWithRule2WithPath() {
String dsl =
"GraphStructure {\n"
+ " A [RelatedParty, __start__='true']\n"
+ " B,C [RelatedParty]\n"
+ " B->C [trans] repeat(1,10) as e\n"
+ " A->B [trans] as f\n"
+ "}\n"
+ "Rule {\n"
+ "R1(\"要求转账logId一致\"): e.edges().constraint((pre,cur) => cur.logId == f.logId)"
+ "R2(\"时间大于第一个\"): e.edges().constraint((pre,cur) => cur.payDate > f.payDate)"
+ "R11(\"要求前一个时间小于后一个时间\"): e.edges().constraint((pre,cur) => cur.payDate > pre.payDate)"
+ "R3(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n"
+ "}\n"
+ "Action {\n"
+ " get(A.id,C.id,__path__) \n"
+ "}";
LocalReasonerResult result = doProcess(dsl);
// check result
Assert.assertEquals(1, result.getRows().size());
Assert.assertEquals(3, result.getRows().get(0).length);
Assert.assertEquals(result.getRows().get(0)[0], "P1");
Assert.assertEquals(result.getRows().get(0)[1], "C5");
// check path format
JSONArray path = JSON.parseArray(result.getRows().get(0)[2].toString());
Assert.assertEquals(path.size(), 9);
Assert.assertEquals(path.getJSONObject(0).get("entityType"), "PERSON");
Assert.assertEquals(path.getJSONObject(4).get("name"), "C6");
Assert.assertEquals(path.getJSONObject(8).get("amount"), 5);
}
@Test
public void testTransitive1() {
String dsl =

View File

@ -49,7 +49,7 @@ public class SelectRowImpl implements Serializable {
Object selectValue;
KgType fieldType;
if (var instanceof PathVar) {
selectValue = getSelectValue(null, Constants.GET_PATH_KEY, context);
selectValue = RunnerUtil.getPathInfo(path);
fieldType = KTString$.MODULE$;
} else {
PropertyVar propertyVar = (PropertyVar) var;

View File

@ -15,11 +15,13 @@ package com.antgroup.openspg.reasoner.utils;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.antgroup.openspg.reasoner.common.constants.Constants;
import com.antgroup.openspg.reasoner.common.exception.NotImplementedException;
import com.antgroup.openspg.reasoner.common.graph.edge.Direction;
import com.antgroup.openspg.reasoner.common.graph.edge.IEdge;
import com.antgroup.openspg.reasoner.common.graph.edge.SPO;
import com.antgroup.openspg.reasoner.common.graph.edge.impl.Edge;
import com.antgroup.openspg.reasoner.common.graph.edge.impl.OptionalEdge;
import com.antgroup.openspg.reasoner.common.graph.edge.impl.PathEdge;
import com.antgroup.openspg.reasoner.common.graph.property.IProperty;
@ -28,6 +30,7 @@ import com.antgroup.openspg.reasoner.common.graph.vertex.IVertex;
import com.antgroup.openspg.reasoner.common.graph.vertex.IVertexId;
import com.antgroup.openspg.reasoner.common.graph.vertex.impl.MirrorVertex;
import com.antgroup.openspg.reasoner.common.graph.vertex.impl.NoneVertex;
import com.antgroup.openspg.reasoner.common.graph.vertex.impl.Vertex;
import com.antgroup.openspg.reasoner.common.utils.CombinationIterator;
import com.antgroup.openspg.reasoner.kggraph.KgGraph;
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl;
@ -312,6 +315,97 @@ public class RunnerUtil {
return context;
}
/**
* KgGraph 2 PathInfo in flat format
*
* @param kgGraph
* @return
*/
public static String getPathInfo(KgGraph<IVertexId> kgGraph) {
List<Map<String, Object>> context = new ArrayList<>();
if (null == kgGraph) {
return JSON.toJSONString(
context,
SerializerFeature.PrettyFormat,
SerializerFeature.DisableCircularReferenceDetect,
SerializerFeature.SortField);
}
for (String alias : kgGraph.getVertexAlias()) {
List<IVertex<IVertexId, IProperty>> vertexList = kgGraph.getVertex(alias);
if (CollectionUtils.isEmpty(vertexList)) {
continue;
}
Map<String, Object> vc = vertexContext(vertexList.get(0));
vc.put(Constants.CONTEXT_TYPE, "vertex");
vc.put(Constants.CONTEXT_ALIAS, alias);
context.add(vc);
}
for (String alias : kgGraph.getEdgeAlias()) {
List<IEdge<IVertexId, IProperty>> edgeList = kgGraph.getEdge(alias);
if (CollectionUtils.isEmpty(edgeList)) {
continue;
}
IEdge<IVertexId, IProperty> edge = edgeList.get(0);
if (null == edge) {
continue;
}
if (edge instanceof PathEdge) {
flattenPathEdgeContext(
(PathEdge<IVertexId, IProperty, IProperty>) edge, null, kgGraph, context);
} else {
Map<String, Object> eMap = getEdgePropertyMap(edge, null, kgGraph, alias);
context.add(eMap);
}
}
return JSON.toJSONString(
context,
SerializerFeature.PrettyFormat,
SerializerFeature.DisableCircularReferenceDetect,
SerializerFeature.SortField);
}
public static void flattenPathEdgeContext(
PathEdge<IVertexId, IProperty, IProperty> edge,
String edgeType,
KgGraph<IVertexId> kgGraph,
List<Map<String, Object>> context) {
List<Vertex<IVertexId, IProperty>> vertexList = edge.getVertexList();
if (CollectionUtils.isNotEmpty(vertexList)) {
for (Vertex<IVertexId, IProperty> v : vertexList) {
Map<String, Object> vc = vertexContext(v);
vc.put(Constants.CONTEXT_TYPE, "vertex");
context.add(vc);
}
}
List<Edge<IVertexId, IProperty>> edgeList = edge.getEdgeList();
if (CollectionUtils.isNotEmpty(edgeList)) {
for (Edge<IVertexId, IProperty> e : edgeList) {
context.add(getEdgePropertyMap(e, edgeType, kgGraph, null));
}
}
}
public static Map<String, Object> getEdgePropertyMap(
IEdge<IVertexId, IProperty> edge, String edgeType, KgGraph<IVertexId> kgGraph, String alias) {
Map<String, Object> edgeProperty = new HashMap<>();
if (edge instanceof OptionalEdge) {
edgeProperty.put(Constants.CONTEXT_LABEL, edgeType);
IProperty property = edge.getValue();
if (null != property) {
for (String key : property.getKeySet()) {
edgeProperty.put(key, property.get(key));
}
}
edgeProperty.put(Constants.OPTIONAL_EDGE_FLAG, true);
} else {
edgeProperty.putAll(edgeContext(edge, edgeType, kgGraph));
}
edgeProperty.put(Constants.CONTEXT_ALIAS, alias);
edgeProperty.put(Constants.CONTEXT_TYPE, "edge");
return edgeProperty;
}
/** get vertex context in alias */
public static Map<String, Object> vertexContext(
IVertex<IVertexId, IProperty> vertex, String alias) {