fix(reasoner): fix pattern schema extra (#351)

This commit is contained in:
royzhao 2024-09-05 14:06:54 +08:00 committed by GitHub
parent 5822cbc7c3
commit 579be8bfae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 64 additions and 5 deletions

View File

@ -131,7 +131,13 @@ object PatternOps {
val typeSet = conn.relTypes
val targetSet = pattern.getNode(conn.target).typeNames
val spoSet = sourceSet.flatMap(s =>
typeSet.flatMap(p => targetSet.map(o => new SPO(s, p, o).toString)))
typeSet.flatMap(p => targetSet.map(o => {
if (conn.direction == Direction.OUT) {
new SPO(s, p, o).toString
} else {
new SPO(o, p, s).toString
}
})))
typeMap.put(conn.alias, spoSet)
})
}

View File

@ -1077,4 +1077,38 @@ public class KgReasonerDependencyFilmTest {
Assert.assertEquals("root", result.get(0)[0]);
Assert.assertEquals("100", result.get(0)[1]);
}
@Test
public void test21() {
FileMutex.runTestWithMutex(this::doTest21);
}
private void doTest21() {
String dsl =
"\n"
+ "GraphStructure {\n"
+ " A [FilmStar] \n"
+ " B [Film,__start__='true'] \n"
+ " C [Film] \n"
+ "\n"
+ " B->A [starOfFilm, __optional__='true'] as BA \n"
+ " C->A [starOfFilm, __optional__='true'] as CA\n"
+ "\n"
+ "}\n"
+ "Rule {\n"
+ " R2: exist(BA) and BA.joinTs >=100\n"
+ " R3: exist(CA) and CA.joinTs >=100\n"
+ "}\n"
+ "Action {\n"
+ " get(\n"
+ " A.id,\n"
+ " B.id,\n"
+ " BA.joinTs,\n"
+ " C.id,\n"
+ " CA.joinTs\n"
+ " ) \n"
+ "}";
List<String[]> result = runTestResult(dsl);
Assert.assertEquals(11, result.size());
}
}

View File

@ -270,7 +270,7 @@ public class KgReasonerTransitiveTest {
+ " A->B [holdShare] repeat(1,10) as e\n"
+ "}\n"
+ "Rule {\n"
+ "totalRate = e.edges().reduce((x,y) => y.rate * x, 1)"
+ "totalRate = e.edges().reduce((x,y) => y.rate * x, 1.0)\n"
+ " R1(\"只保留最长的路径\"): group(A).keep_longest_path(e)\n"
+ "}\n"
+ "Action {\n"

View File

@ -470,7 +470,7 @@ public class TransitiveOptionalTest {
dSet.add(row[3]);
}
Assert.assertTrue(dSet.contains("D_333"));
Assert.assertTrue(dSet.contains(null));
Assert.assertTrue(dSet.contains("null"));
dataGraphStr =
"Graph {\n"

View File

@ -28,6 +28,7 @@ import com.antgroup.openspg.reasoner.lube.logical.Var;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
import com.antgroup.openspg.reasoner.utils.RunnerUtil;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
@ -60,8 +61,6 @@ public class SelectRowImpl implements Serializable {
if (null == selectValue) {
row[i] = null;
} else if (this.forceOutputString) {
row[i] = String.valueOf(selectValue);
} else {
FieldType type = FieldType.fromKgType(fieldType);
if (FieldType.STRING.equals(type)) {
@ -88,10 +87,30 @@ public class SelectRowImpl implements Serializable {
row[i] = selectValue;
}
}
row[i] = convert2OutputFormat(row[i]);
}
return row;
}
private Object convert2OutputFormat(Object obj) {
if (!this.forceOutputString) {
return obj;
}
if (obj == null) {
return "null";
}
if (obj instanceof Double || obj instanceof Float) {
// 将Double或Float转换为BigDecimal
BigDecimal bd = new BigDecimal(obj.toString());
// 使用toPlainString以避免科学计数法
return bd.toPlainString();
} else {
// 对于其他类型直接调用toString方法
return obj.toString();
}
}
public static Object getSelectValue(
String alias, String propertyName, Map<String, Object> context) {
if (Constants.PROPERTY_JSON_KEY.equals(propertyName)) {