feat: add OAuth and Azure AD authentication for Unity Catalog

This commit is contained in:
Keshav Mohta 2025-09-25 16:04:18 +05:30
parent a7859c5906
commit 88fc3c65fb
No known key found for this signature in database
GPG Key ID: 9481AB99C36FAE9C
23 changed files with 232 additions and 45917 deletions

View File

@@ -0,0 +1,14 @@
-- Migration script to restructure Databricks connection configuration
-- Move 'token' field from connection.config.token to connection.config.authType.token
UPDATE dbservice_entity
SET
    json = JSON_SET(
        JSON_REMOVE(json, '$.connection.config.token'),
        '$.connection.config.authType',
        JSON_OBJECT(
            'token',
            JSON_EXTRACT(json, '$.connection.config.token')
        )
    )
WHERE
    serviceType IN ('Databricks', 'UnityCatalog');
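For illustration, this is the reshaping the migration performs on the stored service document, as a minimal Python sketch with placeholder values:

# Hypothetical before/after of the dbservice_entity JSON reshaped by the
# migration above; only the relevant keys are shown.
before = {
    "connection": {
        "config": {"hostPort": "adb-123.azuredatabricks.net:443", "token": "dapi-xxxx"}
    }
}

after = {
    "connection": {
        "config": {
            "hostPort": "adb-123.azuredatabricks.net:443",
            "authType": {"token": "dapi-xxxx"},  # token now nested under authType
        }
    }
}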

View File

@@ -0,0 +1,11 @@
-- Migration script to restructure Databricks connection configuration
-- Move 'token' field from connection.config.token to connection.config.authType.token
UPDATE dbservice_entity
SET json = jsonb_set(
        json #- '{connection,config,token}',
        '{connection,config,authType}',
        jsonb_build_object('token', json #> '{connection,config,token}'),
        true
    )
WHERE serviceType IN ('Databricks', 'UnityCatalog');

View File

@@ -5,18 +5,3 @@
UPDATE profiler_data_time_series
SET json = JSON_SET(json, '$.profileData', json->'$.profileData.profileData')
WHERE json->>'$.profileData.profileData' IS NOT NULL;
--- Migration script to restructure Databricks connection configuration
--- Move 'token' field from connection.config.token to connection.config.authType.token
-UPDATE dbservice_entity
-SET
-    json = JSON_SET(
-        JSON_REMOVE(json, '$.connection.config.token'),
-        '$.connection.config.authType',
-        JSON_OBJECT(
-            'token',
-            JSON_EXTRACT(json, '$.connection.config.token')
-        )
-    )
-WHERE
-    serviceType = 'Databricks';

View File

@@ -5,15 +5,3 @@
UPDATE profiler_data_time_series
SET json = jsonb_set(json::jsonb, '{profileData}', json::jsonb->'profileData'->'profileData')::json
WHERE json->'profileData'->>'profileData' IS NOT NULL;
--- Migration script to restructure Databricks connection configuration
--- Move 'token' field from connection.config.token to connection.config.authType.token
-UPDATE dbservice_entity
-SET json = jsonb_set(
-        json #- '{connection,config,token}',
-        '{connection,config,authType}',
-        jsonb_build_object('token', json #> '{connection,config,token}'),
-        true
-    )
-WHERE serviceType = 'Databricks';

View File

@@ -61,14 +61,13 @@ class DatabricksClient:
    ):
        self.config = config
        base_url, *_ = self.config.hostPort.split(":")
-        auth_token = self.config.token.get_secret_value()
        self.base_url = f"https://{base_url}{API_VERSION}"
        self.base_query_url = f"{self.base_url}{QUERIES_PATH}"
        self.base_job_url = f"https://{base_url}{JOB_API_VERSION}/jobs"
        self.jobs_list_url = f"{self.base_job_url}/list"
        self.jobs_run_list_url = f"{self.base_job_url}/runs/list"
        self.headers = {
-            "Authorization": f"Bearer {auth_token}",
+            **self._get_auth_header(),
            "Content-Type": "application/json",
        }
        self.api_timeout = self.config.connectionTimeout or 120
@@ -79,6 +78,12 @@ class DatabricksClient:
        self.engine = engine
        self.client = requests

+    def _get_auth_header(self) -> dict:
+        """
+        Return the Authorization header used on every API request.
+        """
+        return {"Authorization": f"Bearer {self.config.token.get_secret_value()}"}
+
    def test_query_api_access(self) -> None:
        res = self.client.get(
            self.base_query_url, headers=self.headers, timeout=self.api_timeout
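The effect of this change is to turn header construction into an overridable hook: the base client supplies a personal-access-token header, and subclasses can substitute another scheme. A minimal sketch of the pattern (class names and bodies below are illustrative stand-ins, not code from this commit):

class BaseClient:
    def __init__(self, token: str):
        self.token = token
        # header construction is delegated to a hook subclasses may override
        self.headers = {**self._get_auth_header(), "Content-Type": "application/json"}

    def _get_auth_header(self) -> dict:
        return {"Authorization": f"Bearer {self.token}"}


class OAuthClient(BaseClient):
    def _get_auth_header(self) -> dict:
        # a subclass can source the bearer token elsewhere, e.g. an OAuth flow
        return {"Authorization": f"Bearer {self._fetch_oauth_token()}"}

    def _fetch_oauth_token(self) -> str:
        return "oauth-token-placeholder"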

View File

@@ -16,10 +16,24 @@ import traceback
from requests import HTTPError

+from metadata.generated.schema.entity.services.connections.database.databricks.azureAdSetup import (
+    AzureAdSetup,
+)
+from metadata.generated.schema.entity.services.connections.database.databricks.databricksOAuth import (
+    DatabricksOauth,
+)
+from metadata.generated.schema.entity.services.connections.database.databricks.personalAccessToken import (
+    PersonalAccessToken,
+)
from metadata.ingestion.source.database.databricks.client import (
    API_TIMEOUT,
    DatabricksClient,
)
+from metadata.ingestion.source.database.unitycatalog.connection import (
+    get_azure_ad_auth,
+    get_databricks_oauth_auth,
+    get_personal_access_token_auth,
+)
from metadata.ingestion.source.database.unitycatalog.models import (
    LineageColumnStreams,
    LineageTableStreams,
@@ -37,6 +51,26 @@ class UnityCatalogClient(DatabricksClient):
    UnityCatalogClient creates a Databricks connection based on DatabricksCredentials.
    """

+    def _get_auth_header(self) -> dict:
+        """
+        Build the Authorization header for the configured auth type.
+        """
+        auth_method = {
+            PersonalAccessToken: get_personal_access_token_auth,
+            DatabricksOauth: get_databricks_oauth_auth,
+            AzureAdSetup: get_azure_ad_auth,
+        }.get(type(self.config.authType))
+        if not auth_method:
+            raise ValueError(
+                f"Unsupported authentication type: {type(self.config.authType)}"
+            )
+        auth_args = auth_method(self.config)
+        if auth_args.get("access_token"):
+            return {"Authorization": f"Bearer {auth_args['access_token']}"}
+        # credentials_provider() returns a header factory; calling that factory
+        # yields the actual auth headers, e.g. {"Authorization": "Bearer ..."}
+        return auth_args["credentials_provider"]()()
+
def get_table_lineage(self, table_name: str) -> LineageTableStreams:
"""
Method returns table lineage details

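The dispatch above normalizes two provider shapes: personal access token auth yields an access_token, while the OAuth and Azure AD helpers yield a credentials_provider callable. A self-contained sketch of both paths, using stand-in providers rather than the databricks-sdk ones:

def fake_pat_auth(_config) -> dict:
    # mirrors get_personal_access_token_auth
    return {"access_token": "dapi-secret"}


def fake_oauth_auth(_config) -> dict:
    # mirrors get_databricks_oauth_auth: a provider returning a header factory
    def credential_provider():
        return lambda: {"Authorization": "Bearer oauth-secret"}

    return {"credentials_provider": credential_provider}


def build_header(auth_args: dict) -> dict:
    if auth_args.get("access_token"):
        return {"Authorization": f"Bearer {auth_args['access_token']}"}
    return auth_args["credentials_provider"]()()


assert build_header(fake_pat_auth(None)) == {"Authorization": "Bearer dapi-secret"}
assert build_header(fake_oauth_auth(None)) == {"Authorization": "Bearer oauth-secret"}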
View File

@@ -12,16 +12,27 @@
"""
Source connection handler
"""
+from copy import deepcopy
from functools import partial
from typing import Optional

from databricks.sdk import WorkspaceClient
+from databricks.sdk.core import Config, azure_service_principal, oauth_service_principal
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DatabaseError

from metadata.generated.schema.entity.automations.workflow import (
    Workflow as AutomationWorkflow,
)
+from metadata.generated.schema.entity.services.connections.database.databricks.azureAdSetup import (
+    AzureAdSetup,
+)
+from metadata.generated.schema.entity.services.connections.database.databricks.databricksOAuth import (
+    DatabricksOauth,
+)
+from metadata.generated.schema.entity.services.connections.database.databricks.personalAccessToken import (
+    PersonalAccessToken,
+)
from metadata.generated.schema.entity.services.connections.database.unityCatalogConnection import (
    UnityCatalogConnection,
)
@@ -50,8 +61,50 @@ from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()


+def get_personal_access_token_auth(connection: UnityCatalogConnection) -> dict:
+    """
+    Configure Personal Access Token authentication
+    """
+    return {"access_token": connection.authType.token.get_secret_value()}
+
+
+def get_databricks_oauth_auth(connection: UnityCatalogConnection) -> dict:
+    """
+    Create a Databricks OAuth2 M2M credentials provider for Service Principal authentication
+    """
+
+    def credential_provider():
+        hostname = connection.hostPort.split(":")[0]
+        config = Config(
+            host=f"https://{hostname}",
+            client_id=connection.authType.clientId,
+            client_secret=connection.authType.clientSecret.get_secret_value(),
+        )
+        return oauth_service_principal(config)
+
+    return {"credentials_provider": credential_provider}
+
+
+def get_azure_ad_auth(connection: UnityCatalogConnection) -> dict:
+    """
+    Create an Azure AD credentials provider for Azure Service Principal authentication
+    """
+
+    def credential_provider():
+        hostname = connection.hostPort.split(":")[0]
+        config = Config(
+            host=f"https://{hostname}",
+            azure_client_secret=connection.authType.azureClientSecret.get_secret_value(),
+            azure_client_id=connection.authType.azureClientId,
+            azure_tenant_id=connection.authType.azureTenantId,
+        )
+        return azure_service_principal(config)
+
+    return {"credentials_provider": credential_provider}
+
+
def get_connection_url(connection: UnityCatalogConnection) -> str:
-    url = f"{connection.scheme.value}://token:{connection.token.get_secret_value()}@{connection.hostPort}"
+    url = f"{connection.scheme.value}://{connection.hostPort}"
    return url
@@ -59,10 +112,23 @@ def get_connection(connection: UnityCatalogConnection) -> WorkspaceClient:
    """
    Create connection
    """
+    client_params = {}
+    if isinstance(connection.authType, PersonalAccessToken):
+        client_params["token"] = connection.authType.token.get_secret_value()
+    elif isinstance(connection.authType, DatabricksOauth):
+        client_params["client_id"] = connection.authType.clientId
+        client_params[
+            "client_secret"
+        ] = connection.authType.clientSecret.get_secret_value()
+    elif isinstance(connection.authType, AzureAdSetup):
+        client_params["azure_client_id"] = connection.authType.azureClientId
+        client_params[
+            "azure_client_secret"
+        ] = connection.authType.azureClientSecret.get_secret_value()
+        client_params["azure_tenant_id"] = connection.authType.azureTenantId
+
    return WorkspaceClient(
-        host=get_host_from_host_port(connection.hostPort),
-        token=connection.token.get_secret_value(),
+        host=get_host_from_host_port(connection.hostPort), **client_params
    )
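For context, WorkspaceClient is the databricks-sdk entry point, and host plus one credential set is all it needs. A hedged usage sketch of the personal access token branch (host and token values are placeholders):

from databricks.sdk import WorkspaceClient

# Placeholder values; mirrors the PersonalAccessToken branch above.
client = WorkspaceClient(host="https://adb-123.azuredatabricks.net", token="dapi-xxxx")
for catalog in client.catalogs.list():
    print(catalog.name)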
@@ -76,6 +142,23 @@ def get_sqlalchemy_connection(connection: UnityCatalogConnection) -> Engine:
        connection.connectionArguments = init_empty_connection_arguments()
    connection.connectionArguments.root["http_path"] = connection.httpPath

+    auth_method = {
+        PersonalAccessToken: get_personal_access_token_auth,
+        DatabricksOauth: get_databricks_oauth_auth,
+        AzureAdSetup: get_azure_ad_auth,
+    }.get(type(connection.authType))
+    if not auth_method:
+        raise ValueError(
+            f"Unsupported authentication type: {type(connection.authType)}"
+        )
+    auth_args = auth_method(connection)
+
+    original_connection_arguments = connection.connectionArguments
+    connection.connectionArguments = deepcopy(original_connection_arguments)
+    connection.connectionArguments.root.update(auth_args)
+
    return create_generic_db_connection(
        connection=connection,
        get_connection_url_fn=get_connection_url,

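The deepcopy guards the caller's connectionArguments: the auth arguments (an access_token or a credentials_provider) are merged into a private copy only. The same idea on a plain dict:

from copy import deepcopy

# Merging auth args into a copy so the original mapping is not mutated;
# values are placeholders.
original = {"http_path": "/sql/1.0/warehouses/abc123"}
merged = deepcopy(original)
merged.update({"access_token": "dapi-xxxx"})
assert "access_token" not in original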
View File

@@ -15,9 +15,3 @@ supporting sqlalchemy abstraction layer
"""
-from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
-    DatabricksConnection,
-)
from metadata.sampler.sqlalchemy.databricks.sampler import DatabricksSamplerInterface
@@ -27,22 +24,4 @@ class UnityCatalogSamplerInterface(DatabricksSamplerInterface):
    """

    def __init__(self, *args, **kwargs):
-        # Convert Unity Catalog connection to Databricks and move token to authType.
-        kwargs["service_connection_config"] = DatabricksConnection.model_validate(
-            {
-                **(
-                    (
-                        t := (
-                            cfg := kwargs["service_connection_config"].model_dump(
-                                mode="json"
-                            )
-                        ).pop("token")
-                    )
-                    and cfg
-                ),
-                "type": "Databricks",
-                "authType": {"token": t},
-            }
-        )
        super().__init__(*args, **kwargs)
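For readers decoding the removed walrus-operator expression: unpacked, it performed roughly the conversion below, which is no longer needed now that UnityCatalogConnection carries authType natively.

# Equivalent of the removed one-liner, unpacked for clarity; DatabricksConnection
# and kwargs are the same names used in the deleted code above.
def convert_unity_to_databricks(kwargs: dict) -> None:
    cfg = kwargs["service_connection_config"].model_dump(mode="json")
    token = cfg.pop("token")
    kwargs["service_connection_config"] = DatabricksConnection.model_validate(
        {**cfg, "type": "Databricks", "authType": {"token": token}}
    )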

View File

@@ -42,6 +42,7 @@ import org.openmetadata.schema.services.connections.database.RedshiftConnection;
import org.openmetadata.schema.services.connections.database.SalesforceConnection;
import org.openmetadata.schema.services.connections.database.SapHanaConnection;
import org.openmetadata.schema.services.connections.database.TrinoConnection;
+import org.openmetadata.schema.services.connections.database.UnityCatalogConnection;
import org.openmetadata.schema.services.connections.database.datalake.GCSConfig;
import org.openmetadata.schema.services.connections.database.deltalake.StorageConfig;
import org.openmetadata.schema.services.connections.database.iceberg.IcebergFileSystem;
@@ -106,6 +107,7 @@ public final class ClassConverterFactory {
          Map.entry(VertexAIConnection.class, new VertexAIConnectionClassConverter()),
          Map.entry(RangerConnection.class, new RangerConnectionClassConverter()),
          Map.entry(DatabricksConnection.class, new DatabricksConnectionClassConverter()),
+          Map.entry(UnityCatalogConnection.class, new UnityCatalogConnectionClassConverter()),
          Map.entry(CassandraConnection.class, new CassandraConnectionClassConverter()),
          Map.entry(SSISConnection.class, new SsisConnectionClassConverter()),
          Map.entry(WherescapeConnection.class, new WherescapeConnectionClassConverter()));

View File

@@ -0,0 +1,43 @@
/*
* Copyright 2021 Collate
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.openmetadata.service.secrets.converter;
import java.util.List;
import org.openmetadata.schema.services.connections.database.UnityCatalogConnection;
import org.openmetadata.schema.services.connections.database.databricks.AzureADSetup;
import org.openmetadata.schema.services.connections.database.databricks.DatabricksOAuth;
import org.openmetadata.schema.services.connections.database.databricks.PersonalAccessToken;
import org.openmetadata.schema.utils.JsonUtils;
/** Converter class to get a `UnityCatalogConnection` object. */
public class UnityCatalogConnectionClassConverter extends ClassConverter {

  private static final List<Class<?>> CONFIG_SOURCE_CLASSES =
      List.of(PersonalAccessToken.class, DatabricksOAuth.class, AzureADSetup.class);

  public UnityCatalogConnectionClassConverter() {
    super(UnityCatalogConnection.class);
  }

  @Override
  public Object convert(Object object) {
    UnityCatalogConnection unityCatalogConnection =
        (UnityCatalogConnection) JsonUtils.convertValue(object, this.clazz);
    tryToConvert(unityCatalogConnection.getAuthType(), CONFIG_SOURCE_CLASSES)
        .ifPresent(unityCatalogConnection::setAuthType);
    return unityCatalogConnection;
  }
}
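The converter re-materializes the untyped authType payload as one of the three known classes. A rough Python analog of that selection step (the key-to-class mapping below is an assumption for illustration, not the tryToConvert implementation):

# Hypothetical analog of tryToConvert: pick the auth class whose
# distinguishing key is present in the raw payload.
AUTH_KEY_TO_CLASS = {
    "token": "PersonalAccessToken",
    "clientId": "DatabricksOAuth",
    "azureClientId": "AzureADSetup",
}

def classify_auth(raw: dict) -> str:
    for key, cls_name in AUTH_KEY_TO_CLASS.items():
        if key in raw:
            return cls_name
    raise ValueError(f"Unrecognized authType payload: {sorted(raw)}")

assert classify_auth({"token": "dapi-xxxx"}) == "PersonalAccessToken"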

View File

@@ -9,13 +9,17 @@
    "databricksType": {
      "description": "Service type.",
      "type": "string",
-      "enum": ["UnityCatalog"],
+      "enum": [
+        "UnityCatalog"
+      ],
      "default": "UnityCatalog"
    },
    "databricksScheme": {
      "description": "SQLAlchemy driver scheme options.",
      "type": "string",
-      "enum": ["databricks+connector"],
+      "enum": [
+        "databricks+connector"
+      ],
      "default": "databricks+connector"
    }
  },
@@ -37,11 +41,23 @@
      "description": "Host and port of the Databricks service.",
      "type": "string"
    },
-    "token": {
-      "title": "Token",
-      "description": "Generated Token to connect to Databricks.",
-      "type": "string",
-      "format": "password"
+    "authType": {
+      "title": "Authentication Type",
+      "description": "Choose between different authentication types for Databricks.",
+      "oneOf": [
+        {
+          "title": "Personal Access Token",
+          "$ref": "./databricks/personalAccessToken.json"
+        },
+        {
+          "title": "Databricks OAuth",
+          "$ref": "./databricks/databricksOAuth.json"
+        },
+        {
+          "title": "Azure AD Setup",
+          "$ref": "./databricks/azureAdSetup.json"
+        }
+      ]
    },
    "httpPath": {
      "title": "Http Path",
@@ -78,7 +94,9 @@
      "$ref": "../../../../type/filterPattern.json#/definitions/filterPattern",
      "default": {
        "includes": [],
-        "excludes": ["^information_schema$"]
+        "excludes": [
+          "^information_schema$"
+        ]
      }
    },
    "tableFilterPattern": {
@@ -92,7 +110,9 @@
      "$ref": "../../../../type/filterPattern.json#/definitions/filterPattern",
      "default": {
        "includes": [],
-        "excludes": ["^system$"]
+        "excludes": [
+          "^system$"
+        ]
      }
    },
    "supportsUsageExtraction": {
@@ -126,5 +146,8 @@
    }
  },
  "additionalProperties": false,
-  "required": ["hostPort", "token"]
-}
+  "required": [
+    "hostPort",
+    "authType"
+  ]
+}
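Under the new schema a connection supplies exactly one authType shape. Example payloads for the three options, as Python dicts with placeholder values (field names follow the code elsewhere in this commit):

# Personal Access Token
personal_access_token = {"token": "dapi-xxxx"}

# Databricks OAuth (service principal)
databricks_oauth = {
    "clientId": "my-sp-client-id",
    "clientSecret": "my-sp-client-secret",
}

# Azure AD (Azure service principal)
azure_ad = {
    "azureClientId": "11111111-2222-3333-4444-555555555555",
    "azureClientSecret": "azure-sp-secret",
    "azureTenantId": "66666666-7777-8888-9999-000000000000",
}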

View File

@@ -409,6 +409,7 @@ export enum EntityStatus {
  Draft = "Draft",
  InReview = "In Review",
  Rejected = "Rejected",
+  Unprocessed = "Unprocessed",
}
/**

View File

@@ -1,191 +0,0 @@
/*
* Copyright 2025 Collate.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* UnityCatalog Connection Config
*/
export interface UnityCatalogConnection {
/**
* Catalog of the data source(Example: hive_metastore). This is optional parameter, if you
* would like to restrict the metadata reading to a single catalog. When left blank,
* OpenMetadata Ingestion attempts to scan all the catalog.
*/
catalog?: string;
connectionArguments?: { [key: string]: any };
connectionOptions?: { [key: string]: string };
/**
* The maximum amount of time (in seconds) to wait for a successful connection to the data
* source. If the connection attempt takes longer than this timeout period, an error will be
* returned.
*/
connectionTimeout?: number;
/**
* Regex to only include/exclude databases that matches the pattern.
*/
databaseFilterPattern?: FilterPattern;
/**
* Database Schema of the data source. This is optional parameter, if you would like to
* restrict the metadata reading to a single schema. When left blank, OpenMetadata Ingestion
* attempts to scan all the schemas.
*/
databaseSchema?: string;
/**
* Host and port of the Databricks service.
*/
hostPort: string;
/**
* Databricks compute resources URL.
*/
httpPath?: string;
sampleDataStorageConfig?: SampleDataStorageConfig;
/**
* Regex to only include/exclude schemas that matches the pattern.
*/
schemaFilterPattern?: FilterPattern;
/**
* SQLAlchemy driver scheme options.
*/
scheme?: DatabricksScheme;
supportsDatabase?: boolean;
supportsDBTExtraction?: boolean;
supportsLineageExtraction?: boolean;
supportsMetadataExtraction?: boolean;
supportsProfiler?: boolean;
supportsQueryComment?: boolean;
supportsUsageExtraction?: boolean;
/**
* Regex to only include/exclude tables that matches the pattern.
*/
tableFilterPattern?: FilterPattern;
/**
* Generated Token to connect to Databricks.
*/
token: string;
/**
* Service Type
*/
type?: DatabricksType;
}
/**
* Regex to only include/exclude databases that matches the pattern.
*
* Regex to only fetch entities that matches the pattern.
*
* Regex to only include/exclude schemas that matches the pattern.
*
* Regex to only include/exclude tables that matches the pattern.
*/
export interface FilterPattern {
/**
* List of strings/regex patterns to match and exclude only database entities that match.
*/
excludes?: string[];
/**
* List of strings/regex patterns to match and include only database entities that match.
*/
includes?: string[];
}
/**
* Storage config to store sample data
*/
export interface SampleDataStorageConfig {
config?: DataStorageConfig;
}
/**
* Storage config to store sample data
*/
export interface DataStorageConfig {
/**
* Bucket Name
*/
bucketName?: string;
/**
* Provide the pattern of the path where the generated sample data file needs to be stored.
*/
filePathPattern?: string;
/**
* When this field enabled a single parquet file will be created to store sample data,
* otherwise we will create a new file per day
*/
overwriteData?: boolean;
/**
* Prefix of the data source.
*/
prefix?: string;
storageConfig?: AwsCredentials;
[property: string]: any;
}
/**
* AWS credentials configs.
*/
export interface AwsCredentials {
/**
* The Amazon Resource Name (ARN) of the role to assume. Required Field in case of Assume
* Role
*/
assumeRoleArn?: string;
/**
* An identifier for the assumed role session. Use the role session name to uniquely
* identify a session when the same role is assumed by different principals or for different
* reasons. Required Field in case of Assume Role
*/
assumeRoleSessionName?: string;
/**
* The Amazon Resource Name (ARN) of the role to assume. Optional Field in case of Assume
* Role
*/
assumeRoleSourceIdentity?: string;
/**
* AWS Access key ID.
*/
awsAccessKeyId?: string;
/**
* AWS Region
*/
awsRegion?: string;
/**
* AWS Secret Access Key.
*/
awsSecretAccessKey?: string;
/**
* AWS Session Token.
*/
awsSessionToken?: string;
/**
* EndPoint URL for the AWS
*/
endPointURL?: string;
/**
* The name of a profile to use with the boto session.
*/
profileName?: string;
}
/**
* SQLAlchemy driver scheme options.
*/
export enum DatabricksScheme {
DatabricksConnector = "databricks+connector",
}
/**
* Service Type
*
* Service type.
*/
export enum DatabricksType {
UnityCatalog = "UnityCatalog",
}