fix(reasoner): bugfix in type inference (#150)

Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
FishJoy 2024-03-11 16:52:16 +08:00 committed by GitHub
parent bf57b3319f
commit e4602264a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 175 additions and 77 deletions

View File

@ -112,7 +112,7 @@ object ExprUtil {
}
case FunctionExpr(name, funcArgs) =>
val types = funcArgs.map(getTargetType(_, referVars, udfRepo))
name.toLowerCase(Locale.getDefault) match {
name match {
case "rule_value" => types(1)
case "cast_type" | "Cast" =>
funcArgs(1).asInstanceOf[VString].value.toLowerCase(Locale.getDefault) match {

View File

@ -197,6 +197,6 @@ public class KgReasonerAliasSetKFilmTest {
Assert.assertEquals("1", result.get(0)[0]);
Assert.assertEquals("2", result.get(0)[1]);
Assert.assertEquals("3", result.get(0)[2]);
Assert.assertEquals("700.0", result.get(0)[3]);
Assert.assertEquals("700", result.get(0)[3]);
}
}

View File

@ -36,8 +36,7 @@ import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseG
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle;
import com.antgroup.openspg.reasoner.udf.model.BaseUdaf;
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
import com.antgroup.openspg.reasoner.udf.model.LazyUdaf;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
import com.antgroup.openspg.reasoner.utils.RunnerUtil;
import com.google.common.collect.Lists;
@ -150,11 +149,10 @@ public class KgGraphAggregateImpl implements Serializable {
PropertyVar var = (PropertyVar) aggInfo.getVar();
// Perform the aggregation
UdafMeta udafMeta = aggInfo.getUdafMeta();
Object[] udafInitParams = aggInfo.getUdfInitParams();
LazyUdaf udafMeta = aggInfo.getLazyUdaf();
List<String> ruleList = aggInfo.getRuleList();
List<KgGraph<IVertexId>> valueFilteredList = getValueFilteredList(values, ruleList);
Object aggValue = doAggregation(valueFilteredList, udafMeta, udafInitParams, aggInfo);
Object aggValue = doAggregation(valueFilteredList, udafMeta, aggInfo);
String targetPropertyName = var.field().name();
propertyMap.put(targetPropertyName, aggValue);
}
@ -229,11 +227,10 @@ public class KgGraphAggregateImpl implements Serializable {
Var var = aggInfo.getVar();
// Perform the aggregation
UdafMeta udafMeta = aggInfo.getUdafMeta();
Object[] udafInitParams = aggInfo.getUdfInitParams();
LazyUdaf udaf = aggInfo.getLazyUdaf();
List<String> ruleList = aggInfo.getRuleList();
List<KgGraph<IVertexId>> valueFilteredList = getValueFilteredList(values, ruleList);
Object aggValue = doAggregation(valueFilteredList, udafMeta, udafInitParams, aggInfo);
Object aggValue = doAggregation(valueFilteredList, udaf, aggInfo);
// Assign the aggregation result
if (var instanceof NodeVar) {
@ -333,21 +330,9 @@ public class KgGraphAggregateImpl implements Serializable {
return newEdge;
}
/**
 * Pushes a single property value into the aggregator when the property carries the given key.
 *
 * @param udaf aggregator that receives the value
 * @param property property container to read from
 * @param propertyName key of the value to aggregate; nothing happens if the key is absent
 */
private void updateUdafDataFromProperty(BaseUdaf udaf, IProperty property, String propertyName) {
  if (!property.isKeyExist(propertyName)) {
    return;
  }
  udaf.update(property.get(propertyName));
}
private Object doAggregation(
List<KgGraph<IVertexId>> valueFilteredList,
UdafMeta udafMeta,
Object[] udafInitParams,
BaseGroupProcess aggInfo) {
BaseUdaf udaf = udafMeta.createAggregateFunction();
if (null != udafInitParams) {
udaf.initialize(udafInitParams);
}
List<KgGraph<IVertexId>> valueFilteredList, LazyUdaf udaf, BaseGroupProcess aggInfo) {
udaf.reset();
ParsedAggEle parsedAggEle;
Set<String> aliasList = aggInfo.getExprUseAliasSet();
if (aliasList.size() <= 1) {
@ -383,7 +368,9 @@ public class KgGraphAggregateImpl implements Serializable {
vertexList.forEach(udaf::update);
} else {
vertexList.forEach(
v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName));
v ->
RunnerUtil.updateUdafDataFromProperty(
udaf, v.getValue(), finalSourcePropertyName));
}
} else {
List<IEdge<IVertexId, IProperty>> edgeList =
@ -402,7 +389,9 @@ public class KgGraphAggregateImpl implements Serializable {
edgeList.forEach(udaf::update);
} else {
edgeList.forEach(
e -> updateUdafDataFromProperty(udaf, e.getValue(), finalSourcePropertyName));
e ->
RunnerUtil.updateUdafDataFromProperty(
udaf, e.getValue(), finalSourcePropertyName));
}
}
}

View File

@ -26,6 +26,7 @@ import com.antgroup.openspg.reasoner.lube.logical.PropertyVar;
import com.antgroup.openspg.reasoner.lube.logical.Var;
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils;
import com.antgroup.openspg.reasoner.udf.UdfMngFactory;
import com.antgroup.openspg.reasoner.udf.model.LazyUdaf;
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
import com.antgroup.openspg.reasoner.warehouse.utils.WareHouseUtils;
@ -38,8 +39,7 @@ import scala.collection.JavaConversions;
public abstract class BaseGroupProcess implements Serializable {
protected Var var;
protected UdafMeta udafMeta;
protected Object[] udfInitParams;
protected LazyUdaf lazyUdaf;
protected List<String> ruleList;
protected Aggregator aggOp;
protected String taskId;
@ -59,8 +59,7 @@ public abstract class BaseGroupProcess implements Serializable {
this.var = var;
this.aggOp = aggregator;
this.ruleList = parseRuleList();
this.udfInitParams = parseUdfInitParams();
this.udafMeta = parseUdafMeta();
this.lazyUdaf = createLazyUdafMeta();
this.exprUseAliasSet = parseExprUseAliasSet();
this.exprRuleString = parseExprRuleList();
@ -112,6 +111,15 @@ public abstract class BaseGroupProcess implements Serializable {
return udfInitParams;
}
/**
 * Builds the lazily-resolved aggregate function wrapper for this group process.
 *
 * <p>The UDAF name is derived from the aggregator op set and the init params are parsed, in that
 * order, before both are handed to the {@link LazyUdaf} constructor.
 *
 * @return a new {@link LazyUdaf} instance
 */
public LazyUdaf createLazyUdafMeta() {
  return new LazyUdaf(getUdafStrName(getAggOpSet()), parseUdfInitParams());
}
/**
 * getter
 *
 * @return the lazily-resolved aggregate function held by this group process
 */
public LazyUdaf getLazyUdaf() {
  return this.lazyUdaf;
}
protected UdafMeta parseUdafMeta() {
String udafName = getUdafStrName(getAggOpSet());
UdafMeta udafMeta = UdfMngFactory.getUdfMng().getUdafMeta(udafName, KTString$.MODULE$);
@ -183,24 +191,6 @@ public abstract class BaseGroupProcess implements Serializable {
return var;
}
/**
 * getter
 *
 * @return the UDAF meta stored in the {@code udafMeta} field
 */
public UdafMeta getUdafMeta() {
  return this.udafMeta;
}
/**
 * getter
 *
 * @return the UDF initialization parameters stored in the {@code udfInitParams} field
 */
public Object[] getUdfInitParams() {
  return this.udfInitParams;
}
/**
* getter
*

View File

@ -59,6 +59,7 @@ import com.antgroup.openspg.reasoner.runner.ConfigKey;
import com.antgroup.openspg.reasoner.session.KGReasonerSession;
import com.antgroup.openspg.reasoner.udf.UdfMng;
import com.antgroup.openspg.reasoner.udf.UdfMngFactory;
import com.antgroup.openspg.reasoner.udf.model.LazyUdaf;
import com.antgroup.openspg.reasoner.udf.model.UdtfMeta;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil;
@ -1239,4 +1240,11 @@ public class RunnerUtil {
return JavaConversions.mapAsJavaMap(
KgGraphSchema.getEdgeDirectionDiff(leftSchema, rightSchema));
}
/**
 * Feeds one property value into the aggregator, but only when the property contains the key.
 *
 * @param udaf lazily-resolved aggregator that receives the value
 * @param property property container to read from
 * @param propertyName key of the value to aggregate; the call is a no-op if the key is absent
 */
public static void updateUdafDataFromProperty(
    LazyUdaf udaf, IProperty property, String propertyName) {
  if (!property.isKeyExist(propertyName)) {
    return;
  }
  udaf.update(property.get(propertyName));
}
}

View File

@ -54,6 +54,14 @@ public interface UdfMng {
*/
UdafMeta getUdafMeta(String name, KgType rowDataType);
/**
 * Query the list of UDAF metas registered under a name.
 *
 * @param name the UDAF name to look up
 * @return the UDAF metas found for {@code name}. NOTE(review): an implementation may return
 *     {@code null} instead of an empty list when nothing is registered under the name, so callers
 *     should handle both.
 */
List<UdafMeta> getUdafMetas(String name);
/**
* Query UDTF meta information
*

View File

@ -242,6 +242,15 @@ public class UdfMngImpl implements UdfMng {
return getMeta(name, Lists.newArrayList(rowDataTypes), this.udafMetaMap);
}
/**
 * Collects every UDAF meta registered under the given name into a new list.
 *
 * @param name the UDAF name to look up
 * @return a fresh list of the registered metas, or {@code null} when the name has no entry
 */
@Override
public List<UdafMeta> getUdafMetas(String name) {
  Map<String, UdafMeta> overloads = this.udafMetaMap.get(UdfName.from(name));
  if (overloads != null) {
    return Lists.newArrayList(overloads.values());
  }
  return null;
}
@Override
public UdtfMeta getUdtfMeta(String name, List<KgType> rowDataTypes) {
return getMeta(name, rowDataTypes, this.udtfMetaMap);

View File

@ -0,0 +1,95 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.udf.model;
import com.antgroup.openspg.reasoner.common.types.KTObject$;
import com.antgroup.openspg.reasoner.common.types.KTString$;
import com.antgroup.openspg.reasoner.common.types.KgType;
import com.antgroup.openspg.reasoner.udf.UdfMngFactory;
import com.antgroup.openspg.reasoner.udf.utils.UdfUtils;
import java.util.List;
import org.apache.commons.collections4.CollectionUtils;
/**
 * Aggregate function wrapper that postpones choosing the concrete UDAF overload until the type of
 * the first aggregated row is known.
 *
 * <p>If exactly one UDAF is registered under the name, it is bound in the constructor. Otherwise
 * the overload is looked up from the type of the first row passed to {@link #update(Object)}, or
 * from the string type when {@link #evaluate()} is called before any row was seen.
 *
 * <p>NOTE(review): this class does not implement {@code java.io.Serializable}, while
 * {@code BaseGroupProcess}, which stores an instance in a field, does. Confirm whether the holder
 * is ever Java-serialized.
 */
public class LazyUdaf {
  /** Name under which the UDAF is registered. */
  protected final String name;

  /** Parameters passed to {@code BaseUdaf.initialize}; {@code null} means no initialization. */
  protected final Object[] udfInitParams;

  /** Resolved meta; stays {@code null} until an overload has been chosen. */
  protected UdafMeta udafMeta = null;

  /** Current aggregation instance; {@code null} until first use and after {@link #reset()}. */
  protected BaseUdaf baseUdaf = null;

  /**
   * Creates the wrapper and eagerly binds the UDAF meta when the name is unambiguous.
   *
   * @param name UDAF name
   * @param udfInitParams initialization parameters for the aggregate function, may be null
   * @throws RuntimeException if no UDAF is registered under {@code name}
   */
  public LazyUdaf(String name, Object[] udfInitParams) {
    this.name = name;
    this.udfInitParams = udfInitParams;
    List<UdafMeta> udafMetas = UdfMngFactory.getUdfMng().getUdafMetas(this.name);
    if (CollectionUtils.isEmpty(udafMetas)) {
      throw new RuntimeException("unsupported aggregator function, type=" + this.name);
    } else if (1 == udafMetas.size()) {
      // Only one overload exists, so no type inference is needed later.
      this.udafMeta = udafMetas.get(0);
    }
  }

  /** @return the UDAF name */
  public String getName() {
    return name;
  }

  /** @return the initialization parameters, possibly null */
  public Object[] getUdfInitParams() {
    return udfInitParams;
  }

  /** @return the resolved UDAF meta, or null if no overload has been chosen yet */
  public UdafMeta getUdafMeta() {
    return udafMeta;
  }

  /**
   * Returns the current aggregation instance, creating it with the string input type when no row
   * has been aggregated yet.
   *
   * @return the aggregation instance, never null
   * @throws RuntimeException if no UDAF overload can be resolved
   */
  public BaseUdaf getBaseUdaf() {
    if (null == baseUdaf) {
      createBaseUdaf(KTString$.MODULE$);
    }
    return baseUdaf;
  }

  /** Drops the current aggregation instance so the next use starts from a fresh one. */
  public void reset() {
    this.baseUdaf = null;
  }

  /**
   * Adds one row to the aggregation, creating the aggregation instance from the row's type on the
   * first call.
   *
   * @param row the value to aggregate
   * @throws RuntimeException if no UDAF overload can be resolved for the inferred type
   */
  public void update(Object row) {
    if (null == this.baseUdaf) {
      KgType inputParamType;
      try {
        List<KgType> inputParamTypeList = UdfUtils.getParamTypeList(row);
        inputParamType = inputParamTypeList.get(0);
      } catch (Throwable e) {
        // Deliberate best effort: any failure to infer the type falls back to the object type.
        inputParamType = KTObject$.MODULE$;
      }
      createBaseUdaf(inputParamType);
    }
    this.baseUdaf.update(row);
  }

  /**
   * Resolves the UDAF meta (if still unresolved) and instantiates the aggregation instance.
   *
   * @param kgType input type used to pick the overload when the meta is not bound yet
   * @throws RuntimeException if the lookup yields no UDAF for the name and type
   */
  private void createBaseUdaf(KgType kgType) {
    if (null == this.udafMeta) {
      UdafMeta resolved = UdfMngFactory.getUdfMng().getUdafMeta(this.name, kgType);
      if (null == resolved) {
        // Fail with context instead of a bare NullPointerException a few lines below.
        throw new RuntimeException(
            "unsupported aggregator function, type=" + this.name + ", inputType=" + kgType);
      }
      this.udafMeta = resolved;
    }
    if (null == this.baseUdaf) {
      this.baseUdaf = this.udafMeta.createAggregateFunction();
      if (null != this.udfInitParams) {
        this.baseUdaf.initialize(this.udfInitParams);
      }
    }
  }

  /**
   * Evaluates the aggregation over the rows seen since the last reset.
   *
   * @return the aggregation result produced by the underlying UDAF
   */
  public Object evaluate() {
    return this.getBaseUdaf().evaluate();
  }
}

View File

@ -13,9 +13,7 @@
package com.antgroup.openspg.reasoner.udf.model;
import com.antgroup.openspg.reasoner.common.Utils;
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException;
import com.antgroup.openspg.reasoner.common.types.KTObject$;
import com.antgroup.openspg.reasoner.common.types.KgType;
import com.antgroup.openspg.reasoner.udf.utils.UdfUtils;
import java.util.ArrayList;
@ -92,7 +90,7 @@ public class RuntimeUdfMeta {
*/
public Object invoke(Object... args) {
IUdfMeta udfMeta = null;
List<KgType> inputParamTypeList = getParamTypeList(args);
List<KgType> inputParamTypeList = UdfUtils.getParamTypeList(args);
List<List<KgType>> groupByParamNumList = this.udfParamQueryMap.get(inputParamTypeList.size());
if (null != groupByParamNumList) {
for (List<KgType> udfParamTypeList : groupByParamNumList) {
@ -109,32 +107,6 @@ public class RuntimeUdfMeta {
return udfMeta.invoke(args);
}
/**
 * Maps each runtime argument to its {@link KgType}.
 *
 * <p>A null argument maps to the object type. A {@code java.util.ArrayList} is described as
 * {@code java.util.List<member>}, where the member type comes from the first element: its class
 * name if that starts with {@code java.lang.}, {@code java.lang.Object} for any other class, and
 * the text {@code null} when the list is empty or its first element is null. Every other argument
 * is described by its class name.
 *
 * @param args runtime arguments
 * @return one {@link KgType} per argument, in argument order
 */
private List<KgType> getParamTypeList(Object... args) {
  List<KgType> resolved = new ArrayList<>(args.length);
  for (Object candidate : args) {
    if (candidate == null) {
      resolved.add(KTObject$.MODULE$);
      continue;
    }
    String javaTypeName = candidate.getClass().getName();
    if (javaTypeName.equals("java.util.ArrayList")) {
      ArrayList<?> items = (ArrayList<?>) candidate;
      String itemTypeName = null;
      if (!items.isEmpty() && items.get(0) != null) {
        String firstItemClass = items.get(0).getClass().getName();
        itemTypeName = firstItemClass.startsWith("java.lang.") ? firstItemClass : "java.lang.Object";
      }
      javaTypeName = "java.util.List<" + itemTypeName + ">";
    }
    resolved.add(Utils.javaType2KgType(javaTypeName));
  }
  return resolved;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder("RuntimeUdfMeta{");

View File

@ -13,6 +13,7 @@
package com.antgroup.openspg.reasoner.udf.utils;
import com.antgroup.openspg.reasoner.common.Utils;
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException;
import com.antgroup.openspg.reasoner.common.types.KTArray;
import com.antgroup.openspg.reasoner.common.types.KTList;
@ -165,4 +166,30 @@ public class UdfUtils {
}
return 0;
}
/**
 * Infers the {@link KgType} of every runtime argument.
 *
 * <p>Rules, applied per argument:
 *
 * <ul>
 *   <li>{@code null} maps to the object type;
 *   <li>a {@code java.util.ArrayList} is described as {@code java.util.List<member>}, where the
 *       member type is taken from the first element: its class name when that starts with
 *       {@code java.lang.}, {@code java.lang.Object} for any other class, and the text
 *       {@code null} when the list is empty or its first element is null;
 *   <li>anything else is described by its class name.
 * </ul>
 *
 * The resulting type name is converted with {@code Utils.javaType2KgType}.
 *
 * @param args runtime arguments
 * @return one {@link KgType} per argument, in argument order
 */
public static List<KgType> getParamTypeList(Object... args) {
  List<KgType> inferredTypes = new ArrayList<>(args.length);
  for (Object value : args) {
    if (value == null) {
      inferredTypes.add(KTObject$.MODULE$);
      continue;
    }
    String typeName = value.getClass().getName();
    if (typeName.equals("java.util.ArrayList")) {
      ArrayList<?> elements = (ArrayList<?>) value;
      // Stays null (rendered as the text "null") for an empty list or a null first element.
      String elementType = null;
      if (!elements.isEmpty() && elements.get(0) != null) {
        String firstElementClass = elements.get(0).getClass().getName();
        if (firstElementClass.startsWith("java.lang.")) {
          elementType = firstElementClass;
        } else {
          elementType = "java.lang.Object";
        }
      }
      typeName = "java.util.List<" + elementType + ">";
    }
    inferredTypes.add(Utils.javaType2KgType(typeName));
  }
  return inferredTypes;
}
}