mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-06-27 03:20:10 +00:00
Merge branch 'master' into knext_0112
This commit is contained in: dc0232ef28

.github/workflows/cla.yml (vendored, new file, 27 lines added)
@@ -0,0 +1,27 @@
name: "CLA Assistant"
on:
  issue_comment:
    types: [ created ]
  pull_request_target:
    types: [ opened,closed,synchronize ]

jobs:
  CLAssistant:
    runs-on: ubuntu-latest
    steps:
      - name: "CLA Assistant"
        if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
        uses: contributor-assistant/github-action@v2.3.0
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          # the below token should have repo scope and must be manually added by you in the repository's secret
          PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
        with:
          path-to-signatures: 'signatures/version1/cla.json'
          path-to-document: 'https://github.com/OpenSPG/cla-assistant/blob/master/CLA.md' # e.g. a CLA or a DCO document
          allowlist: test,bot*
          remote-organization-name: OpenSPG
          remote-repository-name: cla-assistant
          lock-pullrequest-aftermerge: True

@@ -17,6 +17,7 @@ import com.antgroup.openspg.reasoner.common.types.KTBoolean$;
import com.antgroup.openspg.reasoner.common.types.KTDouble$;
import com.antgroup.openspg.reasoner.common.types.KTInteger$;
import com.antgroup.openspg.reasoner.common.types.KTLong$;
import com.antgroup.openspg.reasoner.common.types.KTObject$;
import com.antgroup.openspg.reasoner.common.types.KTString$;
import com.antgroup.openspg.reasoner.common.types.KgType;
import java.io.Serializable;
@@ -28,7 +29,7 @@ public enum FieldType implements Serializable {
  DOUBLE(KTDouble$.MODULE$),
  BOOLEAN(KTBoolean$.MODULE$),
  UNKNOWN(KTString$.MODULE$),
  ;
  OBJECT(KTObject$.MODULE$);

  private final KgType kgType;

@@ -16,7 +16,9 @@ package com.antgroup.openspg.reasoner.parser.utils
import com.antgroup.openspg.reasoner.common.Utils
import com.antgroup.openspg.reasoner.common.table.Field
import com.antgroup.openspg.reasoner.common.types.KTString
import com.antgroup.openspg.reasoner.lube.block.{Block, TableResultBlock}
import com.antgroup.openspg.reasoner.lube.block.{Block, MatchBlock, TableResultBlock}
import scala.collection.JavaConverters._
import scala.collection.mutable

object ParserUtils {

@@ -34,4 +36,22 @@ object ParserUtils {
    }
  }

  def getAllEntityName(javaBlockList: java.util.List[Block]): java.util.Set[String] = {
    val blockList: List[Block] = javaBlockList.asScala.toList
    val entityNames = new mutable.HashSet[String]()
    if (blockList != null && blockList.nonEmpty) {
      blockList.foreach(block => {
        block.transform[Unit] {
          case (MatchBlock(_, patterns), _) =>
            patterns.values
              .map(_.graphPattern.nodes)
              .flatMap(_.values)
              .foreach(node => entityNames ++= node.typeNames)
          case _ =>
        }
      })
    }
    entityNames.toSet.asJava
  }

}
@@ -15,9 +15,11 @@ package com.antgroup.openspg.reasoner.parser.utils

import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.table.FieldType
import com.antgroup.openspg.reasoner.lube.block.Block
import com.antgroup.openspg.reasoner.parser.OpenSPGDslParser
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
import scala.collection.JavaConverters._

class ParserUtilsTest extends AnyFunSpec {

@@ -93,4 +95,36 @@ class ParserUtilsTest extends AnyFunSpec {
          FieldType.STRING,
          FieldType.STRING))
  }
  it("test getAllEntityName") {
    // scalastyle:off
    val dsl =
      """
        |Define (s:User)-[p:belongTo]->(o:Crowd/`Men`) {
        |  GraphStructure {
        |    (evt:TradeEvent)-[pr:relateUser]->(s:User)
        |  }
        |  Rule{
        |    R1: s.sex == '男'
        |    R2: evt.statPriod in ['日', '月']
        |    DayliyAmount = group(s).if(evt.statPriod=='日').sum(evt.amount)
        |    MonthAmount = group(s).if(evt.statPriod=='月').sum(evt.amount)
        |    R3: DayliyAmount > 300
        |    R4: MonthAmount < 500
        |    R5: (R3 and R1) and (not(R4 and R1))
        |  }
        |}
        |GraphStructure {
        |  (a:Crowd/`Men`)
        |}
        |Rule {
        |}
        |Action {
        |  get(a.id)
        |}
        |""".stripMargin
    val parser = new OpenSPGDslParser()
    val blockList: List[Block] = parser.parseMultipleStatement(dsl)
    val entityNameSet: Set[String] = ParserUtils.getAllEntityName(blockList.asJava).asScala.toSet
    entityNameSet should equal(Set.apply("User", "TradeEvent", "Crowd/Men"))
  }
}

@@ -126,6 +126,10 @@
      <groupId>com.opencsv</groupId>
      <artifactId>opencsv</artifactId>
    </dependency>
    <dependency>
      <groupId>com.aliyun.odps</groupId>
      <artifactId>odps-sdk-core</artifactId>
    </dependency>

    <!-- parquet start -->
    <dependency>
@@ -13,29 +13,45 @@

package com.antgroup.openspg.reasoner.io;

import com.aliyun.odps.Odps;
import com.aliyun.odps.account.Account;
import com.aliyun.odps.account.AliyunAccount;
import com.aliyun.odps.tunnel.TableTunnel;
import com.aliyun.odps.tunnel.TableTunnel.DownloadSession;
import com.aliyun.odps.tunnel.TableTunnel.UploadSession;
import com.aliyun.odps.tunnel.TunnelException;
import com.antgroup.openspg.reasoner.common.exception.HiveException;
import com.antgroup.openspg.reasoner.common.exception.IllegalArgumentException;
import com.antgroup.openspg.reasoner.common.exception.NotImplementedException;
import com.antgroup.openspg.reasoner.common.exception.OdpsException;
import com.antgroup.openspg.reasoner.io.hive.HiveUtils;
import com.antgroup.openspg.reasoner.io.hive.HiveWriter;
import com.antgroup.openspg.reasoner.io.hive.HiveWriterSession;
import com.antgroup.openspg.reasoner.io.model.AbstractTableInfo;
import com.antgroup.openspg.reasoner.io.model.HiveTableInfo;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import com.antgroup.openspg.reasoner.io.odps.OdpsReader;
import com.antgroup.openspg.reasoner.io.odps.OdpsUtils;
import com.antgroup.openspg.reasoner.io.odps.OdpsWriter;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.cache.RemovalListener;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import scala.Tuple2;

@Slf4j
public class IoFactory {
  private static final Logger log = LoggerFactory.getLogger(IoFactory.class);

  private static final Map<AbstractTableInfo, Tuple2<String, Object>> SESSION_MAP = new HashMap<>();

@@ -47,7 +63,13 @@ public class IoFactory {
    }
    String sessionId = null;
    // create session
    if (tableInfo instanceof HiveTableInfo) {
    if (tableInfo instanceof OdpsTableInfo) {
      OdpsTableInfo odpsTableInfo = (OdpsTableInfo) tableInfo;
      UploadSession uploadSession = OdpsUtils.createUploadSession(odpsTableInfo);
      sessionId = uploadSession.getId();
      SESSION_MAP.put(odpsTableInfo, new Tuple2<>(sessionId, uploadSession));
      odpsTableInfo.setUploadSessionId(sessionId);
    } else if (tableInfo instanceof HiveTableInfo) {
      HiveTableInfo hiveTableInfo = (HiveTableInfo) tableInfo;
      HiveWriterSession hiveWriterSession = HiveUtils.createHiveWriterSession(hiveTableInfo);
      sessionId = hiveWriterSession.getSessionId();
@@ -66,7 +88,22 @@ public class IoFactory {
      String sessionId, int index, int parallel, AbstractTableInfo tableInfo) {
    String cacheKey = getCacheKey(sessionId, index);
    ITableWriter resultWriter;
    if (tableInfo instanceof HiveTableInfo) {
    if (tableInfo instanceof OdpsTableInfo) {
      resultWriter = TABLE_WRITER_CACHE.getIfPresent(cacheKey);
      if (null != resultWriter) {
        return resultWriter;
      }
      synchronized (TABLE_WRITER_CACHE) {
        resultWriter = TABLE_WRITER_CACHE.getIfPresent(cacheKey);
        if (null != resultWriter) {
          return resultWriter;
        }
        OdpsWriter odpsWriter = new OdpsWriter();
        odpsWriter.open(index, parallel, tableInfo);
        TABLE_WRITER_CACHE.put(cacheKey, odpsWriter);
      }
      return TABLE_WRITER_CACHE.getIfPresent(cacheKey);
    } else if (tableInfo instanceof HiveTableInfo) {
      resultWriter = TABLE_WRITER_CACHE.getIfPresent(cacheKey);
      if (null != resultWriter) {
        return resultWriter;
@@ -111,7 +148,14 @@ public class IoFactory {

    log.info("commitWriterSession,sessionId=" + sessionId);

    if (session instanceof HiveWriterSession) {
    if (session instanceof UploadSession) {
      UploadSession uploadSession = (UploadSession) session;
      try {
        uploadSession.commit();
      } catch (TunnelException | IOException e) {
        throw new OdpsException("commit session error", e);
      }
    } else if (session instanceof HiveWriterSession) {
      HiveWriterSession hiveWriterSession = (HiveWriterSession) session;
      hiveWriterSession.commit();
    }
@@ -150,7 +194,20 @@ public class IoFactory {
      throw new IllegalArgumentException(
          "tableInfoList", "emptyList", "please input table info list", null);
    }
    if (tableInfoList.get(0) instanceof HiveTableInfo) {
    if (tableInfoList.get(0) instanceof OdpsTableInfo) {
      Map<OdpsTableInfo, DownloadSession> downloadSessionMap = new HashMap<>();
      for (AbstractTableInfo tableInfo : tableInfoList) {
        OdpsTableInfo odpsTableInfo = (OdpsTableInfo) tableInfo;
        try {
          downloadSessionMap.put(odpsTableInfo, DOWNLOAD_SESSION_CACHE.get(odpsTableInfo));
        } catch (ExecutionException e) {
          throw new OdpsException("create odps download session error", e);
        }
      }
      OdpsReader odpsReader = new OdpsReader(downloadSessionMap);
      odpsReader.init(index, parallel, nowRound, allRound, tableInfoList);
      return odpsReader;
    } else if (tableInfoList.get(0) instanceof HiveTableInfo) {
      if (allRound > 1) {
        throw new HiveException("hive reader not support multiple round read", null);
      }
@@ -197,4 +254,34 @@ public class IoFactory {
                notification.getValue().close();
              })
          .build();

  private static final LoadingCache<OdpsTableInfo, DownloadSession> DOWNLOAD_SESSION_CACHE =
      CacheBuilder.newBuilder()
          .maximumSize(2000)
          .expireAfterAccess(3, TimeUnit.HOURS)
          .expireAfterWrite(6, TimeUnit.HOURS)
          .build(
              new CacheLoader<OdpsTableInfo, DownloadSession>() {
                @Override
                public DownloadSession load(OdpsTableInfo odpsTableInfo) throws Exception {
                  log.info("create_download_session,=" + odpsTableInfo);
                  Account account =
                      new AliyunAccount(odpsTableInfo.getAccessID(), odpsTableInfo.getAccessKey());
                  Odps odps = new Odps(account);
                  odps.setEndpoint(odpsTableInfo.getEndPoint());
                  odps.setDefaultProject(odpsTableInfo.getProject());

                  TableTunnel tunnel = new TableTunnel(odps);
                  if (StringUtils.isNotEmpty(odpsTableInfo.getTunnelEndPoint())) {
                    tunnel.setEndpoint(odpsTableInfo.getTunnelEndPoint());
                  }

                  DownloadSession downloadSession =
                      OdpsUtils.tryCreateDownloadSession(tunnel, odpsTableInfo);
                  if (null == downloadSession) {
                    throw new OdpsException("get_download_session_error", null);
                  }
                  return downloadSession;
                }
              });
}

@@ -0,0 +1,204 @@
/*
 * Copyright 2023 OpenSPG Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

package com.antgroup.openspg.reasoner.io.odps;

import com.aliyun.odps.Column;
import com.aliyun.odps.data.Record;
import com.aliyun.odps.data.RecordReader;
import com.aliyun.odps.tunnel.TableTunnel.DownloadSession;
import com.antgroup.openspg.reasoner.common.exception.OdpsException;
import com.antgroup.openspg.reasoner.common.table.Field;
import com.antgroup.openspg.reasoner.io.ITableReader;
import com.antgroup.openspg.reasoner.io.model.AbstractTableInfo;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import com.antgroup.openspg.reasoner.io.model.ReadRange;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;

@Slf4j
public class OdpsReader implements ITableReader {

  protected final int MAX_ODPS_READER_COUNT = 500 * 10000;

  protected int index;

  protected long readCount;

  protected transient Map<OdpsTableInfo, ReadRange> tableReadRangeMap = new HashMap<>();

  protected final Map<OdpsTableInfo, DownloadSession> downloadSessionMap;

  public OdpsReader(Map<OdpsTableInfo, DownloadSession> downloadSessionMap) {
    this.downloadSessionMap = downloadSessionMap;
  }

  private DownloadSession getDownloadSession(OdpsTableInfo odpsTableInfo) {
    return this.downloadSessionMap.get(odpsTableInfo);
  }

  /** open odps reader */
  @Override
  public void init(
      int index, int parallel, int nowRound, int allRound, List<AbstractTableInfo> tableInfoList) {
    this.index = index;
    Map<OdpsTableInfo, Long> tableCountMap = new HashMap<>();
    for (AbstractTableInfo tableInfo : tableInfoList) {
      OdpsTableInfo odpsTableInfo = (OdpsTableInfo) tableInfo;
      long count = getDownloadSession(odpsTableInfo).getRecordCount();
      tableCountMap.put(odpsTableInfo, count);
    }
    this.tableReadRangeMap =
        OdpsUtils.getReadRange(parallel, index, allRound, nowRound, tableCountMap);

    // init iterator
    this.nowReadTableIt = this.tableReadRangeMap.entrySet().iterator();
    this.nowReadRange = null;
    this.readCount = 0L;
  }

  /** close odps reader */
  @Override
  public void close() {
    log.info("close odps reader, index=" + this.index + ", readCount=" + this.readCount);
  }

  private Iterator<Map.Entry<OdpsTableInfo, ReadRange>> nowReadTableIt;

  private OdpsTableInfo nowOdpsTableInfo = null;
  private RecordReader nowRecordReader = null;
  private Map<String, Integer> columnName2ResultIndexMap = null;
  private ReadRange nowReadRange = null;
  private long nowReadCount = 0;

  @Override
  public boolean hasNext() {
    if (nowReaderHasNext()) {
      return true;
    }
    return nowReadTableIt.hasNext();
  }

  @Override
  public Object[] next() {
    this.readCount++;
    if (nowReaderHasNext()) {
      return readRecord();
    }

    Map.Entry<OdpsTableInfo, ReadRange> entry = this.nowReadTableIt.next();
    this.nowOdpsTableInfo = entry.getKey();
    this.nowReadRange = entry.getValue();
    this.nowReadCount = 0;
    this.nowRecordReader =
        OdpsUtils.tryOpenRecordReader(
            getDownloadSession(this.nowOdpsTableInfo),
            this.nowReadRange.getStart(),
            this.nowReadRange.getEnd());
    this.initColumnName2ResultIndexMap();
    return readRecord();
  }

  private boolean nowReaderHasNext() {
    return null != this.nowReadRange && this.nowReadCount < this.nowReadRange.getCount();
  }

  private void initColumnName2ResultIndexMap() {
    if (CollectionUtils.isEmpty(this.nowOdpsTableInfo.getColumns())) {
      columnName2ResultIndexMap = null;
      return;
    }
    columnName2ResultIndexMap = new HashMap<>();
    int resultSize = this.nowOdpsTableInfo.getColumns().size();
    for (int i = 0; i < resultSize; ++i) {
      Field field = this.nowOdpsTableInfo.getColumns().get(i);
      columnName2ResultIndexMap.put(field.getName(), i);
    }
  }

  private Object[] readRecord() {
    Record record;
    try {
      record = this.nowRecordReader.read();
      nowReadCount++;
      if (nowReadCount > MAX_ODPS_READER_COUNT) {
        // reset the reader after reading a large number of records
        this.nowRecordReader =
            OdpsUtils.tryOpenRecordReader(
                getDownloadSession(this.nowOdpsTableInfo),
                this.nowReadRange.getStart(),
                this.nowReadRange.getEnd());
        this.initColumnName2ResultIndexMap();
      }
    } catch (IOException e) {
      throw new OdpsException("read odps record error", e);
    }

    Column[] columns = record.getColumns();

    // convert type
    Object[] result =
        new Object
            [null == this.columnName2ResultIndexMap
                ? columns.length
                : this.columnName2ResultIndexMap.size()];
    for (int i = 0; i < columns.length; ++i) {
      Column column = columns[i];
      int resultIndex = i;
      if (null != this.columnName2ResultIndexMap) {
        Integer integer = this.columnName2ResultIndexMap.get(column.getName());
        if (null == integer) {
          continue;
        }
        resultIndex = integer;
      }
      switch (column.getTypeInfo().getOdpsType()) {
        case STRING:
        case VARCHAR:
        case CHAR:
          result[resultIndex] = record.getString(i);
          break;
        case FLOAT:
          result[resultIndex] = record.getDouble(i).floatValue();
          break;
        case DOUBLE:
          result[resultIndex] = record.getDouble(i);
          break;
        case INT:
          result[resultIndex] = record.getBigint(i).intValue();
          break;
        case SMALLINT:
          result[resultIndex] = record.getBigint(i).shortValue();
          break;
        case TINYINT:
          result[resultIndex] = record.getBigint(i).byteValue();
          break;
        case BIGINT:
          result[resultIndex] = record.getBigint(i);
          break;
        case BOOLEAN:
          result[resultIndex] = record.getBoolean(i);
          break;
        default:
          result[resultIndex] = record.get(i);
          break;
      }
    }
    return result;
  }
}
@@ -0,0 +1,587 @@
/*
 * Copyright 2023 OpenSPG Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

package com.antgroup.openspg.reasoner.io.odps;

import com.alibaba.fastjson.JSON;
import com.aliyun.odps.Column;
import com.aliyun.odps.Odps;
import com.aliyun.odps.OdpsType;
import com.aliyun.odps.PartitionSpec;
import com.aliyun.odps.ReloadException;
import com.aliyun.odps.Table;
import com.aliyun.odps.TableSchema;
import com.aliyun.odps.account.Account;
import com.aliyun.odps.account.AliyunAccount;
import com.aliyun.odps.tunnel.TableTunnel;
import com.aliyun.odps.tunnel.TableTunnel.DownloadSession;
import com.aliyun.odps.tunnel.TableTunnel.UploadSession;
import com.aliyun.odps.tunnel.TunnelException;
import com.aliyun.odps.tunnel.io.CompressOption;
import com.aliyun.odps.tunnel.io.TunnelBufferedWriter;
import com.aliyun.odps.tunnel.io.TunnelRecordReader;
import com.antgroup.openspg.reasoner.common.exception.OdpsException;
import com.antgroup.openspg.reasoner.common.types.KTBoolean$;
import com.antgroup.openspg.reasoner.common.types.KTDouble$;
import com.antgroup.openspg.reasoner.common.types.KTInteger$;
import com.antgroup.openspg.reasoner.common.types.KTLong$;
import com.antgroup.openspg.reasoner.common.types.KTString$;
import com.antgroup.openspg.reasoner.common.types.KgType;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import com.antgroup.openspg.reasoner.io.model.ReadRange;
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;

@Slf4j
public class OdpsUtils {

  private static final int MAX_TRY_TIMES = 10;
  private static final int ODPS_WAIT_MS = 5000;
  private static final int ODPS_FLOW_EXCEEDED_WAIT_MS = 60 * 1000;

  /** query table schema from odps */
  public static TableSchema getTableSchema(OdpsTableInfo odpsTableInfo) {
    Odps odps = getODPSInstance(odpsTableInfo);
    int tryTimes = MAX_TRY_TIMES;
    while (tryTimes-- > 0) {
      TableSchema schema;
      try {
        schema = odps.tables().get(odpsTableInfo.getTable()).getSchema();
        return schema;
      } catch (ReloadException e) {
        if (e.getMessage().contains("Table not found")) {
          return null;
        }
        log.error("get_table_schema_error,table_info=" + JSON.toJSONString(odpsTableInfo), e);
        if (e.getMessage().contains("time out")) {
          continue;
        }
        throw new OdpsException("get_table_schema_error", e);
      }
    }
    throw new OdpsException("get_table_schema_error, reach max retry times", null);
  }

  /** create table schema from odps table info */
  public static TableSchema createSchema(OdpsTableInfo odpsTableInfo) {
    TableSchema schema = new TableSchema();
    for (Field field : odpsTableInfo.getLubeColumns()) {
      schema.addColumn(new Column(validColumnName(field.name()), toOdpsType(field.kgType())));
    }
    for (String name : odpsTableInfo.getPartition().keySet()) {
      schema.addPartitionColumn(new Column(validColumnName(name), OdpsType.STRING));
    }
    return schema;
  }

  /** check whether the schema matches */
  public static boolean isSchemaMatch(OdpsTableInfo odpsTableInfo, TableSchema realSchema) {
    TableSchema schemaNeeded = createSchema(odpsTableInfo);

    List<String> columnsNeeded =
        schemaNeeded.getPartitionColumns().stream()
            .map(com.aliyun.odps.Column::getName)
            .sorted(String::compareTo)
            .collect(Collectors.toList());

    List<String> columnsReal =
        realSchema.getPartitionColumns().stream()
            .map(com.aliyun.odps.Column::getName)
            .sorted(String::compareTo)
            .collect(Collectors.toList());

    if (!columnsNeeded.equals(columnsReal)) {
      log.error(
          "odps_partition_columns_not_match, need_columns="
              + JSON.toJSONString(columnsNeeded)
              + ",real_columns="
              + JSON.toJSONString(columnsReal));
      return false;
    }

    columnsNeeded =
        schemaNeeded.getColumns().stream()
            .map(com.aliyun.odps.Column::getName)
            .sorted(String::compareTo)
            .collect(Collectors.toList());

    columnsReal =
        realSchema.getColumns().stream()
            .map(com.aliyun.odps.Column::getName)
            .sorted(String::compareTo)
            .collect(Collectors.toList());
    if (!columnsNeeded.equals(columnsReal)) {
      log.error(
          "odps_type_not_match, need_columns="
              + JSON.toJSONString(columnsNeeded)
              + ",real_columns="
              + JSON.toJSONString(columnsReal));
      return false;
    }

    List<OdpsType> neededType =
        schemaNeeded.getColumns().stream()
            .map(column -> column.getTypeInfo().getOdpsType())
            .sorted(Enum::compareTo)
            .collect(Collectors.toList());

    List<OdpsType> realType =
        realSchema.getColumns().stream()
            .map(column -> column.getTypeInfo().getOdpsType())
            .sorted(Enum::compareTo)
            .collect(Collectors.toList());
    /*
    Commented out until type inference is complete
    if (!neededType.equals(realType)) {
      log.error(
          "odps_type_not_match, need_type="
              + JSON.toJSONString(neededType)
              + ",real_type="
              + JSON.toJSONString(realType));
      return false;
    }
    */

    return true;
  }

  /** create odps table */
  public static void createTable(OdpsTableInfo odpsTableInfo) {
    Odps odps = getODPSInstance(odpsTableInfo);
    TableSchema schema = createSchema(odpsTableInfo);
    log.info("start create table=" + odpsTableInfo.getProject() + "." + odpsTableInfo.getTable());
    try {
      odps.tables()
          .createTableWithLifeCycle(
              odpsTableInfo.getProject(), odpsTableInfo.getTable(), schema, null, true, 365L);
      log.info(
          "create table success," + odpsTableInfo.getProject() + "." + odpsTableInfo.getTable());
    } catch (com.aliyun.odps.OdpsException e) {
      log.error("create table error", e);
      throw new OdpsException("create_table_error", e);
    }
  }

  /** convert KgType to odps type */
  public static OdpsType toOdpsType(KgType kgType) {
    if (KTString$.MODULE$.equals(kgType)) {
      return OdpsType.STRING;
    } else if (KTLong$.MODULE$.equals(kgType) || KTInteger$.MODULE$.equals(kgType)) {
      return OdpsType.BIGINT;
    } else if (KTDouble$.MODULE$.equals(kgType)) {
      return OdpsType.DOUBLE;
    } else if (KTBoolean$.MODULE$.equals(kgType)) {
      return OdpsType.BOOLEAN;
    } else {
      throw new OdpsException("unsupported column type, " + kgType, null);
    }
  }

  /** change column name to odps valid name */
  public static String validColumnName(String columnName) {
    return columnName.replaceAll("\\.", "_").toLowerCase();
  }

  /** create odps instance */
  public static Odps getODPSInstance(OdpsTableInfo odpsTableInfo) {
    Account account = new AliyunAccount(odpsTableInfo.getAccessID(), odpsTableInfo.getAccessKey());
    Odps odps = new Odps(account);
    odps.setEndpoint(odpsTableInfo.getEndPoint());
    odps.setDefaultProject(odpsTableInfo.getProject());
    return odps;
  }

  /** get PartitionSpec from OdpsTableInfo */
  public static PartitionSpec getOdpsPartitionSpec(OdpsTableInfo odpsTableInfo) {
    if (null == odpsTableInfo.getPartition() || odpsTableInfo.getPartition().isEmpty()) {
      return null;
    }
    PartitionSpec partitionSpec = new PartitionSpec();
    for (Map.Entry<String, String> entry : odpsTableInfo.getPartition().entrySet()) {
      partitionSpec.set(entry.getKey(), entry.getValue());
    }
    return partitionSpec;
  }

  /** get upload session */
  public static UploadSession tryGetUploadSession(
      OdpsTableInfo odpsTableInfo, String id, int index, int parallel) {
    Odps odps = getODPSInstance(odpsTableInfo);
    TableTunnel tunnel = new TableTunnel(odps);
    if (!StringUtils.isEmpty(odpsTableInfo.getTunnelEndPoint())) {
      log.info("set odps tunnel endpoint=" + odpsTableInfo.getTunnelEndPoint());
      tunnel.setEndpoint(odpsTableInfo.getTunnelEndPoint());
    }
    int maxTryTimes = MAX_TRY_TIMES;
    while (--maxTryTimes >= 0) {
      try {
        return tunnel.getUploadSession(
            odpsTableInfo.getProject(),
            odpsTableInfo.getTable(),
            getOdpsPartitionSpec(odpsTableInfo),
            id,
            parallel,
            index);
      } catch (Throwable e) {
        log.error("create upload session error", e);
      }

      waitMs(ODPS_WAIT_MS);
    }
    throw new OdpsException("create upload session failed", null);
  }

  /** create upload session */
  public static UploadSession tryCreateUploadSession(OdpsTableInfo odpsTableInfo) {
    Odps odps = getODPSInstance(odpsTableInfo);
    TableTunnel tunnel = new TableTunnel(odps);
    if (!StringUtils.isEmpty(odpsTableInfo.getTunnelEndPoint())) {
      log.info("set odps tunnel endpoint=" + odpsTableInfo.getTunnelEndPoint());
      tunnel.setEndpoint(odpsTableInfo.getTunnelEndPoint());
    }
    int maxTryTimes = MAX_TRY_TIMES;
    while (--maxTryTimes >= 0) {
      try {
        PartitionSpec partitionSpec = getOdpsPartitionSpec(odpsTableInfo);
        if (null == partitionSpec) {
          return tunnel.createUploadSession(odpsTableInfo.getProject(), odpsTableInfo.getTable());
        } else {
          return tunnel.createUploadSession(
              odpsTableInfo.getProject(), odpsTableInfo.getTable(), partitionSpec);
        }
      } catch (Throwable e) {
        log.error("create upload session error", e);
      }

      waitMs(ODPS_WAIT_MS);
    }
    throw new OdpsException("create upload session failed", null);
  }

  /** create record writer */
  public static TunnelBufferedWriter tryCreateBufferRecordWriter(UploadSession uploadSession) {
    int maxTryTimes = MAX_TRY_TIMES;
    while (--maxTryTimes >= 0) {
      try {
        CompressOption option = new CompressOption();
        return new TunnelBufferedWriter(uploadSession, option);
      } catch (Throwable e) {
        log.error("create buffer writer error", e);
      }
      waitMs(ODPS_WAIT_MS);
    }
    throw new OdpsException("create buffer writer error", null);
  }

  /** create partition for table */
  public static void createPartition(OdpsTableInfo odpsTableInfo) {
    Odps odps = getODPSInstance(odpsTableInfo);
    PartitionSpec partitionSpec = getOdpsPartitionSpec(odpsTableInfo);
    if (null == partitionSpec) {
      return;
    }
    try {
      Table t = odps.tables().get(odpsTableInfo.getTable());
      if (!t.hasPartition(partitionSpec)) {
        t.createPartition(partitionSpec);
      }
    } catch (com.aliyun.odps.OdpsException e) {
      if (e.getMessage().contains("Partition already exists")) {
        // partition already exists, do not throw error
        // com.aliyun.odps.OdpsException: Catalog Service Failed, ErrorCode: 103,
        // Error Message: ODPS-0110061: Failed to run ddltask - AlreadyExistsException(message:
        // Partition already exists, existed values: ["$partition"])
        return;
      }
      throw new OdpsException("create_partition_error", e);
    }
  }

  /** create download session */
  public static DownloadSession tryCreateDownloadSession(
      TableTunnel tunnel, OdpsTableInfo odpsTableInfo) {
    PartitionSpec partition = getOdpsPartitionSpec(odpsTableInfo);
    int maxTryTimes = MAX_TRY_TIMES;
    Throwable lastError = null;
    while (--maxTryTimes >= 0) {
      try {
        DownloadSession downloadSession;
        if (null != partition) {
          downloadSession =
              tunnel.createDownloadSession(
                  odpsTableInfo.getProject(), odpsTableInfo.getTable(), partition);
        } else {
          downloadSession =
              tunnel.createDownloadSession(odpsTableInfo.getProject(), odpsTableInfo.getTable());
        }
        return downloadSession;
      } catch (TunnelException e) {
        if ("NoSuchPartition".equals(e.getErrorCode())) {
          // continue
          log.info(
              "table="
                  + odpsTableInfo.getProject()
                  + "."
                  + odpsTableInfo.getTable()
                  + ", partition="
                  + partition
                  + ", not exist");
          return null;
        } else if ("InvalidPartitionSpec".equals(e.getErrorCode())) {
          // if this table is not a partition table, we create download session without
          // PartitionSpec
          log.info(
              "table="
                  + odpsTableInfo.getProject()
                  + "."
                  + odpsTableInfo.getTable()
                  + ", partition="
                  + partition
                  + ", InvalidPartitionSpec");
          partition = null;
          continue;
        } else if ("FlowExceeded".equals(e.getErrorCode())) {
          log.warn("create_download_session, flow exceeded");
          // flow exceeded, continue
          // --maxTryTimes;
          waitMs(ODPS_WAIT_MS);
          continue;
        }
        log.error(
            "create_download_session_error, table="
                + odpsTableInfo.getProject()
                + "."
                + odpsTableInfo.getTable()
                + ", partition="
                + partition,
            e);
      } catch (Throwable e) {
        log.error(
            "create_download_session_error, table="
                + odpsTableInfo.getProject()
                + "."
                + odpsTableInfo.getTable()
                + ", partition="
                + partition,
            e);
        lastError = e;
      }

      waitMs(ODPS_WAIT_MS);
    }
    throw new OdpsException(
        "create_download_session_failed, time_out, table="
            + odpsTableInfo.getProject()
            + "."
            + odpsTableInfo.getTable()
            + ", partition="
            + partition,
        lastError);
  }

  /** open record reader */
  public static TunnelRecordReader tryOpenRecordReader(
      TableTunnel.DownloadSession downloadSession, long start, long count) {
    TunnelRecordReader recordReader;
    int maxTryTimes = MAX_TRY_TIMES;
    while (--maxTryTimes >= 0) {
      try {
        recordReader = downloadSession.openRecordReader(start, count);
        return recordReader;
      } catch (Exception e) {
        if (e instanceof com.aliyun.odps.OdpsException) {
          com.aliyun.odps.OdpsException oe = (com.aliyun.odps.OdpsException) e;
          if ("FlowExceeded".equals(oe.getErrorCode())) {
            log.warn("open_record_reader, flow exceeded");
            --maxTryTimes;
            waitMs(ODPS_FLOW_EXCEEDED_WAIT_MS);
            continue;
          }
        }
        log.error("open_record_reader_error", e);
      }
      waitMs(ODPS_WAIT_MS);
    }
    return null;
  }

  /** must call on driver */
  public static UploadSession createUploadSession(OdpsTableInfo odpsTableInfo) {
    // check that the odps table exists
    TableSchema schema = getTableSchema(odpsTableInfo);
    if (null == schema) {
      if (odpsTableInfo.getProject().endsWith("_dev")) {
        createTable(odpsTableInfo);
        schema = getTableSchema(odpsTableInfo);
        if (null == schema) {
          throw new OdpsException("create table error", null);
        }
      } else {
        // table not exist
        throw new OdpsException(
            "table not exist, project="
                + odpsTableInfo.getProject()
                + ",table="
                + odpsTableInfo.getTable(),
            null);
      }
    }

    // check that the table schema matches
    if (!isSchemaMatch(odpsTableInfo, schema)) {
      throw new OdpsException(
          "table "
              + odpsTableInfo.getProject()
              + "."
              + odpsTableInfo.getTable()
              + ",schema not match",
          null);
    }

    // create partition and upload session
    createPartition(odpsTableInfo);
    return tryCreateUploadSession(odpsTableInfo);
  }

  /** delete odps table */
  public static void dropOdpsTable(OdpsTableInfo odpsTableInfo) throws Exception {
    Odps odps = getODPSInstance(odpsTableInfo);
    odps.tables().delete(odpsTableInfo.getProject(), odpsTableInfo.getTable());
    log.info("dropOdpsTable," + odpsTableInfo.getTableInfoKeyString());
  }

  /** get read range */
  public static Map<OdpsTableInfo, ReadRange> getReadRange(
      int parallel, int index, int allRound, int nowRound, Map<OdpsTableInfo, Long> tableCountMap) {

    List<Pair<OdpsTableInfo, Long>> tableCountList = new ArrayList<>();
    long allCount = 0;
    for (OdpsTableInfo tableInfo : tableCountMap.keySet()) {
      tableCountList.add(new ImmutablePair<>(tableInfo, tableCountMap.get(tableInfo)));
      allCount += tableCountMap.get(tableInfo);
    }
    tableCountList.sort(
        Comparator.comparingLong((Pair<OdpsTableInfo, Long> o) -> o.getRight())
            .thenComparing(Pair::getLeft));

    Map<OdpsTableInfo, ReadRange> result = new HashMap<>();

    ReadRange loadRange = getReadRange(parallel, index, allRound, nowRound, allCount, 1);
    long offset1 = 0;
    long offset2 = 0;
    long loadedCount = 0;
    for (Pair<OdpsTableInfo, Long> partitionInfo : tableCountList) {
      offset1 = offset2;
      offset2 += partitionInfo.getRight();

      long start = 0;
      if (loadedCount <= 0) {
        if (loadRange.getStart() >= offset1 && loadRange.getStart() < offset2) {
          start = loadRange.getStart() - offset1;
        } else {
          continue;
        }
      }

      if (loadRange.getEnd() <= offset2) {
        long end = start + (loadRange.getCount() - loadedCount);
        if (end == start) {
          continue;
        }
        result.put(partitionInfo.getLeft(), new ReadRange(start, end));
        loadedCount += end - start;
        break;
      } else {
        long end = offset2 - offset1;
        if (end == start) {
          continue;
        }
        result.put(partitionInfo.getLeft(), new ReadRange(start, end));
        loadedCount += end - start;
      }
    }

    result =
        result.entrySet().stream()
            .filter(tableReadRange -> tableReadRange.getValue().getCount() > 0)
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

    return result;
  }

  private static ReadRange getReadRange(
      int parallel, int index, int allRound, int nowRound, long count, long minLoadSize) {
    long modSize = count % parallel;
    long share = count / parallel;
    long start;
    if (index < modSize) {
      share += 1;
      start = share * index;
    } else {
      start = modSize + share * index;
    }
    long end = start + share;

    if (share > 0 && share < minLoadSize) {
      // allocate to last node
      start = 0;
      if (index == parallel - 1) {
        end = count;
      } else {
        end = 0;
      }
    }

    if (allRound > 1) {
      long roundCount = end - start;
      long roundShare = roundCount / allRound;
      if (roundShare > 0 && roundShare < minLoadSize) {
        // allocate to last node
        if (nowRound == allRound - 1) {
          end = start + roundCount;
        } else {
          end = start;
        }
      } else {
        long roundModSize = roundCount % allRound;
        if (nowRound < roundModSize) {
          roundShare += 1;
          start += nowRound * roundShare;
        } else {
          start += roundModSize + roundShare * nowRound;
        }
        end = start + roundShare;
      }
    }

    return new ReadRange(start, end);
  }

  /** wait ms */
  public static void waitMs(long ms) {
    try {
      Thread.sleep(ms);
    } catch (InterruptedException e) {
      log.warn("sleep_error", e);
    }
  }
}
@@ -0,0 +1,151 @@
/*
 * Copyright 2023 OpenSPG Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

package com.antgroup.openspg.reasoner.io.odps;

import com.aliyun.odps.data.Record;
import com.aliyun.odps.tunnel.TableTunnel.UploadSession;
import com.aliyun.odps.tunnel.io.TunnelBufferedWriter;
import com.antgroup.openspg.reasoner.common.exception.OdpsException;
import com.antgroup.openspg.reasoner.io.ITableWriter;
import com.antgroup.openspg.reasoner.io.model.AbstractTableInfo;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import java.io.IOException;
import java.util.Arrays;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class OdpsWriter implements ITableWriter {
  private int taskIndex;

  private OdpsTableInfo odpsTableInfo;

  private transient UploadSession uploadSession = null;
  private transient TunnelBufferedWriter recordWriter = null;

  private long writeCount = 0L;

  private static final int MAX_TRY_WRITE_TIMES = 5;

  /** write buffer size */
  private static final long WRITER_BUFFER_SIZE = 32 * 1024 * 1024;
  /** reset writer when count reaches 10M */
  private static final long WRITER_RESET_COUNT = 1000 * 10000;

  /**
   * init odps writer
   *
   * <p>The odps writer will not commit the result itself; you must ensure the data is committed
   * yourself.
   *
   * <p>for example:
   *
   * <p>// create upload session on driver: UploadSession session =
   * OdpsUtils.getUploadSession(tableInfo);
   *
   * <p>// set session id, make sure that the writer on each worker is under the same session id:
   * tableInfo.setUploadSessionId(session.getId());
   *
   * <p>// on worker, get writer and write data ...(code on worker)
   *
   * <p>// on driver, commit session: session.commit();
   */
  public void open(int taskIndex, int parallel, AbstractTableInfo tableInfo) {
    this.taskIndex = taskIndex;
    this.odpsTableInfo = (OdpsTableInfo) tableInfo;
    log.info("openOdpsWriter,index=" + this.taskIndex + ",odpsTableInfo=" + this.odpsTableInfo);
    this.uploadSession =
        OdpsUtils.tryGetUploadSession(
            this.odpsTableInfo, this.odpsTableInfo.getUploadSessionId(), taskIndex, parallel);
    resetWriter();
  }

  /** write record */
  @Override
  public void write(Object[] data) {
    long c = this.writeCount++;

    if (1 == c % 10000) {
      log.info(
          "index="
              + this.taskIndex
              + ",write_odps_record["
              + Arrays.toString(data)
              + "], write_count="
              + c);
    }

    Record record = uploadSession.newRecord();
    record.set(data);

    // try five times at most
    int maxTryTimes = MAX_TRY_WRITE_TIMES;
    while (maxTryTimes-- > 0) {
      try {
        synchronized (this) {
          recordWriter.write(record);
        }
        break;
      } catch (IOException e) {
        if (e.getLocalizedMessage().contains("MalformedDataStream")) {
          log.error("write_odps_get_io_exception", e);
          // io exception, reset
          resetWriter();
          continue;
        }
        throw new OdpsException("write_odps_record_error", e);
      }
    }
  }

  /** close writer */
  @Override
  public void close() {
    closeWriter();
  }

  @Override
  public long writeCount() {
    return this.writeCount;
  }

  private void resetWriter() {
    closeWriter();
    recordWriter = OdpsUtils.tryCreateBufferRecordWriter(this.uploadSession);
    recordWriter.setBufferSize(WRITER_BUFFER_SIZE);
  }

  private void closeWriter() {
    if (null != recordWriter) {
      try {
        log.info(
            "odps_writer_close, index="
                + this.taskIndex
                + ", info="
                + odpsTableInfo
                + ", odps_write_count="
                + writeCount);
        recordWriter.close();
      } catch (IOException e) {
        if (e.getLocalizedMessage().contains("MalformedDataStream")) {
          log.error("close_writer_MalformedDataStream", e);
          return;
        }
        log.error("close_writer_error", e);
        throw new OdpsException("close_writer_error", e);
      } finally {
        recordWriter = null;
        writeCount = 0L;
      }
    }
  }
}
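
Note: the OdpsWriter javadoc above describes the driver/worker commit protocol only in prose. The sketch below is an illustration, not part of the commit: it assumes the OdpsTableInfo is configured elsewhere and simulates the parallel workers with a plain loop in one process, using only calls introduced in this diff (OdpsUtils.createUploadSession, OdpsTableInfo.setUploadSessionId, OdpsWriter.open/write/close, UploadSession.commit).

// Hypothetical usage sketch; OdpsWriteExample and writeAll are illustrative names only.
import com.aliyun.odps.tunnel.TableTunnel.UploadSession;
import com.antgroup.openspg.reasoner.io.model.OdpsTableInfo;
import com.antgroup.openspg.reasoner.io.odps.OdpsUtils;
import com.antgroup.openspg.reasoner.io.odps.OdpsWriter;

public class OdpsWriteExample {
  public static void writeAll(OdpsTableInfo tableInfo, Object[][] rows) throws Exception {
    // driver side: create one upload session and propagate its id to every worker
    UploadSession session = OdpsUtils.createUploadSession(tableInfo);
    tableInfo.setUploadSessionId(session.getId());

    int parallel = 2;
    for (int taskIndex = 0; taskIndex < parallel; taskIndex++) {
      // worker side: each task opens its own writer under the shared session id
      OdpsWriter writer = new OdpsWriter();
      writer.open(taskIndex, parallel, tableInfo);
      for (int i = taskIndex; i < rows.length; i += parallel) {
        writer.write(rows[i]);
      }
      writer.close();
    }

    // driver side again: OdpsWriter never commits, so the shared session is committed here
    session.commit();
  }
}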