mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-12-02 02:01:00 +00:00
feat(reasoner): reasoner add main class (#41)
This commit is contained in:
parent
0c82076604
commit
261cd27e72
5
pom.xml
5
pom.xml
@ -305,6 +305,11 @@
|
||||
<artifactId>QLExpress</artifactId>
|
||||
<version>3.3.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-cli</groupId>
|
||||
<artifactId>commons-cli</artifactId>
|
||||
<version>1.6.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
||||
@ -114,6 +114,10 @@
|
||||
<groupId>com.antgroup.openspg.reasoner</groupId>
|
||||
<artifactId>reasoner-common</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-cli</groupId>
|
||||
<artifactId>commons-cli</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
|
||||
@ -57,7 +57,7 @@ public class KGReasonerLocalRunner {
|
||||
return doRun(task);
|
||||
} catch (Throwable e) {
|
||||
log.error("KGReasonerLocalRunner,error", e);
|
||||
return new LocalReasonerResult("KGReasonerLocalRunner,error " + e.getMessage());
|
||||
return new LocalReasonerResult("KGReasonerLocalRunner,error:" + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@ -179,11 +179,23 @@ public class KGReasonerLocalRunner {
|
||||
String graphLoadClass = task.getGraphLoadClass();
|
||||
MemGraphState memGraphState = new MemGraphState();
|
||||
AbstractLocalGraphLoader graphLoader;
|
||||
try {
|
||||
graphLoader =
|
||||
(AbstractLocalGraphLoader) Class.forName(graphLoadClass).getConstructor().newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e);
|
||||
if (StringUtils.isEmpty(task.getGraphStateInitString())) {
|
||||
try {
|
||||
graphLoader =
|
||||
(AbstractLocalGraphLoader) Class.forName(graphLoadClass).getConstructor().newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e);
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
graphLoader =
|
||||
(AbstractLocalGraphLoader)
|
||||
Class.forName(graphLoadClass)
|
||||
.getConstructor(String.class)
|
||||
.newInstance(task.getGraphStateInitString());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e);
|
||||
}
|
||||
}
|
||||
graphLoader.setGraphState(memGraphState);
|
||||
graphLoader.load();
|
||||
|
||||
@ -14,26 +14,214 @@
|
||||
package com.antgroup.openspg.reasoner.runner.local;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.TypeReference;
|
||||
import com.antgroup.openspg.reasoner.catalog.impl.KgSchemaConnectionInfo;
|
||||
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult;
|
||||
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask;
|
||||
import com.opencsv.CSVWriter;
|
||||
import java.io.FileWriter;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Base64;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.cli.CommandLine;
|
||||
import org.apache.commons.cli.CommandLineParser;
|
||||
import org.apache.commons.cli.DefaultParser;
|
||||
import org.apache.commons.cli.HelpFormatter;
|
||||
import org.apache.commons.cli.Option;
|
||||
import org.apache.commons.cli.Options;
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LocalRunnerMain {
|
||||
|
||||
/** result */
|
||||
public static LocalReasonerResult result = null;
|
||||
|
||||
/** KGReasoner main */
|
||||
public static void main(String[] args) {
|
||||
String taskInfoJson = new String(Base64.getDecoder().decode(args[0]), StandardCharsets.UTF_8);
|
||||
LocalReasonerTask task = JSON.parseObject(taskInfoJson, LocalReasonerTask.class);
|
||||
LocalReasonerTask task = parseArgs(args);
|
||||
if (null == task) {
|
||||
System.exit(1);
|
||||
}
|
||||
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
|
||||
result = runner.run(task);
|
||||
if (null != result) {
|
||||
log.info(result.toString());
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
if (null == result) {
|
||||
log.error("local runner return null");
|
||||
return;
|
||||
}
|
||||
if (StringUtils.isNotEmpty(result.getErrMsg())) {
|
||||
log.error(result.getErrMsg());
|
||||
}
|
||||
if (StringUtils.isNotEmpty(task.getOutputFile())) {
|
||||
writeOutputFile(result, task.getOutputFile());
|
||||
}
|
||||
}
|
||||
|
||||
private static void writeOutputFile(LocalReasonerResult result, String file) {
|
||||
Path path = Paths.get(file);
|
||||
try {
|
||||
if (Files.notExists(path.getParent())) {
|
||||
Files.createDirectories(path.getParent());
|
||||
}
|
||||
if (Files.exists(path)) {
|
||||
Files.delete(path);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
log.error("write result file error, file=" + file, e);
|
||||
return;
|
||||
}
|
||||
|
||||
if (StringUtils.isNotEmpty(result.getErrMsg())) {
|
||||
writeFile(path, result.getErrMsg());
|
||||
} else if (result.isGraphResult()) {
|
||||
// write graph result
|
||||
writeFile(path, result.toString());
|
||||
} else {
|
||||
// write csv
|
||||
writeCsv(path, result.getColumns(), result.getRows());
|
||||
}
|
||||
}
|
||||
|
||||
private static void writeCsv(Path path, List<String> columns, List<Object[]> rows) {
|
||||
List<String[]> allLines = new ArrayList<>(rows.size() + 1);
|
||||
allLines.add(columns.toArray(new String[] {}));
|
||||
for (Object[] rowObj : rows) {
|
||||
String[] row = new String[rowObj.length];
|
||||
for (int i = 0; i < rowObj.length; ++i) {
|
||||
if (null != rowObj[i]) {
|
||||
row[i] = String.valueOf(rowObj[i]);
|
||||
} else {
|
||||
row[i] = null;
|
||||
}
|
||||
}
|
||||
allLines.add(row);
|
||||
}
|
||||
|
||||
CSVWriter csvWriter;
|
||||
try {
|
||||
csvWriter = new CSVWriter(new FileWriter(path.toString()));
|
||||
csvWriter.writeAll(allLines);
|
||||
csvWriter.close();
|
||||
} catch (IOException e) {
|
||||
log.error("csvwriter error, file=" + path, e);
|
||||
}
|
||||
}
|
||||
|
||||
private static void writeFile(Path path, String content) {
|
||||
try {
|
||||
Files.write(path, content.getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE);
|
||||
} catch (IOException e) {
|
||||
log.error("write result file error, file=" + path, e);
|
||||
}
|
||||
}
|
||||
|
||||
private static LocalReasonerTask parseArgs(String[] args) {
|
||||
Options options = new Options();
|
||||
|
||||
Option optDsl = new Option("q", "query", true, "query dsl string");
|
||||
optDsl.setRequired(true);
|
||||
options.addOption(optDsl);
|
||||
|
||||
Option optOutputFile = new Option("o", "output", true, "output file");
|
||||
optOutputFile.setRequired(false);
|
||||
options.addOption(optOutputFile);
|
||||
|
||||
Option optSchemaUri = new Option("s", "schema_uri", true, "provide schema uri");
|
||||
optSchemaUri.setRequired(true);
|
||||
options.addOption(optSchemaUri);
|
||||
|
||||
Option optSchemaToken = new Option("st", "schema_token", true, "provide schema token");
|
||||
optSchemaToken.setRequired(true);
|
||||
options.addOption(optSchemaToken);
|
||||
|
||||
Option optGraphStateClass =
|
||||
new Option("g", "graph_state_class", true, "graph state class name");
|
||||
optGraphStateClass.setRequired(true);
|
||||
options.addOption(optGraphStateClass);
|
||||
|
||||
Option optGraphStateUrl = new Option("gs", "graph_state_url", true, "graph state url");
|
||||
optGraphStateUrl.setRequired(false);
|
||||
options.addOption(optGraphStateUrl);
|
||||
|
||||
Option optStartIdList = new Option("start", "start_id_list", true, "start id json list");
|
||||
optStartIdList.setRequired(true);
|
||||
options.addOption(optStartIdList);
|
||||
|
||||
Option optParamsJson =
|
||||
new Option("params", "param_map_json_str", true, "parameter map json string");
|
||||
optParamsJson.setRequired(false);
|
||||
options.addOption(optParamsJson);
|
||||
|
||||
CommandLineParser parser = new DefaultParser();
|
||||
HelpFormatter formatter = new HelpFormatter();
|
||||
CommandLine cmd;
|
||||
|
||||
String dsl;
|
||||
String outputFile;
|
||||
String schemaUri;
|
||||
String schemaToken;
|
||||
String graphStateClass;
|
||||
String graphStateUrl;
|
||||
List<List<String>> startIdList;
|
||||
Map<String, Object> params = null;
|
||||
try {
|
||||
cmd = parser.parse(options, args);
|
||||
dsl = cmd.getOptionValue("q");
|
||||
if (StringUtils.isEmpty(dsl)) {
|
||||
throw new ParseException("please provide query dsl!");
|
||||
}
|
||||
outputFile = cmd.getOptionValue("o");
|
||||
if (StringUtils.isEmpty(outputFile)) {
|
||||
outputFile = null;
|
||||
}
|
||||
schemaUri = cmd.getOptionValue("s");
|
||||
if (StringUtils.isEmpty(schemaUri)) {
|
||||
throw new ParseException("please provide openspg schema uri!");
|
||||
}
|
||||
schemaToken = cmd.getOptionValue("st");
|
||||
if (StringUtils.isEmpty(schemaToken)) {
|
||||
throw new ParseException("please provide openspg schema api token!");
|
||||
}
|
||||
graphStateClass = cmd.getOptionValue("g");
|
||||
if (StringUtils.isEmpty(graphStateClass)) {
|
||||
throw new ParseException("please provide graph state class name!");
|
||||
}
|
||||
graphStateUrl = cmd.getOptionValue("gs");
|
||||
if (StringUtils.isEmpty(graphStateUrl)) {
|
||||
graphStateUrl = null;
|
||||
}
|
||||
String startIdListJson = cmd.getOptionValue("start");
|
||||
if (StringUtils.isEmpty(startIdListJson)) {
|
||||
throw new ParseException("please provide start id");
|
||||
}
|
||||
startIdList = JSON.parseObject(startIdListJson, new TypeReference<List<List<String>>>() {});
|
||||
String paramsJson = cmd.getOptionValue("params");
|
||||
if (StringUtils.isNotEmpty(paramsJson)) {
|
||||
params = new HashMap<>(JSON.parseObject(paramsJson));
|
||||
}
|
||||
} catch (ParseException e) {
|
||||
log.error(e.getMessage());
|
||||
formatter.printHelp("ReasonerLocalRunner", options);
|
||||
return null;
|
||||
}
|
||||
|
||||
LocalReasonerTask task = new LocalReasonerTask();
|
||||
task.setId(UUID.randomUUID().toString());
|
||||
task.setDsl(dsl);
|
||||
task.setOutputFile(outputFile);
|
||||
task.setConnInfo(new KgSchemaConnectionInfo(schemaUri, schemaToken));
|
||||
task.setGraphLoadClass(graphStateClass);
|
||||
task.setGraphStateInitString(graphStateUrl);
|
||||
task.setStartIdList(new ArrayList<>());
|
||||
task.addStartId(startIdList);
|
||||
task.setParams(params);
|
||||
return task;
|
||||
}
|
||||
}
|
||||
|
||||
@ -31,24 +31,30 @@ import scala.Tuple2;
|
||||
|
||||
@Data
|
||||
public class LocalReasonerTask implements Serializable {
|
||||
/** task id */
|
||||
private static final long serialVersionUID = 8591924774057455987L;
|
||||
|
||||
/** task id */
|
||||
private String id = "";
|
||||
|
||||
/** output file name */
|
||||
private String outputFile = null;
|
||||
|
||||
/** Choose between dsl or dslDagList */
|
||||
private String dsl = null;
|
||||
|
||||
private List<PhysicalOperator<LocalRDG>> dslDagList = null;
|
||||
private LocalReasonerSession session = null;
|
||||
|
||||
/** pass catalog to runner or provide schema connection info */
|
||||
/** pass catalog to runner or provide schema connection info or provide schema string */
|
||||
private Catalog catalog = null;
|
||||
|
||||
private KgSchemaConnectionInfo connInfo = null;
|
||||
private String schemaString = null;
|
||||
|
||||
/** Choose between graphLoadClass or graphState */
|
||||
private String graphLoadClass = null;
|
||||
|
||||
private String graphStateInitString = null;
|
||||
private GraphState<IVertexId> graphState = null;
|
||||
|
||||
/** start id from input */
|
||||
@ -73,4 +79,11 @@ public class LocalReasonerTask implements Serializable {
|
||||
|
||||
/** execution information recorder, for debug */
|
||||
private IExecutionRecorder executionRecorder = null;
|
||||
|
||||
/** add start id */
|
||||
public void addStartId(List<List<String>> startIdList) {
|
||||
for (List<String> item : startIdList) {
|
||||
this.startIdList.add(new Tuple2<>(item.get(0), item.get(1)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,12 +23,14 @@ import com.antgroup.openspg.reasoner.runner.ConfigKey;
|
||||
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult;
|
||||
import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask;
|
||||
import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil;
|
||||
import com.antgroup.openspg.reasoner.utils.RunnerUtil;
|
||||
import com.antgroup.openspg.reasoner.utils.SimpleObjSerde;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import java.util.Base64;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import scala.Tuple2;
|
||||
@ -427,7 +429,8 @@ public class LocalRunnerTest {
|
||||
+ "Rule {\n"
|
||||
+ "}\n"
|
||||
+ "Action {\n"
|
||||
+ "\tget(s.id, s.is_trans_raise_more_after_down, s.cur_trans_multiple, s.last_trans_multiple, s.last_last_month_num, s.last_month_num, s.cur_month_num)\n"
|
||||
+ "\tget(s.id, s.is_trans_raise_more_after_down, s.cur_trans_multiple, s.last_trans_multiple, s"
|
||||
+ ".last_last_month_num, s.last_month_num, s.cur_month_num)\n"
|
||||
+ "}\n";
|
||||
LocalReasonerTask task = new LocalReasonerTask();
|
||||
task.setDsl(dsl);
|
||||
@ -447,9 +450,8 @@ public class LocalRunnerTest {
|
||||
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account")));
|
||||
|
||||
LocalRunnerMain.main(
|
||||
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
|
||||
LocalReasonerResult result = LocalRunnerMain.result;
|
||||
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println("##########################");
|
||||
System.out.println(result);
|
||||
System.out.println("##########################");
|
||||
@ -514,9 +516,8 @@ public class LocalRunnerTest {
|
||||
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("保险产品", "InsProduct.Product")));
|
||||
|
||||
LocalRunnerMain.main(
|
||||
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
|
||||
LocalReasonerResult result = LocalRunnerMain.result;
|
||||
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println("##########################");
|
||||
System.out.println(result);
|
||||
System.out.println("##########################");
|
||||
@ -578,9 +579,8 @@ public class LocalRunnerTest {
|
||||
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account")));
|
||||
|
||||
LocalRunnerMain.main(
|
||||
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
|
||||
LocalReasonerResult result = LocalRunnerMain.result;
|
||||
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println("##########################");
|
||||
System.out.println(result);
|
||||
System.out.println("##########################");
|
||||
@ -653,9 +653,8 @@ public class LocalRunnerTest {
|
||||
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account")));
|
||||
|
||||
LocalRunnerMain.main(
|
||||
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
|
||||
LocalReasonerResult result = LocalRunnerMain.result;
|
||||
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println("##########################");
|
||||
System.out.println(result);
|
||||
System.out.println("##########################");
|
||||
@ -741,9 +740,8 @@ public class LocalRunnerTest {
|
||||
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("black_app_1", "Pkg")));
|
||||
|
||||
LocalRunnerMain.main(
|
||||
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
|
||||
LocalReasonerResult result = LocalRunnerMain.result;
|
||||
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println("##########################");
|
||||
System.out.println(result);
|
||||
System.out.println("##########################");
|
||||
@ -784,9 +782,8 @@ public class LocalRunnerTest {
|
||||
task.getParams().put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog));
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("user1", "ABM.User")));
|
||||
|
||||
LocalRunnerMain.main(
|
||||
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
|
||||
LocalReasonerResult result = LocalRunnerMain.result;
|
||||
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println("##########################");
|
||||
System.out.println(result);
|
||||
System.out.println("##########################");
|
||||
@ -875,9 +872,8 @@ public class LocalRunnerTest {
|
||||
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("S", "CustFundKG.Account")));
|
||||
|
||||
LocalRunnerMain.main(
|
||||
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
|
||||
LocalReasonerResult result = LocalRunnerMain.result;
|
||||
KGReasonerLocalRunner runner = new KGReasonerLocalRunner();
|
||||
LocalReasonerResult result = runner.run(task);
|
||||
System.out.println("##########################");
|
||||
System.out.println(result);
|
||||
System.out.println("##########################");
|
||||
@ -896,11 +892,8 @@ public class LocalRunnerTest {
|
||||
+ "Action {\n"
|
||||
+ " get(s.id)\n"
|
||||
+ "}";
|
||||
LocalReasonerTask task = new LocalReasonerTask();
|
||||
task.setDsl(nearbyDsl);
|
||||
task.setGraphLoadClass(
|
||||
"com.antgroup.openspg.reasoner.runner.local.loader.TestSpatioTemporalGraphLoader");
|
||||
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
Map<String, scala.collection.immutable.Set<String>> schema = new HashMap<>();
|
||||
schema.put(
|
||||
"PE.JiuZhi",
|
||||
@ -911,17 +904,31 @@ public class LocalRunnerTest {
|
||||
Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet()));
|
||||
Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema));
|
||||
catalog.init();
|
||||
task.getParams().put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog));
|
||||
task.getParams().put(Constants.START_ALIAS, "s");
|
||||
task.setStartIdList(Lists.newArrayList(new Tuple2<>("MOCK1", "PE.JiuZhi")));
|
||||
|
||||
params.put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog));
|
||||
params.put(Constants.START_ALIAS, "s");
|
||||
|
||||
String outputFile = "/tmp/local/runner/" + UUID.randomUUID() + ".csv";
|
||||
|
||||
LocalRunnerMain.main(
|
||||
new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))});
|
||||
LocalReasonerResult result = LocalRunnerMain.result;
|
||||
System.out.println("##########################");
|
||||
System.out.println(result);
|
||||
System.out.println("##########################");
|
||||
Assert.assertEquals(1, result.getRows().size());
|
||||
new String[] {
|
||||
"-q",
|
||||
nearbyDsl,
|
||||
"-o",
|
||||
outputFile,
|
||||
"-g",
|
||||
"com.antgroup.openspg.reasoner.runner.local.loader.TestSpatioTemporalGraphLoader",
|
||||
"-s",
|
||||
"s",
|
||||
"-st",
|
||||
"st",
|
||||
"-start",
|
||||
"[[\"MOCK1\",\"PE.JiuZhi\"]]",
|
||||
"-params",
|
||||
JSON.toJSONString(params)
|
||||
});
|
||||
List<String[]> rst = RunnerUtil.loadCsvFile(outputFile);
|
||||
Assert.assertEquals(2, rst.size());
|
||||
clear();
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user