Mirror of https://github.com/datahub-project/datahub.git, synced 2025-12-28 10:28:22 +00:00
fix(ingest): enable mypy disallow_incomplete_defs and disallow_untyped_decorators (#2393)
Commit 2af4603e49, parent c57d5a3731
@@ -7,9 +7,10 @@ produces a new JSON file called demo_data.json.
 import csv
 import dataclasses
 import json
+import os
 import pathlib
 import time
-from typing import List
+from typing import Dict, List
 
 from datahub.metadata.schema_classes import (
     AuditStampClass,
@@ -41,20 +42,20 @@ class Directive:
     depends_on: List[str]
 
 
-def read_mces(path) -> List[MetadataChangeEventClass]:
+def read_mces(path: os.PathLike) -> List[MetadataChangeEventClass]:
     with open(path) as f:
         objs = json.load(f)
         mces = [MetadataChangeEventClass.from_obj(obj) for obj in objs]
         return mces
 
 
-def write_mces(path, mces: List[MetadataChangeEventClass]) -> None:
+def write_mces(path: os.PathLike, mces: List[MetadataChangeEventClass]) -> None:
     objs = [mce.to_obj() for mce in mces]
     with open(path, "w") as f:
         json.dump(objs, f, indent=4)
 
 
-def parse_directive(row) -> Directive:
+def parse_directive(row: Dict) -> Directive:
     return Directive(
         table=row["table"],
         drop=bool(row["drop"]),
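Editorial note: annotating path parameters as os.PathLike, as in the read_mces/write_mces changes above, lets callers pass pathlib.Path objects while open() accepts them directly. A minimal standalone sketch of the idea (read_objs and the file name are hypothetical, not part of the diff):

    import json
    import os
    import pathlib
    import tempfile
    from typing import List


    def read_objs(path: os.PathLike) -> List[dict]:
        # open() accepts any os.PathLike (e.g. pathlib.Path) as well as plain str.
        with open(path) as f:
            return json.load(f)


    with tempfile.TemporaryDirectory() as tmp:
        demo = pathlib.Path(tmp) / "demo_data.json"
        demo.write_text(json.dumps([{"table": "foo"}]))
        # pathlib.Path implements os.PathLike, so this call type-checks.
        print(read_objs(demo))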
@@ -20,7 +20,7 @@ def suppress_checks_in_file(filepath: str) -> None:
 @click.command()
 @click.argument("schema_file", type=click.Path(exists=True))
 @click.argument("outdir", type=click.Path())
-def generate(schema_file: str, outdir: str):
+def generate(schema_file: str, outdir: str) -> None:
     # print(f'using {schema_file}')
     with open(schema_file) as f:
         raw_schema_text = f.read()
@@ -24,9 +24,10 @@ exclude = \.git|venv|build|dist
 ignore_missing_imports = yes
 strict_optional = yes
 check_untyped_defs = yes
+disallow_incomplete_defs = yes
+disallow_untyped_decorators = yes
 # eventually we'd like to enable these
 disallow_untyped_defs = no
-disallow_incomplete_defs = no
 
 [isort]
 profile = black
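Editorial note: for context on the two flags enabled above, here is a small illustrative sketch (not from this commit) of the code they reject. disallow_incomplete_defs flags functions that annotate only some parameters; disallow_untyped_decorators flags annotated functions wrapped by an unannotated decorator.

    def partially_annotated(path: str, verbose):
        # flagged by disallow_incomplete_defs: "verbose" and the return type
        # are missing annotations while "path" has one
        return path if verbose else ""


    def untyped_decorator(func):
        # no annotations here, so mypy treats the decorator as untyped
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper


    @untyped_decorator  # flagged by disallow_untyped_decorators
    def fully_typed(x: int) -> int:
        return x + 1


    print(fully_typed(1))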
@@ -45,7 +45,7 @@ def local_docker() -> None:
     default=False,
     help="Include extra information for each plugin",
 )
-def plugins(verbose) -> None:
+def plugins(verbose: bool) -> None:
     """Check the enabled ingestion plugins"""
 
     click.secho("Sources:", bold=True)
@@ -45,7 +45,7 @@ def get_client_with_error():
         docker_cli.close()
 
 
-def memory_in_gb(mem_bytes: int):
+def memory_in_gb(mem_bytes: int) -> float:
     return mem_bytes / (1024 * 1024 * 1000)
 
 
@@ -1,4 +1,4 @@
-from typing import IO
+from typing import IO, cast
 
 import toml
 
@@ -8,6 +8,6 @@ from .common import ConfigurationMechanism
 class TomlConfigurationMechanism(ConfigurationMechanism):
     """Ability to load configuration from toml files"""
 
-    def load_config(self, config_fp: IO):
+    def load_config(self, config_fp: IO) -> dict:
         config = toml.load(config_fp)
-        return config
+        return cast(dict, config)  # converts MutableMapping -> dict
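Editorial note: a brief sketch of why the cast is needed here. toml.load is typed as returning a MutableMapping, so returning it from a function declared to return dict needs either a cast or a copy; cast does no runtime work and only informs the type checker. (Standalone example, not part of the diff.)

    from typing import IO, MutableMapping, cast

    import toml


    def load_config(config_fp: IO) -> dict:
        config: MutableMapping = toml.load(config_fp)
        # cast() only tells the type checker this is a dict; it performs no
        # conversion, so it relies on toml.load actually returning a dict.
        return cast(dict, config)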
@@ -8,6 +8,6 @@ from datahub.configuration import ConfigurationMechanism
 class YamlConfigurationMechanism(ConfigurationMechanism):
     """Ability to load configuration from yaml files"""
 
-    def load_config(self, config_fp: IO):
+    def load_config(self, config_fp: IO) -> dict:
         config = yaml.safe_load(config_fp)
         return config
@@ -3,7 +3,7 @@ from typing import Callable
 from confluent_kafka import SerializingProducer
 from confluent_kafka.schema_registry import SchemaRegistryClient
 from confluent_kafka.schema_registry.avro import AvroSerializer
-from confluent_kafka.serialization import StringSerializer
+from confluent_kafka.serialization import SerializationContext, StringSerializer
 from pydantic import Field
 
 from datahub.configuration.common import ConfigModel
@@ -31,7 +31,9 @@ class DatahubKafkaEmitter:
         }
         schema_registry_client = SchemaRegistryClient(schema_registry_conf)
 
-        def convert_mce_to_dict(mce: MetadataChangeEvent, ctx):
+        def convert_mce_to_dict(
+            mce: MetadataChangeEvent, ctx: SerializationContext
+        ) -> dict:
             tuple_encoding = mce.to_obj(tuples=True)
             return tuple_encoding
 
@@ -54,7 +56,7 @@ class DatahubKafkaEmitter:
         self,
         mce: MetadataChangeEvent,
         callback: Callable[[Exception, str], None],
-    ):
+    ) -> None:
         # Call poll to trigger any callbacks on success / failure of previous writes
         self.producer.poll(0)
         self.producer.produce(
@@ -20,7 +20,7 @@ class Registry(Generic[T]):
         tp = typing_inspect.get_args(cls)[0]
         return tp
 
-    def _check_cls(self, cls: Type[T]):
+    def _check_cls(self, cls: Type[T]) -> None:
         if inspect.isabstract(cls):
             raise ValueError(
                 f"cannot register an abstract type in the registry; got {cls}"
@@ -9,11 +9,11 @@ from datahub.ingestion.api.report import Report
 
 @dataclass
 class SinkReport(Report):
-    records_written = 0
+    records_written: int = 0
     warnings: List[Any] = field(default_factory=list)
     failures: List[Any] = field(default_factory=list)
 
-    def report_record_written(self, record_envelope: RecordEnvelope):
+    def report_record_written(self, record_envelope: RecordEnvelope) -> None:
         self.records_written += 1
 
     def report_warning(self, info: Any) -> None:
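Editorial note: the records_written change above is more than cosmetic. In a dataclass, a bare assignment without an annotation is a plain class attribute and does not become a field, so adding the int annotation also changes what dataclasses generates. A small illustrative sketch (the Demo class is hypothetical):

    from dataclasses import dataclass, field, fields
    from typing import List


    @dataclass
    class Demo:
        counted: int = 0              # annotated -> becomes a dataclass field
        not_counted = 0               # no annotation -> plain class attribute
        items: List[int] = field(default_factory=list)


    print([f.name for f in fields(Demo)])  # ['counted', 'items']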
@@ -25,7 +25,9 @@ class SinkReport(Report):
 
 class WriteCallback(metaclass=ABCMeta):
     @abstractmethod
-    def on_success(self, record_envelope: RecordEnvelope, success_metadata: dict):
+    def on_success(
+        self, record_envelope: RecordEnvelope, success_metadata: dict
+    ) -> None:
         pass
 
     @abstractmethod
@@ -34,7 +36,7 @@ class WriteCallback(metaclass=ABCMeta):
         record_envelope: RecordEnvelope,
         failure_exception: Exception,
         failure_metadata: dict,
-    ):
+    ) -> None:
         pass
 
 
@@ -9,7 +9,7 @@ from .report import Report
 
 @dataclass
 class SourceReport(Report):
-    workunits_produced = 0
+    workunits_produced: int = 0
     workunit_ids: List[str] = field(default_factory=list)
 
     warnings: Dict[str, List[str]] = field(default_factory=dict)
@@ -35,7 +35,7 @@ WorkUnitType = TypeVar("WorkUnitType", bound=WorkUnit)
 
 class Extractor(Generic[WorkUnitType], Closeable, metaclass=ABCMeta):
     @abstractmethod
-    def configure(self, config_dict: dict, ctx: PipelineContext):
+    def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
         pass
 
     @abstractmethod
@@ -2,7 +2,7 @@ from typing import Iterable, cast
 
 from datahub.ingestion.api import RecordEnvelope
 from datahub.ingestion.api.common import PipelineContext
-from datahub.ingestion.api.source import Extractor
+from datahub.ingestion.api.source import Extractor, WorkUnit
 from datahub.ingestion.source.metadata_common import MetadataWorkUnit
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
 
@@ -10,10 +10,12 @@ from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
 class WorkUnitMCEExtractor(Extractor):
     """An extractor that simply returns MCE-s inside workunits back as records"""
 
-    def configure(self, config_dict: dict, ctx: PipelineContext):
+    def configure(self, config_dict: dict, ctx: PipelineContext) -> None:
         pass
 
-    def get_records(self, workunit) -> Iterable[RecordEnvelope[MetadataChangeEvent]]:
+    def get_records(
+        self, workunit: WorkUnit
+    ) -> Iterable[RecordEnvelope[MetadataChangeEvent]]:
         workunit = cast(MetadataWorkUnit, workunit)
         if len(workunit.mce.proposedSnapshot.aspects) == 0:
             raise AttributeError("every mce must have at least one aspect")
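Editorial note: the pattern above, typing the override's parameter as the base WorkUnit class and narrowing with cast inside the body, keeps the signature compatible with the abstract Extractor while preserving the runtime behavior. A rough standalone sketch of the same idea (all class names here are hypothetical):

    from abc import ABC, abstractmethod
    from typing import cast


    class WorkUnit(ABC):
        pass


    class FileWorkUnit(WorkUnit):
        def __init__(self, path: str) -> None:
            self.path = path


    class Extractor(ABC):
        @abstractmethod
        def get_records(self, workunit: WorkUnit) -> str:
            ...


    class FileExtractor(Extractor):
        # The parameter stays typed as the base class so the override is
        # compatible with the ABC; cast() narrows it for attribute access.
        def get_records(self, workunit: WorkUnit) -> str:
            file_workunit = cast(FileWorkUnit, workunit)
            return file_workunit.path


    print(FileExtractor().get_records(FileWorkUnit("demo.json")))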
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, List
+from typing import Any, List, Union
 
 import avro.schema
 
@@ -42,10 +42,10 @@ _field_type_mapping = {
 }
 
 
-def _get_column_type(field_type) -> SchemaFieldDataType:
+def _get_column_type(field_type: Union[str, dict]) -> SchemaFieldDataType:
     tp = field_type
     if hasattr(tp, "type"):
-        tp = tp.type
+        tp = tp.type  # type: ignore
     tp = str(tp)
     TypeClass: Any = _field_type_mapping.get(tp)
     # Note: we could populate the nestedTypes field for unions and similar fields
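Editorial note: the narrow type: ignore on the tp.type access is the trade-off taken here; only one branch of the union has the attribute, which hasattr verifies at runtime but mypy cannot see. A rough equivalent sketch with hypothetical types:

    from typing import Union


    class Wrapped:
        def __init__(self, type_name: str) -> None:
            self.type = type_name


    def type_name(field_type: Union[str, Wrapped]) -> str:
        tp = field_type
        if hasattr(tp, "type"):
            # mypy narrows on isinstance, not hasattr, so it would report that
            # "str" has no attribute "type"; the runtime check makes this safe,
            # hence the targeted ignore.
            tp = tp.type  # type: ignore
        return str(tp)


    print(type_name("long"), type_name(Wrapped("record")))

An isinstance check would avoid the ignore entirely, at the cost of naming the concrete types at the call site.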
@@ -55,7 +55,7 @@ def _get_column_type(field_type) -> SchemaFieldDataType:
     return dt
 
 
-def _is_nullable(schema: avro.schema.Schema):
+def _is_nullable(schema: avro.schema.Schema) -> bool:
     if isinstance(schema, avro.schema.UnionSchema):
         return any(_is_nullable(sub_schema) for sub_schema in schema.schemas)
     elif isinstance(schema, avro.schema.PrimitiveSchema):
@@ -34,13 +34,20 @@ class PipelineConfig(ConfigModel):
 
 
 class LoggingCallback(WriteCallback):
-    def on_success(self, record_envelope: RecordEnvelope, success_meta):
+    def on_success(
+        self, record_envelope: RecordEnvelope, success_metadata: dict
+    ) -> None:
         logger.info(f"sink wrote workunit {record_envelope.metadata['workunit_id']}")
 
-    def on_failure(self, record_envelope: RecordEnvelope, exception, failure_meta):
+    def on_failure(
+        self,
+        record_envelope: RecordEnvelope,
+        failure_exception: Exception,
+        failure_metadata: dict,
+    ) -> None:
         logger.error(
             f"failed to write record with workunit {record_envelope.metadata['workunit_id']}"
-            f" with {exception} and info {failure_meta}"
+            f" with {failure_exception} and info {failure_metadata}"
         )
 
 
@@ -23,7 +23,7 @@ class ConsoleSink(Sink):
 
     def write_record_async(
         self, record_envelope: RecordEnvelope, write_callback: WriteCallback
-    ):
+    ) -> None:
         print(f"{record_envelope}")
         if write_callback:
             self.report.report_record_written(record_envelope)
@@ -36,14 +36,14 @@ class DatahubKafkaSink(Sink):
     report: SinkReport
     emitter: DatahubKafkaEmitter
 
-    def __init__(self, config: KafkaSinkConfig, ctx):
+    def __init__(self, config: KafkaSinkConfig, ctx: PipelineContext):
         super().__init__(ctx)
         self.config = config
         self.report = SinkReport()
         self.emitter = DatahubKafkaEmitter(self.config)
 
     @classmethod
-    def create(cls, config_dict, ctx: PipelineContext):
+    def create(cls, config_dict: dict, ctx: PipelineContext) -> "DatahubKafkaSink":
         config = KafkaSinkConfig.parse_obj(config_dict)
         return cls(config, ctx)
 
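Editorial note: the quoted return type on create() is a forward reference; the class name is not yet bound while its own body is being evaluated, so the annotation is written as a string. A minimal sketch of the pattern (Thing is a hypothetical class):

    class Thing:
        def __init__(self, value: int) -> None:
            self.value = value

        @classmethod
        def create(cls, config_dict: dict) -> "Thing":
            # "Thing" must be quoted here because the name is not defined
            # until the class body finishes executing.
            return cls(int(config_dict["value"]))


    thing = Thing.create({"value": "42"})
    print(thing.value)  # 42

On newer Python/mypy setups the same effect can come from "from __future__ import annotations", but quoting works everywhere.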
@@ -57,7 +57,7 @@ class DatahubKafkaSink(Sink):
         self,
         record_envelope: RecordEnvelope[MetadataChangeEvent],
         write_callback: WriteCallback,
-    ):
+    ) -> None:
         mce = record_envelope.record
         self.emitter.emit_mce_async(
             mce,
@@ -29,7 +29,7 @@ class DatahubRestSink(Sink):
         self.emitter = DatahubRestEmitter(self.config.server)
 
     @classmethod
-    def create(cls, config_dict: dict, ctx: PipelineContext):
+    def create(cls, config_dict: dict, ctx: PipelineContext) -> "DatahubRestSink":
         config = DatahubRestSinkConfig.parse_obj(config_dict)
         return cls(ctx, config)
 
@@ -43,7 +43,7 @@ class DatahubRestSink(Sink):
         self,
         record_envelope: RecordEnvelope[MetadataChangeEvent],
         write_callback: WriteCallback,
-    ):
+    ) -> None:
         mce = record_envelope.record
 
         try:
@@ -29,7 +29,7 @@ class FileSink(Sink):
         self.wrote_something = False
 
     @classmethod
-    def create(cls, config_dict: dict, ctx: PipelineContext):
+    def create(cls, config_dict: dict, ctx: PipelineContext) -> "FileSink":
         config = FileSinkConfig.parse_obj(config_dict)
         return cls(ctx, config)
 
@@ -65,7 +65,7 @@ class DBTNode:
         return self.__class__.__name__ + str(tuple(sorted(fields))).replace("'", "")
 
 
-def get_columns(catalog_node) -> List[DBTColumn]:
+def get_columns(catalog_node: dict) -> List[DBTColumn]:
     columns = []
 
     raw_columns = catalog_node["columns"]
@@ -83,7 +83,7 @@ def get_columns(catalog_node) -> List[DBTColumn]:
 
 
 def extract_dbt_entities(
-    nodes, catalog, platform: str, environment: str
+    nodes: Dict[str, dict], catalog: Dict[str, dict], platform: str, environment: str
 ) -> List[DBTNode]:
     dbt_entities = []
 
@@ -169,7 +169,7 @@ def get_custom_properties(node: DBTNode) -> Dict[str, str]:
 
 
 def get_upstreams(
-    upstreams: List[str], all_nodes, platform: str, environment: str
+    upstreams: List[str], all_nodes: Dict[str, dict], platform: str, environment: str
 ) -> List[str]:
     upstream_urns = []
 
@@ -101,8 +101,10 @@ class GlueSource(Source):
         return cls(config, ctx)
 
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
-        def get_all_tables():
-            def get_tables_from_database(database_name: str, tables: List):
+        def get_all_tables() -> List[dict]:
+            def get_tables_from_database(
+                database_name: str, tables: List
+            ) -> List[dict]:
                 kwargs = {"DatabaseName": database_name}
                 while True:
                     data = self.glue_client.get_tables(**kwargs)
@@ -113,7 +115,7 @@ class GlueSource(Source):
                         break
                 return tables
 
-            def get_tables_from_all_databases():
+            def get_tables_from_all_databases() -> List[dict]:
                 tables = []
                 kwargs: Dict = {}
                 while True:
@@ -126,7 +128,7 @@ class GlueSource(Source):
                 return tables
 
             if self.source_config.database_pattern.is_fully_specified_allow_list():
-                all_tables: List = []
+                all_tables: List[dict] = []
                 database_names = self.source_config.database_pattern.get_allowed_list()
                 for database in database_names:
                     all_tables += get_tables_from_database(database, all_tables)
@@ -153,7 +155,7 @@ class GlueSource(Source):
             yield workunit
 
     def _extract_record(self, table: Dict, table_name: str) -> MetadataChangeEvent:
-        def get_owner(time) -> OwnershipClass:
+        def get_owner(time: int) -> OwnershipClass:
            owner = table.get("Owner")
            if owner:
                owners = [
@@ -187,7 +189,7 @@ class GlueSource(Source):
                 tags=[],
             )
 
-        def get_schema_metadata(glue_source: GlueSource):
+        def get_schema_metadata(glue_source: GlueSource) -> SchemaMetadata:
             schema = table["StorageDescriptor"]["Columns"]
             fields: List[SchemaField] = []
             for field in schema:
@@ -33,7 +33,7 @@ class KafkaSourceConfig(ConfigModel):
 
 @dataclass
 class KafkaSourceReport(SourceReport):
-    topics_scanned = 0
+    topics_scanned: int = 0
     filtered: List[str] = field(default_factory=list)
 
     def report_topic_scanned(self, topic: str) -> None:
@@ -12,7 +12,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
 from datahub.metadata.schema_classes import CorpUserInfoClass, CorpUserSnapshotClass
 
 
-def create_controls(pagesize):
+def create_controls(pagesize: int) -> SimplePagedResultsControl:
     """
     Create an LDAP control with a page size of "pagesize".
     """
@@ -37,7 +37,7 @@ def set_cookie(lc_object, pctrls, pagesize):
     return cookie
 
 
-def guess_person_ldap(dn, attrs) -> Optional[str]:
+def guess_person_ldap(dn: str, attrs: dict) -> Optional[str]:
     """Determine the user's LDAP based on the DN and attributes."""
     if "sAMAccountName" in attrs:
         return attrs["sAMAccountName"][0].decode()
@@ -124,7 +124,7 @@ class LDAPSource(Source):
 
             cookie = set_cookie(self.lc, pctrls, self.config.page_size)
 
-    def handle_user(self, dn, attrs) -> Iterable[MetadataWorkUnit]:
+    def handle_user(self, dn: str, attrs: dict) -> Iterable[MetadataWorkUnit]:
         """
         Handle a DN and attributes by adding manager info and constructing a
         work unit based on the information.
@@ -154,7 +154,7 @@ class LDAPSource(Source):
             yield from []
 
     def build_corp_user_mce(
-        self, dn, attrs, manager_ldap
+        self, dn: str, attrs: dict, manager_ldap: Optional[str]
     ) -> Optional[MetadataChangeEvent]:
         """
         Create the MetadataChangeEvent via DN and attributes.
@@ -2,6 +2,7 @@ from dataclasses import dataclass, field
 from typing import Iterable, List, Optional
 
 import pymongo
+from pymongo.mongo_client import MongoClient
 
 from datahub.configuration.common import AllowDenyPattern, ConfigModel
 from datahub.ingestion.api.common import PipelineContext
@@ -43,6 +44,7 @@ class MongoDBSourceReport(SourceReport):
 class MongoDBSource(Source):
     config: MongoDBConfig
     report: MongoDBSourceReport
+    mongo_client: MongoClient
 
     def __init__(self, ctx: PipelineContext, config: MongoDBConfig):
         super().__init__(ctx)
@@ -68,7 +70,7 @@ class MongoDBSource(Source):
         self.mongo_client.admin.command("ismaster")
 
     @classmethod
-    def create(cls, config_dict: dict, ctx: PipelineContext):
+    def create(cls, config_dict: dict, ctx: PipelineContext) -> "MongoDBSource":
         config = MongoDBConfig.parse_obj(config_dict)
         return cls(ctx, config)
 
@@ -9,7 +9,7 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
     host_port = "localhost:1433"
     scheme = "mssql+pytds"
 
-    def get_identifier(self, schema: str, table: str):
+    def get_identifier(self, schema: str, table: str) -> str:
         regular = f"{schema}.{table}"
         if self.database:
             return f"{self.database}.{regular}"
@@ -32,12 +32,12 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
 )
 from datahub.metadata.schema_classes import DatasetPropertiesClass
 
-logger = logging.getLogger(__name__)
+logger: logging.Logger = logging.getLogger(__name__)
 
 
 @dataclass
 class SQLSourceReport(SourceReport):
-    tables_scanned = 0
+    tables_scanned: int = 0
     filtered: List[str] = field(default_factory=list)
 
     def report_table_scanned(self, table_name: str) -> None:
@@ -150,7 +150,7 @@ def get_column_type(
 
 
 def get_schema_metadata(
-    sql_report: SQLSourceReport, dataset_name: str, platform: str, columns
+    sql_report: SQLSourceReport, dataset_name: str, platform: str, columns: List[dict]
 ) -> SchemaMetadata:
     canonical_schema: List[SchemaField] = []
     for column in columns:
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
     from airflow.configuration import conf
 
 
-def _entities_to_urn_list(iolets: List):
+def _entities_to_urn_list(iolets: List) -> List[str]:
     return [let.urn for let in iolets]
 
 
@@ -39,7 +39,7 @@ class DatahubAirflowLineageBackend(LineageBackend):
         inlets: Optional[List] = None,
         outlets: Optional[List] = None,
         context: Dict = None,
-    ):
+    ) -> None:
         context = context or {}  # ensure not None to satisfy mypy
 
         dag: "DAG" = context["dag"]
@@ -16,8 +16,11 @@ class DatahubBaseOperator(BaseOperator):
 
     hook: Union[DatahubRestHook, DatahubKafkaHook]
 
-    @apply_defaults
-    def __init__(
+    # mypy is not a fan of this. Newer versions of Airflow support proper typing for the decorator
+    # using PEP 612. However, there is not yet a good way to inherit the types of the kwargs from
+    # the superclass.
+    @apply_defaults  # type: ignore[misc]
+    def __init__(  # type: ignore[no-untyped-def]
         self,
         *,
         datahub_conn_id: str,
@@ -30,8 +33,9 @@ class DatahubBaseOperator(BaseOperator):
 
 
 class DatahubEmitterOperator(DatahubBaseOperator):
-    @apply_defaults
-    def __init__(
+    # See above for why these mypy type issues are ignored here.
+    @apply_defaults  # type: ignore[misc]
+    def __init__(  # type: ignore[no-untyped-def]
         self,
         mces: List[MetadataChangeEvent],
         datahub_conn_id: str,
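Editorial note: because apply_defaults ships without annotations in the Airflow versions targeted here, disallow_untyped_decorators would reject these __init__ methods outright; the error-code-specific ignores silence exactly those checks and nothing else. A hedged sketch of the general mechanism, using a stand-in decorator rather than Airflow's:

    def legacy_decorator(func):  # deliberately untyped, like the older apply_defaults
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper


    class DemoOperator:
        # "type: ignore[misc]" silences "Untyped decorator makes function untyped";
        # "type: ignore[no-untyped-def]" silences the unannotated **kwargs that
        # disallow_incomplete_defs would otherwise flag.
        @legacy_decorator  # type: ignore[misc]
        def __init__(self, *, conn_id: str, **kwargs) -> None:  # type: ignore[no-untyped-def]
            self.conn_id = conn_id


    op = DemoOperator(conn_id="datahub_rest_default")
    print(op.conn_id)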
@@ -17,8 +17,8 @@ def wait_for_port(
     docker_services: pytest_docker.plugin.Services,
     container_name: str,
     container_port: int,
-    timeout=15.0,
-):
+    timeout: float = 15.0,
+) -> None:
     # port = docker_services.port_for(container_name, container_port)
     docker_services.wait_until_responsive(
         timeout=timeout,
@@ -9,7 +9,7 @@ def load_json_file(filename: str) -> object:
     return a
 
 
-def assert_mces_equal(output, golden) -> None:
+def assert_mces_equal(output: dict, golden: dict) -> None:
     # This method assumes we're given a list of MCE json objects.
 
     ignore_paths = {
@@ -1,8 +1,10 @@
 import io
 import json
+import pathlib
 
 import fastavro
 import pytest
+from _pytest.config import Config as PytestConfig
 from click.testing import CliRunner
 
 from datahub.entrypoints import datahub
@@ -11,6 +13,12 @@ from datahub.ingestion.source.mce_file import iterate_mce_file
 from datahub.metadata.schema_classes import SCHEMA_JSON_STR, MetadataChangeEventClass
 from tests.test_helpers import mce_helpers
 
+# The current PytestConfig solution is somewhat ugly and not ideal.
+# However, it is currently the best solution available, as the type itself is not
+# exported: https://docs.pytest.org/en/stable/reference.html#config.
+# As pytest's type support improves, this will likely change.
+# TODO: revisit pytestconfig as https://github.com/pytest-dev/pytest/issues/7469 progresses.
+
 
 @pytest.mark.parametrize(
     "json_filename",
@@ -21,7 +29,9 @@ from tests.test_helpers import mce_helpers
         "tests/unit/serde/test_serde_chart_snapshot.json",
     ],
 )
-def test_serde_to_json(pytestconfig, tmp_path, json_filename):
+def test_serde_to_json(
+    pytestconfig: PytestConfig, tmp_path: pathlib.Path, json_filename: str
+) -> None:
     golden_file = pytestconfig.rootpath / json_filename
 
     output_filename = "output.json"
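Editorial note: a small sketch of the same annotation approach for other pytest tests in this style. PytestConfig is the private _pytest.config.Config alias imported above, and tmp_path is pytest's built-in pathlib.Path fixture; the test body itself is hypothetical:

    import pathlib

    from _pytest.config import Config as PytestConfig


    def test_writes_output(pytestconfig: PytestConfig, tmp_path: pathlib.Path) -> None:
        # pytestconfig.rootpath is a pathlib.Path pointing at the rootdir;
        # tmp_path is a per-test temporary directory.
        output_file = tmp_path / "output.json"
        output_file.write_text("[]")
        assert output_file.read_text() == "[]"
        assert pytestconfig.rootpath.exists()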
@@ -48,7 +58,7 @@ def test_serde_to_json(pytestconfig, tmp_path, json_filename):
         "tests/unit/serde/test_serde_chart_snapshot.json",
     ],
 )
-def test_serde_to_avro(pytestconfig, json_filename):
+def test_serde_to_avro(pytestconfig: PytestConfig, json_filename: str) -> None:
     # In this test, we want to read in from JSON -> MCE object.
     # Next we serialize from MCE to Avro and then deserialize back to MCE.
     # Finally, we want to compare the two MCE objects.
@@ -88,7 +98,7 @@ def test_serde_to_avro(pytestconfig, json_filename):
         "examples/mce_files/bootstrap_mce.json",
     ],
 )
-def test_check_mce_schema(pytestconfig, json_filename):
+def test_check_mce_schema(pytestconfig: PytestConfig, json_filename: str) -> None:
     json_file_path = pytestconfig.rootpath / json_filename
 
     runner = CliRunner()
@@ -1,22 +1,22 @@
 from datahub.configuration.common import AllowDenyPattern
 
 
-def test_allow_all():
+def test_allow_all() -> None:
     pattern = AllowDenyPattern.allow_all()
     assert pattern.allowed("foo.table")
 
 
-def test_deny_all():
+def test_deny_all() -> None:
     pattern = AllowDenyPattern(allow=[], deny=[".*"])
     assert not pattern.allowed("foo.table")
 
 
-def test_single_table():
+def test_single_table() -> None:
     pattern = AllowDenyPattern(allow=["foo.mytable"])
     assert pattern.allowed("foo.mytable")
 
 
-def test_default_deny():
+def test_default_deny() -> None:
     pattern = AllowDenyPattern(allow=["foo.mytable"])
     assert not pattern.allowed("foo.bar")
 
@@ -4,11 +4,11 @@ from datahub.ingestion.sink.sink_registry import sink_registry
 from datahub.ingestion.source.source_registry import source_registry
 
 
-def test_sources_not_abstract():
+def test_sources_not_abstract() -> None:
     for cls in source_registry.mapping.values():
         assert not inspect.isabstract(cls)
 
 
-def test_sinks_not_abstract():
+def test_sinks_not_abstract() -> None:
     for cls in sink_registry.mapping.values():
         assert not inspect.isabstract(cls)
@@ -1,5 +1,5 @@
 import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, sentinel
 
 from datahub.ingestion.api.common import RecordEnvelope
 from datahub.ingestion.api.sink import SinkReport, WriteCallback
@@ -32,7 +32,7 @@ class KafkaSinkTest(unittest.TestCase):
         kafka_sink = DatahubKafkaSink.create(
             {"connection": {"bootstrap": "foobar:9092"}}, mock_context
         )
-        re = RecordEnvelope(record="test", metadata={})
+        re = RecordEnvelope(record=sentinel, metadata={})
        kafka_sink.write_record_async(re, callback)
        assert mock_producer_instance.poll.call_count == 1  # poll() called once
        self.validate_kafka_callback(
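Editorial note: swapping the string literal for unittest.mock.sentinel keeps the test honest about the record it passes through: sentinel attributes are unique placeholder objects, so the assertion can check that the exact object reached the mocked producer. A quick illustrative sketch (emit is a hypothetical mock, not DataHub code):

    from unittest.mock import MagicMock, sentinel

    emit = MagicMock()
    emit(sentinel.record, metadata={})

    # The same sentinel attribute always returns the identical object,
    # so identity-based assertions work across the test.
    emit.assert_called_once_with(sentinel.record, metadata={})
    assert sentinel.record is sentinel.record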
@@ -148,15 +148,15 @@ basic_3 = json.loads(
 )
 
 
-def test_basic_diff_same():
+def test_basic_diff_same() -> None:
     mce_helpers.assert_mces_equal(basic_1, basic_2)
 
 
-def test_basic_diff_only_owner_change():
+def test_basic_diff_only_owner_change() -> None:
     with pytest.raises(AssertionError):
         mce_helpers.assert_mces_equal(basic_2, basic_3)
 
 
-def test_basic_diff_owner_change():
+def test_basic_diff_owner_change() -> None:
     with pytest.raises(AssertionError):
         mce_helpers.assert_mces_equal(basic_1, basic_3)
@@ -23,7 +23,7 @@ def test_registry_nonempty(registry):
     assert len(registry.mapping) > 0
 
 
-def test_list_all():
+def test_list_all() -> None:
     # This just verifies that it runs without error.
     runner = CliRunner()
     result = runner.invoke(datahub, ["check", "plugins", "--verbose"])