Tag PATCH + Cleanup of helpers methods (#8150)

* cleanup

* lint

* Add tag patch

* Fix rename

* Don't kill tests
This commit is contained in:
Pere Miquel Brull 2022-10-15 14:56:30 +02:00 committed by GitHub
parent a8970b289d
commit d48fd468d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 377 additions and 248 deletions

View File

@ -97,8 +97,8 @@ run_python_tests: ## Run all Python tests with coverage
.PHONY: coverage
coverage: ## Run all Python tests and generate the coverage XML report
$(MAKE) run_python_tests
coverage xml --rcfile ingestion/.coveragerc -o ingestion/coverage.xml
sed -e 's/$(shell python -c "import site; import os; from pathlib import Path; print(os.path.relpath(site.getsitepackages()[0], str(Path.cwd())).replace('/','\/'))")/src/g' ingestion/coverage.xml >> ingestion/ci-coverage.xml
coverage xml --rcfile ingestion/.coveragerc -o ingestion/coverage.xml || true
sed -e 's/$(shell python -c "import site; import os; from pathlib import Path; print(os.path.relpath(site.getsitepackages()[0], str(Path.cwd())).replace('/','\/'))")/src/g' ingestion/coverage.xml >> ingestion/ci-coverage.xml || true
.PHONY: sonar_ingestion
sonar_ingestion: ## Run the Sonar analysis based on the tests results and push it to SonarCloud

View File

@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple
from sqlparse.sql import Comparison, Identifier, Statement
from metadata.generated.schema.type.tableUsageCount import TableColumn, TableColumnJoin
from metadata.utils.helpers import find_in_list, get_formatted_entity_name
from metadata.utils.helpers import find_in_iter, get_formatted_entity_name
from metadata.utils.logger import ingestion_logger
# Prevent sqllineage from modifying the logger config
@ -96,15 +96,15 @@ def get_table_name_from_list(
:param tables: Contains all involved tables
:return: table name from parser info
"""
table = find_in_list(element=table_name, container=tables)
table = find_in_iter(element=table_name, container=tables)
if table:
return table
schema_table = find_in_list(element=f"{schema_name}.{table_name}", container=tables)
schema_table = find_in_iter(element=f"{schema_name}.{table_name}", container=tables)
if schema_table:
return schema_table
db_schema_table = find_in_list(
db_schema_table = find_in_iter(
element=f"{database_name}.{schema_name}.{table_name}", container=tables
)
if db_schema_table:
@ -173,7 +173,7 @@ def stateful_add_table_joins(
table_columns = [
join_info.tableColumn for join_info in statement_joins[source.table]
]
existing_table_column = find_in_list(element=source, container=table_columns)
existing_table_column = find_in_iter(element=source, container=table_columns)
if existing_table_column:
existing_join_info = [
join_info

View File

@ -0,0 +1,68 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OMeta client create helpers
"""
import traceback
from typing import List
from metadata.generated.schema.entity.data.chart import Chart
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils import fqn
from metadata.utils.logger import ometa_logger
logger = ometa_logger()
def create_ometa_client(
    metadata_config: OpenMetadataConnection,
) -> OpenMetadata:
    """Create an OpenMetadata client and verify the server is reachable.

    Args:
        metadata_config (OpenMetadataConnection): OM connection config
    Returns:
        OpenMetadata: an OM client
    Raises:
        ValueError: if the client cannot be initialised or the health check fails
    """
    try:
        metadata = OpenMetadata(metadata_config)
        # Fail fast if the server cannot be reached with this config
        metadata.health_check()
        return metadata
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(f"Wild error initialising the OMeta Client {exc}")
        # Chain the original exception so the root cause is preserved
        raise ValueError(exc) from exc
def get_chart_entities_from_id(
    chart_ids: List[str], metadata: OpenMetadata, service_name: str
) -> List[EntityReference]:
    """
    Look up each chart by its fully qualified name and return the
    matches wrapped as EntityReferences; missing charts are skipped.
    """
    references = []
    for chart_id in chart_ids:
        chart_fqn = fqn.build(
            metadata, Chart, chart_name=str(chart_id), service_name=service_name
        )
        chart: Chart = metadata.get_by_name(entity=Chart, fqn=chart_fqn)
        if chart:
            references.append(EntityReference(id=chart.id, type="chart"))
    return references

View File

@ -21,8 +21,10 @@ from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.type import basic
from metadata.generated.schema.type.tagLabel import LabelType, State, TagSource
from metadata.ingestion.ometa.client import REST
from metadata.ingestion.ometa.utils import model_str, ometa_logger
from metadata.utils.helpers import find_column_in_table_with_index
logger = ometa_logger()
@ -40,6 +42,9 @@ REPLACE = "replace"
ENTITY_DESCRIPTION = "/description"
COL_DESCRIPTION = "/columns/{index}/description"
ENTITY_TAG = "/tags/{tag_index}"
COL_TAG = "/columns/{index}/tags/{tag_index}"
class OMetaPatchMixin(Generic[T]):
"""
@ -67,7 +72,7 @@ class OMetaPatchMixin(Generic[T]):
instance to update
"""
instance = self.get_by_id(entity=entity, entity_id=entity_id)
instance = self.get_by_id(entity=entity, entity_id=entity_id, fields=["*"])
if not instance:
logger.warning(
@ -139,8 +144,7 @@ class OMetaPatchMixin(Generic[T]):
description: str,
force: bool = False,
) -> Optional[T]:
"""
Given an Entity type and ID, JSON PATCH the description of the column
"""Given an Entity ID, JSON PATCH the description of the column
Args
entity_id: ID
@ -152,18 +156,14 @@ class OMetaPatchMixin(Generic[T]):
Updated Entity
"""
table: Table = self._validate_instance_description(
entity=Table, entity_id=entity_id
entity=Table,
entity_id=entity_id,
)
if not table:
return None
col_index, col = next(
(
(col_index, col)
for col_index, col in enumerate(table.columns)
if str(col.name.__root__).lower() == column_name.lower()
),
None,
col_index, col = find_column_in_table_with_index(
column_name=column_name, table=table
)
if col_index is None:
@ -199,3 +199,123 @@ class OMetaPatchMixin(Generic[T]):
)
return None
def patch_tag(
    self,
    entity: Type[T],
    entity_id: Union[str, basic.Uuid],
    tag_fqn: str,
    from_glossary: bool = False,
) -> Optional[T]:
    """
    Given an Entity type and ID, JSON PATCH a new tag onto the instance.

    Args
        entity (T): Entity Type
        entity_id: ID of the instance to patch
        tag_fqn: fully qualified name of the tag to add
        from_glossary: True if the tag comes from a glossary term
    Returns
        Updated Entity, or None if the instance was not found or the PATCH failed
    """
    instance = self._validate_instance_description(
        entity=entity, entity_id=entity_id
    )
    if not instance:
        return None

    # Append the new tag after any existing ones
    tag_index = len(instance.tags) if instance.tags else 0

    try:
        res = self.client.patch(
            # Use the suffix of the requested entity type, not a hardcoded Table
            path=f"{self.get_suffix(entity)}/{model_str(entity_id)}",
            data=json.dumps(
                [
                    {
                        OPERATION: ADD,
                        PATH: ENTITY_TAG.format(tag_index=tag_index),
                        VALUE: {
                            "labelType": LabelType.Automated.value,
                            "source": TagSource.Tag.value
                            if not from_glossary
                            else TagSource.Glossary.value,
                            "state": State.Confirmed.value,
                            "tagFQN": tag_fqn,
                        },
                    }
                ]
            ),
        )
        return entity(**res)
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.error(
            # `entity` is a class: use __name__, not __class__.__name__
            # (the latter would print the metaclass name)
            f"Error trying to PATCH tags for {entity.__name__} [{entity_id}]: {exc}"
        )

    return None
def patch_column_tag(
    self,
    entity_id: Union[str, basic.Uuid],
    column_name: str,
    tag_fqn: str,
    from_glossary: bool = False,
) -> Optional[T]:
    """Given an Entity ID, JSON PATCH the tag of the column

    Args
        entity_id: ID of the table
        column_name: column to update
        tag_fqn: new tag to add
        from_glossary: the tag comes from a glossary
    Returns
        Updated Entity, or None if the table/column was not found or the PATCH failed
    """
    table: Table = self._validate_instance_description(
        entity=Table, entity_id=entity_id
    )
    if not table:
        return None

    col_index, col = find_column_in_table_with_index(
        column_name=column_name, table=table
    )

    if col_index is None:
        logger.warning(f"Cannot find column {column_name} in Table.")
        return None

    # The PATCH path targets the COLUMN's tags array, so the index must
    # be computed from the column's own tags (not the table-level ones)
    tag_index = len(col.tags) if col.tags else 0

    try:
        res = self.client.patch(
            path=f"{self.get_suffix(Table)}/{model_str(entity_id)}",
            data=json.dumps(
                [
                    {
                        OPERATION: ADD,
                        PATH: COL_TAG.format(index=col_index, tag_index=tag_index),
                        VALUE: {
                            "labelType": LabelType.Automated.value,
                            "source": TagSource.Tag.value
                            if not from_glossary
                            else TagSource.Glossary.value,
                            "state": State.Confirmed.value,
                            "tagFQN": tag_fqn,
                        },
                    }
                ]
            ),
        )
        return Table(**res)
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"Error trying to PATCH tags for Table Column: {entity_id}, {column_name}: {exc}"
        )

    return None

View File

@ -34,6 +34,9 @@ from metadata.generated.schema.api.data.createTableProfile import (
)
from metadata.generated.schema.api.data.createTopic import CreateTopicRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.api.services.createStorageService import (
CreateStorageServiceRequest,
)
from metadata.generated.schema.api.teams.createRole import CreateRoleRequest
from metadata.generated.schema.api.teams.createTeam import CreateTeamRequest
from metadata.generated.schema.api.teams.createUser import CreateUserRequest
@ -67,6 +70,7 @@ from metadata.generated.schema.entity.services.databaseService import DatabaseSe
from metadata.generated.schema.entity.services.messagingService import MessagingService
from metadata.generated.schema.entity.services.mlmodelService import MlModelService
from metadata.generated.schema.entity.services.pipelineService import PipelineService
from metadata.generated.schema.entity.services.storageService import StorageService
from metadata.generated.schema.entity.teams.team import Team
from metadata.generated.schema.entity.teams.user import User
from metadata.generated.schema.metadataIngestion.workflow import (
@ -88,14 +92,11 @@ from metadata.ingestion.models.tests_data import (
OMetaTestSuiteSample,
)
from metadata.ingestion.models.user import OMetaUserProfile
from metadata.ingestion.ometa.client_utils import get_chart_entities_from_id
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.database_service import TableLocationLink
from metadata.utils import fqn
from metadata.utils.helpers import (
get_chart_entities_from_id,
get_standard_chart_type,
get_storage_service_or_create,
)
from metadata.utils.helpers import get_standard_chart_type
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
@ -107,6 +108,25 @@ COL_DESCRIPTION = "Description"
TableKey = namedtuple("TableKey", ["schema", "table_name"])
def get_storage_service_or_create(service_json, metadata_config) -> StorageService:
    """
    Get an existing storage service or create a new one based on the config provided

    To be refactored after cleaning Storage Services
    """
    metadata = OpenMetadata(metadata_config)
    existing: StorageService = metadata.get_by_name(
        entity=StorageService, fqn=service_json["name"]
    )
    if existing is not None:
        return existing
    # Not found: create it from the provided JSON definition
    return metadata.create_or_update(CreateStorageServiceRequest(**service_json))
class InvalidSampleDataException(Exception):
"""
Sample data is not valid to be ingested

View File

@ -60,10 +60,11 @@ from metadata.ingestion.api.common import Entity
from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus
from metadata.ingestion.models.ometa_tag_category import OMetaTagAndCategory
from metadata.ingestion.models.user import OMetaUserProfile
from metadata.ingestion.ometa.client_utils import get_chart_entities_from_id
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.utils import fqn
from metadata.utils.helpers import get_chart_entities_from_id, get_standard_chart_type
from metadata.utils.helpers import get_standard_chart_type
from metadata.utils.logger import ingestion_logger
from metadata.utils.sql_queries import (
NEO4J_AMUNDSEN_DASHBOARD_QUERY,

View File

@ -48,6 +48,7 @@ from metadata.generated.schema.metadataIngestion.workflow import (
from metadata.ingestion.api.parser import parse_workflow_config_gracefully
from metadata.ingestion.api.processor import ProcessorStatus
from metadata.ingestion.api.sink import Sink
from metadata.ingestion.ometa.client_utils import create_ometa_client
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.common_db_source import SQLSourceStatus
from metadata.interfaces.profiler_protocol import ProfilerProtocol
@ -67,7 +68,6 @@ from metadata.utils.class_helper import (
get_service_type_from_source_type,
)
from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
from metadata.utils.helpers import create_ometa_client
from metadata.utils.logger import profiler_logger
from metadata.utils.workflow_output_handler import print_profiler_status

View File

@ -46,13 +46,13 @@ from metadata.generated.schema.tests.testDefinition import TestDefinition
from metadata.generated.schema.tests.testSuite import TestSuite
from metadata.ingestion.api.parser import parse_workflow_config_gracefully
from metadata.ingestion.api.processor import ProcessorStatus
from metadata.ingestion.ometa.client_utils import create_ometa_client
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.interfaces.sqalchemy.sqa_test_suite_interface import SQATestSuiteInterface
from metadata.orm_profiler.api.models import TablePartitionConfig
from metadata.test_suite.api.models import TestCaseDefinition, TestSuiteProcessorConfig
from metadata.test_suite.runner.core import DataTestsRunner
from metadata.utils import entity_link
from metadata.utils.helpers import create_ometa_client
from metadata.utils.logger import test_suite_logger
from metadata.utils.workflow_output_handler import print_test_suite_status

View File

@ -14,42 +14,13 @@ Helpers module for ingestion related methods
"""
import re
import traceback
from datetime import datetime, timedelta
from functools import wraps
from time import perf_counter
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
from metadata.generated.schema.api.services.createDashboardService import (
CreateDashboardServiceRequest,
)
from metadata.generated.schema.api.services.createDatabaseService import (
CreateDatabaseServiceRequest,
)
from metadata.generated.schema.api.services.createMessagingService import (
CreateMessagingServiceRequest,
)
from metadata.generated.schema.api.services.createStorageService import (
CreateStorageServiceRequest,
)
from metadata.generated.schema.entity.data.chart import Chart, ChartType
from metadata.generated.schema.entity.data.chart import ChartType
from metadata.generated.schema.entity.data.table import Column, Table
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.dashboardService import DashboardService
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.messagingService import MessagingService
from metadata.generated.schema.entity.services.storageService import StorageService
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.type.entityReference import (
EntityReference,
EntityReferenceList,
)
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils import fqn
from metadata.utils.logger import utils_logger
logger = utils_logger()
@ -150,139 +121,6 @@ def snake_to_camel(snake_str):
return "".join(split_str)
def get_database_service_or_create(
    config: WorkflowSource, metadata_config, service_name=None
) -> DatabaseService:
    """
    Get an existing database service or create a new one based on the config provided

    :param config: workflow source holding the service connection details
    :param metadata_config: OpenMetadata connection config
    :param service_name: optional override; defaults to config.serviceName
    :return: the existing or newly created DatabaseService
    """
    metadata = OpenMetadata(metadata_config)
    if not service_name:
        service_name = config.serviceName
    service: DatabaseService = metadata.get_by_name(
        entity=DatabaseService, fqn=service_name
    )
    if not service:
        config_dict = config.dict()
        service_connection_config = config_dict.get("serviceConnection").get("config")
        # Pydantic secret values must be unwrapped before building the request
        password = (
            service_connection_config.get("password").get_secret_value()
            if service_connection_config and service_connection_config.get("password")
            else None
        )

        # Use a JSON to dynamically parse the pydantic model
        # based on the serviceType
        # TODO revisit me
        service_json = {
            "connection": {
                "config": {
                    "hostPort": service_connection_config.get("hostPort")
                    if service_connection_config
                    else None,
                    "username": service_connection_config.get("username")
                    if service_connection_config
                    else None,
                    "password": password,
                    "database": service_connection_config.get("database")
                    if service_connection_config
                    else None,
                    "connectionOptions": service_connection_config.get(
                        "connectionOptions"
                    )
                    if service_connection_config
                    else None,
                    "connectionArguments": service_connection_config.get(
                        "connectionArguments"
                    )
                    if service_connection_config
                    else None,
                }
            },
            "name": service_name,
            "description": "",
            "serviceType": service_connection_config.get("type").value
            if service_connection_config
            else None,
        }
        created_service: DatabaseService = metadata.create_or_update(
            CreateDatabaseServiceRequest(**service_json)
        )
        logger.info(f"Creating DatabaseService instance for {service_name}")
        return created_service

    return service
def get_messaging_service_or_create(
    service_name: str,
    message_service_type: str,
    config: dict,
    metadata_config,
) -> MessagingService:
    """
    Return the messaging service with the given name, creating it if absent.
    """
    metadata = OpenMetadata(metadata_config)
    existing: MessagingService = metadata.get_by_name(
        entity=MessagingService, fqn=service_name
    )
    if existing is not None:
        return existing
    request = CreateMessagingServiceRequest(
        name=service_name, serviceType=message_service_type, connection=config
    )
    return metadata.create_or_update(request)
def get_dashboard_service_or_create(
    service_name: str,
    dashboard_service_type: str,
    config: dict,
    metadata_config,
) -> DashboardService:
    """
    Return the dashboard service with the given name, creating it if absent.
    """
    metadata = OpenMetadata(metadata_config)
    existing: DashboardService = metadata.get_by_name(
        entity=DashboardService, fqn=service_name
    )
    if existing is not None:
        return existing
    # The create request expects the raw connection config nested under "config"
    request = CreateDashboardServiceRequest(
        name=service_name,
        serviceType=dashboard_service_type,
        connection={"config": config},
    )
    return metadata.create_or_update(request)
def get_storage_service_or_create(service_json, metadata_config) -> StorageService:
    """
    Return the storage service described by service_json, creating it if absent.
    """
    metadata = OpenMetadata(metadata_config)
    existing: StorageService = metadata.get_by_name(
        entity=StorageService, fqn=service_json["name"]
    )
    if existing is not None:
        return existing
    return metadata.create_or_update(CreateStorageServiceRequest(**service_json))
def datetime_to_ts(date: Optional[datetime]) -> Optional[int]:
"""
Convert a given date to a timestamp as an Int in milliseconds
@ -302,16 +140,6 @@ def get_formatted_entity_name(name: str) -> Optional[str]:
)
def get_raw_extract_iter(alchemy_helper) -> Iterable[Dict[str, Any]]:
    """
    Provides iterator of result row from SQLAlchemy helper
    :return:
    """
    # Delegate iteration directly to the query results
    yield from alchemy_helper.execute_query()
def replace_special_with(raw: str, replacement: str) -> str:
"""
Replace special characters in a string by a hyphen
@ -331,28 +159,7 @@ def get_standard_chart_type(raw_chart_type: str) -> str:
return om_chart_type_dict.get(raw_chart_type.lower(), ChartType.Other)
def get_chart_entities_from_id(
    chart_ids: List[str], metadata: OpenMetadata, service_name: str
) -> List[EntityReference]:
    """
    Method to get the chart entity using get_by_name api

    :param chart_ids: chart identifiers to look up
    :param metadata: OpenMetadata client
    :param service_name: dashboard service the charts belong to
    :return: EntityReference list for the charts that were found
    """
    # NOTE: the return annotation was List[EntityReferenceList], but the
    # list elements appended below are plain EntityReference instances.
    entities = []
    for chart_id in chart_ids:
        chart: Chart = metadata.get_by_name(
            entity=Chart,
            fqn=fqn.build(
                metadata, Chart, chart_name=str(chart_id), service_name=service_name
            ),
        )
        if chart:
            # Charts that cannot be found are silently skipped
            entity = EntityReference(id=chart.id, type="chart")
            entities.append(entity)
    return entities
def find_in_iter(element: Any, container: Iterable[Any]) -> Optional[Any]:
    """
    If the element is in the container, return it.
    Otherwise, return None

    :param element: element to find
    :param container: container with element
    :return: element or None
    """
    # Lazy generator: stops at the first match instead of materializing a list
    return next((elem for elem in container if elem == element), None)
def find_column_in_table(column_name: str, table: Table) -> Optional[Column]:
@ -372,6 +179,30 @@ def find_column_in_table(column_name: str, table: Table) -> Optional[Column]:
)
def find_column_in_table_with_index(
    column_name: str, table: Table
) -> Tuple[Optional[int], Optional[Column]]:
    """Return a column and its index in a Table Entity

    Args:
        column_name (str): column to find (matched case-insensitively)
        table (Table): Table Entity
    Return:
        (index, column) when found, otherwise (None, None)
    """
    # Annotation fixed: the miss path returns (None, None), which
    # Optional[Tuple[int, Column]] did not describe.
    col_index, col = next(
        (
            (idx, candidate)
            for idx, candidate in enumerate(table.columns)
            if str(candidate.name.__root__).lower() == column_name.lower()
        ),
        (None, None),
    )
    return col_index, col
def list_to_dict(original: Optional[List[str]], sep: str = "=") -> Dict[str, str]:
"""
Given a list with strings that have a separator,
@ -386,30 +217,6 @@ def list_to_dict(original: Optional[List[str]], sep: str = "=") -> Dict[str, str
return dict(split_original)
def create_ometa_client(
    metadata_config: OpenMetadataConnection,
) -> OpenMetadata:
    """Create an OpenMetadata client

    Args:
        metadata_config (OpenMetadataConnection): OM connection config
    Returns:
        OpenMetadata: an OM client
    Raises:
        ValueError: when the server is unreachable or the config is invalid
    """
    try:
        metadata = OpenMetadata(metadata_config)
        # Fail fast if the server cannot be reached with this config
        metadata.health_check()
        return metadata
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"No OpenMetadata server configuration found. "
            f"Setting client to `None`. You won't be able to access the server from the client: {exc}"
        )
        # Chain the original exception so the root cause is preserved
        raise ValueError(exc) from exc
def clean_up_starting_ending_double_quotes_in_string(string: str) -> str:
"""Remove start and ending double quotes in a string

View File

@ -14,7 +14,6 @@ OpenMetadata high-level API Table test
"""
from unittest import TestCase
from ingestion.src.metadata.utils.helpers import find_column_in_table
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
CreateDatabaseSchemaRequest,
@ -40,6 +39,7 @@ from metadata.generated.schema.security.client.openMetadataJWTClientConfig impor
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.helpers import find_column_in_table
class OMetaTableTest(TestCase):
@ -191,3 +191,46 @@ class OMetaTableTest(TestCase):
updated_col = find_column_in_table(column_name="another", table=force_updated)
assert updated_col.description.__root__ == "Forced new"
def test_patch_tag(self):
    """
    Update table tags
    """
    # First PATCH: the tag lands at index 0 of the (initially empty) tag list
    updated: Table = self.metadata.patch_tag(
        entity=Table,
        entity_id=self.entity_id,
        tag_fqn="PII.Sensitive",  # Shipped by default
    )
    assert updated.tags[0].tagFQN.__root__ == "PII.Sensitive"

    # Second PATCH appends: the existing tag must be preserved in order
    updated: Table = self.metadata.patch_tag(
        entity=Table,
        entity_id=self.entity_id,
        tag_fqn="Tier.Tier2",  # Shipped by default
    )
    assert updated.tags[0].tagFQN.__root__ == "PII.Sensitive"
    assert updated.tags[1].tagFQN.__root__ == "Tier.Tier2"
def test_patch_column_tags(self):
    """
    Update column tags
    """
    # First PATCH: tag the "id" column
    updated: Table = self.metadata.patch_column_tag(
        entity_id=self.entity_id,
        tag_fqn="PII.Sensitive",  # Shipped by default
        column_name="id",
    )
    updated_col = find_column_in_table(column_name="id", table=updated)
    assert updated_col.tags[0].tagFQN.__root__ == "PII.Sensitive"

    # Second PATCH appends to the same column: ordering must be preserved
    updated_again: Table = self.metadata.patch_column_tag(
        entity_id=self.entity_id,
        tag_fqn="Tier.Tier2",  # Shipped by default
        column_name="id",
    )
    updated_again_col = find_column_in_table(column_name="id", table=updated_again)
    assert updated_again_col.tags[0].tagFQN.__root__ == "PII.Sensitive"
    assert updated_again_col.tags[1].tagFQN.__root__ == "Tier.Tier2"

View File

@ -0,0 +1,70 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test helpers
"""
import uuid
from unittest import TestCase
from metadata.generated.schema.entity.data.table import Column, DataType, Table
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.utils.helpers import (
find_column_in_table,
find_column_in_table_with_index,
find_in_iter,
)
class HelpersTest(TestCase):
    """Validate the generic helpers from metadata.utils.helpers"""

    def test_find_in_iter(self):
        """We can find elements within a list"""
        iter_ = ("A", "B", "C")

        found = find_in_iter(element="B", container=iter_)
        self.assertEqual("B", found)

        not_found = find_in_iter(element="random", container=iter_)
        self.assertIsNone(not_found)

    def test_find_column_in_table(self):
        """Check we can find a column inside a table"""
        table = Table(
            id=uuid.uuid4(),
            name="test",
            databaseSchema=EntityReference(
                id=uuid.uuid4(),
                type="databaseSchema",
            ),
            fullyQualifiedName="test-service-table.test-db.test-schema.test",
            columns=[
                Column(name="id", dataType=DataType.BIGINT),
                Column(name="hello", dataType=DataType.BIGINT),
                Column(name="foo", dataType=DataType.BIGINT),
                Column(name="bar", dataType=DataType.BIGINT),
            ],
        )

        col = find_column_in_table(column_name="foo", table=table)
        self.assertEqual(col, Column(name="foo", dataType=DataType.BIGINT))

        not_found = find_column_in_table(column_name="random", table=table)
        self.assertIsNone(not_found)

        idx, col = find_column_in_table_with_index(column_name="foo", table=table)
        self.assertEqual(col, Column(name="foo", dataType=DataType.BIGINT))
        self.assertEqual(idx, 2)

        # The helper returns (index, column): unpack in that order and
        # assert on the freshly returned values (the original asserted the
        # stale `not_found` variable and used swapped names).
        not_found_idx, not_found_col = find_column_in_table_with_index(
            column_name="random", table=table
        )
        self.assertIsNone(not_found_col)
        self.assertIsNone(not_found_idx)