fix(reasoner): bugfix of transform ListOpExpr (#328)

Co-authored-by: peilong <peilong.zpl@antgroup.com>
This commit is contained in:
wenchengyao 2024-07-23 10:18:08 +08:00 committed by GitHub
parent 2f18a19769
commit a7592a639d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 111 additions and 46 deletions

View File

@ -127,7 +127,7 @@ public class Utils {
case "[Ljava.util.Date;":
return new KTArray(KTDate$.MODULE$);
default:
throw new RuntimeException("unsupported type " + typeName);
return KTObject$.MODULE$;
}
}

View File

@ -82,9 +82,6 @@ abstract class Catalog() extends Serializable {
*/
def getConnection(typeName: String): Set[AbstractConnection] = {
val finalType = LabelTypeUtils.getMetaType(typeName)
if (!connections.contains(finalType)) {
throw ConnectionNotFoundException(s"$finalType not found.", null)
}
connections.getOrElse(finalType, mutable.Set.empty).toSet
}

View File

@ -225,7 +225,13 @@ final case class PathOpExpr(name: PathOpSet, pathName: Ref) extends TypeValidate
case class AggOpExpr(name: AggregatorOpSet, aggEleExpr: Expr) extends Aggregator {
override def withNewChildren(newChildren: Array[Expr]): Expr = AggOpExpr(name, newChildren.head)
override def children: Array[Expr] = Array.apply(aggEleExpr)
override def children: Array[Expr] = {
name match {
case AggUdf(name, funcArgs) => Array.apply(aggEleExpr) ++ funcArgs
case _ => Array.apply(aggEleExpr)
}
}
}
/**

View File

@ -120,6 +120,10 @@ object ExprUtils {
} else {
List.apply(IRVariable(refName))
}
case FunctionExpr(_, funcArgs) =>
funcArgs.map(arg => {
getAllInputFieldInRule(arg, nodesAlias, edgeAlias)
}).filter(_.nonEmpty).flatten
case ListOpExpr(name, _) =>
name match {
case constraint: Constraint =>

View File

@ -40,19 +40,31 @@ object ExprUtil {
* @param rule
* @return
*/
def getReferProperties(rule: Expr): List[Tuple2[String, String]] = {
def getReferProperties(rule: Expr): List[(String, String)] = {
def transformHelper(expr: Expr): List[(String, String)] =
expr.transform[List[(String, String)]] {
case (Ref(name), _) => List((null, name))
case (UnaryOpExpr(GetField(name), Ref(alias)), _) => List((alias, name))
case (BinaryOpExpr(_, Ref(left), Ref(right)), _) => List((null, left), (null, right))
case (_, tupleList) => tupleList.flatten
}
if (rule == null) {
List.empty
} else {
rule.transform[List[Tuple2[String, String]]] {
case (Ref(name), _) => List.apply((null, name))
case (UnaryOpExpr(GetField(name), Ref(alis)), _) => List.apply((alis, name))
case (BinaryOpExpr(_, Ref(left), Ref(right)), _) =>
List.apply((null, left), (null, right))
rule.transform[List[(String, String)]] {
case (Ref(name), _) => List((null, name))
case (UnaryOpExpr(GetField(name), Ref(alias)), _) => List((alias, name))
case (BinaryOpExpr(_, Ref(left), Ref(right)), _) => List((null, left), (null, right))
case (ListOpExpr(name, _), tupleList) =>
name match {
case constraint: Constraint => transformHelper(constraint.reduceFunc)
case compute: Reduce => transformHelper(compute.reduceFunc)
case _ => tupleList.flatten
}
case (_, tupleList) => tupleList.flatten
}
}
}
def needResolved(rule: Expr): Boolean = {
@ -61,14 +73,25 @@ object ExprUtil {
def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = {
def rewriter: PartialFunction[Expr, Expr] = { case Ref(refName) =>
def rewriter: PartialFunction[Expr, Expr] = {
case Ref(refName) =>
if (replaceVar.contains(refName)) {
val propertyVar = replaceVar(refName)
UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name))
} else {
Ref(refName)
}
case ListOpExpr(name, opInput) =>
name match {
case constraint: Constraint =>
val reduceFunc = BottomUp(rewriter).transform(constraint.reduceFunc)
ListOpExpr(Constraint(constraint.pre, constraint.cur, reduceFunc), opInput)
case compute: Reduce =>
val reduceFunc = BottomUp(rewriter).transform(compute.reduceFunc)
val initValue = BottomUp(rewriter).transform(compute.initValue)
ListOpExpr(Reduce(compute.ele, compute.res, reduceFunc, initValue), opInput)
case _ => ListOpExpr(name, opInput)
}
}
BottomUp(rewriter).transform(rule)

View File

@ -19,7 +19,7 @@ import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationExcept
import com.antgroup.openspg.reasoner.common.types.KTObject
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.common.graph.{IREdge, IRNode}
import com.antgroup.openspg.reasoner.lube.common.graph.{IREdge, IRNode, IRVariable}
import com.antgroup.openspg.reasoner.lube.logical._
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
@ -41,6 +41,9 @@ final case class Project(in: LogicalOperator, expr: Map[Var, Expr], solved: Solv
case IREdge(name, fields) =>
val edge = EdgeVar(name, fields.map(new Field(_, KTObject, true)))
fieldsMap.put(name, edge.merge(fieldsMap.get(name)))
case IRVariable(name) =>
val v = Variable(new Field(name, KTObject, true))
fieldsMap.put(name, v)
case _ => throw UnsupportedOperationException(s"unsupported $expr")
}
}
@ -78,8 +81,8 @@ final case class Project(in: LogicalOperator, expr: Map[Var, Expr], solved: Solv
fieldsMap.values.toList
}
override def withNewChildren(newChildren: Array[LogicalOperator]): LogicalOperator = {
this.copy(in = newChildren.head)
}
}

View File

@ -598,4 +598,21 @@ public class RuleRunnerTest {
Object rst = RuleRunner.getInstance().executeExpression(context, rules, "");
Assert.assertEquals(rst, 0.15);
}
@Test
public void testRepeatReduce3() {
Expr e =
ruleExprParser.parse(
"e.nodes().reduce((res, ele) => concat(res, \"#\", Cast(ele.age - A.age, 'String')), '')");
Expr2QlexpressTransformer transformer =
new Expr2QlexpressTransformer(RuleRunner::convertPropertyName);
List<String> rules =
Lists.newArrayList(JavaConversions.asJavaCollection(transformer.transform(e)));
Assert.assertEquals(
"repeat_reduce(e.nodes, \"\", 'res', 'ele', 'concat(res,\"#\",cast_type(ele.age - A.age,\"String\"))', context_capturer([\"A.age\"],[A.age]))",
rules.get(0));
Map<String, Object> context = getRepeatTestContext();
Object rst = RuleRunner.getInstance().executeExpression(context, rules, "");
Assert.assertEquals(rst, "#0#1#2");
}
}

View File

@ -17,7 +17,10 @@ import com.antgroup.openspg.reasoner.common.graph.edge.impl.Edge;
import com.antgroup.openspg.reasoner.common.graph.property.IProperty;
import com.antgroup.openspg.reasoner.common.graph.vertex.IVertexId;
import com.antgroup.openspg.reasoner.common.graph.vertex.impl.Vertex;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class CommonUtils {
@ -43,4 +46,28 @@ public class CommonUtils {
}
return result;
}
/** get patent context in repeat compute * */
public static Map<String, Object> getParentContext(Object context) {
Map<String, Object> result = new HashMap<>();
if (null == context) {
return result;
}
Map<String, Object> contextMap = (Map<String, Object>) context;
for (Map.Entry<String, Object> entry : contextMap.entrySet()) {
Map<String, Object> curMap = result;
List<String> nameList = Lists.newArrayList(Splitter.on(".").split(entry.getKey()));
for (int i = 0; i < nameList.size(); ++i) {
String name = nameList.get(i);
Map<String, Object> newMap = new HashMap<>();
if (i == nameList.size() - 1) {
curMap.put(name, entry.getValue());
} else {
curMap.putIfAbsent(name, newMap);
}
curMap = newMap;
}
}
return result;
}
}

View File

@ -16,7 +16,6 @@ package com.antgroup.openspg.reasoner.udf.builtin.udf;
import com.antgroup.openspg.reasoner.udf.builtin.CommonUtils;
import com.antgroup.openspg.reasoner.udf.model.UdfDefine;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import java.util.HashMap;
import java.util.List;
@ -28,7 +27,7 @@ public class RepeatConstraint {
@UdfDefine(name = "repeat_constraint")
public boolean constraint(
List<Object> itemList, String preName, String curName, String express, Object context) {
Map<String, Object> contextMap = getParentContext(context);
Map<String, Object> contextMap = CommonUtils.getParentContext(context);
int processIndex = 1;
if (StringUtils.isEmpty(preName) || !express.contains(preName)) {
processIndex = 0;
@ -51,27 +50,4 @@ public class RepeatConstraint {
public boolean constraint(List<Object> itemList, String preName, String curName, String express) {
return constraint(itemList, preName, curName, express, null);
}
private Map<String, Object> getParentContext(Object context) {
Map<String, Object> result = new HashMap<>();
if (null == context) {
return result;
}
Map<String, Object> contextMap = (Map<String, Object>) context;
for (Map.Entry<String, Object> entry : contextMap.entrySet()) {
Map<String, Object> curMap = result;
List<String> nameList = Lists.newArrayList(Splitter.on(".").split(entry.getKey()));
for (int i = 0; i < nameList.size(); ++i) {
String name = nameList.get(i);
Map<String, Object> newMap = new HashMap<>();
if (i == nameList.size() - 1) {
curMap.put(name, entry.getValue());
} else {
curMap.putIfAbsent(name, newMap);
}
curMap = newMap;
}
}
return result;
}
}

View File

@ -25,19 +25,31 @@ public class RepeatReduce {
@UdfDefine(name = "repeat_reduce")
public Object reduce(
List<Object> itemList, Object defaultValue, String preName, String curName, String express) {
List<Object> itemList,
Object defaultValue,
String preName,
String curName,
String express,
Object context) {
Object preValue = defaultValue;
Map<String, Object> contextMap = CommonUtils.getParentContext(context);
for (int i = 0; i < itemList.size(); ++i) {
Object cur = itemList.get(i);
Map<String, Object> context = new HashMap<>();
context.put(preName, preValue);
context.put(curName, CommonUtils.getRepeatItemContext(cur));
Map<String, Object> subContext = new HashMap<>(contextMap);
subContext.put(preName, preValue);
subContext.put(curName, CommonUtils.getRepeatItemContext(cur));
preValue =
RuleRunner.getInstance().executeExpression(context, Lists.newArrayList(express), "");
RuleRunner.getInstance().executeExpression(subContext, Lists.newArrayList(express), "");
}
return preValue;
}
@UdfDefine(name = "repeat_reduce")
public Object reduce(
List<Object> itemList, Object defaultValue, String preName, String curName, String express) {
return reduce(itemList, defaultValue, preName, curName, express, null);
}
@UdfDefine(name = "repeat_reduce")
public Object reduce(
List<Object> itemList,