fix(reasoner): support std prop during type inference && fix some bugs in NodeIdToEdgeProperty (#159)

This commit is contained in:
FishJoy 2024-03-15 15:28:20 +08:00 committed by GitHub
parent 3b214401ec
commit 9eac15c4de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 134 additions and 1866 deletions

View File

@ -38,7 +38,7 @@
<properties> <properties>
<antlr4.version>4.8</antlr4.version> <antlr4.version>4.8</antlr4.version>
<dropwizard.metrics.version>4.2.9</dropwizard.metrics.version> <dropwizard.metrics.version>4.2.9</dropwizard.metrics.version>
<geotools.version>27.0</geotools.version> <geotools.version>28.0</geotools.version>
<google.s2.version>2.0.0</google.s2.version> <google.s2.version>2.0.0</google.s2.version>
<groovy.version>3.0.9</groovy.version> <groovy.version>3.0.9</groovy.version>
<gson.version>2.10.1</gson.version> <gson.version>2.10.1</gson.version>

View File

@ -15,7 +15,7 @@ package com.antgroup.openspg.reasoner.catalog.impl
import scala.language.implicitConversions import scala.language.implicitConversions
import com.antgroup.openspg.core.schema.model.`type`.{BaseSPGType, BasicType, ConceptType, EntityType, EventType, SPGTypeEnum, StandardType} import com.antgroup.openspg.core.schema.model.`type`._
import com.antgroup.openspg.reasoner.catalog.impl.struct.PropertyMeta import com.antgroup.openspg.reasoner.catalog.impl.struct.PropertyMeta
import com.antgroup.openspg.reasoner.common.exception.KGValueException import com.antgroup.openspg.reasoner.common.exception.KGValueException
import com.antgroup.openspg.reasoner.common.types._ import com.antgroup.openspg.reasoner.common.types._
@ -30,9 +30,15 @@ object PropertySchemaOps {
case "CONCEPT" => case "CONCEPT" =>
KTConcept(propertySchema.getPropRange.getRangeEntityName) KTConcept(propertySchema.getPropRange.getRangeEntityName)
case "STANDARD" => case "STANDARD" =>
KTStd(propertySchema.getPropRange.getRangeEntityName, propertySchema.isSpreadable) KTStd(
propertySchema.getPropRange.getRangeEntityName,
toKgType(propertySchema.getPropRange.getAttrRangeTypeEnum),
propertySchema.isSpreadable)
case "PROPERTY" => case "PROPERTY" =>
KTStd(propertySchema.getPropRange.getRangeEntityName, propertySchema.isSpreadable) KTStd(
propertySchema.getPropRange.getRangeEntityName,
toKgType(propertySchema.getPropRange.getAttrRangeTypeEnum),
propertySchema.isSpreadable)
case "ENTITY" => case "ENTITY" =>
KTAdvanced(propertySchema.getPropRange.getRangeEntityName) KTAdvanced(propertySchema.getPropRange.getRangeEntityName)
case _ => throw KGValueException(s"unsupported type: ${propertySchema.getCategory}") case _ => throw KGValueException(s"unsupported type: ${propertySchema.getCategory}")
@ -49,7 +55,8 @@ object PropertySchemaOps {
// todo // todo
KTAdvanced(eventType.getName) KTAdvanced(eventType.getName)
case standardType: StandardType => case standardType: StandardType =>
KTStd(spgType.getName, standardType.getSpreadable) // todo basicType support
KTStd(spgType.getName, null, standardType.getSpreadable)
case basicType: BasicType => case basicType: BasicType =>
toKgType(basicType.getBasicType.name()) toKgType(basicType.getBasicType.name())
case _ => case _ =>
@ -57,7 +64,7 @@ object PropertySchemaOps {
} }
} }
private def toKgType(basicType: String): KgType = { private def toKgType(basicType: String): BasicKgType = {
basicType.toUpperCase() match { basicType.toUpperCase() match {
case "INTEGER" => KTLong case "INTEGER" => KTLong
case "LONG" => KTLong case "LONG" => KTLong

View File

@ -19,53 +19,57 @@ trait KgType {
def isNullable: Boolean = false def isNullable: Boolean = false
} }
case object KTString extends KgType trait BasicKgType extends KgType
case object KTCharacter extends KgType
case object KTInteger extends KgType trait AdvancedKgType extends KgType
case object KTLong extends KgType
case object KTDouble extends KgType case object KTString extends BasicKgType
case object KTCharacter extends BasicKgType
case object KTInteger extends BasicKgType
case object KTLong extends BasicKgType
case object KTDouble extends BasicKgType
// corresponding to java object // corresponding to java object
case object KTObject extends KgType case object KTObject extends BasicKgType
case object KTDate extends KgType case object KTDate extends BasicKgType
case object KTBoolean extends KgType case object KTBoolean extends BasicKgType
case object KTParameter extends KgType
/** /**
* list type * list type
* @param elementType element type * @param elementType element type
*/ */
final case class KTList(elementType: KgType) extends KgType final case class KTList(elementType: KgType) extends AdvancedKgType
/** /**
* array type * array type
* @param elementType element type * @param elementType element type
*/ */
final case class KTArray(elementType: KgType) extends KgType final case class KTArray(elementType: KgType) extends AdvancedKgType
/** /**
* Standard entity in Knowledge Graph. * Standard entity in Knowledge Graph.
* @param label entity type name * @param label entity type name
* @param spreadable is spreadable. * @param spreadable is spreadable.
*/ */
final case class KTStd(label: String, spreadable: Boolean) extends KgType final case class KTStd(label: String, basicType: BasicKgType, spreadable: Boolean)
extends AdvancedKgType
/** /**
* Meta concept in Knowledge Graph. * Meta concept in Knowledge Graph.
* @param label meta concept name * @param label meta concept name
*/ */
final case class KTConcept(label: String) extends KgType final case class KTConcept(label: String) extends AdvancedKgType
/** /**
* Custom semantic type, which linked to entity in Knowledge Graph. * Custom semantic type, which linked to entity in Knowledge Graph.
* @param label entity type TODO add link function * @param label entity type TODO add link function
*/ */
final case class KTAdvanced(label: String) extends KgType final case class KTAdvanced(label: String) extends AdvancedKgType
/** /**
* multi version property, default version number unit is ms * multi version property, default version number unit is ms
* @param elementType * @param elementType
*/ */
final case class KTMultiVersion(elementType: KgType) extends KgType final case class KTMultiVersion(elementType: KgType) extends AdvancedKgType
object KgType { object KgType {

View File

@ -91,7 +91,16 @@ object ExprUtil {
} else { } else {
KTObject KTObject
} }
case UnaryOpExpr(GetField(name), Ref(alis)) => referVars(IRProperty(alis, name)) case UnaryOpExpr(GetField(name), Ref(alis)) =>
val kgType = referVars(IRProperty(alis, name))
if (kgType.isInstanceOf[BasicKgType]) {
kgType
} else {
kgType match {
case KTStd(_, basicType, _) => basicType
case _ => KTObject
}
}
case BinaryOpExpr(name, l, r) => case BinaryOpExpr(name, l, r) =>
name match { name match {
case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan | case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan |

View File

@ -20,7 +20,7 @@ import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.graph.edge import com.antgroup.openspg.reasoner.common.graph.edge
import com.antgroup.openspg.reasoner.common.types.KTString import com.antgroup.openspg.reasoner.common.types.KTString
import com.antgroup.openspg.reasoner.lube.block.{AddPredicate, DDLOp} import com.antgroup.openspg.reasoner.lube.block.{AddPredicate, AddProperty, AddVertex, DDLOp}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.Expr import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRNode, IRProperty, IRVariable} import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRNode, IRProperty, IRVariable}
@ -85,8 +85,10 @@ object NodeIdToEdgeProperty extends Rule {
private def genField(direction: edge.Direction, fieldName: String): String = { private def genField(direction: edge.Direction, fieldName: String): String = {
(direction, fieldName) match { (direction, fieldName) match {
case (edge.Direction.OUT, Constants.NODE_ID_KEY) => Constants.EDGE_TO_ID_KEY case (edge.Direction.OUT | edge.Direction.BOTH, Constants.NODE_ID_KEY) =>
case (edge.Direction.OUT, Constants.CONTEXT_LABEL) => Constants.EDGE_TO_ID_TYPE_KEY Constants.EDGE_TO_ID_KEY
case (edge.Direction.OUT | edge.Direction.BOTH, Constants.CONTEXT_LABEL) =>
Constants.EDGE_TO_ID_TYPE_KEY
case (edge.Direction.IN, Constants.NODE_ID_KEY) => Constants.EDGE_FROM_ID_KEY case (edge.Direction.IN, Constants.NODE_ID_KEY) => Constants.EDGE_FROM_ID_KEY
case (edge.Direction.IN, Constants.CONTEXT_LABEL) => Constants.EDGE_FROM_ID_TYPE_KEY case (edge.Direction.IN, Constants.CONTEXT_LABEL) => Constants.EDGE_FROM_ID_TYPE_KEY
case (_, _) => case (_, _) =>
@ -124,10 +126,19 @@ object NodeIdToEdgeProperty extends Rule {
private def projectUpdate(project: Project, map: Map[String, Object]): Project = { private def projectUpdate(project: Project, map: Map[String, Object]): Project = {
val exprMap = new mutable.HashMap[Var, Expr]() val exprMap = new mutable.HashMap[Var, Expr]()
for (expr <- project.expr) { for (expr <- project.expr) {
val input = ExprUtils.getAllInputFieldInRule( exprMap.put(
expr._2, expr._1,
project.solved.getNodeAliasSet, exprRewrite(expr._2, project.solved.getNodeAliasSet, project.solved.getEdgeAliasSet, map))
project.solved.getEdgeAliasSet) }
project.copy(expr = exprMap.toMap)
}
private def exprRewrite(
expr: Expr,
nodes: Set[String],
edges: Set[String],
map: Map[String, Object]): Expr = {
val input = ExprUtils.getAllInputFieldInRule(expr, nodes, edges)
val replaceVar = new mutable.HashMap[IRField, IRField] val replaceVar = new mutable.HashMap[IRField, IRField]
for (irField <- input) { for (irField <- input) {
if (irField.isInstanceOf[IRNode] && map.contains(irField.name)) { if (irField.isInstanceOf[IRNode] && map.contains(irField.name)) {
@ -143,12 +154,28 @@ object NodeIdToEdgeProperty extends Rule {
} }
} }
if (replaceVar.isEmpty) { if (replaceVar.isEmpty) {
exprMap.+=(expr) expr
} else { } else {
exprMap.put(expr._1, ExprUtils.renameVariableInExpr(expr._2, replaceVar.toMap)) ExprUtils.renameVariableInExpr(expr, replaceVar.toMap)
} }
} }
project.copy(expr = exprMap.toMap)
private def ddlUpdate(ddl: DDL, map: Map[String, Object]): DDL = {
val nodes = ddl.solved.getNodeAliasSet
val edges = ddl.solved.getEdgeAliasSet
val newOps = new mutable.HashSet[DDLOp]()
for (ddlOp <- ddl.ddlOp) {
ddlOp match {
case ddlOp: AddProperty => newOps.add(ddlOp)
case AddVertex(s, props) =>
val newProps = props.map(p => (p._1, exprRewrite(p._2, nodes, edges, map)))
newOps.add(AddVertex(s, newProps))
case AddPredicate(predicate) =>
val newProps = predicate.fields.map(p => (p._1, exprRewrite(p._2, nodes, edges, map)))
newOps.add(AddPredicate(predicate.copy(fields = newProps)))
}
}
ddl.copy(ddlOp = newOps.toSet)
} }
private def selectUpdate(select: Select, map: Map[String, Object]): Select = { private def selectUpdate(select: Select, map: Map[String, Object]): Select = {
@ -172,40 +199,6 @@ object NodeIdToEdgeProperty extends Rule {
select.copy(fields = newFields.toList) select.copy(fields = newFields.toList)
} }
private def ddlUpdate(ddl: DDL, map: Map[String, Object]): DDL = {
val ddlOps = new mutable.HashSet[DDLOp]()
for (ddlOp <- ddl.ddlOp) {
ddlOp match {
case AddPredicate(predicate) =>
val newFields = new mutable.HashMap[String, Expr]()
for (field <- predicate.fields) {
val input = ExprUtils.getAllInputFieldInRule(field._2, null, null)
val replaceVar = new mutable.HashMap[IRField, IRProperty]
for (irField <- input) {
if (irField.isInstanceOf[IRNode] && map.contains(irField.name)) {
for (propName <- irField.asInstanceOf[IRNode].fields) {
if (NODE_DEFAULT_PROPS.contains(propName)) {
val edgeInfo = map(irField.name).asInstanceOf[Connection]
replaceVar.put(
IRProperty(irField.name, propName),
IRProperty(edgeInfo.alias, genField(edgeInfo.direction, propName)))
}
}
}
}
if (replaceVar.isEmpty) {
newFields.put(field._1, field._2)
} else {
newFields.put(field._1, ExprUtils.renameVariableInExpr(field._2, replaceVar.toMap))
}
}
ddlOps.add(AddPredicate(predicate.copy(fields = newFields.toMap)))
case _ => ddlOps.add(ddlOp)
}
}
ddl.copy(ddlOp = ddlOps.toSet)
}
private def targetConnection(expandInto: ExpandInto): Connection = { private def targetConnection(expandInto: ExpandInto): Connection = {
val alias = expandInto.pattern.root.alias val alias = expandInto.pattern.root.alias
val edgeAlias = expandInto.transform[Connection] { val edgeAlias = expandInto.transform[Connection] {
@ -217,6 +210,12 @@ object NodeIdToEdgeProperty extends Rule {
} else { } else {
targetConnection(alias, expandInto.pattern) targetConnection(alias, expandInto.pattern)
} }
case (linkedExpand: LinkedExpand, list) =>
if (!list.isEmpty && list.head != null) {
list.head
} else {
targetConnection(alias, linkedExpand.edgePattern)
}
case (_, list) => case (_, list) =>
if (list.isEmpty) { if (list.isEmpty) {
null null

View File

@ -39,8 +39,8 @@ object Pure extends SimpleRule {
val projects = select.refFields.map((_, Directly)).toMap val projects = select.refFields.map((_, Directly)).toMap
select.withNewChildren(Array.apply(Project(in, projects, in.solved))) select.withNewChildren(Array.apply(Project(in, projects, in.solved)))
case project @ Project(in, _, _) => case project @ Project(in, _, _) =>
if (in.isInstanceOf[Project] || in.isInstanceOf[ExpandInto] || in if (in.isInstanceOf[Project] || in.isInstanceOf[ExpandInto] || in.isInstanceOf[LinkedExpand]
.isInstanceOf[PatternScan] || in.isInstanceOf[BinaryLogicalOperator]) { || in.isInstanceOf[PatternScan] || in.isInstanceOf[BinaryLogicalOperator]) {
project project
} else { } else {
val projectOutput: List[Var] = project.fields val projectOutput: List[Var] = project.fields
@ -71,17 +71,20 @@ object Pure extends SimpleRule {
} }
} }
varMap.values.map(f => { varMap.values
.map(f => {
if (f.isInstanceOf[PathVar]) { if (f.isInstanceOf[PathVar]) {
f f
} else if (!solved.fields.contains(f.name)) { } else if (!solved.fields.contains(f.name)) {
throw InvalidRefVariable(s"can not find $f") throw InvalidRefVariable(s"can not find $f")
} else if (solved.fields.get(f.name).get.isInstanceOf[RepeatPathVar]) { } else if (solved.fields.get(f.name).get.isInstanceOf[RepeatPathVar]) {
f.intersect(solved.fields.get(f.name).get.asInstanceOf[RepeatPathVar].pathVar.elements(1)) f.intersect(
solved.fields.get(f.name).get.asInstanceOf[RepeatPathVar].pathVar.elements(1))
} else { } else {
f.intersect(solved.fields.get(f.name).get) f.intersect(solved.fields.get(f.name).get)
} }
}).toList })
.toList
} }
override def direction: Direction = Down override def direction: Direction = Down

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,7 @@
package com.antgroup.openspg.reasoner.lube.logical package com.antgroup.openspg.reasoner.lube.logical
import com.antgroup.openspg.reasoner.common.types.{KgType, KTBoolean, KTDouble, KTInteger, KTLong, KTString} import com.antgroup.openspg.reasoner.common.types._
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr} import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr}
import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable} import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
@ -76,4 +76,23 @@ class ExprUtilTests extends AnyFunSpec {
ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong) ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong)
} }
it("test advanced type") {
val parser = new RuleExprParser()
val udfRepo = UdfMngFactory.getUdfMng
val map = Map
.apply(IRProperty("e", "eventTime") -> KTStd("STD.Timestamp", KTLong, false))
.asInstanceOf[Map[IRField, KgType]]
val r1 = ProjectRule(
IRVariable("eventDay"),
parser.parse("from_unix_time_ms(e.eventTime, 'yyyyMMdd')"))
val r2 = ProjectRule(
IRVariable("timeInDay"),
parser.parse("from_unix_time_ms(e.eventTime, 'yyyyMMdd') in ['20240311', '20240312']"))
ExprUtil.getTargetType(r1, map, udfRepo) should equal(KTString)
ExprUtil.getTargetType(r2, map, udfRepo) should equal(KTBoolean)
}
} }

View File

@ -63,101 +63,6 @@ class OptimizerTests extends AnyFunSpec {
} }
} }
it("testEdgeToProperty") {
val schema = ResourceLoader.loadResourceFile("TuringSchema.json")
val catalog = new JSONGraphCatalog(schema)
val dsl =
"""
|Define (user:TuringCore.AlipayUser)-[bt:belongTo]->(tc:`TuringCrowd`/`通勤用户`) {
| GraphStructure {
| (user) -[pwl:workLoc]-> (aa1:CKG.AdministrativeArea)
| (te:TuringCore.TravelEvent) -[ptler:traveler]-> (user)
| (te) -[ptm:travelMode]-> (tm:TuringCore.TravelMode)
| (te) -[pte:travelEndpoint]-> (aa1:CKG.AdministrativeArea)
| }
| Rule {
| R1('常驻地在杭州'): aa1.id == '中国-浙江省-杭州市'
| R2('工作日上班时间通勤用户'): dayOfWeek(te.eventTime) in [1, 2, 3, 4, 5]
| and hourOfDay(te.eventTime) in [6, 7, 8, 9, 10, 17, 18, 19, 20, 21]
| R3('公交地铁'): tm.id in ['bus', 'subway']
| tmCount('出行次数') = group(user).count(te.id)
| R4('出行次数大于3次'): tmCount >= 3
| }
|}
|""".stripMargin
catalog.init()
val parser = new OpenSPGDslParser()
val block = parser.parse(dsl)
implicit val context: LogicalPlannerContext =
LogicalPlannerContext(catalog, parser, Map.empty)
val logicalPlan = LogicalPlanner.plan(block).head
val finalOp = LogicalOptimizer.optimize(
logicalPlan,
Seq.apply(FilterPushDown, EdgeToProperty, SolvedModelPure))
val qlTransformer = new Expr2QlexpressTransformer()
finalOp.findExactlyOne { case ExpandInto(_, _, pattern) =>
qlTransformer.transform(pattern.getNode("te").rule).head should equal(
"((te.travelMode in [\"bus\",\"subway\"]) && ((dayOfWeek(te.eventTime) in [1,2,3,4,5]) && (hourOfDay(te.eventTime) in [6,7,8,9,10,17,18,19,20,21]))) && (te.travelEndpoint == \"中国-浙江省-杭州市\")")
}
finalOp.findExactlyOne { case PatternScan(_, pattern) =>
qlTransformer.transform(pattern.getNode("user").rule).head should equal(
"user.workLoc == \"中国-浙江省-杭州市\"")
}
finalOp.findExactlyOne { case Start(_, _, _, solved) =>
solved.alias2Types.keys.toSet should equal(Set.apply("user", "te", "ptler"))
solved.fields("user").asInstanceOf[NodeVar].fields.map(_.name) should equal(
Set.apply("workLoc"))
solved.fields("te").asInstanceOf[NodeVar].fields.map(_.name) should equal(
Set.apply("eventTime", "id", "travelMode", "travelEndpoint"))
}
}
it("concept to property") {
val schema = ResourceLoader.loadResourceFile("TuringSchema.json")
val catalog = new JSONGraphCatalog(schema)
catalog.init()
val start = Start(
catalog.getGraph(Catalog.defaultGraphName),
null,
Set.empty,
SolvedModel(
Map.empty,
Map.apply(
("te", NodeVar("te", Set.apply(new Field("eventTime", KTString, true)))),
("user", NodeVar("user", Set.apply())),
("tm", NodeVar("tm", Set.apply(new Field("id", KTString, true))))),
Map.empty))
val r1 = LogicRule(
"R1",
"xx",
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("eventTime"), Ref("te")), VString("1")))
val r2 = LogicRule(
"R2",
"xx",
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("id"), Ref("tm")), VString("bus")))
val patternElementMap = Map.apply(
("te", PatternElement("te", Set.apply("TuringCore.TravelEvent"), r1)),
("tm", PatternElement("tm", Set.apply("TuringCore.TravelMode"), r2)))
val edges: Map[String, Set[Connection]] = Map.apply((
"te",
Set.apply(
new PatternConnection("ptm", "te", Set.apply("travelMode"), "tm", Direction.OUT, null))))
val expand = ExpandInto(
start,
patternElementMap("te"),
PartialGraphPattern("te", patternElementMap, edges))
val logicalOp =
ExpandInto(expand, patternElementMap("tm"), NodePattern(patternElementMap("tm")))
val finalOp = BottomUp[LogicalOperator](EdgeToProperty.rule(null)).transform(logicalOp)
val qlTransformer = new Expr2QlexpressTransformer()
finalOp.findExactlyOne { case ExpandInto(_, _, pattern) =>
pattern.root.alias should equal("te")
qlTransformer.transform(pattern.getNode("te").rule).head should equal(
"(te.travelMode == \"bus\") && (te.eventTime == \"1\")")
}
}
it("expandInto pure") { it("expandInto pure") {
val dsl = val dsl =
""" """