From e4602264a5ceebb4d3cf7abb439dc2aa76b1672a Mon Sep 17 00:00:00 2001 From: FishJoy Date: Mon, 11 Mar 2024 16:52:16 +0800 Subject: [PATCH] fix(reaonser): bugfix in type infer (#150) Co-authored-by: Donghai --- .../reasoner/lube/logical/ExprUtil.scala | 2 +- .../main/KgReasonerAliasSetKFilmTest.java | 2 +- .../rdg/common/KgGraphAggregateImpl.java | 37 +++----- .../common/groupProcess/BaseGroupProcess.java | 34 +++---- .../openspg/reasoner/utils/RunnerUtil.java | 8 ++ .../antgroup/openspg/reasoner/udf/UdfMng.java | 8 ++ .../openspg/reasoner/udf/impl/UdfMngImpl.java | 9 ++ .../openspg/reasoner/udf/model/LazyUdaf.java | 95 +++++++++++++++++++ .../reasoner/udf/model/RuntimeUdfMeta.java | 30 +----- .../openspg/reasoner/udf/utils/UdfUtils.java | 27 ++++++ 10 files changed, 175 insertions(+), 77 deletions(-) create mode 100644 reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/model/LazyUdaf.java diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala index 343003f5..9f70869b 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala @@ -112,7 +112,7 @@ object ExprUtil { } case FunctionExpr(name, funcArgs) => val types = funcArgs.map(getTargetType(_, referVars, udfRepo)) - name.toLowerCase(Locale.getDefault) match { + name match { case "rule_value" => types(1) case "cast_type" | "Cast" => funcArgs(1).asInstanceOf[VString].value.toLowerCase(Locale.getDefault) match { diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java index 443e6624..bfe95b4b 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java @@ -197,6 +197,6 @@ public class KgReasonerAliasSetKFilmTest { Assert.assertEquals("1", result.get(0)[0]); Assert.assertEquals("2", result.get(0)[1]); Assert.assertEquals("3", result.get(0)[2]); - Assert.assertEquals("700.0", result.get(0)[3]); + Assert.assertEquals("700", result.get(0)[3]); } } diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java index ed2ec729..9ba4bde1 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java @@ -36,8 +36,7 @@ import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseG import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess; import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess; import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle; -import com.antgroup.openspg.reasoner.udf.model.BaseUdaf; -import com.antgroup.openspg.reasoner.udf.model.UdafMeta; +import com.antgroup.openspg.reasoner.udf.model.LazyUdaf; import com.antgroup.openspg.reasoner.udf.rule.RuleRunner; import com.antgroup.openspg.reasoner.utils.RunnerUtil; import com.google.common.collect.Lists; @@ -150,11 +149,10 @@ public class KgGraphAggregateImpl implements Serializable { PropertyVar var = (PropertyVar) aggInfo.getVar(); // 进行聚合计算 - UdafMeta udafMeta = aggInfo.getUdafMeta(); - Object[] udafInitParams = aggInfo.getUdfInitParams(); + LazyUdaf udafMeta = aggInfo.getLazyUdaf(); List ruleList = aggInfo.getRuleList(); List> valueFilteredList = getValueFilteredList(values, ruleList); - Object aggValue = doAggregation(valueFilteredList, udafMeta, udafInitParams, aggInfo); + Object aggValue = doAggregation(valueFilteredList, udafMeta, aggInfo); String targetPropertyName = var.field().name(); propertyMap.put(targetPropertyName, aggValue); } @@ -229,11 +227,10 @@ public class KgGraphAggregateImpl implements Serializable { Var var = aggInfo.getVar(); // 进行聚合计算 - UdafMeta udafMeta = aggInfo.getUdafMeta(); - Object[] udafInitParams = aggInfo.getUdfInitParams(); + LazyUdaf udaf = aggInfo.getLazyUdaf(); List ruleList = aggInfo.getRuleList(); List> valueFilteredList = getValueFilteredList(values, ruleList); - Object aggValue = doAggregation(valueFilteredList, udafMeta, udafInitParams, aggInfo); + Object aggValue = doAggregation(valueFilteredList, udaf, aggInfo); // 聚合结果赋值 if (var instanceof NodeVar) { @@ -333,21 +330,9 @@ public class KgGraphAggregateImpl implements Serializable { return newEdge; } - private void updateUdafDataFromProperty(BaseUdaf udaf, IProperty property, String propertyName) { - if (property.isKeyExist(propertyName)) { - udaf.update(property.get(propertyName)); - } - } - private Object doAggregation( - List> valueFilteredList, - UdafMeta udafMeta, - Object[] udafInitParams, - BaseGroupProcess aggInfo) { - BaseUdaf udaf = udafMeta.createAggregateFunction(); - if (null != udafInitParams) { - udaf.initialize(udafInitParams); - } + List> valueFilteredList, LazyUdaf udaf, BaseGroupProcess aggInfo) { + udaf.reset(); ParsedAggEle parsedAggEle; Set aliasList = aggInfo.getExprUseAliasSet(); if (aliasList.size() <= 1) { @@ -383,7 +368,9 @@ public class KgGraphAggregateImpl implements Serializable { vertexList.forEach(udaf::update); } else { vertexList.forEach( - v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName)); + v -> + RunnerUtil.updateUdafDataFromProperty( + udaf, v.getValue(), finalSourcePropertyName)); } } else { List> edgeList = @@ -402,7 +389,9 @@ public class KgGraphAggregateImpl implements Serializable { edgeList.forEach(udaf::update); } else { edgeList.forEach( - e -> updateUdafDataFromProperty(udaf, e.getValue(), finalSourcePropertyName)); + e -> + RunnerUtil.updateUdafDataFromProperty( + udaf, e.getValue(), finalSourcePropertyName)); } } } diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java index 564afd51..bed24613 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java @@ -26,6 +26,7 @@ import com.antgroup.openspg.reasoner.lube.logical.PropertyVar; import com.antgroup.openspg.reasoner.lube.logical.Var; import com.antgroup.openspg.reasoner.lube.utils.ExprUtils; import com.antgroup.openspg.reasoner.udf.UdfMngFactory; +import com.antgroup.openspg.reasoner.udf.model.LazyUdaf; import com.antgroup.openspg.reasoner.udf.model.UdafMeta; import com.antgroup.openspg.reasoner.udf.rule.RuleRunner; import com.antgroup.openspg.reasoner.warehouse.utils.WareHouseUtils; @@ -38,8 +39,7 @@ import scala.collection.JavaConversions; public abstract class BaseGroupProcess implements Serializable { protected Var var; - protected UdafMeta udafMeta; - protected Object[] udfInitParams; + protected LazyUdaf lazyUdaf; protected List ruleList; protected Aggregator aggOp; protected String taskId; @@ -59,8 +59,7 @@ public abstract class BaseGroupProcess implements Serializable { this.var = var; this.aggOp = aggregator; this.ruleList = parseRuleList(); - this.udfInitParams = parseUdfInitParams(); - this.udafMeta = parseUdafMeta(); + this.lazyUdaf = createLazyUdafMeta(); this.exprUseAliasSet = parseExprUseAliasSet(); this.exprRuleString = parseExprRuleList(); @@ -112,6 +111,15 @@ public abstract class BaseGroupProcess implements Serializable { return udfInitParams; } + public LazyUdaf createLazyUdafMeta() { + String udafName = getUdafStrName(getAggOpSet()); + return new LazyUdaf(udafName, parseUdfInitParams()); + } + + public LazyUdaf getLazyUdaf() { + return lazyUdaf; + } + protected UdafMeta parseUdafMeta() { String udafName = getUdafStrName(getAggOpSet()); UdafMeta udafMeta = UdfMngFactory.getUdfMng().getUdafMeta(udafName, KTString$.MODULE$); @@ -183,24 +191,6 @@ public abstract class BaseGroupProcess implements Serializable { return var; } - /** - * getter - * - * @return - */ - public UdafMeta getUdafMeta() { - return udafMeta; - } - - /** - * getter - * - * @return - */ - public Object[] getUdfInitParams() { - return udfInitParams; - } - /** * getter * diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/utils/RunnerUtil.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/utils/RunnerUtil.java index 8b434f22..d1eaef38 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/utils/RunnerUtil.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/utils/RunnerUtil.java @@ -59,6 +59,7 @@ import com.antgroup.openspg.reasoner.runner.ConfigKey; import com.antgroup.openspg.reasoner.session.KGReasonerSession; import com.antgroup.openspg.reasoner.udf.UdfMng; import com.antgroup.openspg.reasoner.udf.UdfMngFactory; +import com.antgroup.openspg.reasoner.udf.model.LazyUdaf; import com.antgroup.openspg.reasoner.udf.model.UdtfMeta; import com.antgroup.openspg.reasoner.udf.rule.RuleRunner; import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil; @@ -1239,4 +1240,11 @@ public class RunnerUtil { return JavaConversions.mapAsJavaMap( KgGraphSchema.getEdgeDirectionDiff(leftSchema, rightSchema)); } + + public static void updateUdafDataFromProperty( + LazyUdaf udaf, IProperty property, String propertyName) { + if (property.isKeyExist(propertyName)) { + udaf.update(property.get(propertyName)); + } + } } diff --git a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/UdfMng.java b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/UdfMng.java index 36fc93c1..4ecdc860 100644 --- a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/UdfMng.java +++ b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/UdfMng.java @@ -54,6 +54,14 @@ public interface UdfMng { */ UdafMeta getUdafMeta(String name, KgType rowDataType); + /** + * query UDAF mete list from name + * + * @param name + * @return + */ + List getUdafMetas(String name); + /** * Query UDTF meta information * diff --git a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/impl/UdfMngImpl.java b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/impl/UdfMngImpl.java index 2faafa21..37b47d5b 100644 --- a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/impl/UdfMngImpl.java +++ b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/impl/UdfMngImpl.java @@ -242,6 +242,15 @@ public class UdfMngImpl implements UdfMng { return getMeta(name, Lists.newArrayList(rowDataTypes), this.udafMetaMap); } + @Override + public List getUdafMetas(String name) { + Map subMetaMap = this.udafMetaMap.get(UdfName.from(name)); + if (null == subMetaMap) { + return null; + } + return Lists.newArrayList(subMetaMap.values()); + } + @Override public UdtfMeta getUdtfMeta(String name, List rowDataTypes) { return getMeta(name, rowDataTypes, this.udtfMetaMap); diff --git a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/model/LazyUdaf.java b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/model/LazyUdaf.java new file mode 100644 index 00000000..cf054d1d --- /dev/null +++ b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/model/LazyUdaf.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 OpenSPG Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. + */ + +package com.antgroup.openspg.reasoner.udf.model; + +import com.antgroup.openspg.reasoner.common.types.KTObject$; +import com.antgroup.openspg.reasoner.common.types.KTString$; +import com.antgroup.openspg.reasoner.common.types.KgType; +import com.antgroup.openspg.reasoner.udf.UdfMngFactory; +import com.antgroup.openspg.reasoner.udf.utils.UdfUtils; +import java.util.List; +import org.apache.commons.collections4.CollectionUtils; + +public class LazyUdaf { + + protected final String name; + protected final Object[] udfInitParams; + + protected UdafMeta udafMeta = null; + protected BaseUdaf baseUdaf = null; + + public LazyUdaf(String name, Object[] udfInitParams) { + this.name = name; + this.udfInitParams = udfInitParams; + List udafMetas = UdfMngFactory.getUdfMng().getUdafMetas(this.name); + if (CollectionUtils.isEmpty(udafMetas)) { + throw new RuntimeException("unsupported aggregator function, type=" + this.name); + } else if (1 == udafMetas.size()) { + this.udafMeta = udafMetas.get(0); + } + } + + public String getName() { + return name; + } + + public Object[] getUdfInitParams() { + return udfInitParams; + } + + public UdafMeta getUdafMeta() { + return udafMeta; + } + + public BaseUdaf getBaseUdaf() { + if (null == baseUdaf) { + createBaseUdaf(KTString$.MODULE$); + } + return baseUdaf; + } + + public void reset() { + this.baseUdaf = null; + } + + public void update(Object row) { + if (null == this.baseUdaf) { + KgType inputParamType; + try { + List inputParamTypeList = UdfUtils.getParamTypeList(row); + inputParamType = inputParamTypeList.get(0); + } catch (Throwable e) { + inputParamType = KTObject$.MODULE$; + } + createBaseUdaf(inputParamType); + } + this.baseUdaf.update(row); + } + + private void createBaseUdaf(KgType kgType) { + if (null == this.udafMeta) { + this.udafMeta = UdfMngFactory.getUdfMng().getUdafMeta(this.name, kgType); + } + if (null == this.baseUdaf) { + this.baseUdaf = this.udafMeta.createAggregateFunction(); + if (null != this.udfInitParams) { + this.baseUdaf.initialize(this.udfInitParams); + } + } + } + + public Object evaluate() { + return this.getBaseUdaf().evaluate(); + } +} diff --git a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/model/RuntimeUdfMeta.java b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/model/RuntimeUdfMeta.java index 8a7507df..9ca8bc5d 100644 --- a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/model/RuntimeUdfMeta.java +++ b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/model/RuntimeUdfMeta.java @@ -13,9 +13,7 @@ package com.antgroup.openspg.reasoner.udf.model; -import com.antgroup.openspg.reasoner.common.Utils; import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException; -import com.antgroup.openspg.reasoner.common.types.KTObject$; import com.antgroup.openspg.reasoner.common.types.KgType; import com.antgroup.openspg.reasoner.udf.utils.UdfUtils; import java.util.ArrayList; @@ -92,7 +90,7 @@ public class RuntimeUdfMeta { */ public Object invoke(Object... args) { IUdfMeta udfMeta = null; - List inputParamTypeList = getParamTypeList(args); + List inputParamTypeList = UdfUtils.getParamTypeList(args); List> groupByParamNumList = this.udfParamQueryMap.get(inputParamTypeList.size()); if (null != groupByParamNumList) { for (List udfParamTypeList : groupByParamNumList) { @@ -109,32 +107,6 @@ public class RuntimeUdfMeta { return udfMeta.invoke(args); } - private List getParamTypeList(Object... args) { - List kgTypeList = new ArrayList<>(args.length); - for (Object arg : args) { - if (null == arg) { - kgTypeList.add(KTObject$.MODULE$); - continue; - } - String className = arg.getClass().getName(); - if ("java.util.ArrayList".equals(className)) { - ArrayList list = (ArrayList) arg; - String memberType; - if ((list.isEmpty() || list.get(0) == null)) { - memberType = null; - } else { - memberType = list.get(0).getClass().getName(); - if (!memberType.startsWith("java.lang.")) { - memberType = "java.lang.Object"; - } - } - className = "java.util.List<" + memberType + ">"; - } - kgTypeList.add(Utils.javaType2KgType(className)); - } - return kgTypeList; - } - @Override public String toString() { StringBuilder sb = new StringBuilder("RuntimeUdfMeta{"); diff --git a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/utils/UdfUtils.java b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/utils/UdfUtils.java index 6932eeb0..7de507be 100644 --- a/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/utils/UdfUtils.java +++ b/reasoner/udf/src/main/java/com/antgroup/openspg/reasoner/udf/utils/UdfUtils.java @@ -13,6 +13,7 @@ package com.antgroup.openspg.reasoner.udf.utils; +import com.antgroup.openspg.reasoner.common.Utils; import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException; import com.antgroup.openspg.reasoner.common.types.KTArray; import com.antgroup.openspg.reasoner.common.types.KTList; @@ -165,4 +166,30 @@ public class UdfUtils { } return 0; } + + public static List getParamTypeList(Object... args) { + List kgTypeList = new ArrayList<>(args.length); + for (Object arg : args) { + if (null == arg) { + kgTypeList.add(KTObject$.MODULE$); + continue; + } + String className = arg.getClass().getName(); + if ("java.util.ArrayList".equals(className)) { + ArrayList list = (ArrayList) arg; + String memberType; + if ((list.isEmpty() || list.get(0) == null)) { + memberType = null; + } else { + memberType = list.get(0).getClass().getName(); + if (!memberType.startsWith("java.lang.")) { + memberType = "java.lang.Object"; + } + } + className = "java.util.List<" + memberType + ">"; + } + kgTypeList.add(Utils.javaType2KgType(className)); + } + return kgTypeList; + } }