Refactored Datalake and Deltalake for Topology (#6034)

* rebasing with main

* refactored deltalake for topology

* using requests instead of urllib

* formatting fixes

Co-authored-by: Onkar Ravgan <onkarravgan@Onkars-MacBook-Pro.local>
Onkar Ravgan 2022-07-19 10:07:27 +05:30 committed by GitHub
parent e7dca141ed
commit 8c9dc91ccf
6 changed files with 390 additions and 261 deletions

View File

@@ -25,6 +25,7 @@ except ImportError:
from typing_compat import get_args
from pydantic import BaseModel
from requests.utils import quote
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.chart import Chart
@@ -458,7 +459,11 @@ class OpenMetadata(
Return entity by name or None
"""
return self._get(entity=entity, path=f"name/{model_str(fqn)}", fields=fields)
return self._get(
entity=entity,
path=f"name/{quote(model_str(fqn), safe='')}",
fields=fields,
)
def get_by_id(
self,
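
For reference, a minimal sketch of what the new quoting in get_by_name does to a fully qualified name before it is placed in the REST path. The sample FQN is invented; only requests.utils.quote and safe='' come from the change above.

from requests.utils import quote

# An FQN can contain spaces or slashes that would otherwise break the URL path.
fqn = "datalake.default.my-bucket/raw data/orders.csv"  # hypothetical FQN

# safe='' also encodes '/', so the whole FQN stays a single path segment,
# e.g. /api/v1/tables/name/<encoded-fqn>
print(quote(fqn, safe=""))
# datalake.default.my-bucket%2Fraw%20data%2Forders.csv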

View File

@@ -13,15 +13,17 @@
DataLake connector to fetch metadata from files stored in s3, gcs and Hdfs
"""
import traceback
import uuid
from typing import Iterable, Optional
from typing import Iterable, Optional, Tuple
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
CreateDatabaseSchemaRequest,
)
from metadata.generated.schema.api.data.createTable import CreateTableRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.table import (
Column,
DataType,
Table,
TableData,
TableType,
)
@@ -42,10 +44,13 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.common import Entity
from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus
from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable
from metadata.ingestion.api.source import InvalidSourceException, SourceStatus
from metadata.ingestion.models.ometa_tag_category import OMetaTagAndCategory
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.common_db_source import SQLSourceStatus
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
SQLSourceStatus,
)
from metadata.utils.connections import get_connection, test_connection
from metadata.utils.filters import filter_by_table
from metadata.utils.gcs_utils import (
@@ -66,10 +71,11 @@ logger = ingestion_logger()
DATALAKE_INT_TYPES = {"int64", "INT"}
DATALAKE_SUPPORTED_FILE_TYPES = (".csv", ".tsv", ".json", ".parquet")
class DatalakeSource(Source[Entity]):
class DatalakeSource(DatabaseServiceSource):
def __init__(self, config: WorkflowSource, metadata_config: OpenMetadataConnection):
super().__init__()
self.status = SQLSourceStatus()
self.config = config
@@ -84,8 +90,10 @@ class DatalakeSource(Source[Entity]):
)
self.connection = get_connection(self.service_connection)
self.client = self.connection.client
self.table_constraints = None
self.database_source_state = set()
super().__init__()
@classmethod
def create(cls, config_dict, metadata_config: OpenMetadataConnection):
@@ -97,146 +105,171 @@ class DatalakeSource(Source[Entity]):
)
return cls(config, metadata_config)
def prepare(self):
pass
def get_database_names(self) -> Iterable[str]:
"""
Default case with a single database.
def next_record(self) -> Iterable[Entity]:
try:
It might come informed - or not - from the source.
bucket_name = self.service_connection.bucketName
prefix = self.service_connection.prefix
Sources with multiple databases should overwrite this and
apply the necessary filters.
"""
database_name = "default"
yield database_name
def yield_database(self, database_name: str) -> Iterable[CreateDatabaseRequest]:
"""
From topology.
Prepare a database request and pass it to the sink
"""
yield CreateDatabaseRequest(
name=database_name,
service=EntityReference(
id=self.context.database_service.id,
type="databaseService",
),
)
def get_database_schema_names(self) -> Iterable[str]:
"""
return schema names
"""
bucket_name = self.service_connection.bucketName
if isinstance(self.service_connection.configSource, GCSConfig):
if bucket_name:
yield bucket_name
else:
for bucket in self.client.list_buckets():
yield bucket.name
if isinstance(self.service_connection.configSource, S3Config):
if bucket_name:
yield bucket_name
else:
for bucket in self.client.list_buckets()["Buckets"]:
yield bucket["Name"]
def yield_database_schema(
self, schema_name: str
) -> Iterable[CreateDatabaseSchemaRequest]:
"""
From topology.
Prepare a database schema request and pass it to the sink
"""
yield CreateDatabaseSchemaRequest(
name=schema_name,
database=EntityReference(id=self.context.database.id, type="database"),
)
def get_tables_name_and_type(self) -> Optional[Iterable[Tuple[str, str]]]:
"""
Handle table and views.
Fetches them up using the context information and
the inspector set when preparing the db.
:return: tables or views, depending on config
"""
bucket_name = self.context.database_schema.name.__root__
prefix = self.service_connection.prefix
if self.source_config.includeTables:
if isinstance(self.service_connection.configSource, GCSConfig):
if bucket_name:
yield from self.get_gcs_files(bucket_name, prefix)
else:
for bucket in self.client.list_buckets():
yield from self.get_gcs_files(bucket.name, prefix)
bucket = self.client.get_bucket(bucket_name)
for key in bucket.list_blobs(prefix=prefix):
if filter_by_table(
self.config.sourceConfig.config.tableFilterPattern, key.name
) or not self.check_valid_file_type(key.name):
self.status.filter(
"{}".format(key["Key"]),
"Table pattern not allowed",
)
continue
table_name = self.standardize_table_name(bucket_name, key.name)
yield table_name, TableType.Regular
if isinstance(self.service_connection.configSource, S3Config):
if bucket_name:
yield from self.get_s3_files(bucket_name, prefix)
else:
for bucket in self.client.list_buckets()["Buckets"]:
yield from self.get_s3_files(bucket["Name"], prefix)
kwargs = {"Bucket": bucket_name}
if prefix:
kwargs["Prefix"] = prefix if prefix.endswith("/") else f"{prefix}/"
for key in self.client.list_objects(**kwargs)["Contents"]:
if filter_by_table(
self.config.sourceConfig.config.tableFilterPattern, key["Key"]
) or not self.check_valid_file_type(key["Key"]):
self.status.filter(
"{}".format(key["Key"]),
"Table pattern not allowed",
)
continue
table_name = self.standardize_table_name(bucket_name, key["Key"])
yield table_name, TableType.Regular
def yield_table(
self, table_name_and_type: Tuple[str, str]
) -> Iterable[Optional[CreateTableRequest]]:
"""
From topology.
Prepare a table request and pass it to the sink
"""
table_name, table_type = table_name_and_type
schema_name = self.context.database_schema.name.__root__
try:
table_constraints = None
if isinstance(self.service_connection.configSource, GCSConfig):
df = self.get_gcs_files(key=table_name, bucket_name=schema_name)
if isinstance(self.service_connection.configSource, S3Config):
df = self.get_s3_files(key=table_name, bucket_name=schema_name)
columns = self.get_columns(df)
table_request = CreateTableRequest(
name=table_name,
tableType=table_type,
description="",
columns=columns,
tableConstraints=table_constraints if table_constraints else None,
databaseSchema=EntityReference(
id=self.context.database_schema.id,
type="databaseSchema",
),
)
yield table_request
self.register_record(table_request=table_request)
except Exception as err:
logger.error(traceback.format_exc())
logger.debug(traceback.format_exc())
logger.error(err)
self.status.failures.append(
"{}.{}".format(self.config.serviceName, table_name)
)
def get_gcs_files(self, key, bucket_name):
try:
if key.endswith(".csv"):
return read_csv_from_gcs(key, bucket_name)
if key.endswith(".tsv"):
return read_tsv_from_gcs(key, bucket_name)
if key.endswith(".json"):
return read_json_from_gcs(self.client, key, bucket_name)
if key.endswith(".parquet"):
return read_parquet_from_gcs(key, bucket_name)
except Exception as err:
logger.debug(traceback.format_exc())
logger.error(err)
def get_gcs_files(self, bucket_name, prefix):
bucket = self.client.get_bucket(bucket_name)
for key in bucket.list_blobs(prefix=prefix):
try:
if filter_by_table(
self.config.sourceConfig.config.tableFilterPattern, key.name
):
self.status.filter(
"{}".format(key["Key"]),
"Table pattern not allowed",
)
continue
if key.name.endswith(".csv"):
df = read_csv_from_gcs(key, bucket_name)
yield from self.ingest_tables(key.name, df, bucket_name)
if key.name.endswith(".tsv"):
df = read_tsv_from_gcs(key, bucket_name)
yield from self.ingest_tables(key.name, df, bucket_name)
if key.name.endswith(".json"):
df = read_json_from_gcs(key)
yield from self.ingest_tables(key.name, df, bucket_name)
if key.name.endswith(".parquet"):
df = read_parquet_from_gcs(key, bucket_name)
yield from self.ingest_tables(key.name, df, bucket_name)
except Exception as err:
logger.debug(traceback.format_exc())
logger.error(err)
def get_s3_files(self, bucket_name, prefix):
kwargs = {"Bucket": bucket_name}
if prefix:
kwargs["Prefix"] = prefix if prefix.endswith("/") else f"{prefix}/"
for key in self.client.list_objects(**kwargs)["Contents"]:
try:
if filter_by_table(
self.config.sourceConfig.config.tableFilterPattern, key["Key"]
):
self.status.filter(
"{}".format(key["Key"]),
"Table pattern not allowed",
)
continue
if key["Key"].endswith(".csv"):
df = read_csv_from_s3(self.client, key, bucket_name)
yield from self.ingest_tables(key["Key"], df, bucket_name)
if key["Key"].endswith(".tsv"):
df = read_tsv_from_s3(self.client, key, bucket_name)
yield from self.ingest_tables(key["Key"], df, bucket_name)
if key["Key"].endswith(".json"):
df = read_json_from_s3(self.client, key, bucket_name)
yield from self.ingest_tables(key["Key"], df, bucket_name)
if key["Key"].endswith(".parquet"):
df = read_parquet_from_s3(self.client, key, bucket_name)
yield from self.ingest_tables(key["Key"], df, bucket_name)
except Exception as err:
logger.debug(traceback.format_exc())
logger.error(err)
def ingest_tables(self, key, df, bucket_name) -> Iterable[OMetaDatabaseAndTable]:
def get_s3_files(self, key, bucket_name):
try:
table_columns = self.get_columns(df)
database_entity = Database(
id=uuid.uuid4(),
name="default",
service=EntityReference(id=self.service.id, type="databaseService"),
)
table_entity = Table(
id=uuid.uuid4(),
name=key,
description="",
columns=table_columns,
tableType=TableType.External,
)
schema_entity = DatabaseSchema(
id=uuid.uuid4(),
name=bucket_name,
database=EntityReference(id=database_entity.id, type="database"),
service=EntityReference(id=self.service.id, type="databaseService"),
)
table_and_db = OMetaDatabaseAndTable(
table=table_entity,
database=database_entity,
database_schema=schema_entity,
)
if key.endswith(".csv"):
return read_csv_from_s3(self.client, key, bucket_name)
yield table_and_db
if key.endswith(".tsv"):
return read_tsv_from_s3(self.client, key, bucket_name)
if key.endswith(".json"):
return read_json_from_s3(self.client, key, bucket_name)
if key.endswith(".parquet"):
return read_parquet_from_s3(self.client, key, bucket_name)
except Exception as err:
logger.debug(traceback.format_exc())
@@ -259,27 +292,43 @@ class DatalakeSource(Source[Entity]):
return None
def get_columns(self, df):
df_columns = list(df.columns)
for column in df_columns:
try:
if hasattr(df, "columns"):
df_columns = list(df.columns)
for column in df_columns:
try:
if (
hasattr(df[column], "dtypes")
and df[column].dtypes.name in DATALAKE_INT_TYPES
):
if df[column].dtypes.name == "int64":
data_type = DataType.INT.value
else:
data_type = DataType.STRING.value
parsed_string = {}
parsed_string["dataTypeDisplay"] = data_type
parsed_string["dataType"] = data_type
parsed_string["name"] = column[:64]
parsed_string["dataLength"] = parsed_string.get("dataLength", 1)
yield Column(**parsed_string)
except Exception as err:
logger.debug(traceback.format_exc())
logger.error(err)
if (
hasattr(df[column], "dtypes")
and df[column].dtypes.name in DATALAKE_INT_TYPES
):
if df[column].dtypes.name == "int64":
data_type = DataType.INT.value
else:
data_type = DataType.STRING.value
parsed_string = {}
parsed_string["dataTypeDisplay"] = data_type
parsed_string["dataType"] = data_type
parsed_string["name"] = column[:64]
parsed_string["dataLength"] = parsed_string.get("dataLength", 1)
yield Column(**parsed_string)
except Exception as err:
logger.debug(traceback.format_exc())
logger.error(err)
def yield_view_lineage(
self, table_name_and_type: Tuple[str, str]
) -> Optional[Iterable[AddLineageRequest]]:
pass
def yield_tag(self, schema_name: str) -> Iterable[OMetaTagAndCategory]:
pass
def standardize_table_name(self, schema: str, table: str) -> str:
return table
def check_valid_file_type(self, key_name):
if key_name.endswith(DATALAKE_SUPPORTED_FILE_TYPES):
return True
return False
def close(self):
pass
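
To make the shape of the refactor easier to follow: instead of a single next_record that built OMetaDatabaseAndTable objects, the source now exposes one small generator per topology level (database -> schema -> table), and the framework drives them while a shared context carries the parent entity downwards. Below is a deliberately simplified, framework-free sketch of that contract; the method names mirror the diff, but ToyDatalakeSource and run_topology are stand-ins, not OpenMetadata's actual DatabaseServiceSource or topology runner.

from types import SimpleNamespace

class ToyDatalakeSource:
    """Not the real DatabaseServiceSource; just the shape of the contract."""

    def __init__(self):
        self.context = SimpleNamespace()

    def get_database_names(self):
        yield "default"

    def yield_database(self, database_name):
        yield {"kind": "database", "name": database_name}

    def get_database_schema_names(self):
        yield "my-bucket"  # one schema per bucket, as in the source above

    def yield_database_schema(self, schema_name):
        yield {"kind": "databaseSchema", "name": schema_name,
               "database": self.context.database["name"]}

    def get_tables_name_and_type(self):
        yield "raw/orders.csv", "Regular"

    def yield_table(self, table_name_and_type):
        name, table_type = table_name_and_type
        yield {"kind": "table", "name": name, "tableType": table_type,
               "databaseSchema": self.context.database_schema["name"]}

def run_topology(source):
    """Stand-in for the framework's topology runner, not the real one."""
    for db_name in source.get_database_names():
        for database in source.yield_database(db_name):
            source.context.database = database
            yield database
            for schema_name in source.get_database_schema_names():
                for schema in source.yield_database_schema(schema_name):
                    source.context.database_schema = schema
                    yield schema
                    for name_and_type in source.get_tables_name_and_type():
                        yield from source.yield_table(name_and_type)

for record in run_topology(ToyDatalakeSource()):
    print(record)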

View File

@@ -1,15 +1,18 @@
import logging
import re
import uuid
from typing import Any, Dict, Iterable, List, Optional
import traceback
from typing import Any, Dict, Iterable, List, Optional, Tuple
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, MapType, StructType
from pyspark.sql.utils import AnalysisException, ParseException
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import Column, Table, TableType
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
CreateDatabaseSchemaRequest,
)
from metadata.generated.schema.api.data.createTable import CreateTableRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.table import Column, TableType
from metadata.generated.schema.entity.services.connections.database.deltaLakeConnection import (
DeltaLakeConnection,
)
@@ -17,15 +20,20 @@ from metadata.generated.schema.entity.services.connections.metadata.openMetadata
OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.metadataIngestion.databaseServiceMetadataPipeline import (
DatabaseServiceMetadataPipeline,
)
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.common import Entity
from metadata.ingestion.api.source import InvalidSourceException, Source
from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.models.ometa_tag_category import OMetaTagAndCategory
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.common_db_source import SQLSourceStatus
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
SQLSourceStatus,
)
from metadata.utils.column_type_parser import ColumnTypeParser
from metadata.utils.connections import get_connection
from metadata.utils.filters import filter_by_schema, filter_by_table
@@ -42,7 +50,7 @@ class MetaStoreNotFoundException(Exception):
"""
class DeltalakeSource(Source[Entity]):
class DeltalakeSource(DatabaseServiceSource):
spark: SparkSession = None
def __init__(
@@ -50,8 +58,11 @@ class DeltalakeSource(Source[Entity]):
config: WorkflowSource,
metadata_config: OpenMetadataConnection,
):
super().__init__()
self.config = config
self.source_config: DatabaseServiceMetadataPipeline = (
self.config.sourceConfig.config
)
self.metadata_config = metadata_config
self.metadata = OpenMetadata(metadata_config)
self.service_connection = self.config.serviceConnection.__root__.config
@@ -72,6 +83,9 @@ class DeltalakeSource(Source[Entity]):
self.array_datatype_replace_map = {"(": "<", ")": ">", "=": ":", "<>": ""}
self.ARRAY_CHILD_START_INDEX = 6
self.ARRAY_CHILD_END_INDEX = -1
self.table_constraints = None
self.database_source_state = set()
super().__init__()
@classmethod
def create(cls, config_dict, metadata_config: OpenMetadataConnection):
@@ -87,7 +101,35 @@ class DeltalakeSource(Source[Entity]):
)
return cls(config, metadata_config)
def next_record(self) -> Iterable[OMetaDatabaseAndTable]:
def get_database_names(self) -> Iterable[str]:
"""
Default case with a single database.
It might come informed - or not - from the source.
Sources with multiple databases should overwrite this and
apply the necessary filters.
"""
yield DEFAULT_DATABASE
def yield_database(self, database_name: str) -> Iterable[CreateDatabaseRequest]:
"""
From topology.
Prepare a database request and pass it to the sink
"""
yield CreateDatabaseRequest(
name=database_name,
service=EntityReference(
id=self.context.database_service.id,
type="databaseService",
),
)
def get_database_schema_names(self) -> Iterable[str]:
"""
return schema names
"""
schemas = self.spark.catalog.listDatabases()
for schema in schemas:
if filter_by_schema(
@@ -95,88 +137,110 @@ class DeltalakeSource(Source[Entity]):
):
self.status.filter(schema.name, "Schema pattern not allowed")
continue
yield from self.fetch_tables(schema.name)
yield schema.name
def get_status(self):
return self.status
def yield_database_schema(
self, schema_name: str
) -> Iterable[CreateDatabaseSchemaRequest]:
"""
From topology.
Prepare a database schema request and pass it to the sink
"""
yield CreateDatabaseSchemaRequest(
name=schema_name,
database=EntityReference(id=self.context.database.id, type="database"),
)
def prepare(self):
pass
def get_tables_name_and_type(self) -> Optional[Iterable[Tuple[str, str]]]:
"""
Handle table and views.
def _get_table_type(self, table_type: str):
return self.table_type_map.get(table_type.lower(), TableType.Regular.value)
Fetches them up using the context information and
the inspector set when preparing the db.
def fetch_tables(self, schema: str) -> Iterable[OMetaDatabaseAndTable]:
for table in self.spark.catalog.listTables(schema):
:return: tables or views, depending on config
"""
schema_name = self.context.database_schema.name.__root__
for table in self.spark.catalog.listTables(schema_name):
try:
table_name = table.name
if filter_by_table(
self.config.sourceConfig.config.tableFilterPattern, table_name
self.source_config.tableFilterPattern, table_name=table_name
):
self.status.filter(
"{}.{}".format(self.config.serviceName, table_name),
f"{table_name}",
"Table pattern not allowed",
)
continue
self.status.scanned("{}.{}".format(self.config.serviceName, table_name))
table_columns = self._fetch_columns(schema, table_name)
if table.tableType and table.tableType.lower() != "view":
table_entity = Table(
id=uuid.uuid4(),
name=table_name,
tableType=self._get_table_type(table.tableType),
description=table.description,
columns=table_columns,
)
else:
view_definition = self._fetch_view_schema(table_name)
table_entity = Table(
id=uuid.uuid4(),
name=table_name,
tableType=self._get_table_type(table.tableType),
description=table.description,
columns=table_columns,
viewDefinition=view_definition,
)
if (
self.source_config.includeTables
and table.tableType
and table.tableType.lower() != "view"
):
table_name = self.standardize_table_name(schema_name, table_name)
self.context.table_description = table.description
yield table_name, TableType.Regular
if (
self.source_config.includeViews
and table.tableType
and table.tableType.lower() == "view"
):
view_name = self.standardize_table_name(schema_name, table_name)
self.context.table_description = table.description
yield view_name, TableType.View
database = self.get_database_entity()
table_and_db = OMetaDatabaseAndTable(
table=table_entity,
database=database,
database_schema=self._get_database_schema(database, schema),
)
yield table_and_db
except Exception as err:
logger.error(err)
self.status.warnings.append(
"{}.{}".format(self.config.serviceName, table.name)
)
def get_database_entity(self) -> Database:
return Database(
id=uuid.uuid4(),
name=DEFAULT_DATABASE,
service=EntityReference(
id=self.service.id, type=self.service_connection.type.value
),
)
def _get_database_schema(self, database: Database, schema: str) -> DatabaseSchema:
return DatabaseSchema(
name=schema,
service=EntityReference(
id=self.service.id, type=self.service_connection.type.value
),
database=EntityReference(id=database.id, type="database"),
)
def _fetch_table_description(self, table_name: str) -> Optional[Dict]:
def yield_table(
self, table_name_and_type: Tuple[str, str]
) -> Iterable[Optional[CreateTableRequest]]:
"""
From topology.
Prepare a table request and pass it to the sink
"""
table_name, table_type = table_name_and_type
schema_name = self.context.database_schema.name.__root__
try:
table_details_df = self.spark.sql(f"describe detail {table_name}")
table_detail = table_details_df.collect()[0]
return table_detail.asDict()
except Exception as e:
logging.error(e)
columns = self.get_columns(schema_name, table_name)
view_definition = (
self._fetch_view_schema(table_name)
if table_type == TableType.View
else None
)
table_request = CreateTableRequest(
name=table_name,
tableType=table_type,
description=self.context.table_description,
columns=columns,
tableConstraints=None,
databaseSchema=EntityReference(
id=self.context.database_schema.id,
type="databaseSchema",
),
viewDefinition=view_definition,
)
yield table_request
self.register_record(table_request=table_request)
except Exception as err:
logger.debug(traceback.format_exc())
logger.error(err)
self.status.failures.append(
"{}.{}".format(self.config.serviceName, table_name)
)
def get_status(self):
return self.status
def prepare(self):
pass
def _fetch_view_schema(self, view_name: str) -> Optional[Dict]:
describe_output = []
@@ -251,7 +315,7 @@ class DeltalakeSource(Source[Entity]):
)
return column
def _fetch_columns(self, schema: str, table: str) -> List[Column]:
def get_columns(self, schema: str, table: str) -> List[Column]:
raw_columns = []
field_dict: Dict[str, Any] = {}
table_name = f"{schema}.{table}"
@@ -275,12 +339,19 @@ class DeltalakeSource(Source[Entity]):
return parsed_columns
def _is_complex_delta_type(self, delta_type: Any) -> bool:
return (
isinstance(delta_type, StructType)
or isinstance(delta_type, ArrayType)
or isinstance(delta_type, MapType)
)
def yield_view_lineage(
self, table_name_and_type: Tuple[str, str]
) -> Optional[Iterable[AddLineageRequest]]:
pass
def yield_tag(self, schema_name: str) -> Iterable[OMetaTagAndCategory]:
pass
def close(self):
pass
def standardize_table_name(self, schema: str, table: str) -> str:
return table
def test_connection(self) -> None:
pass
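
As a companion to the generators above, a small standalone sketch of the Spark catalog calls the Deltalake source builds on. listDatabases, listTables and the tableType/description fields are standard pyspark Catalog APIs; the app name, session settings and warehouse path here are invented.

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("deltalake-catalog-sketch")
    .config("spark.sql.warehouse.dir", "/tmp/warehouse")  # hypothetical path
    .enableHiveSupport()
    .getOrCreate()
)

for schema in spark.catalog.listDatabases():               # feeds get_database_schema_names()
    for table in spark.catalog.listTables(schema.name):    # feeds get_tables_name_and_type()
        kind = "View" if (table.tableType or "").lower() == "view" else "Regular"
        print(schema.name, table.name, kind, table.description)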

View File

@@ -280,6 +280,10 @@ def _(connection: DeltaLakeConnection, verbose: bool = False) -> DeltaLakeClient
elif connection.metastoreFilePath:
builder.config("spark.sql.warehouse.dir", f"{connection.metastoreFilePath}")
if connection.connectionArguments:
for key, value in connection.connectionArguments:
builder.config(key, value)
deltalake_connection = DeltaLakeClient(
configure_spark_with_delta_pip(builder).getOrCreate()
)
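
For context, a sketch of how extra connectionArguments end up on the Spark builder, mirroring the loop added above. configure_spark_with_delta_pip is the real delta-spark helper; the argument keys and values are made up, and the real code iterates the pydantic connectionArguments object rather than a plain dict.

from delta import configure_spark_with_delta_pip
from pyspark.sql import SparkSession

connection_arguments = {                      # hypothetical user-supplied extras
    "spark.executor.memory": "2g",
    "spark.sql.shuffle.partitions": "8",
}

builder = (
    SparkSession.builder.appName("deltalake-ingestion")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config(
        "spark.sql.catalog.spark_catalog",
        "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    )
)

# Same idea as `for key, value in connection.connectionArguments:` in the diff.
for key, value in connection_arguments.items():
    builder = builder.config(key, value)

spark = configure_spark_with_delta_pip(builder).getOrCreate()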

View File

@@ -16,7 +16,6 @@ import dask.dataframe as dd
import gcsfs
import pandas as pd
import pyarrow.parquet as pq
from google.cloud.storage.blob import Blob
from pandas import DataFrame
from metadata.utils.logger import utils_logger
@@ -24,19 +23,21 @@ from metadata.utils.logger import utils_logger
logger = utils_logger()
def read_csv_from_gcs(key: Blob, bucket_name: str) -> DataFrame:
df = dd.read_csv(f"gs://{bucket_name}/{key.name}")
def read_csv_from_gcs(key: str, bucket_name: str) -> DataFrame:
df = dd.read_csv(f"gs://{bucket_name}/{key}")
return df
def read_tsv_from_gcs(key: Blob, bucket_name: str) -> DataFrame:
df = dd.read_csv(f"gs://{bucket_name}/{key.name}", sep="\t")
def read_tsv_from_gcs(key: str, bucket_name: str) -> DataFrame:
df = dd.read_csv(f"gs://{bucket_name}/{key}", sep="\t")
return df
def read_json_from_gcs(key: Blob) -> DataFrame:
def read_json_from_gcs(client, key: str, bucket_name: str) -> DataFrame:
try:
data = key.download_as_string().decode()
bucket = client.get_bucket(bucket_name)
blob = bucket.get_blob(key)
data = blob.download_as_string().decode()
data = json.loads(data)
if isinstance(data, list):
df = pd.DataFrame.from_dict(data)
@@ -51,8 +52,8 @@ def read_json_from_gcs(key: Blob) -> DataFrame:
logger.error(verr)
def read_parquet_from_gcs(key: Blob, bucket_name: str) -> DataFrame:
def read_parquet_from_gcs(key: str, bucket_name: str) -> DataFrame:
gs = gcsfs.GCSFileSystem()
arrow_df = pq.ParquetDataset(f"gs://{bucket_name}/{key.name}", filesystem=gs)
arrow_df = pq.ParquetDataset(f"gs://{bucket_name}/{key}", filesystem=gs)
df = arrow_df.read_pandas().to_pandas()
return df
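
A short usage example for the string-key signatures above. The bucket and object names are invented, and the import path is assumed to be metadata.utils.gcs_utils, the name these helpers are imported under elsewhere in this change.

from google.cloud import storage

from metadata.utils.gcs_utils import (
    read_csv_from_gcs,
    read_json_from_gcs,
    read_parquet_from_gcs,
)

client = storage.Client()
bucket_name = "my-datalake-bucket"   # hypothetical

csv_df = read_csv_from_gcs("raw/orders.csv", bucket_name)              # dask DataFrame
json_df = read_json_from_gcs(client, "raw/events.json", bucket_name)   # pandas DataFrame
parquet_df = read_parquet_from_gcs("raw/orders.parquet", bucket_name)
print(list(csv_df.columns), list(json_df.columns), list(parquet_df.columns))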

View File

@@ -16,25 +16,24 @@ import pandas as pd
from pandas import DataFrame
def read_csv_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
csv_obj = client.get_object(Bucket=bucket_name, Key=key["Key"])
def read_csv_from_s3(client: Any, key: str, bucket_name: str) -> DataFrame:
csv_obj = client.get_object(Bucket=bucket_name, Key=key)
body = csv_obj["Body"]
csv_string = body.read().decode("utf-8")
df = pd.read_csv(StringIO(csv_string))
return df
def read_tsv_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
tsv_obj = client.get_object(Bucket=bucket_name, Key=key["Key"])
def read_tsv_from_s3(client: Any, key: str, bucket_name: str) -> DataFrame:
tsv_obj = client.get_object(Bucket=bucket_name, Key=key)
body = tsv_obj["Body"]
tsv_string = body.read().decode("utf-8")
df = pd.read_csv(StringIO(tsv_string), sep="\t")
return df
def read_json_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
obj = client.get_object(Bucket=bucket_name, Key=key["Key"])
def read_json_from_s3(client: Any, key: str, bucket_name: str) -> DataFrame:
obj = client.get_object(Bucket=bucket_name, Key=key)
json_text = obj["Body"].read().decode("utf-8")
data = json.loads(json_text)
if isinstance(data, list):
@@ -44,7 +43,7 @@ def read_json_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
return df
def read_parquet_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
obj = client.get_object(Bucket=bucket_name, Key=key["Key"])
def read_parquet_from_s3(client: Any, key: str, bucket_name: str) -> DataFrame:
obj = client.get_object(Bucket=bucket_name, Key=key)
df = pd.read_parquet(BytesIO(obj["Body"].read()))
return df
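
And the S3 counterparts, now also keyed by plain object-key strings. The bucket and keys are invented; the import path metadata.utils.s3_utils is an assumption.

import boto3

from metadata.utils.s3_utils import read_csv_from_s3, read_parquet_from_s3

client = boto3.client("s3")
bucket_name = "my-datalake-bucket"   # hypothetical

orders = read_csv_from_s3(client, "raw/orders.csv", bucket_name)
metrics = read_parquet_from_s3(client, "raw/metrics.parquet", bucket_name)
print(orders.shape, metrics.shape)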