mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-12-04 19:21:00 +00:00
fix(reaonser): bugfix in type infer (#150)
Co-authored-by: Donghai <donghai.ydh@antgroup.com>
This commit is contained in:
parent
bf57b3319f
commit
e4602264a5
@ -112,7 +112,7 @@ object ExprUtil {
|
||||
}
|
||||
case FunctionExpr(name, funcArgs) =>
|
||||
val types = funcArgs.map(getTargetType(_, referVars, udfRepo))
|
||||
name.toLowerCase(Locale.getDefault) match {
|
||||
name match {
|
||||
case "rule_value" => types(1)
|
||||
case "cast_type" | "Cast" =>
|
||||
funcArgs(1).asInstanceOf[VString].value.toLowerCase(Locale.getDefault) match {
|
||||
|
||||
@ -197,6 +197,6 @@ public class KgReasonerAliasSetKFilmTest {
|
||||
Assert.assertEquals("1", result.get(0)[0]);
|
||||
Assert.assertEquals("2", result.get(0)[1]);
|
||||
Assert.assertEquals("3", result.get(0)[2]);
|
||||
Assert.assertEquals("700.0", result.get(0)[3]);
|
||||
Assert.assertEquals("700", result.get(0)[3]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -36,8 +36,7 @@ import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseG
|
||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess;
|
||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess;
|
||||
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle;
|
||||
import com.antgroup.openspg.reasoner.udf.model.BaseUdaf;
|
||||
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
|
||||
import com.antgroup.openspg.reasoner.udf.model.LazyUdaf;
|
||||
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
|
||||
import com.antgroup.openspg.reasoner.utils.RunnerUtil;
|
||||
import com.google.common.collect.Lists;
|
||||
@ -150,11 +149,10 @@ public class KgGraphAggregateImpl implements Serializable {
|
||||
PropertyVar var = (PropertyVar) aggInfo.getVar();
|
||||
|
||||
// 进行聚合计算
|
||||
UdafMeta udafMeta = aggInfo.getUdafMeta();
|
||||
Object[] udafInitParams = aggInfo.getUdfInitParams();
|
||||
LazyUdaf udafMeta = aggInfo.getLazyUdaf();
|
||||
List<String> ruleList = aggInfo.getRuleList();
|
||||
List<KgGraph<IVertexId>> valueFilteredList = getValueFilteredList(values, ruleList);
|
||||
Object aggValue = doAggregation(valueFilteredList, udafMeta, udafInitParams, aggInfo);
|
||||
Object aggValue = doAggregation(valueFilteredList, udafMeta, aggInfo);
|
||||
String targetPropertyName = var.field().name();
|
||||
propertyMap.put(targetPropertyName, aggValue);
|
||||
}
|
||||
@ -229,11 +227,10 @@ public class KgGraphAggregateImpl implements Serializable {
|
||||
Var var = aggInfo.getVar();
|
||||
|
||||
// 进行聚合计算
|
||||
UdafMeta udafMeta = aggInfo.getUdafMeta();
|
||||
Object[] udafInitParams = aggInfo.getUdfInitParams();
|
||||
LazyUdaf udaf = aggInfo.getLazyUdaf();
|
||||
List<String> ruleList = aggInfo.getRuleList();
|
||||
List<KgGraph<IVertexId>> valueFilteredList = getValueFilteredList(values, ruleList);
|
||||
Object aggValue = doAggregation(valueFilteredList, udafMeta, udafInitParams, aggInfo);
|
||||
Object aggValue = doAggregation(valueFilteredList, udaf, aggInfo);
|
||||
|
||||
// 聚合结果赋值
|
||||
if (var instanceof NodeVar) {
|
||||
@ -333,21 +330,9 @@ public class KgGraphAggregateImpl implements Serializable {
|
||||
return newEdge;
|
||||
}
|
||||
|
||||
private void updateUdafDataFromProperty(BaseUdaf udaf, IProperty property, String propertyName) {
|
||||
if (property.isKeyExist(propertyName)) {
|
||||
udaf.update(property.get(propertyName));
|
||||
}
|
||||
}
|
||||
|
||||
private Object doAggregation(
|
||||
List<KgGraph<IVertexId>> valueFilteredList,
|
||||
UdafMeta udafMeta,
|
||||
Object[] udafInitParams,
|
||||
BaseGroupProcess aggInfo) {
|
||||
BaseUdaf udaf = udafMeta.createAggregateFunction();
|
||||
if (null != udafInitParams) {
|
||||
udaf.initialize(udafInitParams);
|
||||
}
|
||||
List<KgGraph<IVertexId>> valueFilteredList, LazyUdaf udaf, BaseGroupProcess aggInfo) {
|
||||
udaf.reset();
|
||||
ParsedAggEle parsedAggEle;
|
||||
Set<String> aliasList = aggInfo.getExprUseAliasSet();
|
||||
if (aliasList.size() <= 1) {
|
||||
@ -383,7 +368,9 @@ public class KgGraphAggregateImpl implements Serializable {
|
||||
vertexList.forEach(udaf::update);
|
||||
} else {
|
||||
vertexList.forEach(
|
||||
v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName));
|
||||
v ->
|
||||
RunnerUtil.updateUdafDataFromProperty(
|
||||
udaf, v.getValue(), finalSourcePropertyName));
|
||||
}
|
||||
} else {
|
||||
List<IEdge<IVertexId, IProperty>> edgeList =
|
||||
@ -402,7 +389,9 @@ public class KgGraphAggregateImpl implements Serializable {
|
||||
edgeList.forEach(udaf::update);
|
||||
} else {
|
||||
edgeList.forEach(
|
||||
e -> updateUdafDataFromProperty(udaf, e.getValue(), finalSourcePropertyName));
|
||||
e ->
|
||||
RunnerUtil.updateUdafDataFromProperty(
|
||||
udaf, e.getValue(), finalSourcePropertyName));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,6 +26,7 @@ import com.antgroup.openspg.reasoner.lube.logical.PropertyVar;
|
||||
import com.antgroup.openspg.reasoner.lube.logical.Var;
|
||||
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils;
|
||||
import com.antgroup.openspg.reasoner.udf.UdfMngFactory;
|
||||
import com.antgroup.openspg.reasoner.udf.model.LazyUdaf;
|
||||
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
|
||||
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
|
||||
import com.antgroup.openspg.reasoner.warehouse.utils.WareHouseUtils;
|
||||
@ -38,8 +39,7 @@ import scala.collection.JavaConversions;
|
||||
|
||||
public abstract class BaseGroupProcess implements Serializable {
|
||||
protected Var var;
|
||||
protected UdafMeta udafMeta;
|
||||
protected Object[] udfInitParams;
|
||||
protected LazyUdaf lazyUdaf;
|
||||
protected List<String> ruleList;
|
||||
protected Aggregator aggOp;
|
||||
protected String taskId;
|
||||
@ -59,8 +59,7 @@ public abstract class BaseGroupProcess implements Serializable {
|
||||
this.var = var;
|
||||
this.aggOp = aggregator;
|
||||
this.ruleList = parseRuleList();
|
||||
this.udfInitParams = parseUdfInitParams();
|
||||
this.udafMeta = parseUdafMeta();
|
||||
this.lazyUdaf = createLazyUdafMeta();
|
||||
|
||||
this.exprUseAliasSet = parseExprUseAliasSet();
|
||||
this.exprRuleString = parseExprRuleList();
|
||||
@ -112,6 +111,15 @@ public abstract class BaseGroupProcess implements Serializable {
|
||||
return udfInitParams;
|
||||
}
|
||||
|
||||
public LazyUdaf createLazyUdafMeta() {
|
||||
String udafName = getUdafStrName(getAggOpSet());
|
||||
return new LazyUdaf(udafName, parseUdfInitParams());
|
||||
}
|
||||
|
||||
public LazyUdaf getLazyUdaf() {
|
||||
return lazyUdaf;
|
||||
}
|
||||
|
||||
protected UdafMeta parseUdafMeta() {
|
||||
String udafName = getUdafStrName(getAggOpSet());
|
||||
UdafMeta udafMeta = UdfMngFactory.getUdfMng().getUdafMeta(udafName, KTString$.MODULE$);
|
||||
@ -183,24 +191,6 @@ public abstract class BaseGroupProcess implements Serializable {
|
||||
return var;
|
||||
}
|
||||
|
||||
/**
|
||||
* getter
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public UdafMeta getUdafMeta() {
|
||||
return udafMeta;
|
||||
}
|
||||
|
||||
/**
|
||||
* getter
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public Object[] getUdfInitParams() {
|
||||
return udfInitParams;
|
||||
}
|
||||
|
||||
/**
|
||||
* getter
|
||||
*
|
||||
|
||||
@ -59,6 +59,7 @@ import com.antgroup.openspg.reasoner.runner.ConfigKey;
|
||||
import com.antgroup.openspg.reasoner.session.KGReasonerSession;
|
||||
import com.antgroup.openspg.reasoner.udf.UdfMng;
|
||||
import com.antgroup.openspg.reasoner.udf.UdfMngFactory;
|
||||
import com.antgroup.openspg.reasoner.udf.model.LazyUdaf;
|
||||
import com.antgroup.openspg.reasoner.udf.model.UdtfMeta;
|
||||
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
|
||||
import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil;
|
||||
@ -1239,4 +1240,11 @@ public class RunnerUtil {
|
||||
return JavaConversions.mapAsJavaMap(
|
||||
KgGraphSchema.getEdgeDirectionDiff(leftSchema, rightSchema));
|
||||
}
|
||||
|
||||
public static void updateUdafDataFromProperty(
|
||||
LazyUdaf udaf, IProperty property, String propertyName) {
|
||||
if (property.isKeyExist(propertyName)) {
|
||||
udaf.update(property.get(propertyName));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,6 +54,14 @@ public interface UdfMng {
|
||||
*/
|
||||
UdafMeta getUdafMeta(String name, KgType rowDataType);
|
||||
|
||||
/**
|
||||
* query UDAF mete list from name
|
||||
*
|
||||
* @param name
|
||||
* @return
|
||||
*/
|
||||
List<UdafMeta> getUdafMetas(String name);
|
||||
|
||||
/**
|
||||
* Query UDTF meta information
|
||||
*
|
||||
|
||||
@ -242,6 +242,15 @@ public class UdfMngImpl implements UdfMng {
|
||||
return getMeta(name, Lists.newArrayList(rowDataTypes), this.udafMetaMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<UdafMeta> getUdafMetas(String name) {
|
||||
Map<String, UdafMeta> subMetaMap = this.udafMetaMap.get(UdfName.from(name));
|
||||
if (null == subMetaMap) {
|
||||
return null;
|
||||
}
|
||||
return Lists.newArrayList(subMetaMap.values());
|
||||
}
|
||||
|
||||
@Override
|
||||
public UdtfMeta getUdtfMeta(String name, List<KgType> rowDataTypes) {
|
||||
return getMeta(name, rowDataTypes, this.udtfMetaMap);
|
||||
|
||||
@ -0,0 +1,95 @@
|
||||
/*
|
||||
* Copyright 2023 OpenSPG Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
* in compliance with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
* or implied.
|
||||
*/
|
||||
|
||||
package com.antgroup.openspg.reasoner.udf.model;
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.types.KTObject$;
|
||||
import com.antgroup.openspg.reasoner.common.types.KTString$;
|
||||
import com.antgroup.openspg.reasoner.common.types.KgType;
|
||||
import com.antgroup.openspg.reasoner.udf.UdfMngFactory;
|
||||
import com.antgroup.openspg.reasoner.udf.utils.UdfUtils;
|
||||
import java.util.List;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
public class LazyUdaf {
|
||||
|
||||
protected final String name;
|
||||
protected final Object[] udfInitParams;
|
||||
|
||||
protected UdafMeta udafMeta = null;
|
||||
protected BaseUdaf baseUdaf = null;
|
||||
|
||||
public LazyUdaf(String name, Object[] udfInitParams) {
|
||||
this.name = name;
|
||||
this.udfInitParams = udfInitParams;
|
||||
List<UdafMeta> udafMetas = UdfMngFactory.getUdfMng().getUdafMetas(this.name);
|
||||
if (CollectionUtils.isEmpty(udafMetas)) {
|
||||
throw new RuntimeException("unsupported aggregator function, type=" + this.name);
|
||||
} else if (1 == udafMetas.size()) {
|
||||
this.udafMeta = udafMetas.get(0);
|
||||
}
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public Object[] getUdfInitParams() {
|
||||
return udfInitParams;
|
||||
}
|
||||
|
||||
public UdafMeta getUdafMeta() {
|
||||
return udafMeta;
|
||||
}
|
||||
|
||||
public BaseUdaf getBaseUdaf() {
|
||||
if (null == baseUdaf) {
|
||||
createBaseUdaf(KTString$.MODULE$);
|
||||
}
|
||||
return baseUdaf;
|
||||
}
|
||||
|
||||
public void reset() {
|
||||
this.baseUdaf = null;
|
||||
}
|
||||
|
||||
public void update(Object row) {
|
||||
if (null == this.baseUdaf) {
|
||||
KgType inputParamType;
|
||||
try {
|
||||
List<KgType> inputParamTypeList = UdfUtils.getParamTypeList(row);
|
||||
inputParamType = inputParamTypeList.get(0);
|
||||
} catch (Throwable e) {
|
||||
inputParamType = KTObject$.MODULE$;
|
||||
}
|
||||
createBaseUdaf(inputParamType);
|
||||
}
|
||||
this.baseUdaf.update(row);
|
||||
}
|
||||
|
||||
private void createBaseUdaf(KgType kgType) {
|
||||
if (null == this.udafMeta) {
|
||||
this.udafMeta = UdfMngFactory.getUdfMng().getUdafMeta(this.name, kgType);
|
||||
}
|
||||
if (null == this.baseUdaf) {
|
||||
this.baseUdaf = this.udafMeta.createAggregateFunction();
|
||||
if (null != this.udfInitParams) {
|
||||
this.baseUdaf.initialize(this.udfInitParams);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public Object evaluate() {
|
||||
return this.getBaseUdaf().evaluate();
|
||||
}
|
||||
}
|
||||
@ -13,9 +13,7 @@
|
||||
|
||||
package com.antgroup.openspg.reasoner.udf.model;
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.Utils;
|
||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException;
|
||||
import com.antgroup.openspg.reasoner.common.types.KTObject$;
|
||||
import com.antgroup.openspg.reasoner.common.types.KgType;
|
||||
import com.antgroup.openspg.reasoner.udf.utils.UdfUtils;
|
||||
import java.util.ArrayList;
|
||||
@ -92,7 +90,7 @@ public class RuntimeUdfMeta {
|
||||
*/
|
||||
public Object invoke(Object... args) {
|
||||
IUdfMeta udfMeta = null;
|
||||
List<KgType> inputParamTypeList = getParamTypeList(args);
|
||||
List<KgType> inputParamTypeList = UdfUtils.getParamTypeList(args);
|
||||
List<List<KgType>> groupByParamNumList = this.udfParamQueryMap.get(inputParamTypeList.size());
|
||||
if (null != groupByParamNumList) {
|
||||
for (List<KgType> udfParamTypeList : groupByParamNumList) {
|
||||
@ -109,32 +107,6 @@ public class RuntimeUdfMeta {
|
||||
return udfMeta.invoke(args);
|
||||
}
|
||||
|
||||
private List<KgType> getParamTypeList(Object... args) {
|
||||
List<KgType> kgTypeList = new ArrayList<>(args.length);
|
||||
for (Object arg : args) {
|
||||
if (null == arg) {
|
||||
kgTypeList.add(KTObject$.MODULE$);
|
||||
continue;
|
||||
}
|
||||
String className = arg.getClass().getName();
|
||||
if ("java.util.ArrayList".equals(className)) {
|
||||
ArrayList list = (ArrayList) arg;
|
||||
String memberType;
|
||||
if ((list.isEmpty() || list.get(0) == null)) {
|
||||
memberType = null;
|
||||
} else {
|
||||
memberType = list.get(0).getClass().getName();
|
||||
if (!memberType.startsWith("java.lang.")) {
|
||||
memberType = "java.lang.Object";
|
||||
}
|
||||
}
|
||||
className = "java.util.List<" + memberType + ">";
|
||||
}
|
||||
kgTypeList.add(Utils.javaType2KgType(className));
|
||||
}
|
||||
return kgTypeList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringBuilder sb = new StringBuilder("RuntimeUdfMeta{");
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
|
||||
package com.antgroup.openspg.reasoner.udf.utils;
|
||||
|
||||
import com.antgroup.openspg.reasoner.common.Utils;
|
||||
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException;
|
||||
import com.antgroup.openspg.reasoner.common.types.KTArray;
|
||||
import com.antgroup.openspg.reasoner.common.types.KTList;
|
||||
@ -165,4 +166,30 @@ public class UdfUtils {
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static List<KgType> getParamTypeList(Object... args) {
|
||||
List<KgType> kgTypeList = new ArrayList<>(args.length);
|
||||
for (Object arg : args) {
|
||||
if (null == arg) {
|
||||
kgTypeList.add(KTObject$.MODULE$);
|
||||
continue;
|
||||
}
|
||||
String className = arg.getClass().getName();
|
||||
if ("java.util.ArrayList".equals(className)) {
|
||||
ArrayList list = (ArrayList) arg;
|
||||
String memberType;
|
||||
if ((list.isEmpty() || list.get(0) == null)) {
|
||||
memberType = null;
|
||||
} else {
|
||||
memberType = list.get(0).getClass().getName();
|
||||
if (!memberType.startsWith("java.lang.")) {
|
||||
memberType = "java.lang.Object";
|
||||
}
|
||||
}
|
||||
className = "java.util.List<" + memberType + ">";
|
||||
}
|
||||
kgTypeList.add(Utils.javaType2KgType(className));
|
||||
}
|
||||
return kgTypeList;
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user