feat(reasoner): reasoner add main class (#41)

This commit is contained in:
Donghai 2023-12-20 19:03:27 +08:00 committed by GitHub
parent 0c82076604
commit 261cd27e72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 282 additions and 53 deletions

View File

@ -305,6 +305,11 @@
<artifactId>QLExpress</artifactId>
<version>3.3.2</version>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.6.0</version>
</dependency>
</dependencies>
</dependencyManagement>

View File

@ -114,6 +114,10 @@
<groupId>com.antgroup.openspg.reasoner</groupId>
<artifactId>reasoner-common</artifactId>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
</dependency>
</dependencies>
</project>

View File

@ -57,7 +57,7 @@ public class KGReasonerLocalRunner {
return doRun(task);
} catch (Throwable e) {
log.error("KGReasonerLocalRunner,error", e);
return new LocalReasonerResult("KGReasonerLocalRunner,error " + e.getMessage());
return new LocalReasonerResult("KGReasonerLocalRunner,error:" + e.getMessage());
}
}
@ -179,11 +179,23 @@ public class KGReasonerLocalRunner {
String graphLoadClass = task.getGraphLoadClass();
MemGraphState memGraphState = new MemGraphState();
AbstractLocalGraphLoader graphLoader;
try {
graphLoader =
(AbstractLocalGraphLoader) Class.forName(graphLoadClass).getConstructor().newInstance();
} catch (Exception e) {
throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e);
if (StringUtils.isEmpty(task.getGraphStateInitString())) {
try {
graphLoader =
(AbstractLocalGraphLoader) Class.forName(graphLoadClass).getConstructor().newInstance();
} catch (Exception e) {
throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e);
}
} else {
try {
graphLoader =
(AbstractLocalGraphLoader)
Class.forName(graphLoadClass)
.getConstructor(String.class)
.newInstance(task.getGraphStateInitString());
} catch (Exception e) {
throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e);
}
}
graphLoader.setGraphState(memGraphState);
graphLoader.load();

View File

