Merge branch 'master' into knext_0112

Author: Qu
Date: 2024-01-19 17:57:05 +08:00
Commit: dc0232ef28
9 changed files with 1124 additions and 9 deletions

.github/workflows/cla.yml (new file, 27 lines)

@@ -0,0 +1,27 @@
name: "CLA Assistant"
on:
issue_comment:
types: [ created ]
pull_request_target:
types: [ opened,closed,synchronize ]
jobs:
CLAssistant:
runs-on: ubuntu-latest
steps:
- name: "CLA Assistant"
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
uses: contributor-assistant/github-action@v2.3.0
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# the below token should have repo scope and must be manually added by you in the repository's secret
PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
with:
path-to-signatures: 'signatures/version1/cla.json'
path-to-document: 'https://github.com/OpenSPG/cla-assistant/blob/master/CLA.md' # e.g. a CLA or a DCO document
allowlist: test,bot*
remote-organization-name: OpenSPG
remote-repository-name: cla-assistant
lock-pullrequest-aftermerge: true

FieldType.java

@@ -17,6 +17,7 @@ import com.antgroup.openspg.reasoner.common.types.KTBoolean$;
import com.antgroup.openspg.reasoner.common.types.KTDouble$;
import com.antgroup.openspg.reasoner.common.types.KTInteger$;
import com.antgroup.openspg.reasoner.common.types.KTLong$;
import com.antgroup.openspg.reasoner.common.types.KTObject$;
import com.antgroup.openspg.reasoner.common.types.KTString$;
import com.antgroup.openspg.reasoner.common.types.KgType;
import java.io.Serializable;
@@ -28,7 +29,7 @@ public enum FieldType implements Serializable {
DOUBLE(KTDouble$.MODULE$),
BOOLEAN(KTBoolean$.MODULE$),
UNKNOWN(KTString$.MODULE$),
;
OBJECT(KTObject$.MODULE$);
private final KgType kgType;

ParserUtils.scala

@@ -16,7 +16,9 @@ package com.antgroup.openspg.reasoner.parser.utils
import com.antgroup.openspg.reasoner.common.Utils
import com.antgroup.openspg.reasoner.common.table.Field
import com.antgroup.openspg.reasoner.common.types.KTString
import com.antgroup.openspg.reasoner.lube.block.{Block, TableResultBlock}
import com.antgroup.openspg.reasoner.lube.block.{Block, MatchBlock, TableResultBlock}
import scala.collection.JavaConverters._
import scala.collection.mutable
object ParserUtils {
@@ -34,4 +36,22 @@
}
}
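// collect every entity type name referenced by MatchBlock patterns in the parsed blocks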
def getAllEntityName(javaBlockList: java.util.List[Block]): java.util.Set[String] = {
val blockList: List[Block] = javaBlockList.asScala.toList
val entityNames = new mutable.HashSet[String]()
if (blockList != null && blockList.nonEmpty) {
blockList.foreach(block => {
block.transform[Unit] {
case (MatchBlock(_, patterns), _) =>
patterns.values
.map(_.graphPattern.nodes)
.flatMap(_.values)
.foreach(node => entityNames ++= node.typeNames)
case _ =>
}
})
}
entityNames.toSet.asJava
}
}

ParserUtilsTest.scala

@@ -15,9 +15,11 @@ package com.antgroup.openspg.reasoner.parser.utils
import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.table.FieldType
import com.antgroup.openspg.reasoner.lube.block.Block
import com.antgroup.openspg.reasoner.parser.OpenSPGDslParser
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
import scala.collection.JavaConverters._
class ParserUtilsTest extends AnyFunSpec {
@@ -93,4 +95,36 @@
FieldType.STRING,
FieldType.STRING))
}
it("test getAllEntityName") {
// scalastyle:off
val dsl =
"""
|Define (s:User)-[p:belongTo]->(o:Crowd/`Men`) {
| GraphStructure {
| (evt:TradeEvent)-[pr:relateUser]->(s:User)
| }
| Rule{
| R1: s.sex == '男'
| R2: evt.statPriod in ['', '']
| DayliyAmount = group(s).if(evt.statPriod=='日').sum(evt.amount)
| MonthAmount = group(s).if(evt.statPriod=='月').sum(evt.amount)
| R3: DayliyAmount > 300
| R4: MonthAmount < 500
| R5: (R3 and R1) and (not(R4 and R1))
| }
|}
|GraphStructure {
| (a:Crowd/`Men`)
|}
|Rule {
|}
|Action {
| get(a.id)
|}
|""".stripMargin
val parser = new OpenSPGDslParser()
val blockList: List[Block] = parser.parseMultipleStatement(dsl)
val entityNameSet: Set[String] = ParserUtils.getAllEntityName(blockList.asJava).asScala.toSet
entityNameSet should equal(Set.apply("User", "TradeEvent", "Crowd/Men"))
}
}

pom.xml

@@ -126,6 +126,10 @@
<groupId>com.opencsv</groupId>
<artifactId>opencsv</artifactId>
</dependency>
<dependency>
<groupId>com.aliyun.odps</groupId>
<artifactId>odps-sdk-core</artifactId>
</dependency>
<!-- parquet start -->
<dependency>

IoFactory.java

@@ -13,29 +13,45 @@
package com.antgroup.openspg.reasoner.io;
import com.aliyun.odps.Odps;
import com.aliyun.odps.account.Account;
import com.aliyun.odps.account.AliyunAccount;
import com.aliyun.odps.tunnel.TableTunnel;
import com.aliyun.odps.tunnel.TableTunnel.DownloadSession;
import com.aliyun.odps.tunnel.TableTunnel.UploadSession;
import com.aliyun.odps.tunnel.TunnelException;
import com.antgroup.openspg.reasoner.common.exception.HiveException;
import com.antgroup.openspg.reasoner.common.exception.IllegalArgumentException;
import com.antgroup.openspg.reasoner.common.exception.NotImplementedException;
import com.antgroup.openspg.reasoner.common.exception.OdpsException;
import com.antgroup.openspg.reasoner.io.hive.HiveUtils;
import com.antgroup.openspg.reasoner.io.hive.HiveWriter;
import com.antgroup.openspg.reasoner.io.hive.HiveWriterSession;
import com.antgroup.openspg.reasoner.io.model.AbstractTableInfo;
import com.antgroup.openspg.reasoner.io.model.HiveTableInfo;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import com.antgroup.openspg.reasoner.io.odps.OdpsReader;
import com.antgroup.openspg.reasoner.io.odps.OdpsUtils;
import com.antgroup.openspg.reasoner.io.odps.OdpsWriter;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.cache.RemovalListener;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import scala.Tuple2;
@Slf4j
public class IoFactory {
private static final Logger log = LoggerFactory.getLogger(IoFactory.class);
private static final Map<AbstractTableInfo, Tuple2<String, Object>> SESSION_MAP = new HashMap<>();
@@ -47,7 +63,13 @@
}
String sessionId = null;
// create session
if (tableInfo instanceof HiveTableInfo) {
if (tableInfo instanceof OdpsTableInfo) {
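// ODPS: create the upload session once (on the driver) and record its id on the
// table info, so every worker writer later attaches to the same session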
OdpsTableInfo odpsTableInfo = (OdpsTableInfo) tableInfo;
UploadSession uploadSession = OdpsUtils.createUploadSession(odpsTableInfo);
sessionId = uploadSession.getId();
SESSION_MAP.put(odpsTableInfo, new Tuple2<>(sessionId, uploadSession));
odpsTableInfo.setUploadSessionId(sessionId);
} else if (tableInfo instanceof HiveTableInfo) {
HiveTableInfo hiveTableInfo = (HiveTableInfo) tableInfo;
HiveWriterSession hiveWriterSession = HiveUtils.createHiveWriterSession(hiveTableInfo);
sessionId = hiveWriterSession.getSessionId();
@@ -66,7 +88,22 @@
String sessionId, int index, int parallel, AbstractTableInfo tableInfo) {
String cacheKey = getCacheKey(sessionId, index);
ITableWriter resultWriter;
if (tableInfo instanceof HiveTableInfo) {
if (tableInfo instanceof OdpsTableInfo) {
resultWriter = TABLE_WRITER_CACHE.getIfPresent(cacheKey);
if (null != resultWriter) {
return resultWriter;
}
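// double-checked locking: re-check inside the lock so only one OdpsWriter is
// created and cached per (sessionId, index) key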
synchronized (TABLE_WRITER_CACHE) {
resultWriter = TABLE_WRITER_CACHE.getIfPresent(cacheKey);
if (null != resultWriter) {
return resultWriter;
}
OdpsWriter odpsWriter = new OdpsWriter();
odpsWriter.open(index, parallel, tableInfo);
TABLE_WRITER_CACHE.put(cacheKey, odpsWriter);
}
return TABLE_WRITER_CACHE.getIfPresent(cacheKey);
} else if (tableInfo instanceof HiveTableInfo) {
resultWriter = TABLE_WRITER_CACHE.getIfPresent(cacheKey);
if (null != resultWriter) {
return resultWriter;
@@ -111,7 +148,14 @@
log.info("commitWriterSession,sessionId=" + sessionId);
if (session instanceof HiveWriterSession) {
if (session instanceof UploadSession) {
UploadSession uploadSession = (UploadSession) session;
try {
uploadSession.commit();
} catch (TunnelException | IOException e) {
throw new OdpsException("commit session error", e);
}
} else if (session instanceof HiveWriterSession) {
HiveWriterSession hiveWriterSession = (HiveWriterSession) session;
hiveWriterSession.commit();
}
@@ -150,7 +194,20 @@
throw new IllegalArgumentException(
"tableInfoList", "emptyList", "please input table info list", null);
}
if (tableInfoList.get(0) instanceof HiveTableInfo) {
if (tableInfoList.get(0) instanceof OdpsTableInfo) {
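// ODPS: fetch (or lazily create) a cached download session for each table, then
// hand the whole map to a single OdpsReader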
Map<OdpsTableInfo, DownloadSession> downloadSessionMap = new HashMap<>();
for (AbstractTableInfo tableInfo : tableInfoList) {
OdpsTableInfo odpsTableInfo = (OdpsTableInfo) tableInfo;
try {
downloadSessionMap.put(odpsTableInfo, DOWNLOAD_SESSION_CACHE.get(odpsTableInfo));
} catch (ExecutionException e) {
throw new OdpsException("create odps download session error", e);
}
}
OdpsReader odpsReader = new OdpsReader(downloadSessionMap);
odpsReader.init(index, parallel, nowRound, allRound, tableInfoList);
return odpsReader;
} else if (tableInfoList.get(0) instanceof HiveTableInfo) {
if (allRound > 1) {
throw new HiveException("hive reader not support multiple round read", null);
}
@@ -197,4 +254,34 @@
notification.getValue().close();
})
.build();
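// cache of ODPS download sessions, one per table info; entries expire 3 hours
// after last access or 6 hours after creation and are rebuilt on demand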
private static final LoadingCache<OdpsTableInfo, DownloadSession> DOWNLOAD_SESSION_CACHE =
CacheBuilder.newBuilder()
.maximumSize(2000)
.expireAfterAccess(3, TimeUnit.HOURS)
.expireAfterWrite(6, TimeUnit.HOURS)
.build(
new CacheLoader<OdpsTableInfo, DownloadSession>() {
@Override
public DownloadSession load(OdpsTableInfo odpsTableInfo) throws Exception {
log.info("create_download_session,=" + odpsTableInfo);
Account account =
new AliyunAccount(odpsTableInfo.getAccessID(), odpsTableInfo.getAccessKey());
Odps odps = new Odps(account);
odps.setEndpoint(odpsTableInfo.getEndPoint());
odps.setDefaultProject(odpsTableInfo.getProject());
TableTunnel tunnel = new TableTunnel(odps);
if (StringUtils.isNotEmpty(odpsTableInfo.getTunnelEndPoint())) {
tunnel.setEndpoint(odpsTableInfo.getTunnelEndPoint());
}
DownloadSession downloadSession =
OdpsUtils.tryCreateDownloadSession(tunnel, odpsTableInfo);
if (null == downloadSession) {
throw new OdpsException("get_download_session_error", null);
}
return downloadSession;
}
});
}

OdpsReader.java

@@ -0,0 +1,204 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.io.odps;
import com.aliyun.odps.Column;
import com.aliyun.odps.data.Record;
import com.aliyun.odps.data.RecordReader;
import com.aliyun.odps.tunnel.TableTunnel.DownloadSession;
import com.antgroup.openspg.reasoner.common.exception.OdpsException;
import com.antgroup.openspg.reasoner.common.table.Field;
import com.antgroup.openspg.reasoner.io.ITableReader;
import com.antgroup.openspg.reasoner.io.model.AbstractTableInfo;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import com.antgroup.openspg.reasoner.io.model.ReadRange;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
@Slf4j
public class OdpsReader implements ITableReader {
protected final int MAX_ODPS_READER_COUNT = 500 * 10000;
protected int index;
protected long readCount;
protected transient Map<OdpsTableInfo, ReadRange> tableReadRangeMap = new HashMap<>();
protected final Map<OdpsTableInfo, DownloadSession> downloadSessionMap;
public OdpsReader(Map<OdpsTableInfo, DownloadSession> downloadSessionMap) {
this.downloadSessionMap = downloadSessionMap;
}
private DownloadSession getDownloadSession(OdpsTableInfo odpsTableInfo) {
return this.downloadSessionMap.get(odpsTableInfo);
}
/** open odps reader */
@Override
public void init(
int index, int parallel, int nowRound, int allRound, List<AbstractTableInfo> tableInfoList) {
this.index = index;
Map<OdpsTableInfo, Long> tableCountMap = new HashMap<>();
for (AbstractTableInfo tableInfo : tableInfoList) {
OdpsTableInfo odpsTableInfo = (OdpsTableInfo) tableInfo;
long count = getDownloadSession(odpsTableInfo).getRecordCount();
tableCountMap.put(odpsTableInfo, count);
}
this.tableReadRangeMap =
OdpsUtils.getReadRange(parallel, index, allRound, nowRound, tableCountMap);
// init iterator
this.nowReadTableIt = this.tableReadRangeMap.entrySet().iterator();
this.nowReadRange = null;
this.readCount = 0L;
}
/** close odps reader */
@Override
public void close() {
log.info("close odps reader, index=" + this.index + ", readCount=" + this.readCount);
}
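// iteration state: the reader walks tableReadRangeMap one entry at a time,
// opening a fresh RecordReader per range and counting the rows read within it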
private Iterator<Map.Entry<OdpsTableInfo, ReadRange>> nowReadTableIt;
private OdpsTableInfo nowOdpsTableInfo = null;
private RecordReader nowRecordReader = null;
private Map<String, Integer> columnName2ResultIndexMap = null;
private ReadRange nowReadRange = null;
private long nowReadCount = 0;
@Override
public boolean hasNext() {
if (nowReaderHasNext()) {
return true;
}
return nowReadTableIt.hasNext();
}
@Override
public Object[] next() {
this.readCount++;
if (nowReaderHasNext()) {
return readRecord();
}
Map.Entry<OdpsTableInfo, ReadRange> entry = this.nowReadTableIt.next();
this.nowOdpsTableInfo = entry.getKey();
this.nowReadRange = entry.getValue();
this.nowReadCount = 0;
this.nowRecordReader =
OdpsUtils.tryOpenRecordReader(
getDownloadSession(this.nowOdpsTableInfo),
this.nowReadRange.getStart(),
this.nowReadRange.getEnd());
this.initColumnName2ResultIndexMap();
return readRecord();
}
private boolean nowReaderHasNext() {
return null != this.nowReadRange && this.nowReadCount < this.nowReadRange.getCount();
}
private void initColumnName2ResultIndexMap() {
if (CollectionUtils.isEmpty(this.nowOdpsTableInfo.getColumns())) {
columnName2ResultIndexMap = null;
return;
}
columnName2ResultIndexMap = new HashMap<>();
int resultSize = this.nowOdpsTableInfo.getColumns().size();
for (int i = 0; i < resultSize; ++i) {
Field field = this.nowOdpsTableInfo.getColumns().get(i);
columnName2ResultIndexMap.put(field.getName(), i);
}
}
private Object[] readRecord() {
Record record;
try {
record = this.nowRecordReader.read();
nowReadCount++;
if (nowReadCount > MAX_ODPS_READER_COUNT) {
// reset the reader after reading a large amount of data
this.nowRecordReader =
OdpsUtils.tryOpenRecordReader(
getDownloadSession(this.nowOdpsTableInfo),
this.nowReadRange.getStart(),
this.nowReadRange.getEnd());
this.initColumnName2ResultIndexMap();
}
} catch (IOException e) {
throw new OdpsException("read odps record error", e);
}
Column[] columns = record.getColumns();
// convert type
Object[] result =
new Object
[null == this.columnName2ResultIndexMap
? columns.length
: this.columnName2ResultIndexMap.size()];
for (int i = 0; i < columns.length; ++i) {
Column column = columns[i];
int resultIndex = i;
if (null != this.columnName2ResultIndexMap) {
Integer integer = this.columnName2ResultIndexMap.get(column.getName());
if (null == integer) {
continue;
}
resultIndex = integer;
}
switch (column.getTypeInfo().getOdpsType()) {
case STRING:
case VARCHAR:
case CHAR:
result[resultIndex] = record.getString(i);
break;
case FLOAT:
result[resultIndex] = record.getDouble(i).floatValue();
break;
case DOUBLE:
result[resultIndex] = record.getDouble(i);
break;
case INT:
result[resultIndex] = record.getBigint(i).intValue();
break;
case SMALLINT:
result[resultIndex] = record.getBigint(i).shortValue();
break;
case TINYINT:
result[resultIndex] = record.getBigint(i).byteValue();
break;
case BIGINT:
result[resultIndex] = record.getBigint(i);
break;
case BOOLEAN:
result[resultIndex] = record.getBoolean(i);
break;
default:
result[resultIndex] = record.get(i);
break;
}
}
return result;
}
}

OdpsUtils.java

@@ -0,0 +1,587 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.io.odps;
import com.alibaba.fastjson.JSON;
import com.aliyun.odps.Column;
import com.aliyun.odps.Odps;
import com.aliyun.odps.OdpsType;
import com.aliyun.odps.PartitionSpec;
import com.aliyun.odps.ReloadException;
import com.aliyun.odps.Table;
import com.aliyun.odps.TableSchema;
import com.aliyun.odps.account.Account;
import com.aliyun.odps.account.AliyunAccount;
import com.aliyun.odps.tunnel.TableTunnel;
import com.aliyun.odps.tunnel.TableTunnel.DownloadSession;
import com.aliyun.odps.tunnel.TableTunnel.UploadSession;
import com.aliyun.odps.tunnel.TunnelException;
import com.aliyun.odps.tunnel.io.CompressOption;
import com.aliyun.odps.tunnel.io.TunnelBufferedWriter;
import com.aliyun.odps.tunnel.io.TunnelRecordReader;
import com.antgroup.openspg.reasoner.common.exception.OdpsException;
import com.antgroup.openspg.reasoner.common.types.KTBoolean$;
import com.antgroup.openspg.reasoner.common.types.KTDouble$;
import com.antgroup.openspg.reasoner.common.types.KTInteger$;
import com.antgroup.openspg.reasoner.common.types.KTLong$;
import com.antgroup.openspg.reasoner.common.types.KTString$;
import com.antgroup.openspg.reasoner.common.types.KgType;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import com.antgroup.openspg.reasoner.io.model.ReadRange;
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
@Slf4j
public class OdpsUtils {
private static final int MAX_TRY_TIMES = 10;
private static final int ODPS_WAIT_MS = 5000;
private static final int ODPS_FLOW_EXCEEDED_WAIT_MS = 60 * 1000;
/** query table schema from odps */
public static TableSchema getTableSchema(OdpsTableInfo odpsTableInfo) {
Odps odps = getODPSInstance(odpsTableInfo);
int tryTimes = MAX_TRY_TIMES;
while (tryTimes-- > 0) {
TableSchema schema;
try {
schema = odps.tables().get(odpsTableInfo.getTable()).getSchema();
return schema;
} catch (ReloadException e) {
if (e.getMessage().contains("Table not found")) {
return null;
}
log.error("get_table_schema_error,table_info=" + JSON.toJSONString(odpsTableInfo), e);
if (e.getMessage().contains("time out")) {
continue;
}
throw new OdpsException("get_table_schema_error", e);
}
}
throw new OdpsException("get_table_schema_error, reach max retry times", null);
}
/** create table schema from odps table info */
public static TableSchema createSchema(OdpsTableInfo odpsTableInfo) {
TableSchema schema = new TableSchema();
for (Field field : odpsTableInfo.getLubeColumns()) {
schema.addColumn(new Column(validColumnName(field.name()), toOdpsType(field.kgType())));
}
for (String name : odpsTableInfo.getPartition().keySet()) {
schema.addPartitionColumn(new Column(validColumnName(name), OdpsType.STRING));
}
return schema;
}
/** check whether the schema matches */
public static boolean isSchemaMatch(OdpsTableInfo odpsTableInfo, TableSchema realSchema) {
TableSchema schemaNeeded = createSchema(odpsTableInfo);
List<String> columnsNeeded =
schemaNeeded.getPartitionColumns().stream()
.map(com.aliyun.odps.Column::getName)
.sorted(String::compareTo)
.collect(Collectors.toList());
List<String> columnsReal =
realSchema.getPartitionColumns().stream()
.map(com.aliyun.odps.Column::getName)
.sorted(String::compareTo)
.collect(Collectors.toList());
if (!columnsNeeded.equals(columnsReal)) {
log.error(
"odps_partition_columns_not_match, need_columns="
+ JSON.toJSONString(columnsNeeded)
+ ",real_columns="
+ JSON.toJSONString(columnsReal));
return false;
}
columnsNeeded =
schemaNeeded.getColumns().stream()
.map(com.aliyun.odps.Column::getName)
.sorted(String::compareTo)
.collect(Collectors.toList());
columnsReal =
realSchema.getColumns().stream()
.map(com.aliyun.odps.Column::getName)
.sorted(String::compareTo)
.collect(Collectors.toList());
if (!columnsNeeded.equals(columnsReal)) {
log.error(
"odps_type_not_match, need_columns="
+ JSON.toJSONString(columnsNeeded)
+ ",real_columns="
+ JSON.toJSONString(columnsReal));
return false;
}
List<OdpsType> neededType =
schemaNeeded.getColumns().stream()
.map(column -> column.getTypeInfo().getOdpsType())
.sorted(Enum::compareTo)
.collect(Collectors.toList());
List<OdpsType> realType =
realSchema.getColumns().stream()
.map(column -> column.getTypeInfo().getOdpsType())
.sorted(Enum::compareTo)
.collect(Collectors.toList());
/*
Type comparison is commented out until type inference is complete:
if (!neededType.equals(realType)) {
log.error(
"odps_type_not_match, need_type="
+ JSON.toJSONString(neededType)
+ ",real_type="
+ JSON.toJSONString(realType));
return false;
}
*/
return true;
}
/** create odps table */
public static void createTable(OdpsTableInfo odpsTableInfo) {
Odps odps = getODPSInstance(odpsTableInfo);
TableSchema schema = createSchema(odpsTableInfo);
log.info("start create table=" + odpsTableInfo.getProject() + "." + odpsTableInfo.getTable());
try {
odps.tables()
.createTableWithLifeCycle(
odpsTableInfo.getProject(), odpsTableInfo.getTable(), schema, null, true, 365L);
log.info(
"create table success," + odpsTableInfo.getProject() + "." + odpsTableInfo.getTable());
} catch (com.aliyun.odps.OdpsException e) {
log.error("create table error", e);
throw new OdpsException("create_table_error", e);
}
}
/** convert KgType to odps type */
public static OdpsType toOdpsType(KgType kgType) {
if (KTString$.MODULE$.equals(kgType)) {
return OdpsType.STRING;
} else if (KTLong$.MODULE$.equals(kgType) || KTInteger$.MODULE$.equals(kgType)) {
return OdpsType.BIGINT;
} else if (KTDouble$.MODULE$.equals(kgType)) {
return OdpsType.DOUBLE;
} else if (KTBoolean$.MODULE$.equals(kgType)) {
return OdpsType.BOOLEAN;
} else {
throw new OdpsException("unsupported column type, " + kgType, null);
}
}
/** change column name to odps valid name */
public static String validColumnName(String columnName) {
return columnName.replaceAll("\\.", "_").toLowerCase();
}
/** create odps instance */
public static Odps getODPSInstance(OdpsTableInfo odpsTableInfo) {
Account account = new AliyunAccount(odpsTableInfo.getAccessID(), odpsTableInfo.getAccessKey());
Odps odps = new Odps(account);
odps.setEndpoint(odpsTableInfo.getEndPoint());
odps.setDefaultProject(odpsTableInfo.getProject());
return odps;
}
/** get PartitionSpec from OdpsTableInfo */
public static PartitionSpec getOdpsPartitionSpec(OdpsTableInfo odpsTableInfo) {
if (null == odpsTableInfo.getPartition() || odpsTableInfo.getPartition().isEmpty()) {
return null;
}
PartitionSpec partitionSpec = new PartitionSpec();
for (Map.Entry<String, String> entry : odpsTableInfo.getPartition().entrySet()) {
partitionSpec.set(entry.getKey(), entry.getValue());
}
return partitionSpec;
}
/** get upload session */
public static UploadSession tryGetUploadSession(
OdpsTableInfo odpsTableInfo, String id, int index, int parallel) {
Odps odps = getODPSInstance(odpsTableInfo);
TableTunnel tunnel = new TableTunnel(odps);
if (!StringUtils.isEmpty(odpsTableInfo.getTunnelEndPoint())) {
log.info("set odps tunnel endpoint=" + odpsTableInfo.getTunnelEndPoint());
tunnel.setEndpoint(odpsTableInfo.getTunnelEndPoint());
}
int maxTryTimes = MAX_TRY_TIMES;
while (--maxTryTimes >= 0) {
try {
return tunnel.getUploadSession(
odpsTableInfo.getProject(),
odpsTableInfo.getTable(),
getOdpsPartitionSpec(odpsTableInfo),
id,
parallel,
index);
} catch (Throwable e) {
log.error("create upload session error", e);
}
waitMs(ODPS_WAIT_MS);
}
throw new OdpsException("create upload session failed", null);
}
/** create upload session */
public static UploadSession tryCreateUploadSession(OdpsTableInfo odpsTableInfo) {
Odps odps = getODPSInstance(odpsTableInfo);
TableTunnel tunnel = new TableTunnel(odps);
if (!StringUtils.isEmpty(odpsTableInfo.getTunnelEndPoint())) {
log.info("set odps tunnel endpoint=" + odpsTableInfo.getTunnelEndPoint());
tunnel.setEndpoint(odpsTableInfo.getTunnelEndPoint());
}
int maxTryTimes = MAX_TRY_TIMES;
while (--maxTryTimes >= 0) {
try {
PartitionSpec partitionSpec = getOdpsPartitionSpec(odpsTableInfo);
if (null == partitionSpec) {
return tunnel.createUploadSession(odpsTableInfo.getProject(), odpsTableInfo.getTable());
} else {
return tunnel.createUploadSession(
odpsTableInfo.getProject(), odpsTableInfo.getTable(), partitionSpec);
}
} catch (Throwable e) {
log.error("create upload session error", e);
}
waitMs(ODPS_WAIT_MS);
}
throw new OdpsException("create upload session failed", null);
}
/** create record writer */
public static TunnelBufferedWriter tryCreateBufferRecordWriter(UploadSession uploadSession) {
int maxTryTimes = MAX_TRY_TIMES;
while (--maxTryTimes >= 0) {
try {
CompressOption option = new CompressOption();
return new TunnelBufferedWriter(uploadSession, option);
} catch (Throwable e) {
log.error("create buffer writer error", e);
}
waitMs(ODPS_WAIT_MS);
}
throw new OdpsException("create buffer writer error", null);
}
/** create partition for table */
public static void createPartition(OdpsTableInfo odpsTableInfo) {
Odps odps = getODPSInstance(odpsTableInfo);
PartitionSpec partitionSpec = getOdpsPartitionSpec(odpsTableInfo);
if (null == partitionSpec) {
return;
}
try {
Table t = odps.tables().get(odpsTableInfo.getTable());
if (!t.hasPartition(partitionSpec)) {
t.createPartition(partitionSpec);
}
} catch (com.aliyun.odps.OdpsException e) {
if (e.getMessage().contains("Partition already exists")) {
// partition already exists, do not throw error
// com.aliyun.odps.OdpsException: Catalog Service Failed, ErrorCode: 103,
// Error Message: ODPS-0110061: Failed to run ddltask - AlreadyExistsException(message:
// Partition already exists, existed values: ["$partition"])
return;
}
throw new OdpsException("create_partition_error", e);
}
}
/** create download session */
public static DownloadSession tryCreateDownloadSession(
TableTunnel tunnel, OdpsTableInfo odpsTableInfo) {
PartitionSpec partition = getOdpsPartitionSpec(odpsTableInfo);
int maxTryTimes = MAX_TRY_TIMES;
Throwable lastError = null;
while (--maxTryTimes >= 0) {
try {
DownloadSession downloadSession;
if (null != partition) {
downloadSession =
tunnel.createDownloadSession(
odpsTableInfo.getProject(), odpsTableInfo.getTable(), partition);
} else {
downloadSession =
tunnel.createDownloadSession(odpsTableInfo.getProject(), odpsTableInfo.getTable());
}
return downloadSession;
} catch (TunnelException e) {
if ("NoSuchPartition".equals(e.getErrorCode())) {
// continue
log.info(
"table="
+ odpsTableInfo.getProject()
+ "."
+ odpsTableInfo.getTable()
+ ", partition="
+ partition
+ ", not exist");
return null;
} else if ("InvalidPartitionSpec".equals(e.getErrorCode())) {
// if this table is not a partition table, we create download session without
// PartitionSpec
log.info(
"table="
+ odpsTableInfo.getProject()
+ "."
+ odpsTableInfo.getTable()
+ ", partition="
+ partition
+ ", InvalidPartitionSpec");
partition = null;
continue;
} else if ("FlowExceeded".equals(e.getErrorCode())) {
log.warn("create_download_session, flow exceeded");
// flow exceeded, continue
// --maxTryTimes;
waitMs(ODPS_WAIT_MS);
continue;
}
log.error(
"create_download_session_error, table="
+ odpsTableInfo.getProject()
+ "."
+ odpsTableInfo.getTable()
+ ", partition="
+ partition,
e);
} catch (Throwable e) {
log.error(
"create_download_session_error, table="
+ odpsTableInfo.getProject()
+ "."
+ odpsTableInfo.getTable()
+ ", partition="
+ partition,
e);
lastError = e;
}
waitMs(ODPS_WAIT_MS);
}
throw new OdpsException(
"create_download_session_failed, time_out, table="
+ odpsTableInfo.getProject()
+ "."
+ odpsTableInfo.getTable()
+ ", partition="
+ partition,
lastError);
}
/** open record reader */
public static TunnelRecordReader tryOpenRecordReader(
TableTunnel.DownloadSession downloadSession, long start, long count) {
TunnelRecordReader recordReader;
int maxTryTimes = MAX_TRY_TIMES;
while (--maxTryTimes >= 0) {
try {
recordReader = downloadSession.openRecordReader(start, count);
return recordReader;
} catch (Exception e) {
if (e instanceof com.aliyun.odps.OdpsException) {
com.aliyun.odps.OdpsException oe = (com.aliyun.odps.OdpsException) e;
if ("FlowExceeded".equals(oe.getErrorCode())) {
log.warn("open_record_reader, flow exceeded");
--maxTryTimes;
waitMs(ODPS_FLOW_EXCEEDED_WAIT_MS);
continue;
}
}
log.error("open_record_reader_error", e);
}
waitMs(ODPS_WAIT_MS);
}
return null;
}
/** must be called on the driver */
public static UploadSession createUploadSession(OdpsTableInfo odpsTableInfo) {
// check whether the odps table exists
TableSchema schema = getTableSchema(odpsTableInfo);
if (null == schema) {
if (odpsTableInfo.getProject().endsWith("_dev")) {
createTable(odpsTableInfo);
schema = getTableSchema(odpsTableInfo);
if (null == schema) {
throw new OdpsException("create table error", null);
}
} else {
// table not exist
throw new OdpsException(
"table not exist, project="
+ odpsTableInfo.getProject()
+ ",table="
+ odpsTableInfo.getTable(),
null);
}
}
// check that the table schema matches
if (!isSchemaMatch(odpsTableInfo, schema)) {
throw new OdpsException(
"table "
+ odpsTableInfo.getProject()
+ "."
+ odpsTableInfo.getTable()
+ ",schema not match",
null);
}
// create partition and upload session
createPartition(odpsTableInfo);
return tryCreateUploadSession(odpsTableInfo);
}
/** delete odps table */
public static void dropOdpsTable(OdpsTableInfo odpsTableInfo) throws Exception {
Odps odps = getODPSInstance(odpsTableInfo);
odps.tables().delete(odpsTableInfo.getProject(), odpsTableInfo.getTable());
log.info("dropOdpsTable," + odpsTableInfo.getTableInfoKeyString());
}
/** get read range */
public static Map<OdpsTableInfo, ReadRange> getReadRange(
int parallel, int index, int allRound, int nowRound, Map<OdpsTableInfo, Long> tableCountMap) {
List<Pair<OdpsTableInfo, Long>> tableCountList = new ArrayList<>();
long allCount = 0;
for (OdpsTableInfo tableInfo : tableCountMap.keySet()) {
tableCountList.add(new ImmutablePair<>(tableInfo, tableCountMap.get(tableInfo)));
allCount += tableCountMap.get(tableInfo);
}
tableCountList.sort(
Comparator.comparingLong((Pair<OdpsTableInfo, Long> o) -> o.getRight())
.thenComparing(Pair::getLeft));
Map<OdpsTableInfo, ReadRange> result = new HashMap<>();
ReadRange loadRange = getReadRange(parallel, index, allRound, nowRound, allCount, 1);
long offset1 = 0;
long offset2 = 0;
long loadedCount = 0;
for (Pair<OdpsTableInfo, Long> partitionInfo : tableCountList) {
offset1 = offset2;
offset2 += partitionInfo.getRight();
long start = 0;
if (loadedCount <= 0) {
if (loadRange.getStart() >= offset1 && loadRange.getStart() < offset2) {
start = loadRange.getStart() - offset1;
} else {
continue;
}
}
if (loadRange.getEnd() <= offset2) {
long end = start + (loadRange.getCount() - loadedCount);
if (end == start) {
continue;
}
result.put(partitionInfo.getLeft(), new ReadRange(start, end));
loadedCount += end - start;
break;
} else {
long end = offset2 - offset1;
if (end == start) {
continue;
}
result.put(partitionInfo.getLeft(), new ReadRange(start, end));
loadedCount += end - start;
}
}
result =
result.entrySet().stream()
.filter(tableReadRange -> tableReadRange.getValue().getCount() > 0)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
return result;
}
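// Illustrative example (not part of the original source): a single table with 100
// rows, parallel=4, allRound=1, nowRound=0 yields
//   index 0 -> ReadRange[0,25)    index 1 -> ReadRange[25,50)
//   index 2 -> ReadRange[50,75)   index 3 -> ReadRange[75,100)
// Any remainder of count % parallel goes one extra row apiece to the lowest reader
// indexes, and a per-reader share smaller than minLoadSize collapses the whole
// load onto the last reader.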
private static ReadRange getReadRange(
int parallel, int index, int allRound, int nowRound, long count, long minLoadSize) {
long modSize = count % parallel;
long share = count / parallel;
long start;
if (index < modSize) {
share += 1;
start = share * index;
} else {
start = modSize + share * index;
}
long end = start + share;
if (share > 0 && share < minLoadSize) {
// allocate to last node
start = 0;
if (index == parallel - 1) {
end = count;
} else {
end = 0;
}
}
if (allRound > 1) {
long roundCount = end - start;
long roundShare = roundCount / allRound;
if (roundShare > 0 && roundShare < minLoadSize) {
// allocate to last node
if (nowRound == allRound - 1) {
end = start + roundCount;
} else {
end = start;
}
} else {
long roundModSize = roundCount % allRound;
if (nowRound < roundModSize) {
roundShare += 1;
start += nowRound * roundShare;
} else {
start += roundModSize + roundShare * nowRound;
}
end = start + roundShare;
}
}
return new ReadRange(start, end);
}
/** wait ms */
public static void waitMs(long ms) {
try {
Thread.sleep(ms);
} catch (InterruptedException e) {
log.warn("sleep_error", e);
}
}
}

OdpsWriter.java

@@ -0,0 +1,151 @@
/*
* Copyright 2023 OpenSPG Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
package com.antgroup.openspg.reasoner.io.odps;
import com.aliyun.odps.data.Record;
import com.aliyun.odps.tunnel.TableTunnel.UploadSession;
import com.aliyun.odps.tunnel.io.TunnelBufferedWriter;
import com.antgroup.openspg.reasoner.common.exception.OdpsException;
import com.antgroup.openspg.reasoner.io.ITableWriter;
import com.antgroup.openspg.reasoner.io.model.AbstractTableInfo;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import java.io.IOException;
import java.util.Arrays;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class OdpsWriter implements ITableWriter {
private int taskIndex;
private OdpsTableInfo odpsTableInfo;
private transient UploadSession uploadSession = null;
private transient TunnelBufferedWriter recordWriter = null;
private long writeCount = 0L;
private static final int MAX_TRY_WRITE_TIMES = 5;
/** write buffer size */
private static final long WRITER_BUFFER_SIZE = 32 * 1024 * 1024;
/** reset the writer after 10M records */
private static final long WRITER_RESET_COUNT = 1000 * 10000;
/**
* init odps writer
*
* <p>The odps writer does not commit the result itself; you must make sure the data is committed
* on the driver. For example:
*
* <pre>{@code
* // create the upload session on the driver
* UploadSession session = OdpsUtils.createUploadSession(tableInfo);
*
* // set the session id, making sure the writer on every worker uses the same session
* tableInfo.setUploadSessionId(session.getId());
*
* // on each worker: get a writer and write data ...
*
* // back on the driver: commit the session
* session.commit();
* }</pre>
*/
public void open(int taskIndex, int parallel, AbstractTableInfo tableInfo) {
this.taskIndex = taskIndex;
this.odpsTableInfo = (OdpsTableInfo) tableInfo;
log.info("openOdpsWriter,index=" + this.taskIndex + ",odpsTableInfo=" + this.odpsTableInfo);
this.uploadSession =
OdpsUtils.tryGetUploadSession(
this.odpsTableInfo, this.odpsTableInfo.getUploadSessionId(), taskIndex, parallel);
resetWriter();
}
/** write record */
@Override
public void write(Object[] data) {
long c = this.writeCount++;
if (1 == c % 10000) {
log.info(
"index="
+ this.taskIndex
+ ",write_odps_record["
+ Arrays.toString(data)
+ "], write_count="
+ c);
}
Record record = uploadSession.newRecord();
record.set(data);
// try five times at most
int maxTryTimes = MAX_TRY_WRITE_TIMES;
while (maxTryTimes-- > 0) {
try {
synchronized (this) {
recordWriter.write(record);
}
break;
} catch (IOException e) {
if (e.getLocalizedMessage().contains("MalformedDataStream")) {
log.error("write_odps_get_io_exception", e);
// io exception, reset
resetWriter();
continue;
}
throw new OdpsException("write_odps_record_error", e);
}
}
}
/** close writer */
@Override
public void close() {
closeWriter();
}
@Override
public long writeCount() {
return this.writeCount;
}
private void resetWriter() {
closeWriter();
recordWriter = OdpsUtils.tryCreateBufferRecordWriter(this.uploadSession);
recordWriter.setBufferSize(WRITER_BUFFER_SIZE);
}
private void closeWriter() {
if (null != recordWriter) {
try {
log.info(
"odps_writer_close, index="
+ this.taskIndex
+ ", info="
+ odpsTableInfo
+ ", odps_write_count="
+ writeCount);
recordWriter.close();
} catch (IOException e) {
if (e.getLocalizedMessage().contains("MalformedDataStream")) {
log.error("close_writer_MalformedDataStream", e);
return;
}
log.error("close_writer_error", e);
throw new OdpsException("close_writer_error", e);
} finally {
recordWriter = null;
writeCount = 0L;
}
}
}
}