From 261cd27e720274c69bbd48c6d8930da2efba3fc6 Mon Sep 17 00:00:00 2001 From: Donghai Date: Wed, 20 Dec 2023 19:03:27 +0800 Subject: [PATCH] feat(reasoner): reasoner add main class (#41) --- pom.xml | 5 + reasoner/runner/local-runner/pom.xml | 4 + .../runner/local/KGReasonerLocalRunner.java | 24 +- .../runner/local/LocalRunnerMain.java | 206 +++++++++++++++++- .../runner/local/model/LocalReasonerTask.java | 17 +- .../runner/local/LocalRunnerTest.java | 79 ++++--- 6 files changed, 282 insertions(+), 53 deletions(-) diff --git a/pom.xml b/pom.xml index f51b085d..98bfa33f 100644 --- a/pom.xml +++ b/pom.xml @@ -305,6 +305,11 @@ QLExpress 3.3.2 + + commons-cli + commons-cli + 1.6.0 + diff --git a/reasoner/runner/local-runner/pom.xml b/reasoner/runner/local-runner/pom.xml index 15e4f8b4..b3f937a5 100644 --- a/reasoner/runner/local-runner/pom.xml +++ b/reasoner/runner/local-runner/pom.xml @@ -114,6 +114,10 @@ com.antgroup.openspg.reasoner reasoner-common + + commons-cli + commons-cli + diff --git a/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/KGReasonerLocalRunner.java b/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/KGReasonerLocalRunner.java index 61c3d15c..8847d212 100644 --- a/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/KGReasonerLocalRunner.java +++ b/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/KGReasonerLocalRunner.java @@ -57,7 +57,7 @@ public class KGReasonerLocalRunner { return doRun(task); } catch (Throwable e) { log.error("KGReasonerLocalRunner,error", e); - return new LocalReasonerResult("KGReasonerLocalRunner,error " + e.getMessage()); + return new LocalReasonerResult("KGReasonerLocalRunner,error:" + e.getMessage()); } } @@ -179,11 +179,23 @@ public class KGReasonerLocalRunner { String graphLoadClass = task.getGraphLoadClass(); MemGraphState memGraphState = new MemGraphState(); AbstractLocalGraphLoader graphLoader; - try { - graphLoader = - (AbstractLocalGraphLoader) Class.forName(graphLoadClass).getConstructor().newInstance(); - } catch (Exception e) { - throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e); + if (StringUtils.isEmpty(task.getGraphStateInitString())) { + try { + graphLoader = + (AbstractLocalGraphLoader) Class.forName(graphLoadClass).getConstructor().newInstance(); + } catch (Exception e) { + throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e); + } + } else { + try { + graphLoader = + (AbstractLocalGraphLoader) + Class.forName(graphLoadClass) + .getConstructor(String.class) + .newInstance(task.getGraphStateInitString()); + } catch (Exception e) { + throw new RuntimeException("can not create graph loader from name " + graphLoadClass, e); + } } graphLoader.setGraphState(memGraphState); graphLoader.load(); diff --git a/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/LocalRunnerMain.java b/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/LocalRunnerMain.java index 1728fd42..dbf32c8b 100644 --- a/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/LocalRunnerMain.java +++ b/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/LocalRunnerMain.java @@ -14,26 +14,214 @@ package com.antgroup.openspg.reasoner.runner.local; import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.TypeReference; +import com.antgroup.openspg.reasoner.catalog.impl.KgSchemaConnectionInfo; import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult; import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask; +import com.opencsv.CSVWriter; +import java.io.FileWriter; +import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.Base64; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.commons.lang3.StringUtils; @Slf4j public class LocalRunnerMain { - /** result */ - public static LocalReasonerResult result = null; - /** KGReasoner main */ public static void main(String[] args) { - String taskInfoJson = new String(Base64.getDecoder().decode(args[0]), StandardCharsets.UTF_8); - LocalReasonerTask task = JSON.parseObject(taskInfoJson, LocalReasonerTask.class); + LocalReasonerTask task = parseArgs(args); + if (null == task) { + System.exit(1); + } KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); - result = runner.run(task); - if (null != result) { - log.info(result.toString()); + LocalReasonerResult result = runner.run(task); + if (null == result) { + log.error("local runner return null"); + return; + } + if (StringUtils.isNotEmpty(result.getErrMsg())) { + log.error(result.getErrMsg()); + } + if (StringUtils.isNotEmpty(task.getOutputFile())) { + writeOutputFile(result, task.getOutputFile()); } } + + private static void writeOutputFile(LocalReasonerResult result, String file) { + Path path = Paths.get(file); + try { + if (Files.notExists(path.getParent())) { + Files.createDirectories(path.getParent()); + } + if (Files.exists(path)) { + Files.delete(path); + } + } catch (IOException e) { + log.error("write result file error, file=" + file, e); + return; + } + + if (StringUtils.isNotEmpty(result.getErrMsg())) { + writeFile(path, result.getErrMsg()); + } else if (result.isGraphResult()) { + // write graph result + writeFile(path, result.toString()); + } else { + // write csv + writeCsv(path, result.getColumns(), result.getRows()); + } + } + + private static void writeCsv(Path path, List columns, List rows) { + List allLines = new ArrayList<>(rows.size() + 1); + allLines.add(columns.toArray(new String[] {})); + for (Object[] rowObj : rows) { + String[] row = new String[rowObj.length]; + for (int i = 0; i < rowObj.length; ++i) { + if (null != rowObj[i]) { + row[i] = String.valueOf(rowObj[i]); + } else { + row[i] = null; + } + } + allLines.add(row); + } + + CSVWriter csvWriter; + try { + csvWriter = new CSVWriter(new FileWriter(path.toString())); + csvWriter.writeAll(allLines); + csvWriter.close(); + } catch (IOException e) { + log.error("csvwriter error, file=" + path, e); + } + } + + private static void writeFile(Path path, String content) { + try { + Files.write(path, content.getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE); + } catch (IOException e) { + log.error("write result file error, file=" + path, e); + } + } + + private static LocalReasonerTask parseArgs(String[] args) { + Options options = new Options(); + + Option optDsl = new Option("q", "query", true, "query dsl string"); + optDsl.setRequired(true); + options.addOption(optDsl); + + Option optOutputFile = new Option("o", "output", true, "output file"); + optOutputFile.setRequired(false); + options.addOption(optOutputFile); + + Option optSchemaUri = new Option("s", "schema_uri", true, "provide schema uri"); + optSchemaUri.setRequired(true); + options.addOption(optSchemaUri); + + Option optSchemaToken = new Option("st", "schema_token", true, "provide schema token"); + optSchemaToken.setRequired(true); + options.addOption(optSchemaToken); + + Option optGraphStateClass = + new Option("g", "graph_state_class", true, "graph state class name"); + optGraphStateClass.setRequired(true); + options.addOption(optGraphStateClass); + + Option optGraphStateUrl = new Option("gs", "graph_state_url", true, "graph state url"); + optGraphStateUrl.setRequired(false); + options.addOption(optGraphStateUrl); + + Option optStartIdList = new Option("start", "start_id_list", true, "start id json list"); + optStartIdList.setRequired(true); + options.addOption(optStartIdList); + + Option optParamsJson = + new Option("params", "param_map_json_str", true, "parameter map json string"); + optParamsJson.setRequired(false); + options.addOption(optParamsJson); + + CommandLineParser parser = new DefaultParser(); + HelpFormatter formatter = new HelpFormatter(); + CommandLine cmd; + + String dsl; + String outputFile; + String schemaUri; + String schemaToken; + String graphStateClass; + String graphStateUrl; + List> startIdList; + Map params = null; + try { + cmd = parser.parse(options, args); + dsl = cmd.getOptionValue("q"); + if (StringUtils.isEmpty(dsl)) { + throw new ParseException("please provide query dsl!"); + } + outputFile = cmd.getOptionValue("o"); + if (StringUtils.isEmpty(outputFile)) { + outputFile = null; + } + schemaUri = cmd.getOptionValue("s"); + if (StringUtils.isEmpty(schemaUri)) { + throw new ParseException("please provide openspg schema uri!"); + } + schemaToken = cmd.getOptionValue("st"); + if (StringUtils.isEmpty(schemaToken)) { + throw new ParseException("please provide openspg schema api token!"); + } + graphStateClass = cmd.getOptionValue("g"); + if (StringUtils.isEmpty(graphStateClass)) { + throw new ParseException("please provide graph state class name!"); + } + graphStateUrl = cmd.getOptionValue("gs"); + if (StringUtils.isEmpty(graphStateUrl)) { + graphStateUrl = null; + } + String startIdListJson = cmd.getOptionValue("start"); + if (StringUtils.isEmpty(startIdListJson)) { + throw new ParseException("please provide start id"); + } + startIdList = JSON.parseObject(startIdListJson, new TypeReference>>() {}); + String paramsJson = cmd.getOptionValue("params"); + if (StringUtils.isNotEmpty(paramsJson)) { + params = new HashMap<>(JSON.parseObject(paramsJson)); + } + } catch (ParseException e) { + log.error(e.getMessage()); + formatter.printHelp("ReasonerLocalRunner", options); + return null; + } + + LocalReasonerTask task = new LocalReasonerTask(); + task.setId(UUID.randomUUID().toString()); + task.setDsl(dsl); + task.setOutputFile(outputFile); + task.setConnInfo(new KgSchemaConnectionInfo(schemaUri, schemaToken)); + task.setGraphLoadClass(graphStateClass); + task.setGraphStateInitString(graphStateUrl); + task.setStartIdList(new ArrayList<>()); + task.addStartId(startIdList); + task.setParams(params); + return task; + } } diff --git a/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/model/LocalReasonerTask.java b/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/model/LocalReasonerTask.java index 5fc67901..bb9c9d0f 100644 --- a/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/model/LocalReasonerTask.java +++ b/reasoner/runner/local-runner/src/main/java/com/antgroup/openspg/reasoner/runner/local/model/LocalReasonerTask.java @@ -31,24 +31,30 @@ import scala.Tuple2; @Data public class LocalReasonerTask implements Serializable { - /** task id */ private static final long serialVersionUID = 8591924774057455987L; + /** task id */ private String id = ""; + + /** output file name */ + private String outputFile = null; + /** Choose between dsl or dslDagList */ private String dsl = null; private List> dslDagList = null; private LocalReasonerSession session = null; - /** pass catalog to runner or provide schema connection info */ + /** pass catalog to runner or provide schema connection info or provide schema string */ private Catalog catalog = null; private KgSchemaConnectionInfo connInfo = null; + private String schemaString = null; /** Choose between graphLoadClass or graphState */ private String graphLoadClass = null; + private String graphStateInitString = null; private GraphState graphState = null; /** start id from input */ @@ -73,4 +79,11 @@ public class LocalReasonerTask implements Serializable { /** execution information recorder, for debug */ private IExecutionRecorder executionRecorder = null; + + /** add start id */ + public void addStartId(List> startIdList) { + for (List item : startIdList) { + this.startIdList.add(new Tuple2<>(item.get(0), item.get(1))); + } + } } diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/LocalRunnerTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/LocalRunnerTest.java index 62b2872d..68b67d41 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/LocalRunnerTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/LocalRunnerTest.java @@ -23,12 +23,14 @@ import com.antgroup.openspg.reasoner.runner.ConfigKey; import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerResult; import com.antgroup.openspg.reasoner.runner.local.model.LocalReasonerTask; import com.antgroup.openspg.reasoner.util.Convert2ScalaUtil; +import com.antgroup.openspg.reasoner.utils.RunnerUtil; import com.antgroup.openspg.reasoner.utils.SimpleObjSerde; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import java.util.Base64; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.UUID; import org.junit.Assert; import org.junit.Test; import scala.Tuple2; @@ -427,7 +429,8 @@ public class LocalRunnerTest { + "Rule {\n" + "}\n" + "Action {\n" - + "\tget(s.id, s.is_trans_raise_more_after_down, s.cur_trans_multiple, s.last_trans_multiple, s.last_last_month_num, s.last_month_num, s.cur_month_num)\n" + + "\tget(s.id, s.is_trans_raise_more_after_down, s.cur_trans_multiple, s.last_trans_multiple, s" + + ".last_last_month_num, s.last_month_num, s.cur_month_num)\n" + "}\n"; LocalReasonerTask task = new LocalReasonerTask(); task.setDsl(dsl); @@ -447,9 +450,8 @@ public class LocalRunnerTest { task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account"))); - LocalRunnerMain.main( - new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))}); - LocalReasonerResult result = LocalRunnerMain.result; + KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); + LocalReasonerResult result = runner.run(task); System.out.println("##########################"); System.out.println(result); System.out.println("##########################"); @@ -514,9 +516,8 @@ public class LocalRunnerTest { task.setStartIdList(Lists.newArrayList(new Tuple2<>("保险产品", "InsProduct.Product"))); - LocalRunnerMain.main( - new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))}); - LocalReasonerResult result = LocalRunnerMain.result; + KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); + LocalReasonerResult result = runner.run(task); System.out.println("##########################"); System.out.println(result); System.out.println("##########################"); @@ -578,9 +579,8 @@ public class LocalRunnerTest { task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account"))); - LocalRunnerMain.main( - new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))}); - LocalReasonerResult result = LocalRunnerMain.result; + KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); + LocalReasonerResult result = runner.run(task); System.out.println("##########################"); System.out.println(result); System.out.println("##########################"); @@ -653,9 +653,8 @@ public class LocalRunnerTest { task.setStartIdList(Lists.newArrayList(new Tuple2<>("1", "CustFundKG.Account"))); - LocalRunnerMain.main( - new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))}); - LocalReasonerResult result = LocalRunnerMain.result; + KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); + LocalReasonerResult result = runner.run(task); System.out.println("##########################"); System.out.println(result); System.out.println("##########################"); @@ -741,9 +740,8 @@ public class LocalRunnerTest { task.setStartIdList(Lists.newArrayList(new Tuple2<>("black_app_1", "Pkg"))); - LocalRunnerMain.main( - new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))}); - LocalReasonerResult result = LocalRunnerMain.result; + KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); + LocalReasonerResult result = runner.run(task); System.out.println("##########################"); System.out.println(result); System.out.println("##########################"); @@ -784,9 +782,8 @@ public class LocalRunnerTest { task.getParams().put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog)); task.setStartIdList(Lists.newArrayList(new Tuple2<>("user1", "ABM.User"))); - LocalRunnerMain.main( - new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))}); - LocalReasonerResult result = LocalRunnerMain.result; + KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); + LocalReasonerResult result = runner.run(task); System.out.println("##########################"); System.out.println(result); System.out.println("##########################"); @@ -875,9 +872,8 @@ public class LocalRunnerTest { task.setStartIdList(Lists.newArrayList(new Tuple2<>("S", "CustFundKG.Account"))); - LocalRunnerMain.main( - new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))}); - LocalReasonerResult result = LocalRunnerMain.result; + KGReasonerLocalRunner runner = new KGReasonerLocalRunner(); + LocalReasonerResult result = runner.run(task); System.out.println("##########################"); System.out.println(result); System.out.println("##########################"); @@ -896,11 +892,8 @@ public class LocalRunnerTest { + "Action {\n" + " get(s.id)\n" + "}"; - LocalReasonerTask task = new LocalReasonerTask(); - task.setDsl(nearbyDsl); - task.setGraphLoadClass( - "com.antgroup.openspg.reasoner.runner.local.loader.TestSpatioTemporalGraphLoader"); + Map params = new HashMap<>(); Map> schema = new HashMap<>(); schema.put( "PE.JiuZhi", @@ -911,17 +904,31 @@ public class LocalRunnerTest { Convert2ScalaUtil.toScalaImmutableSet(Sets.newHashSet())); Catalog catalog = new PropertyGraphCatalog(Convert2ScalaUtil.toScalaImmutableMap(schema)); catalog.init(); - task.getParams().put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog)); - task.getParams().put(Constants.START_ALIAS, "s"); - task.setStartIdList(Lists.newArrayList(new Tuple2<>("MOCK1", "PE.JiuZhi"))); + + params.put(ConfigKey.KG_REASONER_CATALOG, SimpleObjSerde.ser(catalog)); + params.put(Constants.START_ALIAS, "s"); + + String outputFile = "/tmp/local/runner/" + UUID.randomUUID() + ".csv"; LocalRunnerMain.main( - new String[] {new String(Base64.getEncoder().encode(JSON.toJSONBytes(task)))}); - LocalReasonerResult result = LocalRunnerMain.result; - System.out.println("##########################"); - System.out.println(result); - System.out.println("##########################"); - Assert.assertEquals(1, result.getRows().size()); + new String[] { + "-q", + nearbyDsl, + "-o", + outputFile, + "-g", + "com.antgroup.openspg.reasoner.runner.local.loader.TestSpatioTemporalGraphLoader", + "-s", + "s", + "-st", + "st", + "-start", + "[[\"MOCK1\",\"PE.JiuZhi\"]]", + "-params", + JSON.toJSONString(params) + }); + List rst = RunnerUtil.loadCsvFile(outputFile); + Assert.assertEquals(2, rst.size()); clear(); }