@ -14,26 +14,214 @@
package com.antgroup.openspg.reasoner.runner.local;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.TypeReference;
import com.antgroup.openspg.reasoner.catalog.impl.KgSchemaConnectionInfo;
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult;
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask;
import com.opencsv.CSVWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class LocalRunnerMain {
/** result */
public static LocalReasonerResult result = null;
/** KGReasoner main */
public static void main(String[] args) {
String taskInfoJson = new String(Base64.getDecoder().decode(args[0]), StandardCharsets.UTF_8);
LocalReasonerTask task = JSON.parseObject(taskInfoJson, LocalReasonerTask.class);
LocalReasonerTask task = parseArgs(args);
if (null == task) {
System.exit(1);
}
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
result = runner.run(task);
if (null != result) {
log.info(result.toString());
LocalReasonerResult result = runner.run(task);
if (null == result) {
log.error("local runner return null");
return;
}
if (StringUtils.isNotEmpty(result.getErrMsg())) {
log.error(result.getErrMsg());
}
if (StringUtils.isNotEmpty(task.getOutputFile())) {
writeOutputFile(result, task.getOutputFile());
}
}
private static void writeOutputFile(LocalReasonerResult result, String file) {
Path path = Paths.get(file);
try {
if (Files.notExists(path.getParent())) {
Files.createDirectories(path.getParent());
}
if (Files.exists(path)) {
Files.delete(path);
}
} catch (IOException e) {
log.error("write result file error, file=" + file, e);
return;
}
if (StringUtils.isNotEmpty(result.getErrMsg())) {
writeFile(path, result.getErrMsg());
} else if (result.isGraphResult()) {
// write graph result
writeFile(path, result.toString());
} else {
// write csv
writeCsv(path, result.getColumns(), result.getRows());
}
}
private static void writeCsv(Path path, List<String> columns, List<Object[]> rows) {
List<String[]> allLines = new ArrayList<>(rows.size() + 1);
allLines.add(columns.toArray(new String[] {}));
for (Object[] rowObj : rows) {
String[] row = new String[rowObj.length];
for (int i = 0; i < rowObj.length; ++i) {
if (null != rowObj[i]) {
row[i] = String.valueOf(rowObj[i]);
} else {
row[i] = null;
}
}
allLines.add(row);
}
CSVWriter csvWriter;
try {
csvWriter = new CSVWriter(new FileWriter(path.toString()));
csvWriter.writeAll(allLines);
csvWriter.close();
} catch (IOException e) {
log.error("csvwriter error, file=" + path, e);
}
}
private static void writeFile(Path path, String content) {
try {
Files.write(path, content.getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE);
} catch (IOException e) {
log.error("write result file error, file=" + path, e);
}
}
private static LocalReasonerTask parseArgs(String[] args) {
Options options = new Options();
Option optDsl = new Option("q", "query", true, "query dsl string");
optDsl.setRequired(true);
options.addOption(optDsl);
Option optOutputFile = new Option("o", "output", true, "output file");
optOutputFile.setRequired(false);
options.addOption(optOutputFile);
Option optSchemaUri = new Option("s", "schema_uri", true, "provide schema uri");
optSchemaUri.setRequired(true);
options.addOption(optSchemaUri);
Option optSchemaToken = new Option("st", "schema_token", true, "provide schema token");
optSchemaToken.setRequired(true);
options.addOption(optSchemaToken);
Option optGraphStateClass =
new Option("g", "graph_state_class", true, "graph state class name");
optGraphStateClass.setRequired(true);
options.addOption(optGraphStateClass);
Option optGraphStateUrl = new Option("gs", "graph_state_url", true, "graph state url");
optGraphStateUrl.setRequired(false);
options.addOption(optGraphStateUrl);
Option optStartIdList = new Option("start", "start_id_list", true, "start id json list");
optStartIdList.setRequired(true);
options.addOption(optStartIdList);
Option optParamsJson =
new Option("params", "param_map_json_str", true, "parameter map json string");
optParamsJson.setRequired(false);
options.addOption(optParamsJson);
CommandLineParser parser = new DefaultParser();
HelpFormatter formatter = new HelpFormatter();
CommandLine cmd;
String dsl;
String outputFile;
String schemaUri;
String schemaToken;
String graphStateClass;
String graphStateUrl;
List<List<String>> startIdList;
Map<String, Object> params = null;
try {
cmd = parser.parse(options, args);
dsl = cmd.getOptionValue("q");
if (StringUtils.isEmpty(dsl)) {
throw new ParseException("please provide query dsl!");
}
outputFile = cmd.getOptionValue("o");
if (StringUtils.isEmpty(outputFile)) {
outputFile = null;
}
schemaUri = cmd.getOptionValue("s");
if (StringUtils.isEmpty(schemaUri)) {
throw new ParseException("please provide openspg schema uri!");
}
schemaToken = cmd.getOptionValue("st");
if (StringUtils.isEmpty(schemaToken)) {
throw new ParseException("please provide openspg schema api token!");
}
graphStateClass = cmd.getOptionValue("g");
if (StringUtils.isEmpty(graphStateClass)) {
throw new ParseException("please provide graph state class name!");
}
graphStateUrl = cmd.getOptionValue("gs");
if (StringUtils.isEmpty(graphStateUrl)) {
graphStateUrl = null;
}
String startIdListJson = cmd.getOptionValue("start");
if (StringUtils.isEmpty(startIdListJson)) {
throw new ParseException("please provide start id");
}
startIdList = JSON.parseObject(startIdListJson, new TypeReference<List<List<String>>>() {});
String paramsJson = cmd.getOptionValue("params");
if (StringUtils.isNotEmpty(paramsJson)) {
params = new HashMap<>(JSON.parseObject(paramsJson));
}
} catch (ParseException e) {
log.error(e.getMessage());
formatter.printHelp("ReasonerLocalRunner", options);
return null;
}
LocalReasonerTask task = new LocalReasonerTask();
task.setId(UUID.randomUUID().toString());
task.setDsl(dsl);
task.setOutputFile(outputFile);
task.setConnInfo(new KgSchemaConnectionInfo(schemaUri, schemaToken));
task.setGraphLoadClass(graphStateClass);
task.setGraphStateInitString(graphStateUrl);
task.setStartIdList(new ArrayList<>());
task.addStartId(startIdList);
task.setParams(params);
return task;
}
}

View File

@ -31,24 +31,30 @@ import scala.Tuple2;
@Data
public class LocalReasonerTask implements Serializable {
/** task id */
private static final long serialVersionUID = 8591924774057455987L;
/** task id */
private String id = "";
/** output file name */
private String outputFile = null;
/** Choose between dsl or dslDagList */
private String dsl = null;
private List<PhysicalOperator<LocalRDG>> dslDagList = null;
private LocalReasonerSession session = null;
/** pass catalog to runner or provide schema connection info */
/** pass catalog to runner or provide schema connection info or provide schema string */
private Catalog catalog = null;
private KgSchemaConnectionInfo connInfo = null;
private String schemaString = null;
/** Choose between graphLoadClass or graphState */
private String graphLoadClass = null;
private String graphStateInitString = null;
private GraphState<IVertexId> graphState = null;
/** start id from input */
@ -73,4 +79,11 @@ public class LocalReasonerTask implements Serializable {
/** execution information recorder, for debug */
private IExecutionRecorder executionRecorder = null;
/** add start id */
public void addStartId(List<List<String>> startIdList) {
for (List<String> item : startIdList) {
this.startIdList.add(new Tuple2<>(item.get(0), item.get(1)));
}
}
}

View File

@ -23,12 +23,14 @@ import com.antgroup.openspg.reasoner.runner.ConfigKey;
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult;
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask;
import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil;
import com.antgroup.openspg.reasoner.utils.RunnerUtil;
import com.antgroup.openspg.reasoner.utils.SimpleObjSerde;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.junit.Assert;
import org.junit.Test;
import scala.Tuple2;
@ -427,7 +429,8 @@ public class LocalRunnerTest {
+ "Rule {\n"
+ "}\n"
+ "Action {\n"
+ "\tget(s.id, s.is_trans_raise_more_after_down, s.cur_trans_multiple, s.last_trans_multiple, s.last_last_month_num, s.last_month_num, s.cur_month_num)\n"
+ "\tget(s.id, s.is_trans_raise_more_after_down, s.cur_trans_multiple, s.last_trans_multiple, s"
+ ".last_last_month_num, s.last_month_num, s.cur_month_num)\n"
+ "}\n";
LocalReasonerTask task = new LocalReasonerTask();
task.setDsl(dsl);
@ -447,9 +450,8 @@ public class LocalRunnerTest {
task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account")));
LocalRunnerMain.main(
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
LocalReasonerResult result = LocalRunnerMain.result;
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
LocalReasonerResult result = runner.run(task);
System.out.println("##########################");
System.out.println(result);
System.out.println("##########################");
@ -514,9 +516,8 @@ public class LocalRunnerTest {
task.setStartIdList(Lists.newArrayList(new Tuple2<>("保险产品", "InsProduct.Product")));
LocalRunnerMain.main(
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
LocalReasonerResult result = LocalRunnerMain.result;
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
LocalReasonerResult result = runner.run(task);
System.out.println("##########################");
System.out.println(result);
System.out.println("##########################");
@ -578,9 +579,8 @@ public class LocalRunnerTest {
task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account")));
LocalRunnerMain.main(
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
LocalReasonerResult result = LocalRunnerMain.result;
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
LocalReasonerResult result = runner.run(task);
System.out.println("##########################");
System.out.println(result);
System.out.println("##########################");
@ -653,9 +653,8 @@ public class LocalRunnerTest {
task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account")));
LocalRunnerMain.main(
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
LocalReasonerResult result = LocalRunnerMain.result;
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
LocalReasonerResult result = runner.run(task);
System.out.println("##########################");
System.out.println(result);
System.out.println("##########################");
@ -741,9 +740,8 @@ public class LocalRunnerTest {
task.setStartIdList(Lists.newArrayList(new Tuple2<>("black_app_1", "Pkg")));
LocalRunnerMain.main(
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
LocalReasonerResult result = LocalRunnerMain.result;
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
LocalReasonerResult result = runner.run(task);
System.out.println("##########################");
System.out.println(result);
System.out.println("##########################");
@ -784,9 +782,8 @@ public class LocalRunnerTest {
task.getParams().put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog));
task.setStartIdList(Lists.newArrayList(new Tuple2<>("user1", "ABM.User")));
LocalRunnerMain.main(
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
LocalReasonerResult result = LocalRunnerMain.result;
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
LocalReasonerResult result = runner.run(task);
System.out.println("##########################");
System.out.println(result);
System.out.println("##########################");
@ -875,9 +872,8 @@ public class LocalRunnerTest {
task.setStartIdList(Lists.newArrayList(new Tuple2<>("S", "CustFundKG.Account")));
LocalRunnerMain.main(
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
LocalReasonerResult result = LocalRunnerMain.result;
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
LocalReasonerResult result = runner.run(task);
System.out.println("##########################");
System.out.println(result);
System.out.println("##########################");
@ -896,11 +892,8 @@ public class LocalRunnerTest {
+ "Action {\n"
+ " get(s.id)\n"
+ "}";
LocalReasonerTask task = new LocalReasonerTask();
task.setDsl(nearbyDsl);
task.setGraphLoadClass(
"com.antgroup.openspg.reasoner.runner.local.loader.TestSpatioTemporalGraphLoader");
Map<String, Object> params = new HashMap<>();
Map<String, scala.collection.immutable.Set<String>> schema = new HashMap<>();
schema.put(
"PE.JiuZhi",
@ -911,17 +904,31 @@ public class LocalRunnerTest {
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet()));
Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema));
catalog.init();
task.getParams().put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog));
task.getParams().put(Constants.START_ALIAS, "s");
task.setStartIdList(Lists.newArrayList(new Tuple2<>("MOCK1", "PE.JiuZhi")));
params.put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog));
params.put(Constants.START_ALIAS, "s");
String outputFile = "/tmp/local/runner/" + UUID.randomUUID() + ".csv";
LocalRunnerMain.main(
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
LocalReasonerResult result = LocalRunnerMain.result;
System.out.println("##########################");
System.out.println(result);
System.out.println("##########################");
Assert.assertEquals(1, result.getRows().size());
new String[] {
"-q",
nearbyDsl,
"-o",
outputFile,
"-g",
"com.antgroup.openspg.reasoner.runner.local.loader.TestSpatioTemporalGraphLoader",
"-s",
"s",
"-st",
"st",
"-start",
"[[\"MOCK1\",\"PE.JiuZhi\"]]",
"-params",
JSON.toJSONString(params)
});
List<String[]> rst = RunnerUtil.loadCsvFile(outputFile);
Assert.assertEquals(2, rst.size());
clear();
}