diff --git a/ingestion/src/metadata/ingestion/bulksink/metadata_usage.py b/ingestion/src/metadata/ingestion/bulksink/metadata_usage.py index e7bd9af46c0..174a875c78e 100644 --- a/ingestion/src/metadata/ingestion/bulksink/metadata_usage.py +++ b/ingestion/src/metadata/ingestion/bulksink/metadata_usage.py @@ -30,10 +30,7 @@ from metadata.ingestion.models.table_queries import ( ) from metadata.ingestion.ometa.client import APIError from metadata.ingestion.ometa.ometa_api import OpenMetadata -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig logger = logging.getLogger(__name__) diff --git a/ingestion/src/metadata/ingestion/ometa/ometa_api.py b/ingestion/src/metadata/ingestion/ometa/ometa_api.py index e2d6a6c9a6a..5e459de8926 100644 --- a/ingestion/src/metadata/ingestion/ometa/ometa_api.py +++ b/ingestion/src/metadata/ingestion/ometa/ometa_api.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args +from typing import Generic, List, Optional, Type, TypeVar, Union, get_args from pydantic import BaseModel @@ -12,12 +12,12 @@ from metadata.generated.schema.entity.data.model import Model from metadata.generated.schema.entity.data.pipeline import Pipeline from metadata.generated.schema.entity.data.report import Report from metadata.generated.schema.entity.data.table import Table -from metadata.generated.schema.entity.data.task import Task from metadata.generated.schema.entity.data.topic import Topic from metadata.generated.schema.entity.services.dashboardService import DashboardService from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.entity.services.messagingService import MessagingService from metadata.generated.schema.entity.services.pipelineService import PipelineService +from metadata.generated.schema.entity.tags.tagCategory import Tag from metadata.generated.schema.entity.teams.user import User from metadata.ingestion.ometa.auth_provider import AuthenticationProvider from metadata.ingestion.ometa.client import REST, APIError, ClientConfig @@ -158,6 +158,9 @@ class OpenMetadata(OMetaLineageMixin, OMetaTableMixin, Generic[T, C]): if issubclass(entity, Report): return "/reports" + if issubclass(entity, Tag): + return "/tags" + if issubclass(entity, get_args(Union[User, self.get_create_entity_type(User)])): return "/users" @@ -335,7 +338,11 @@ class OpenMetadata(OMetaLineageMixin, OMetaTableMixin, Generic[T, C]): return None def list_entities( - self, entity: Type[T], fields: str = None, after: str = None, limit: int = 1000 + self, + entity: Type[T], + fields: Optional[List[str]] = None, + after: str = None, + limit: int = 1000, ) -> EntityList[T]: """ Helps us paginate over the collection @@ -344,7 +351,7 @@ class OpenMetadata(OMetaLineageMixin, OMetaTableMixin, Generic[T, C]): suffix = self.get_suffix(entity) url_limit = f"?limit={limit}" url_after = f"&after={after}" if after else "" - url_fields = f"&fields={fields}" if fields else "" + url_fields = f"&fields={','.join(fields)}" if fields else "" resp = self.client.get(f"{suffix}{url_limit}{url_after}{url_fields}") @@ -378,6 +385,13 @@ class OpenMetadata(OMetaLineageMixin, OMetaTableMixin, Generic[T, C]): resp = self.client.post(f"/usage/compute.percentile/{entity_name}/{date}") logger.debug("published compute percentile {}".format(resp)) + def list_tags_by_category(self, category: str) -> List[Tag]: + """ + List all tags + """ + resp = self.client.get(f"{self.get_suffix(Tag)}/{category}") + return [Tag(**d) for d in resp["children"]] + def health_check(self) -> bool: """ Run endpoint health-check. Return `true` if OK diff --git a/ingestion/src/metadata/ingestion/ometa/openmetadata_rest.py b/ingestion/src/metadata/ingestion/ometa/openmetadata_rest.py index 8cae59a985c..4263bf85070 100644 --- a/ingestion/src/metadata/ingestion/ometa/openmetadata_rest.py +++ b/ingestion/src/metadata/ingestion/ometa/openmetadata_rest.py @@ -18,8 +18,7 @@ import json import logging import time import uuid -from typing import List, Optional -from urllib.error import HTTPError +from typing import List import google.auth import google.auth.transport.requests @@ -28,53 +27,15 @@ from jose import jwt from pydantic import BaseModel from metadata.config.common import ConfigModel -from metadata.generated.schema.api.data.createChart import CreateChartEntityRequest -from metadata.generated.schema.api.data.createDashboard import ( - CreateDashboardEntityRequest, -) -from metadata.generated.schema.api.data.createDatabase import ( - CreateDatabaseEntityRequest, -) -from metadata.generated.schema.api.data.createPipeline import ( - CreatePipelineEntityRequest, -) -from metadata.generated.schema.api.data.createTable import CreateTableEntityRequest -from metadata.generated.schema.api.data.createTask import CreateTaskEntityRequest -from metadata.generated.schema.api.data.createTopic import CreateTopicEntityRequest -from metadata.generated.schema.api.lineage.addLineage import AddLineage -from metadata.generated.schema.api.services.createDashboardService import ( - CreateDashboardServiceEntityRequest, -) -from metadata.generated.schema.api.services.createDatabaseService import ( - CreateDatabaseServiceEntityRequest, -) -from metadata.generated.schema.api.services.createMessagingService import ( - CreateMessagingServiceEntityRequest, -) -from metadata.generated.schema.api.services.createPipelineService import ( - CreatePipelineServiceEntityRequest, -) -from metadata.generated.schema.entity.data.chart import Chart from metadata.generated.schema.entity.data.dashboard import Dashboard from metadata.generated.schema.entity.data.database import Database -from metadata.generated.schema.entity.data.model import Model from metadata.generated.schema.entity.data.pipeline import Pipeline -from metadata.generated.schema.entity.data.table import ( - Table, - TableData, - TableJoins, - TableProfile, -) +from metadata.generated.schema.entity.data.table import Table, TableProfile from metadata.generated.schema.entity.data.task import Task from metadata.generated.schema.entity.data.topic import Topic -from metadata.generated.schema.entity.services.dashboardService import DashboardService from metadata.generated.schema.entity.services.databaseService import DatabaseService -from metadata.generated.schema.entity.services.messagingService import MessagingService -from metadata.generated.schema.entity.services.pipelineService import PipelineService from metadata.generated.schema.entity.tags.tagCategory import Tag -from metadata.ingestion.models.table_queries import TableUsageRequest from metadata.ingestion.ometa.auth_provider import AuthenticationProvider -from metadata.ingestion.ometa.client import REST, APIError, ClientConfig logger = logging.getLogger(__name__) @@ -197,409 +158,3 @@ class Auth0AuthenticationProvider(AuthenticationProvider): data = res.read() token = json.loads(data.decode("utf-8")) return token["access_token"] - - -class OpenMetadataAPIClient(object): - client: REST - _auth_provider: AuthenticationProvider - - def __init__(self, config: MetadataServerConfig, raw_data: bool = False): - self.config = config - if self.config.auth_provider_type == "google": - self._auth_provider: AuthenticationProvider = ( - GoogleAuthenticationProvider.create(self.config) - ) - elif self.config.auth_provider_type == "okta": - self._auth_provider: AuthenticationProvider = ( - OktaAuthenticationProvider.create(self.config) - ) - elif self.config.auth_provider_type == "auth0": - self._auth_provider: AuthenticationProvider = ( - Auth0AuthenticationProvider.create(self.config) - ) - else: - self._auth_provider: AuthenticationProvider = ( - NoOpAuthenticationProvider.create(self.config) - ) - client_config: ClientConfig = ClientConfig( - base_url=self.config.api_endpoint, - api_version=self.config.api_version, - auth_header="X-Catalog-Source", - auth_token=self._auth_provider.auth_token(), - ) - self.client = REST(client_config) - self._use_raw_data = raw_data - - def get_database_service(self, service_name: str) -> Optional[DatabaseService]: - """Get the Database service""" - try: - resp = self.client.get( - "/services/databaseServices/name/{}".format(service_name) - ) - return DatabaseService(**resp) - except APIError as err: - logger.error(f"Error trying to GET the database service {service_name}") - return None - - def get_database_service_by_id(self, service_id: str) -> DatabaseService: - """Get the Database Service by ID""" - resp = self.client.get("/services/databaseServices/{}".format(service_id)) - return DatabaseService(**resp) - - def list_database_services(self) -> DatabaseServiceEntities: - """Get a list of mysql services""" - resp = self.client.get("/services/databaseServices") - if self._use_raw_data: - return resp - else: - return [DatabaseService(**p) for p in resp["data"]] - - def create_database_service( - self, database_service: CreateDatabaseServiceEntityRequest - ) -> DatabaseService: - """Create a new Database Service""" - resp = self.client.post( - "/services/databaseServices", data=database_service.json() - ) - return DatabaseService(**resp) - - def delete_database_service(self, service_id: str) -> None: - """Delete a Database service""" - self.client.delete("/services/databaseServices/{}".format(service_id)) - - def get_database_by_name( - self, database_name: str, fields: [] = ["service"] - ) -> Database: - """Get the Database""" - params = {"fields": ",".join(fields)} - resp = self.client.get("/databases/name/{}".format(database_name), data=params) - return Database(**resp) - - def list_databases(self, fields: [] = ["service"]) -> DatabaseEntities: - """List all databases""" - params = {"fields": ",".join(fields)} - resp = self.client.get("/databases", data=params) - if self._use_raw_data: - return resp - else: - return [Database(**d) for d in resp["data"]] - - def get_database_by_id( - self, database_id: str, fields: [] = ["owner,service,tables,usageSummary"] - ) -> Database: - """Get Database By ID""" - params = {"fields": ",".join(fields)} - resp = self.client.get("/databases/{}".format(database_id), data=params) - return Database(**resp) - - def create_database( - self, create_database_request: CreateDatabaseEntityRequest - ) -> Database: - """Create a Database""" - resp = self.client.put("/databases", data=create_database_request.json()) - return Database(**resp) - - def delete_database(self, database_id: str): - """Delete Database using ID""" - self.client.delete("/databases/{}".format(database_id)) - - def list_tables( - self, fields: str = None, after: str = None, limit: int = 1000 - ) -> TableEntities: - """List all tables""" - - if fields is None: - resp = self.client.get("/tables") - else: - if after is not None: - resp = self.client.get( - "/tables?fields={}&after={}&limit={}".format(fields, after, limit) - ) - else: - resp = self.client.get( - "/tables?fields={}&limit={}".format(fields, limit) - ) - - if self._use_raw_data: - return resp - else: - tables = [Table(**t) for t in resp["data"]] - total = resp["paging"]["total"] - after = resp["paging"]["after"] if "after" in resp["paging"] else None - return TableEntities(tables=tables, total=total, after=after) - - def get_table_by_id(self, table_id: str, fields: [] = ["columns"]) -> Table: - """Get Table By ID""" - params = {"fields": ",".join(fields)} - resp = self.client.get("/tables/{}".format(table_id), data=params) - return Table(**resp) - - def create_or_update_table( - self, create_table_request: CreateTableEntityRequest - ) -> Table: - """Create or Update a Table""" - resp = self.client.put("/tables", data=create_table_request.json()) - resp.pop("database", None) - return Table(**resp) - - def get_table_by_name(self, table_name: str, fields: [] = ["columns"]) -> Table: - """Get Table By Name""" - params = {"fields": ",".join(fields)} - resp = self.client.get("/tables/name/{}".format(table_name), data=params) - return Table(**resp) - - def list_tags_by_category(self, category: str) -> {}: - """List all tags""" - resp = self.client.get("/tags/{}".format(category)) - return [Tag(**d) for d in resp["children"]] - - def get_messaging_service(self, service_name: str) -> Optional[MessagingService]: - """Get the Messaging service""" - try: - resp = self.client.get( - "/services/messagingServices/name/{}".format(service_name) - ) - return MessagingService(**resp) - except APIError as err: - logger.error(f"Error trying to GET the messaging service {service_name}") - return None - - def get_messaging_service_by_id(self, service_id: str) -> MessagingService: - """Get the Messaging Service by ID""" - resp = self.client.get("/services/messagingServices/{}".format(service_id)) - return MessagingService(**resp) - - def create_messaging_service( - self, messaging_service: CreateMessagingServiceEntityRequest - ) -> MessagingService: - """Create a new Database Service""" - resp = self.client.post( - "/services/messagingServices", data=messaging_service.json() - ) - return MessagingService(**resp) - - def create_or_update_topic( - self, create_topic_request: CreateTopicEntityRequest - ) -> Topic: - """Create or Update a Table""" - resp = self.client.put("/topics", data=create_topic_request.json()) - return Topic(**resp) - - def list_topics( - self, fields: str = None, after: str = None, limit: int = 1000 - ) -> TopicEntities: - """List all topics""" - if fields is None: - resp = self.client.get("/tables") - else: - if after is not None: - resp = self.client.get( - "/topics?fields={}&after={}&limit={}".format(fields, after, limit) - ) - else: - resp = self.client.get( - "/topics?fields={}&limit={}".format(fields, limit) - ) - - if self._use_raw_data: - return resp - else: - topics = [Topic(**t) for t in resp["data"]] - total = resp["paging"]["total"] - after = resp["paging"]["after"] if "after" in resp["paging"] else None - return TopicEntities(topics=topics, total=total, after=after) - - def get_dashboard_service(self, service_name: str) -> Optional[DashboardService]: - """Get the Dashboard service""" - try: - resp = self.client.get( - "/services/dashboardServices/name/{}".format(service_name) - ) - return DashboardService(**resp) - except APIError as err: - logger.error(f"Error trying to GET the dashboard service {service_name}") - return None - - def get_dashboard_service_by_id(self, service_id: str) -> DashboardService: - """Get the Dashboard Service by ID""" - resp = self.client.get("/services/dashboardServices/{}".format(service_id)) - return DashboardService(**resp) - - def create_dashboard_service( - self, dashboard_service: CreateDashboardServiceEntityRequest - ) -> Optional[DashboardService]: - """Create a new Database Service""" - try: - resp = self.client.post( - "/services/dashboardServices", data=dashboard_service.json() - ) - return DashboardService(**resp) - except APIError as err: - logger.error( - f"Error trying to POST the dashboard service {dashboard_service}" - ) - return None - - def create_or_update_chart( - self, create_chart_request: CreateChartEntityRequest - ) -> Chart: - """Create or Update a Chart""" - resp = self.client.put("/charts", data=create_chart_request.json()) - return Chart(**resp) - - def get_chart_by_id(self, chart_id: str, fields: [] = ["tags,service"]) -> Chart: - """Get Chart By ID""" - params = {"fields": ",".join(fields)} - resp = self.client.get("/charts/{}".format(chart_id), data=params) - return Chart(**resp) - - def create_or_update_dashboard( - self, create_dashboard_request: CreateDashboardEntityRequest - ) -> Dashboard: - """Create or Update a Dashboard""" - resp = self.client.put("/dashboards", data=create_dashboard_request.json()) - return Dashboard(**resp) - - def get_dashboard_by_name( - self, dashboard_name: str, fields: [] = ["charts", "service"] - ) -> Dashboard: - """Get Dashboard By Name""" - params = {"fields": ",".join(fields)} - resp = self.client.get( - "/dashboards/name/{}".format(dashboard_name), data=params - ) - return Dashboard(**resp) - - def list_dashboards( - self, fields: str = None, after: str = None, limit: int = 1000 - ) -> DashboardEntities: - """List all dashboards""" - - if fields is None: - resp = self.client.get("/dashboards") - else: - if after is not None: - resp = self.client.get( - "/dashboards?fields={}&after={}&limit={}".format( - fields, after, limit - ) - ) - else: - resp = self.client.get( - "/dashboards?fields={}&limit={}".format(fields, limit) - ) - - if self._use_raw_data: - return resp - else: - dashboards = [Dashboard(**t) for t in resp["data"]] - total = resp["paging"]["total"] - after = resp["paging"]["after"] if "after" in resp["paging"] else None - return DashboardEntities(dashboards=dashboards, total=total, after=after) - - def get_pipeline_service(self, service_name: str) -> Optional[PipelineService]: - """Get the Pipeline service""" - try: - resp = self.client.get( - "/services/pipelineServices/name/{}".format(service_name) - ) - return PipelineService(**resp) - except APIError as err: - logger.error(f"Error trying to GET the pipeline service {service_name}") - return None - - def get_pipeline_service_by_id(self, service_id: str) -> PipelineService: - """Get the Pipeline Service by ID""" - resp = self.client.get("/services/pipelineServices/{}".format(service_id)) - return PipelineService(**resp) - - def create_pipeline_service( - self, pipeline_service: CreatePipelineServiceEntityRequest - ) -> Optional[PipelineService]: - """Create a new Pipeline Service""" - try: - resp = self.client.post( - "/services/pipelineServices", data=pipeline_service.json() - ) - return PipelineService(**resp) - except APIError as err: - logger.error( - f"Error trying to POST the pipeline service {pipeline_service}" - ) - return None - - def create_or_update_task( - self, create_task_request: CreateTaskEntityRequest - ) -> Task: - """Create or Update a Task""" - resp = self.client.put("/tasks", data=create_task_request.json()) - return Task(**resp) - - def get_task_by_id(self, task_id: str, fields: [] = ["tags, service"]) -> Task: - """Get Task By ID""" - params = {"fields": ",".join(fields)} - resp = self.client.get("/tasks/{}".format(task_id), data=params) - return Task(**resp) - - def list_tasks( - self, fields: str = None, offset: int = 0, limit: int = 1000 - ) -> Tasks: - """List all tasks""" - if fields is None: - resp = self.client.get("/tasks?offset={}&limit={}".format(offset, limit)) - else: - resp = self.client.get( - "/tasks?fields={}&offset={}&limit={}".format(fields, offset, limit) - ) - if self._use_raw_data: - return resp - else: - return [Task(**t) for t in resp["data"]] - - def create_or_update_pipeline( - self, create_pipeline_request: CreatePipelineEntityRequest - ) -> Pipeline: - """Create or Update a Pipeline""" - resp = self.client.put("/pipelines", data=create_pipeline_request.json()) - return Pipeline(**resp) - - def list_pipelines( - self, fields: str = None, after: str = None, limit: int = 1000 - ) -> PipelineEntities: - """List all pipelines""" - if fields is None: - resp = self.client.get("/pipelines") - else: - if after is not None: - resp = self.client.get( - "/pipelines?fields={}&after={}&limit={}".format( - fields, after, limit - ) - ) - else: - resp = self.client.get( - "/pipelines?fields={}&limit={}".format(fields, limit) - ) - - if self._use_raw_data: - return resp - else: - pipelines = [Pipeline(**t) for t in resp["data"]] - total = resp["paging"]["total"] - after = resp["paging"]["after"] if "after" in resp["paging"] else None - return PipelineEntities(pipelines=pipelines, total=total, after=after) - - def get_pipeline_by_name( - self, pipeline_name: str, fields: [] = ["tasks", "service"] - ) -> Pipeline: - """Get Pipeline By Name""" - params = {"fields": ",".join(fields)} - resp = self.client.get("/pipelines/name/{}".format(pipeline_name), data=params) - return Pipeline(**resp) - - def create_or_update_model(self, model: Model): - resp = self.client.put("/models", data=model.json()) - return Model(**resp) - - def close(self): - self.client.close() diff --git a/ingestion/src/metadata/ingestion/processor/pii.py b/ingestion/src/metadata/ingestion/processor/pii.py index a1d7a78e17c..42529f8ca98 100644 --- a/ingestion/src/metadata/ingestion/processor/pii.py +++ b/ingestion/src/metadata/ingestion/processor/pii.py @@ -28,10 +28,8 @@ from metadata.generated.schema.type.tagLabel import TagLabel from metadata.ingestion.api.common import Record, WorkflowContext from metadata.ingestion.api.processor import Processor, ProcessorStatus from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.utils.helpers import snake_to_camel @@ -173,7 +171,7 @@ class PiiProcessor(Processor): config: PiiProcessorConfig metadata_config: MetadataServerConfig status: ProcessorStatus - client: OpenMetadataAPIClient + metadata: OpenMetadata def __init__( self, @@ -185,7 +183,7 @@ class PiiProcessor(Processor): self.config = config self.metadata_config = metadata_config self.status = ProcessorStatus() - self.client = OpenMetadataAPIClient(self.metadata_config) + self.metadata = OpenMetadata(self.metadata_config) self.tags = self.__get_tags() self.column_scanner = ColumnNameScanner() self.ner_scanner = NERScanner() @@ -199,7 +197,7 @@ class PiiProcessor(Processor): return cls(ctx, config, metadata_config) def __get_tags(self) -> {}: - user_tags = self.client.list_tags_by_category("user") + user_tags = self.metadata.list_tags_by_category("user") tags_dict = {} for tag in user_tags: tags_dict[tag.name.__root__] = tag diff --git a/ingestion/src/metadata/ingestion/sink/elasticsearch.py b/ingestion/src/metadata/ingestion/sink/elasticsearch.py index b02194016e8..50ced40e067 100644 --- a/ingestion/src/metadata/ingestion/sink/elasticsearch.py +++ b/ingestion/src/metadata/ingestion/sink/elasticsearch.py @@ -22,10 +22,15 @@ from elasticsearch import Elasticsearch from metadata.config.common import ConfigModel from metadata.generated.schema.entity.data.chart import Chart from metadata.generated.schema.entity.data.dashboard import Dashboard +from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.pipeline import Pipeline from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.data.task import Task from metadata.generated.schema.entity.data.topic import Topic +from metadata.generated.schema.entity.services.dashboardService import DashboardService +from metadata.generated.schema.entity.services.databaseService import DatabaseService +from metadata.generated.schema.entity.services.messagingService import MessagingService +from metadata.generated.schema.entity.services.pipelineService import PipelineService from metadata.generated.schema.type import entityReference from metadata.ingestion.api.common import Record, WorkflowContext from metadata.ingestion.api.sink import Sink, SinkStatus @@ -35,10 +40,8 @@ from metadata.ingestion.models.table_metadata import ( TableESDocument, TopicESDocument, ) -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.sink.elasticsearch_constants import ( DASHBOARD_ELASTICSEARCH_INDEX_MAPPING, PIPELINE_ELASTICSEARCH_INDEX_MAPPING, @@ -88,7 +91,7 @@ class ElasticsearchSink(Sink): self.metadata_config = metadata_config self.ctx = ctx self.status = SinkStatus() - self.rest = OpenMetadataAPIClient(self.metadata_config) + self.metadata = OpenMetadata(self.metadata_config) self.elasticsearch_doc_type = "_doc" http_auth = None if self.config.es_username: @@ -205,9 +208,11 @@ class ElasticsearchSink(Sink): for col_tag in column.tags: tags.add(col_tag.tagFQN) - database_entity = self.rest.get_database_by_id(table.database.id.__root__) - service_entity = self.rest.get_database_service_by_id( - database_entity.service.id.__root__ + database_entity = self.metadata.get_by_id( + entity=Database, entity_id=str(table.database.id.__root__) + ) + service_entity = self.metadata.get_by_id( + entity=DatabaseService, entity_id=str(database_entity.service.id.__root__) ) table_owner = str(table.owner.id.__root__) if table.owner is not None else "" table_followers = [] @@ -250,8 +255,8 @@ class ElasticsearchSink(Sink): ] tags = set() timestamp = time.time() - service_entity = self.rest.get_messaging_service_by_id( - str(topic.service.id.__root__) + service_entity = self.metadata.get_by_id( + entity=MessagingService, entity_id=str(topic.service.id.__root__) ) topic_owner = str(topic.owner.id.__root__) if topic.owner is not None else "" topic_followers = [] @@ -287,8 +292,8 @@ class ElasticsearchSink(Sink): suggest = [{"input": [dashboard.displayName], "weight": 10}] tags = set() timestamp = time.time() - service_entity = self.rest.get_dashboard_service_by_id( - str(dashboard.service.id.__root__) + service_entity = self.metadata.get_by_id( + entity=DashboardService, entity_id=str(dashboard.service.id.__root__) ) dashboard_owner = ( str(dashboard.owner.id.__root__) if dashboard.owner is not None else "" @@ -343,8 +348,8 @@ class ElasticsearchSink(Sink): suggest = [{"input": [pipeline.displayName], "weight": 10}] tags = set() timestamp = time.time() - service_entity = self.rest.get_pipeline_service_by_id( - str(pipeline.service.id.__root__) + service_entity = self.metadata.get_by_id( + entity=PipelineService, entity_id=str(pipeline.service.id.__root__) ) pipeline_owner = ( str(pipeline.owner.id.__root__) if pipeline.owner is not None else "" @@ -391,20 +396,14 @@ class ElasticsearchSink(Sink): def _get_charts(self, chart_refs: Optional[List[entityReference.EntityReference]]): charts = [] - if chart_refs is not None: + if chart_refs: for chart_ref in chart_refs: - chart = self.rest.get_chart_by_id(str(chart_ref.id.__root__)) + chart = self.metadata.get_by_id( + entity=Chart, entity_id=str(chart_ref.id.__root__), fields=["tags"] + ) charts.append(chart) return charts - def _get_tasks(self, task_refs: Optional[List[entityReference.EntityReference]]): - tasks = [] - if task_refs is not None: - for task_ref in task_refs: - task = self.rest.get_task_by_id(str(task_ref.id.__root__)) - tasks.append(task) - return tasks - def get_status(self): return self.status diff --git a/ingestion/src/metadata/ingestion/sink/ldap_rest_users.py b/ingestion/src/metadata/ingestion/sink/ldap_rest_users.py index 91bc2da91b5..04ca2d8e8fe 100644 --- a/ingestion/src/metadata/ingestion/sink/ldap_rest_users.py +++ b/ingestion/src/metadata/ingestion/sink/ldap_rest_users.py @@ -19,10 +19,8 @@ from metadata.config.common import ConfigModel from metadata.ingestion.api.common import Record, WorkflowContext from metadata.ingestion.api.sink import Sink, SinkStatus from metadata.ingestion.models.user import MetadataUser -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig logger = logging.getLogger(__name__) @@ -46,7 +44,7 @@ class LdapRestUsersSink(Sink): self.metadata_config = metadata_config self.status = SinkStatus() self.api_users = "/users" - self.rest = OpenMetadataAPIClient(metadata_config).client + self.rest = OpenMetadata(metadata_config).client @classmethod def create( diff --git a/ingestion/src/metadata/ingestion/sink/metadata_rest.py b/ingestion/src/metadata/ingestion/sink/metadata_rest.py index c21208253c6..0ffdbb449d3 100644 --- a/ingestion/src/metadata/ingestion/sink/metadata_rest.py +++ b/ingestion/src/metadata/ingestion/sink/metadata_rest.py @@ -26,6 +26,7 @@ from metadata.generated.schema.api.data.createDashboard import ( from metadata.generated.schema.api.data.createDatabase import ( CreateDatabaseEntityRequest, ) +from metadata.generated.schema.api.data.createModel import CreateModelEntityRequest from metadata.generated.schema.api.data.createPipeline import ( CreatePipelineEntityRequest, ) @@ -45,10 +46,7 @@ from metadata.ingestion.models.table_metadata import Chart, Dashboard from metadata.ingestion.models.user import MetadataTeam, MetadataUser, User from metadata.ingestion.ometa.client import APIError from metadata.ingestion.ometa.ometa_api import OpenMetadata -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig logger = logging.getLogger(__name__) @@ -87,10 +85,8 @@ class MetadataRestSink(Sink): self.status = SinkStatus() self.wrote_something = False self.charts_dict = {} - self.client = OpenMetadataAPIClient(self.metadata_config) - # Let's migrate usages from OpenMetadataAPIClient to OpenMetadata self.metadata = OpenMetadata(self.metadata_config) - self.api_client = self.client.client + self.api_client = self.metadata.client self.api_team = "/teams" self.api_users = "/users" self.team_entities = {} @@ -190,7 +186,7 @@ class MetadataRestSink(Sink): def write_topics(self, topic: CreateTopicEntityRequest) -> None: try: - created_topic = self.client.create_or_update_topic(topic) + created_topic = self.metadata.create_or_update(topic) logger.info(f"Successfully ingested topic {created_topic.name.__root__}") self.status.records_written(f"Topic: {created_topic.name.__root__}") except (APIError, ValidationError) as err: @@ -215,7 +211,7 @@ class MetadataRestSink(Sink): chartUrl=chart.url, service=chart.service, ) - created_chart = self.client.create_or_update_chart(chart_request) + created_chart = self.metadata.create_or_update(chart_request) self.charts_dict[chart.name] = EntityReference( id=created_chart.id, type="chart" ) @@ -238,9 +234,7 @@ class MetadataRestSink(Sink): charts=charts, service=dashboard.service, ) - created_dashboard = self.client.create_or_update_dashboard( - dashboard_request - ) + created_dashboard = self.metadata.create_or_update(dashboard_request) logger.info( f"Successfully ingested dashboard {created_dashboard.displayName}" ) @@ -267,7 +261,7 @@ class MetadataRestSink(Sink): downstreamTasks=task.downstreamTasks, service=task.service, ) - created_task = self.client.create_or_update_task(task_request) + created_task = self.metadata.create_or_update(task_request) logger.info(f"Successfully ingested Task {created_task.displayName}") self.status.records_written(f"Task: {created_task.displayName}") except (APIError, ValidationError) as err: @@ -285,7 +279,7 @@ class MetadataRestSink(Sink): tasks=pipeline.tasks, service=pipeline.service, ) - created_pipeline = self.client.create_or_update_pipeline(pipeline_request) + created_pipeline = self.metadata.create_or_update(pipeline_request) logger.info( f"Successfully ingested Pipeline {created_pipeline.displayName}" ) @@ -309,7 +303,14 @@ class MetadataRestSink(Sink): def write_model(self, model: Model): try: logger.info(model) - created_model = self.client.create_or_update_model(model) + model_request = CreateModelEntityRequest( + name=model.name, + displayName=model.displayName, + description=model.description, + algorithm=model.algorithm, + dashboard=model.dashboard, + ) + created_model = self.metadata.create_or_update(model_request) logger.info(f"Successfully added Model {created_model.displayName}") self.status.records_written(f"Model: {created_model.displayName}") except (APIError, ValidationError) as err: diff --git a/ingestion/src/metadata/ingestion/source/metadata.py b/ingestion/src/metadata/ingestion/source/metadata.py index 05ad1369c84..3fbc77bb3a8 100644 --- a/ingestion/src/metadata/ingestion/source/metadata.py +++ b/ingestion/src/metadata/ingestion/source/metadata.py @@ -18,9 +18,10 @@ from dataclasses import dataclass, field from typing import Iterable, List, Optional from metadata.config.common import ConfigModel +from metadata.generated.schema.entity.data.pipeline import Pipeline from metadata.ingestion.api.common import Record, WorkflowContext from metadata.ingestion.api.source import Source, SourceStatus -from metadata.ingestion.ometa.openmetadata_rest import OpenMetadataAPIClient +from metadata.ingestion.ometa.ometa_api import OpenMetadata from ...generated.schema.entity.data.dashboard import Dashboard from ...generated.schema.entity.data.table import Table @@ -79,7 +80,7 @@ class MetadataSource(Source): self.metadata_config = metadata_config self.status = MetadataSourceStatus() self.wrote_something = False - self.client = OpenMetadataAPIClient(self.metadata_config) + self.metadata = OpenMetadata(self.metadata_config) self.tables = None self.topics = None @@ -104,12 +105,21 @@ class MetadataSource(Source): if self.config.include_tables: after = None while True: - table_entities = self.client.list_tables( - fields="columns,tableConstraints,usageSummary,owner,database,tags,followers", + table_entities = self.metadata.list_entities( + entity=Table, + fields=[ + "columns", + "tableConstraints", + "usageSummary", + "owner", + "database", + "tags", + "followers", + ], after=after, limit=self.config.limit_records, ) - for table in table_entities.tables: + for table in table_entities.entities: self.status.scanned_table(table.name.__root__) yield table if table_entities.after is None: @@ -120,12 +130,13 @@ class MetadataSource(Source): if self.config.include_topics: after = None while True: - topic_entities = self.client.list_topics( - fields="owner,service,tags,followers", + topic_entities = self.metadata.list_entities( + entity=Topic, + fields=["owner", "service", "tags", "followers"], after=after, limit=self.config.limit_records, ) - for topic in topic_entities.topics: + for topic in topic_entities.entities: self.status.scanned_topic(topic.name.__root__) yield topic if topic_entities.after is None: @@ -136,28 +147,37 @@ class MetadataSource(Source): if self.config.include_dashboards: after = None while True: - dashboard_entities = self.client.list_dashboards( - fields="owner,service,tags,followers,charts,usageSummary", + dashboard_entities = self.metadata.list_entities( + entity=Dashboard, + fields=[ + "owner", + "service", + "tags", + "followers", + "charts", + "usageSummary", + ], after=after, limit=self.config.limit_records, ) - for dashboard in dashboard_entities.dashboards: + for dashboard in dashboard_entities.entities: self.status.scanned_dashboard(dashboard.name) yield dashboard if dashboard_entities.after is None: break after = dashboard_entities.after - def fetch_pipeline(self) -> Dashboard: + def fetch_pipeline(self) -> Pipeline: if self.config.include_pipelines: after = None while True: - pipeline_entities = self.client.list_pipelines( - fields="owner,service,tags,followers,tasks", + pipeline_entities = self.metadata.list_entities( + entity=Pipeline, + fields=["owner", "service", "tags", "followers", "tasks"], after=after, limit=self.config.limit_records, ) - for pipeline in pipeline_entities.pipelines: + for pipeline in pipeline_entities.entities: self.status.scanned_dashboard(pipeline.name) yield pipeline if pipeline_entities.after is None: diff --git a/ingestion/src/metadata/ingestion/source/sample_data.py b/ingestion/src/metadata/ingestion/source/sample_data.py index 768f9f99ebd..097bbc8e1b6 100644 --- a/ingestion/src/metadata/ingestion/source/sample_data.py +++ b/ingestion/src/metadata/ingestion/source/sample_data.py @@ -57,10 +57,7 @@ from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable from metadata.ingestion.models.table_metadata import Chart, Dashboard from metadata.ingestion.models.user import User from metadata.ingestion.ometa.ometa_api import OpenMetadata -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.utils.helpers import get_database_service_or_create logger: logging.Logger = logging.getLogger(__name__) @@ -79,64 +76,64 @@ class InvalidSampleDataException(Exception): def get_database_service_or_create(service_json, metadata_config) -> DatabaseService: - client = OpenMetadataAPIClient(metadata_config) - service = client.get_database_service(service_json["name"]) + metadata = OpenMetadata(metadata_config) + service = metadata.get_by_name(entity=DatabaseService, fqdn=service_json["name"]) if service is not None: return service else: - created_service = client.create_database_service( + created_service = metadata.create_or_update( CreateDatabaseServiceEntityRequest(**service_json) ) return created_service def get_messaging_service_or_create(service_json, metadata_config) -> MessagingService: - client = OpenMetadataAPIClient(metadata_config) - service = client.get_messaging_service(service_json["name"]) + metadata = OpenMetadata(metadata_config) + service = metadata.get_by_name(entity=MessagingService, fqdn=service_json["name"]) if service is not None: return service else: - created_service = client.create_messaging_service( + created_service = metadata.create_or_update( CreateMessagingServiceEntityRequest(**service_json) ) return created_service def get_dashboard_service_or_create(service_json, metadata_config) -> DashboardService: - client = OpenMetadataAPIClient(metadata_config) - service = client.get_dashboard_service(service_json["name"]) + metadata = OpenMetadata(metadata_config) + service = metadata.get_by_name(entity=DashboardService, fqdn=service_json["name"]) if service is not None: return service else: - created_service = client.create_dashboard_service( + created_service = metadata.create_or_update( CreateDashboardServiceEntityRequest(**service_json) ) return created_service def get_pipeline_service_or_create(service_json, metadata_config) -> PipelineService: - client = OpenMetadataAPIClient(metadata_config) - service = client.get_pipeline_service(service_json["name"]) + metadata = OpenMetadata(metadata_config) + service = metadata.get_by_name(entity=PipelineService, fqdn=service_json["name"]) if service is not None: return service else: - created_service = client.create_pipeline_service( + created_service = metadata.create_or_update( CreatePipelineServiceEntityRequest(**service_json) ) return created_service def get_lineage_entity_ref(edge, metadata_config) -> EntityReference: - client = OpenMetadataAPIClient(metadata_config) + metadata = OpenMetadata(metadata_config) fqn = edge["fqn"] if edge["type"] == "table": - table = client.get_table_by_name(fqn) + table = metadata.get_by_name(entity=Table, fqdn=fqn) return EntityReference(id=table.id, type="table") elif edge["type"] == "pipeline": - pipeline = client.get_pipeline_by_name(edge["fqn"]) + pipeline = metadata.get_by_name(entity=Pipeline, fqdn=fqn) return EntityReference(id=pipeline.id, type="pipeline") elif edge["type"] == "dashboard": - dashboard = client.get_dashboard_by_name(fqn) + dashboard = metadata.get_by_name(entity=Dashboard, fqdn=fqn) return EntityReference(id=dashboard.id, type="dashboard") @@ -280,7 +277,6 @@ class SampleDataSource(Source): self.status = SampleDataSourceStatus() self.config = config self.metadata_config = metadata_config - self.client = OpenMetadataAPIClient(metadata_config) self.metadata = OpenMetadata(metadata_config) self.database_service_json = json.load( open(self.config.sample_data_folder + "/datasets/service.json", "r") diff --git a/ingestion/src/metadata/ingestion/source/sample_entity.py b/ingestion/src/metadata/ingestion/source/sample_entity.py index 9631a61ec9d..a513754870b 100644 --- a/ingestion/src/metadata/ingestion/source/sample_entity.py +++ b/ingestion/src/metadata/ingestion/source/sample_entity.py @@ -27,10 +27,8 @@ from metadata.ingestion.api.source import Source, SourceStatus from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable from metadata.ingestion.models.table_metadata import Chart, Dashboard from metadata.ingestion.ometa.client import APIError -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.processor.pii import ColumnNameScanner from metadata.ingestion.source.sample_data import get_database_service_or_create from metadata.ingestion.source.sql_source import SQLConnectionConfig @@ -82,7 +80,7 @@ class SampleEntitySource(Source): self.status = SampleEntitySourceStatus() self.config = config self.metadata_config = metadata_config - self.client = OpenMetadataAPIClient(metadata_config) + self.metadata = OpenMetadata(metadata_config) self.column_scanner = ColumnNameScanner() self.service_name = lambda: self.faker.word() self.service_type = lambda: random.choice( @@ -118,7 +116,7 @@ class SampleEntitySource(Source): pass def __get_tags(self) -> {}: - return self.client.list_tags_by_category("user") + return self.metadata.list_tags_by_category("user") def scan(self, text): types = set() @@ -148,7 +146,7 @@ class SampleEntitySource(Source): create_service = None while True: try: - create_service = self.client.create_database_service( + create_service = self.metadata.create_or_update( CreateDatabaseServiceEntityRequest(**service) ) break @@ -230,7 +228,7 @@ class SampleEntitySource(Source): "password": "admin", "serviceType": "Superset", } - create_service = self.client.create_dashboard_service( + create_service = self.metadata.create_or_update( CreateDashboardServiceEntityRequest(**service) ) break @@ -293,7 +291,7 @@ class SampleEntitySource(Source): "schemaRegistry": "http://localhost:8081", "serviceType": "Kafka", } - create_service = self.client.create_messaging_service( + create_service = self.metadata.create_or_update( CreateMessagingServiceEntityRequest(**service) ) break diff --git a/ingestion/src/metadata/ingestion/source/sample_usage.py b/ingestion/src/metadata/ingestion/source/sample_usage.py index 318ccef1ae0..1191820cbf9 100644 --- a/ingestion/src/metadata/ingestion/source/sample_usage.py +++ b/ingestion/src/metadata/ingestion/source/sample_usage.py @@ -6,7 +6,7 @@ from typing import Iterable from metadata.ingestion.api.source import Source from metadata.ingestion.models.table_queries import TableQuery -from ..ometa.openmetadata_rest import MetadataServerConfig, OpenMetadataAPIClient +from ..ometa.openmetadata_rest import MetadataServerConfig from .sample_data import ( SampleDataSourceConfig, SampleDataSourceStatus, @@ -25,7 +25,6 @@ class SampleUsageSource(Source): self.status = SampleDataSourceStatus() self.config = config self.metadata_config = metadata_config - self.client = OpenMetadataAPIClient(metadata_config) self.service_json = json.load( open(config.sample_data_folder + "/datasets/service.json", "r") ) diff --git a/ingestion/src/metadata/utils/helpers.py b/ingestion/src/metadata/utils/helpers.py index 244d3d09372..0b9fbf99c62 100644 --- a/ingestion/src/metadata/utils/helpers.py +++ b/ingestion/src/metadata/utils/helpers.py @@ -27,7 +27,7 @@ from metadata.generated.schema.api.services.createMessagingService import ( from metadata.generated.schema.entity.services.dashboardService import DashboardService from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.entity.services.messagingService import MessagingService -from metadata.ingestion.ometa.openmetadata_rest import OpenMetadataAPIClient +from metadata.ingestion.ometa.ometa_api import OpenMetadata def get_start_and_end(duration): @@ -48,8 +48,8 @@ def snake_to_camel(s): def get_database_service_or_create(config, metadata_config) -> DatabaseService: - client = OpenMetadataAPIClient(metadata_config) - service = client.get_database_service(config.service_name) + metadata = OpenMetadata(metadata_config) + service = metadata.get_by_name(entity=DatabaseService, fqdn=config.service_name) if service is not None: return service else: @@ -62,7 +62,7 @@ def get_database_service_or_create(config, metadata_config) -> DatabaseService: "description": "", "serviceType": config.get_service_type(), } - created_service = client.create_database_service( + created_service = metadata.create_or_update( CreateDatabaseServiceEntityRequest(**service) ) return created_service @@ -75,19 +75,18 @@ def get_messaging_service_or_create( brokers: List[str], metadata_config, ) -> MessagingService: - client = OpenMetadataAPIClient(metadata_config) - service = client.get_messaging_service(service_name) + metadata = OpenMetadata(metadata_config) + service = metadata.get_by_name(entity=MessagingService, fqdn=service_name) if service is not None: return service else: - create_messaging_service_request = CreateMessagingServiceEntityRequest( - name=service_name, - serviceType=message_service_type, - brokers=brokers, - schemaRegistry=schema_registry_url, - ) - created_service = client.create_messaging_service( - create_messaging_service_request + created_service = metadata.create_or_update( + CreateMessagingServiceEntityRequest( + name=service_name, + serviceType=message_service_type, + brokers=brokers, + schemaRegistry=schema_registry_url, + ) ) return created_service @@ -100,20 +99,19 @@ def get_dashboard_service_or_create( dashboard_url: str, metadata_config, ) -> DashboardService: - client = OpenMetadataAPIClient(metadata_config) - service = client.get_dashboard_service(service_name) + metadata = OpenMetadata(metadata_config) + service = metadata.get_by_name(entity=DashboardService, fqdn=service_name) if service is not None: return service else: - create_dashboard_service_request = CreateDashboardServiceEntityRequest( - name=service_name, - serviceType=dashboard_service_type, - username=username, - password=password, - dashboardUrl=dashboard_url, - ) - created_service = client.create_dashboard_service( - create_dashboard_service_request + created_service = metadata.create_or_update( + CreateDashboardServiceEntityRequest( + name=service_name, + serviceType=dashboard_service_type, + username=username, + password=password, + dashboardUrl=dashboard_url, + ) ) return created_service diff --git a/ingestion/tests/integration/hive/test_hive_crud.py b/ingestion/tests/integration/hive/test_hive_crud.py index 4a12c61de2e..0236dad4d56 100644 --- a/ingestion/tests/integration/hive/test_hive_crud.py +++ b/ingestion/tests/integration/hive/test_hive_crud.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import socket import time +from typing import List +from urllib.parse import urlparse import pytest import requests -import socket -from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig, OpenMetadataAPIClient from sqlalchemy.engine import create_engine from sqlalchemy.inspection import inspect @@ -29,9 +30,13 @@ from metadata.generated.schema.api.data.createTable import CreateTableEntityRequ from metadata.generated.schema.api.services.createDatabaseService import ( CreateDatabaseServiceEntityRequest, ) -from metadata.generated.schema.entity.data.table import Column +from metadata.generated.schema.entity.data.database import Database +from metadata.generated.schema.entity.data.table import Column, Table +from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.type.entityReference import EntityReference -from urllib.parse import urlparse +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig + def is_responsive(url): try: @@ -41,6 +46,7 @@ def is_responsive(url): except ConnectionError: return False + def is_port_open(url): url_parts = urlparse(url) hostname = url_parts.hostname @@ -54,6 +60,7 @@ def is_port_open(url): finally: s.close() + def sleep(timeout_s): print(f"sleeping for {timeout_s} seconds") n = len(str(timeout_s)) @@ -70,8 +77,7 @@ def status(r): return 0 -def create_delete_table(client): - databases = client.list_databases() +def create_delete_table(client: OpenMetadata, databases: List[Database]): columns = [ Column(name="id", dataType="INT", dataLength=1), Column(name="name", dataType="VARCHAR", dataLength=1), @@ -79,20 +85,16 @@ def create_delete_table(client): table = CreateTableEntityRequest( name="test1", columns=columns, database=databases[0].id ) - created_table = client.create_or_update_table(table) + created_table = client.create_or_update(table) if table.name.__root__ == created_table.name.__root__: - requests.delete( - "http://localhost:8585/api/v1/tables/{}".format(created_table.id.__root__) - ) + client.delete(entity=Table, entity_id=str(created_table.id.__root__)) return 1 else: - requests.delete( - "http://localhost:8585/api/v1/tables/{}".format(created_table.id.__root__) - ) + client.delete(entity=Table, entity_id=str(created_table.id.__root__)) return 0 -def create_delete_database(client): +def create_delete_database(client: OpenMetadata, databases: List[Database]): data = { "jdbc": {"connectionUrl": "hive://localhost/default", "driverClass": "jdbc"}, "name": "temp_local_hive", @@ -100,15 +102,15 @@ def create_delete_database(client): "description": "local hive env", } create_hive_service = CreateDatabaseServiceEntityRequest(**data) - hive_service = client.create_database_service(create_hive_service) + hive_service = client.create_or_update(create_hive_service) create_database_request = CreateDatabaseEntityRequest( name="dwh", service=EntityReference(id=hive_service.id, type="databaseService") ) - created_database = client.create_database(create_database_request) - resp = create_delete_table(client) + created_database = client.create_or_update(create_database_request) + resp = create_delete_table(client, databases) print(resp) - client.delete_database(created_database.id.__root__) - client.delete_database_service(hive_service.id.__root__) + client.delete(entity=Database, entity_id=str(created_database.id.__root__)) + client.delete(entity=DatabaseService, entity_id=str(hive_service.id.__root__)) return resp @@ -127,6 +129,7 @@ def hive_service(docker_ip, docker_services): inspector = inspect(engine) return inspector + def test_check_schema(hive_service): inspector = hive_service schemas = [] @@ -161,9 +164,9 @@ def test_check_table(): metadata_config = MetadataServerConfig.parse_obj( {"api_endpoint": "http://localhost:8585/api", "auth_provider_type": "no-auth"} ) - client = OpenMetadataAPIClient(metadata_config) - databases = client.list_databases() + client = OpenMetadata(metadata_config) + databases = client.list_entities(entity=Database).entities if len(databases) > 0: - assert create_delete_table(client) + assert create_delete_table(client, databases) else: - assert create_delete_database(client) + assert create_delete_database(client, databases) diff --git a/ingestion/tests/integration/mysql/test_mysql_crud.py b/ingestion/tests/integration/mysql/test_mysql_crud.py index d773d4843f7..b69fe4d9470 100644 --- a/ingestion/tests/integration/mysql/test_mysql_crud.py +++ b/ingestion/tests/integration/mysql/test_mysql_crud.py @@ -28,12 +28,12 @@ from metadata.generated.schema.api.data.createTable import CreateTableEntityRequ from metadata.generated.schema.api.services.createDatabaseService import ( CreateDatabaseServiceEntityRequest, ) +from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.table import Column +from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.type.entityReference import EntityReference -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig def is_responsive(url): @@ -45,8 +45,8 @@ def is_responsive(url): return False -def create_delete_table(client): - databases = client.list_databases() +def create_delete_table(client: OpenMetadata): + databases = client.list_entities(entity=Database).entities columns = [ Column(name="id", dataType="INT", dataLength=1), Column(name="name", dataType="VARCHAR", dataLength=1), @@ -54,7 +54,7 @@ def create_delete_table(client): table = CreateTableEntityRequest( name="test1", columns=columns, database=databases[0].id ) - created_table = client.create_or_update_table(table) + created_table = client.create_or_update(table) if table.name.__root__ == created_table.name.__root__: requests.delete( "http://localhost:8585/api/v1/tables/{}".format(created_table.id.__root__) @@ -67,7 +67,7 @@ def create_delete_table(client): return 0 -def create_delete_database(client): +def create_delete_database(client: OpenMetadata): data = { "jdbc": { "connectionUrl": "mysql://localhost/catalog_db", @@ -78,15 +78,15 @@ def create_delete_database(client): "description": "local mysql env", } create_mysql_service = CreateDatabaseServiceEntityRequest(**data) - mysql_service = client.create_database_service(create_mysql_service) + mysql_service = client.create_or_update(create_mysql_service) create_database_request = CreateDatabaseEntityRequest( name="dwh", service=EntityReference(id=mysql_service.id, type="databaseService") ) - created_database = client.create_database(create_database_request) + created_database = client.create_or_update(create_database_request) resp = create_delete_table(client) print(resp) - client.delete_database(created_database.id.__root__) - client.delete_database_service(mysql_service.id.__root__) + client.delete(entity=Database, entity_id=str(created_database.id.__root__)) + client.delete(entity=DatabaseService, entity_id=str(mysql_service.id.__root__)) return resp @@ -107,8 +107,7 @@ def test_check_tables(catalog_service): metadata_config = MetadataServerConfig.parse_obj( {"api_endpoint": catalog_service + "/api", "auth_provider_type": "no-auth"} ) - client = OpenMetadataAPIClient(metadata_config) - databases = client.list_databases() + client = OpenMetadata(metadata_config) assert create_delete_database(client) diff --git a/ingestion/tests/unit/helpers_test.py b/ingestion/tests/unit/helpers_test.py deleted file mode 100644 index 546c5eaee9f..00000000000 --- a/ingestion/tests/unit/helpers_test.py +++ /dev/null @@ -1,175 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from unittest import TestCase - -from metadata.generated.schema.api.data.createDatabase import ( - CreateDatabaseEntityRequest, -) -from metadata.generated.schema.api.data.createTable import CreateTableEntityRequest -from metadata.generated.schema.api.services.createDashboardService import ( - CreateDashboardServiceEntityRequest, -) -from metadata.generated.schema.api.services.createDatabaseService import ( - CreateDatabaseServiceEntityRequest, -) -from metadata.generated.schema.api.services.createMessagingService import ( - CreateMessagingServiceEntityRequest, -) -from metadata.generated.schema.entity.data.table import Column -from metadata.generated.schema.type.entityReference import EntityReference -from metadata.ingestion.ometa.client import APIError -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) - - -class RestTest(TestCase): - file_path = "tests/unit/mysql_test.json" - with open(file_path) as ingestionFile: - ingestionData = ingestionFile.read() - client_config = json.loads(ingestionData).get("metadata_server") - config = client_config.get("config", {}) - metadata_config = MetadataServerConfig.parse_obj(config) - openmetadata_client = OpenMetadataAPIClient(metadata_config) - client = OpenMetadataAPIClient(metadata_config).client - - def test_1_create_service(self): - data = { - "jdbc": { - "connectionUrl": "mysql://localhost/openmetadata_db", - "driverClass": "jdbc", - }, - "name": "local_mysql_test", - "serviceType": "MySQL", - "description": "local mysql env", - } - create_mysql_service = CreateDatabaseServiceEntityRequest(**data) - mysql_service = self.openmetadata_client.create_database_service( - create_mysql_service - ) - self.assertEqual(mysql_service.name, create_mysql_service.name) - - def test_2_get_service(self): - mysql_service = self.openmetadata_client.get_database_service( - "local_mysql_test" - ) - self.assertEqual(mysql_service.name, "local_mysql_test") - - def test_3_get_service_by_id(self): - mysql_service = self.openmetadata_client.get_database_service( - "local_mysql_test" - ) - mysql_service_get_id = self.openmetadata_client.get_database_service_by_id( - mysql_service.id.__root__ - ) - self.assertEqual(mysql_service.id, mysql_service_get_id.id) - - def test_4_create_update_databases(self): - mysql_service = self.openmetadata_client.get_database_service( - "local_mysql_test" - ) - service_reference = EntityReference( - id=mysql_service.id.__root__, type="databaseService" - ) - create_database_request = CreateDatabaseEntityRequest( - name="dwh", service=service_reference - ) - created_database = self.openmetadata_client.create_database( - create_database_request - ) - created_database.description = "hello world" - update_database_request = CreateDatabaseEntityRequest( - name=created_database.name, - description=created_database.description, - service=service_reference, - ) - updated_database = self.openmetadata_client.create_database( - update_database_request - ) - self.assertEqual(updated_database.description, created_database.description) - - def test_5_create_table(self): - databases = self.openmetadata_client.list_databases() - columns = [ - Column(name="id", columnDataType="INT"), - Column(name="name", columnDataType="VARCHAR"), - ] - table = CreateTableEntityRequest( - name="test1", columns=columns, database=databases[0].id.__root__ - ) - created_table = self.openmetadata_client.create_or_update_table(table) - self.client.delete(f"/tables/{created_table.id.__root__}") - self.client.delete(f"/databases/{databases[0].id.__root__}") - self.assertEqual(table.name, created_table.name) - - def test_6_delete_service(self): - mysql_service = self.openmetadata_client.get_database_service( - "local_mysql_test" - ) - self.openmetadata_client.delete_database_service(mysql_service.id.__root__) - self.assertRaises( - APIError, - self.openmetadata_client.get_database_service_by_id, - mysql_service.id.__root__, - ) - - def test_7_create_messaging_service(self): - create_messaging_service = CreateMessagingServiceEntityRequest( - name="sample_kafka_test", - serviceType="Kafka", - brokers=["localhost:9092"], - schemaRegistry="http://localhost:8081", - ) - messaging_service = self.openmetadata_client.create_messaging_service( - create_messaging_service - ) - self.assertEqual(create_messaging_service.name, messaging_service.name) - - def test_8_get_messaging_service(self): - messaging_service = self.openmetadata_client.get_messaging_service( - "sample_kafka_test" - ) - self.client.delete( - f"/services/messagingServices/{messaging_service.id.__root__}" - ) - self.assertEqual(messaging_service.name, "sample_kafka_test") - - def test_9_create_dashboard_service(self): - create_dashboard_service = CreateDashboardServiceEntityRequest( - name="sample_superset_test", - serviceType="Superset", - username="admin", - password="admin", - dashboardUrl="http://localhost:8088", - ) - dashboard_service = None - try: - dashboard_service = self.openmetadata_client.create_dashboard_service( - create_dashboard_service - ) - except APIError: - print(APIError) - self.assertEqual(create_dashboard_service.name, dashboard_service.name) - - def test_10_get_dashboard_service(self): - dashboard_service = self.openmetadata_client.get_dashboard_service( - "sample_superset_test" - ) - self.client.delete( - f"/services/dashboardServices/{dashboard_service.id.__root__}" - ) - self.assertEqual(dashboard_service.name, "sample_superset_test") diff --git a/ingestion/tests/unit/test_ometa_endpoints.py b/ingestion/tests/unit/test_ometa_endpoints.py index 1adb29f5eab..dedbc269196 100644 --- a/ingestion/tests/unit/test_ometa_endpoints.py +++ b/ingestion/tests/unit/test_ometa_endpoints.py @@ -57,7 +57,6 @@ class OMetaEndpointTest(TestCase): # Pipelines self.assertEqual(self.metadata.get_suffix(Pipeline), "/pipelines") - self.assertEqual(self.metadata.get_suffix(Task), "/tasks") # Topic self.assertEqual(self.metadata.get_suffix(Topic), "/topics") diff --git a/ingestion/tests/unit/workflow_test.py b/ingestion/tests/unit/workflow_test.py index 0c543b39c8f..18345133661 100644 --- a/ingestion/tests/unit/workflow_test.py +++ b/ingestion/tests/unit/workflow_test.py @@ -4,10 +4,8 @@ from unittest import TestCase from metadata.config.common import load_config_file from metadata.ingestion.api.workflow import Workflow -from metadata.ingestion.ometa.openmetadata_rest import ( - MetadataServerConfig, - OpenMetadataAPIClient, -) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig class WorkflowTest(TestCase): @@ -59,7 +57,7 @@ class WorkflowTest(TestCase): config = MetadataServerConfig.parse_obj( workflow_config.get("metadata_server").get("config") ) - client = OpenMetadataAPIClient(config).client + client = OpenMetadata(config).client client.delete( f"/services/databaseServices/"