feat(ingest): support user group filtering for Azure AD (#3312)

This commit is contained in:
Vincenzo Lavorini 2021-10-06 08:03:30 +02:00 committed by GitHub
parent b4c0e20c68
commit 9cb71d974e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 170 additions and 120 deletions

View File

@ -90,7 +90,12 @@ source:
graph_url: "https://graph.microsoft.com/v1.0" graph_url: "https://graph.microsoft.com/v1.0"
ingest_users: True ingest_users: True
ingest_groups: True ingest_groups: True
ingest_group_membership: True groups_pattern:
allow:
- ".*"
users_pattern:
allow:
- ".*"
sink: sink:
# sink configs # sink configs
@ -117,16 +122,13 @@ Note that a `.` is used to denote nested fields in the YAML configuration block.
| `ingest_group_membership` | bool | | `True` | Whether group membership should be ingested into DataHub. ingest_groups must be True if this is True. | | `ingest_group_membership` | bool | | `True` | Whether group membership should be ingested into DataHub. ingest_groups must be True if this is True. |
| `azure_ad_response_to_username_attr` | string | | `"login"` | Which Azure AD User Response attribute to use as input to DataHub username mapping. | | `azure_ad_response_to_username_attr` | string | | `"login"` | Which Azure AD User Response attribute to use as input to DataHub username mapping. |
| `azure_ad_response_to_username_regex` | string | | `"([^@]+)"` | A regex used to parse the DataHub username from the attribute specified in `azure_ad_response_to_username_attr`. | | `azure_ad_response_to_username_regex` | string | | `"([^@]+)"` | A regex used to parse the DataHub username from the attribute specified in `azure_ad_response_to_username_attr`. |
| `users_pattern.allow` | list of strings | | | List of regex patterns for users to include in ingestion. Each pattern is matched against the DataHub user name, i.e. the name obtained by applying `azure_ad_response_to_username_attr` and `azure_ad_response_to_username_regex`. |
| `users_pattern.deny` | list of strings | | | As above, but for excluding users from ingestion. |
| `azure_ad_response_to_groupname_attr` | string | | `"name"` | Which Azure AD Group Response attribute to use as input to DataHub group name mapping. | | `azure_ad_response_to_groupname_attr` | string | | `"name"` | Which Azure AD Group Response attribute to use as input to DataHub group name mapping. |
| `azure_ad_response_to_groupname_regex` | string | | `"(.*)"` | A regex used to parse the DataHub group name from the attribute specified in `azure_ad_response_to_groupname_attr`. | | `azure_ad_response_to_groupname_regex` | string | | `"(.*)"` | A regex used to parse the DataHub group name from the attribute specified in `azure_ad_response_to_groupname_attr`. |
| `groups_pattern.allow` | list of strings | | | List of regex patterns for groups to include in ingestion. Each pattern is matched against the DataHub group name, i.e. the name obtained by applying `azure_ad_response_to_groupname_attr` and `azure_ad_response_to_groupname_regex`. |
| `groups_pattern.deny` | list of strings | | | As above, but for excluding groups from ingestion. |
## Compatibility | `ingest_groups_users` | bool | | `True` | This option is useful only when `ingest_users` is set to False and `ingest_group_membership` to True. As a result, only the users who belong to the selected groups will be ingested. |
Validated against load:
- User Count: `1000`
- Group Count: `100`
- Group Membership Edges: `1000` (1 per User)
## Questions ## Questions

View File

@ -3,12 +3,13 @@ import logging
import re import re
import urllib import urllib
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Union from typing import Any, Dict, Generator, Iterable, List
import click import click
import requests import requests
from datahub.configuration import ConfigModel from datahub.configuration import ConfigModel
from datahub.configuration.common import AllowDenyPattern
from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.source import Source, SourceReport from datahub.ingestion.api.source import Source, SourceReport
from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.api.workunit import MetadataWorkUnit
@ -53,6 +54,10 @@ class AzureADConfig(ConfigModel):
ingest_groups: bool = True ingest_groups: bool = True
ingest_group_membership: bool = True ingest_group_membership: bool = True
ingest_groups_users: bool = True
users_pattern: AllowDenyPattern = AllowDenyPattern.allow_all()
groups_pattern: AllowDenyPattern = AllowDenyPattern.allow_all()
@dataclass @dataclass
class AzureADSourceReport(SourceReport): class AzureADSourceReport(SourceReport):
@ -63,11 +68,6 @@ class AzureADSourceReport(SourceReport):
# Source that extracts Azure AD users, groups and group memberships using Microsoft Graph REST API # Source that extracts Azure AD users, groups and group memberships using Microsoft Graph REST API
#
# Validated against load:
# - user count: 1000
# - group count: 100
# - group membership edges: 1000 (1 per user)
class AzureADSource(Source): class AzureADSource(Source):
@ -98,27 +98,48 @@ class AzureADSource(Source):
token = token_response.json().get("access_token") token = token_response.json().get("access_token")
return token return token
else: else:
error_str = f"Token response status code: {str(token_response.status_code)}. Token response content: {str(token_response.content)}" error_str = (
f"Token response status code: {str(token_response.status_code)}. "
f"Token response content: {str(token_response.content)}"
)
logger.error(error_str) logger.error(error_str)
self.report.report_failure("get_token", error_str) self.report.report_failure("get_token", error_str)
click.echo("Error: Token response invalid") click.echo("Error: Token response invalid")
exit() exit()
selected_azure_ad_groups: list = []
azure_ad_groups_users: list = []
def get_workunits(self) -> Iterable[MetadataWorkUnit]: def get_workunits(self) -> Iterable[MetadataWorkUnit]:
# For future developers: this ingestion logic must be executed in the following order:
# 1) the groups
# 2) the groups' memberships
# 3) the users
# Create MetadataWorkUnits for CorpGroups # Create MetadataWorkUnits for CorpGroups
if self.config.ingest_groups: if self.config.ingest_groups:
azure_ad_groups = next(self._get_azure_ad_groups()) # 1) the groups
datahub_corp_group_snapshots = self._map_azure_ad_groups(azure_ad_groups) for azure_ad_groups in self._get_azure_ad_groups():
logger.info("Processing another groups batch...")
datahub_corp_group_snapshots = self._map_azure_ad_groups(
azure_ad_groups
)
for datahub_corp_group_snapshot in datahub_corp_group_snapshots: for datahub_corp_group_snapshot in datahub_corp_group_snapshots:
mce = MetadataChangeEvent(proposedSnapshot=datahub_corp_group_snapshot) mce = MetadataChangeEvent(
proposedSnapshot=datahub_corp_group_snapshot
)
wu = MetadataWorkUnit(id=datahub_corp_group_snapshot.urn, mce=mce) wu = MetadataWorkUnit(id=datahub_corp_group_snapshot.urn, mce=mce)
self.report.report_workunit(wu) self.report.report_workunit(wu)
yield wu yield wu
# Populate GroupMembership Aspects for CorpUsers # Populate GroupMembership Aspects for CorpUsers
datahub_corp_user_urn_to_group_membership: Dict[str, GroupMembershipClass] = {} datahub_corp_user_urn_to_group_membership: Dict[str, GroupMembershipClass] = {}
if self.config.ingest_group_membership and azure_ad_groups: if (
# Fetch membership for each group self.config.ingest_group_membership
for azure_ad_group in azure_ad_groups: and len(self.selected_azure_ad_groups) > 0
):
# 2) the groups' membership
for azure_ad_group in self.selected_azure_ad_groups:
datahub_corp_group_urn = self._map_azure_ad_group_to_urn(azure_ad_group) datahub_corp_group_urn = self._map_azure_ad_group_to_urn(azure_ad_group)
if not datahub_corp_group_urn: if not datahub_corp_group_urn:
error_str = "Failed to extract DataHub Group Name from Azure AD Group named {}. Skipping...".format( error_str = "Failed to extract DataHub Group Name from Azure AD Group named {}. Skipping...".format(
@ -127,9 +148,9 @@ class AzureADSource(Source):
self.report.report_failure("azure_ad_group_mapping", error_str) self.report.report_failure("azure_ad_group_mapping", error_str)
continue continue
# Extract and map users for each group # Extract and map users for each group
azure_ad_group_users = next( for azure_ad_group_users in self._get_azure_ad_group_users(
self._get_azure_ad_group_users(azure_ad_group) azure_ad_group
) ):
# if group doesn't have any members, continue # if group doesn't have any members, continue
if not azure_ad_group_users: if not azure_ad_group_users:
continue continue
@ -141,9 +162,11 @@ class AzureADSource(Source):
error_str = "Failed to extract DataHub Username from Azure ADUser {}. Skipping...".format( error_str = "Failed to extract DataHub Username from Azure ADUser {}. Skipping...".format(
azure_ad_user.get("displayName") azure_ad_user.get("displayName")
) )
self.report.report_failure("azure_ad_user_mapping", error_str) self.report.report_failure(
"azure_ad_user_mapping", error_str
)
continue continue
self.azure_ad_groups_users.append(azure_ad_user)
# update/create the GroupMembership aspect for this group member. # update/create the GroupMembership aspect for this group member.
if ( if (
datahub_corp_user_urn datahub_corp_user_urn
@ -157,15 +180,41 @@ class AzureADSource(Source):
datahub_corp_user_urn datahub_corp_user_urn
] = GroupMembershipClass(groups=[datahub_corp_group_urn]) ] = GroupMembershipClass(groups=[datahub_corp_group_urn])
if (
self.config.ingest_groups_users
and self.config.ingest_group_membership
and not self.config.ingest_users
):
# 3) the users
# getting infos about the users belonging to the found groups
datahub_corp_user_snapshots = self._map_azure_ad_users(
self.azure_ad_groups_users
)
yield from self.ingest_ad_users(
datahub_corp_user_snapshots, datahub_corp_user_urn_to_group_membership
)
# Create MetadataWorkUnits for CorpUsers # Create MetadataWorkUnits for CorpUsers
if self.config.ingest_users: if self.config.ingest_users:
azure_ad_users = next(self._get_azure_ad_users()) # 3) the users
for azure_ad_users in self._get_azure_ad_users():
# azure_ad_users = next(self._get_azure_ad_users())
datahub_corp_user_snapshots = self._map_azure_ad_users(azure_ad_users) datahub_corp_user_snapshots = self._map_azure_ad_users(azure_ad_users)
yield from self.ingest_ad_users(
datahub_corp_user_snapshots,
datahub_corp_user_urn_to_group_membership,
)
def ingest_ad_users(
self,
datahub_corp_user_snapshots: Generator[CorpUserSnapshot, Any, None],
datahub_corp_user_urn_to_group_membership: dict,
) -> Generator[MetadataWorkUnit, Any, None]:
for datahub_corp_user_snapshot in datahub_corp_user_snapshots: for datahub_corp_user_snapshot in datahub_corp_user_snapshots:
# Add GroupMembership if applicable # Add GroupMembership if applicable
if ( if (
datahub_corp_user_snapshot.urn datahub_corp_user_snapshot.urn
in datahub_corp_user_urn_to_group_membership in datahub_corp_user_urn_to_group_membership.keys()
): ):
datahub_group_membership = ( datahub_group_membership = (
datahub_corp_user_urn_to_group_membership.get( datahub_corp_user_urn_to_group_membership.get(
@ -185,69 +234,62 @@ class AzureADSource(Source):
def close(self) -> None: def close(self) -> None:
pass pass
def _get_azure_ad_groups(self): def _get_azure_ad_groups(self) -> Iterable[List]:
yield from self._get_azure_ad_data(kind="/groups")
def _get_azure_ad_users(self) -> Iterable[List]:
yield from self._get_azure_ad_data(kind="/users")
def _get_azure_ad_group_users(self, azure_ad_group: dict) -> Iterable[List]:
group_id = azure_ad_group.get("id")
kind = f"/groups/{group_id}/members"
yield from self._get_azure_ad_data(kind=kind)
def _get_azure_ad_data(self, kind: str) -> Iterable[List]:
headers = {"Authorization": "Bearer {}".format(self.token)} headers = {"Authorization": "Bearer {}".format(self.token)}
url = self.config.graph_url + "/groups" # 'ConsistencyLevel': 'eventual'}
url = self.config.graph_url + kind
while True: while True:
if not url: if not url:
break break
response = requests.get(url, headers=headers) response = requests.get(url, headers=headers)
if response.status_code == 200: if response.status_code == 200:
json_data = json.loads(response.text) json_data = json.loads(response.text)
url = json_data.get("@odata.nextLink", None) try:
url = json_data["@odata.nextLink"]
except KeyError:
# no more data will follow
url = False # type: ignore
yield json_data["value"] yield json_data["value"]
else: else:
error_str = f"Response status code: {str(response.status_code)}. Response content: {str(response.content)}" error_str = (
logger.error(error_str) f"Response status code: {str(response.status_code)}. "
self.report.report_failure("_get_azure_ad_groups", error_str) f"Response content: {str(response.content)}"
continue
def _get_azure_ad_users(self):
headers = {"Authorization": "Bearer {}".format(self.token)}
url = self.config.graph_url + "/users"
while True:
if not url:
break
response = requests.get(url, headers=headers)
if response.status_code == 200:
json_data = json.loads(response.text)
url = json_data.get("@odata.nextLink", None)
yield json_data["value"]
else:
error_str = f"Response status code: {str(response.status_code)}. Response content: {str(response.content)}"
logger.error(error_str)
self.report.report_failure("_get_azure_ad_groups", error_str)
continue
def _get_azure_ad_group_users(self, azure_ad_group):
headers = {"Authorization": "Bearer {}".format(self.token)}
url = "{0}/groups/{1}/members".format(
self.config.graph_url, azure_ad_group.get("id")
) )
while True:
if not url:
break
response = requests.get(url, headers=headers)
if response.status_code == 200:
json_data = json.loads(response.text)
url = json_data.get("@odata.nextLink", None)
yield json_data["value"]
else:
error_str = f"Response status code: {str(response.status_code)}. Response content: {str(response.content)}"
logger.error(error_str) logger.error(error_str)
self.report.report_failure("_get_azure_ad_groups", error_str) self.report.report_failure("_get_azure_ad_data_", error_str)
continue continue
def _map_azure_ad_groups(self, azure_ad_groups): def _map_azure_ad_groups(self, azure_ad_groups):
for azure_ad_group in azure_ad_groups: for azure_ad_group in azure_ad_groups:
corp_group_urn = self._map_azure_ad_group_to_urn(azure_ad_group) corp_group_urn = self._map_azure_ad_group_to_urn(azure_ad_group)
if not corp_group_urn: if not corp_group_urn:
error_str = "Failed to extract DataHub Group Name from Azure Group for group named {}. Skipping...".format( error_str = (
azure_ad_group.get("displayName") "Failed to extract DataHub Group Name from Azure Group for group named {}. "
"Skipping...".format(azure_ad_group.get("displayName"))
) )
logger.error(error_str) logger.error(error_str)
self.report.report_failure("azure_ad_group_mapping", error_str) self.report.report_failure("azure_ad_group_mapping", error_str)
continue continue
group_name = self._extract_regex_match_from_dict_value(
azure_ad_group,
self.config.azure_ad_response_to_groupname_attr,
self.config.azure_ad_response_to_groupname_regex,
)
if not self.config.groups_pattern.allowed(group_name):
self.report.report_filtered(f"{corp_group_urn}")
continue
self.selected_azure_ad_groups.append(azure_ad_group)
corp_group_snapshot = CorpGroupSnapshot( corp_group_snapshot = CorpGroupSnapshot(
urn=corp_group_urn, urn=corp_group_urn,
aspects=[], aspects=[],
@ -272,7 +314,7 @@ class AzureADSource(Source):
group_name = self._map_azure_ad_group_to_group_name(azure_ad_group) group_name = self._map_azure_ad_group_to_group_name(azure_ad_group)
if not group_name: if not group_name:
return None return None
# URL encode the group name to deal with potential spaces # URL-encode the group name to deal with spaces and other special characters
url_encoded_group_name = urllib.parse.quote(group_name) url_encoded_group_name = urllib.parse.quote(group_name)
return self._make_corp_group_urn(url_encoded_group_name) return self._make_corp_group_urn(url_encoded_group_name)
@ -293,6 +335,9 @@ class AzureADSource(Source):
logger.error(error_str) logger.error(error_str)
self.report.report_failure("azure_ad_user_mapping", error_str) self.report.report_failure("azure_ad_user_mapping", error_str)
continue continue
if not self.config.users_pattern.allowed(corp_user_urn):
self.report.report_filtered(f"{corp_user_urn}.*")
continue
corp_user_snapshot = CorpUserSnapshot( corp_user_snapshot = CorpUserSnapshot(
urn=corp_user_urn, urn=corp_user_urn,
aspects=[], aspects=[],
@ -340,11 +385,13 @@ class AzureADSource(Source):
def _extract_regex_match_from_dict_value( def _extract_regex_match_from_dict_value(
self, str_dict: Dict[str, str], key: str, pattern: str self, str_dict: Dict[str, str], key: str, pattern: str
) -> Union[str, None]: ) -> str:
raw_value = str_dict.get(key) raw_value = str_dict.get(key)
if raw_value is None: if raw_value is None:
return None raise ValueError(f"Unable to find the key {key} in Group. Is it wrong?")
match = re.search(pattern, raw_value) match = re.search(pattern, raw_value)
if match is None: if match is None:
return None raise ValueError(
f"Unable to extract a name from {raw_value} with the pattern {pattern}"
)
return match.group() return match.group()

View File

@ -190,7 +190,8 @@ def mocked_functions(
# For simplicity, each user is placed in ALL groups. # For simplicity, each user is placed in ALL groups.
# Create a separate response mock for each group in our sample data. # Create a separate response mock for each group in our sample data.
r = [] mock_groups_users.return_value = [users]
for _ in groups: # r = []
r.append(users) # for _ in groups:
mock_groups_users.return_value = iter(r) # r.append(users)
# mock_groups_users.return_value = iter(r)