package auth.sso.oidc;
import static auth.AuthUtils.*;
import static com.linkedin.metadata.Constants.CORP_USER_ENTITY_NAME;
import static com.linkedin.metadata.Constants.GROUP_MEMBERSHIP_ASPECT_NAME;
import static org.pac4j.play.store.PlayCookieSessionStore.*;
import static play.mvc.Results.internalServerError;
import auth.CookieConfigs;
import auth.sso.SsoManager;
import client.AuthServiceClient;
import com.datahub.authentication.Authentication;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.linkedin.common.AuditStamp;
import com.linkedin.common.CorpGroupUrnArray;
import com.linkedin.common.CorpuserUrnArray;
import com.linkedin.common.UrnArray;
import com.linkedin.common.url.Url;
import com.linkedin.common.urn.CorpGroupUrn;
import com.linkedin.common.urn.CorpuserUrn;
import com.linkedin.common.urn.Urn;
import com.linkedin.data.template.SetMode;
import com.linkedin.entity.Entity;
import com.linkedin.entity.client.SystemEntityClient;
import com.linkedin.events.metadata.ChangeType;
import com.linkedin.identity.CorpGroupInfo;
import com.linkedin.identity.CorpUserEditableInfo;
import com.linkedin.identity.CorpUserInfo;
import com.linkedin.identity.CorpUserStatus;
import com.linkedin.identity.GroupMembership;
import com.linkedin.metadata.Constants;
import com.linkedin.metadata.aspect.CorpGroupAspect;
import com.linkedin.metadata.aspect.CorpGroupAspectArray;
import com.linkedin.metadata.aspect.CorpUserAspect;
import com.linkedin.metadata.aspect.CorpUserAspectArray;
import com.linkedin.metadata.snapshot.CorpGroupSnapshot;
import com.linkedin.metadata.snapshot.CorpUserSnapshot;
import com.linkedin.metadata.snapshot.Snapshot;
import com.linkedin.metadata.utils.GenericRecordUtils;
import com.linkedin.mxe.MetadataChangeProposal;
import com.linkedin.r2.RemoteInvocationException;
import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.pac4j.core.config.Config;
import org.pac4j.core.context.Cookie;
import org.pac4j.core.engine.DefaultCallbackLogic;
import org.pac4j.core.http.adapter.HttpActionAdapter;
import org.pac4j.core.profile.CommonProfile;
import org.pac4j.core.profile.ProfileManager;
import org.pac4j.core.profile.UserProfile;
import org.pac4j.core.util.Pac4jConstants;
import org.pac4j.play.PlayWebContext;
import play.mvc.Result;
/**
 * This class contains the logic that is executed when an OpenID Connect Identity Provider redirects
 * back to D DataHub after an authentication attempt.
 *
 * 
On receiving a user profile from the IdP (using /userInfo endpoint), we attempt to extract
 * basic information about the user including their name, email, groups, & more. If just-in-time
 * provisioning is enabled, we also attempt to create a DataHub User ({@link CorpUserSnapshot}) for
 * the user, along with any Groups ({@link CorpGroupSnapshot}) that can be extracted, only doing so
 * if the user does not already exist.
 */
@Slf4j
public class OidcCallbackLogic extends DefaultCallbackLogic {
  private final SsoManager _ssoManager;
  private final SystemEntityClient _entityClient;
  private final Authentication _systemAuthentication;
  private final AuthServiceClient _authClient;
  private final CookieConfigs _cookieConfigs;
  public OidcCallbackLogic(
      final SsoManager ssoManager,
      final Authentication systemAuthentication,
      final SystemEntityClient entityClient,
      final AuthServiceClient authClient,
      final CookieConfigs cookieConfigs) {
    _ssoManager = ssoManager;
    _systemAuthentication = systemAuthentication;
    _entityClient = entityClient;
    _authClient = authClient;
    _cookieConfigs = cookieConfigs;
  }
  @Override
  public Result perform(
      PlayWebContext context,
      Config config,
      HttpActionAdapter httpActionAdapter,
      String defaultUrl,
      Boolean saveInSession,
      Boolean multiProfile,
      Boolean renewSession,
      String defaultClient) {
    setContextRedirectUrl(context);
    final Result result =
        super.perform(
            context,
            config,
            httpActionAdapter,
            defaultUrl,
            saveInSession,
            multiProfile,
            renewSession,
            defaultClient);
    // Handle OIDC authentication errors.
    if (OidcResponseErrorHandler.isError(context)) {
      return OidcResponseErrorHandler.handleError(context);
    }
    // By this point, we know that OIDC is the enabled provider.
    final OidcConfigs oidcConfigs = (OidcConfigs) _ssoManager.getSsoProvider().configs();
    return handleOidcCallback(oidcConfigs, result, context, getProfileManager(context));
  }
  @SuppressWarnings("unchecked")
  private void setContextRedirectUrl(PlayWebContext context) {
    Optional redirectUrl =
        context.getRequestCookies().stream()
            .filter(cookie -> REDIRECT_URL_COOKIE_NAME.equals(cookie.getName()))
            .findFirst();
    redirectUrl.ifPresent(
        cookie ->
            context
                .getSessionStore()
                .set(
                    context,
                    Pac4jConstants.REQUESTED_URL,
                    JAVA_SER_HELPER.deserializeFromBytes(
                        uncompressBytes(Base64.getDecoder().decode(cookie.getValue())))));
  }
  private Result handleOidcCallback(
      final OidcConfigs oidcConfigs,
      final Result result,
      final PlayWebContext context,
      final ProfileManager profileManager) {
    log.debug("Beginning OIDC Callback Handling...");
    if (profileManager.isAuthenticated()) {
      // If authenticated, the user should have a profile.
      final CommonProfile profile = (CommonProfile) profileManager.get(true).get();
      log.debug(
          String.format(
              "Found authenticated user with profile %s", profile.getAttributes().toString()));
      // Extract the User name required to log into DataHub.
      final String userName = extractUserNameOrThrow(oidcConfigs, profile);
      final CorpuserUrn corpUserUrn = new CorpuserUrn(userName);
      try {
        // If just-in-time User Provisioning is enabled, try to create the DataHub user if it does
        // not exist.
        if (oidcConfigs.isJitProvisioningEnabled()) {
          log.debug("Just-in-time provisioning is enabled. Beginning provisioning process...");
          CorpUserSnapshot extractedUser = extractUser(corpUserUrn, profile);
          tryProvisionUser(extractedUser);
          if (oidcConfigs.isExtractGroupsEnabled()) {
            // Extract groups & provision them.
            List extractedGroups = extractGroups(profile);
            tryProvisionGroups(extractedGroups);
            // Add users to groups on DataHub. Note that this clears existing group membership for a
            // user if it already exists.
            updateGroupMembership(corpUserUrn, createGroupMembership(extractedGroups));
          }
        } else if (oidcConfigs.isPreProvisioningRequired()) {
          // We should only allow logins for user accounts that have been pre-provisioned
          log.debug("Pre Provisioning is required. Beginning validation of extracted user...");
          verifyPreProvisionedUser(corpUserUrn);
        }
        // Update user status to active on login.
        // If we want to prevent certain users from logging in, here's where we'll want to do it.
        setUserStatus(
            corpUserUrn,
            new CorpUserStatus()
                .setStatus(Constants.CORP_USER_STATUS_ACTIVE)
                .setLastModified(
                    new AuditStamp()
                        .setActor(Urn.createFromString(Constants.SYSTEM_ACTOR))
                        .setTime(System.currentTimeMillis())));
      } catch (Exception e) {
        log.error("Failed to perform post authentication steps. Redirecting to error page.", e);
        return internalServerError(
            String.format(
                "Failed to perform post authentication steps. Error message: %s", e.getMessage()));
      }
      // Successfully logged in - Generate GMS login token
      final String accessToken = _authClient.generateSessionTokenForUser(corpUserUrn.getId());
      return result
          .withSession(createSessionMap(corpUserUrn.toString(), accessToken))
          .withCookies(
              createActorCookie(
                  corpUserUrn.toString(),
                  _cookieConfigs.getTtlInHours(),
                  _cookieConfigs.getAuthCookieSameSite(),
                  _cookieConfigs.getAuthCookieSecure()));
    }
    return internalServerError(
        "Failed to authenticate current user. Cannot find valid identity provider profile in session.");
  }
  private String extractUserNameOrThrow(
      final OidcConfigs oidcConfigs, final CommonProfile profile) {
    // Ensure that the attribute exists (was returned by IdP)
    if (!profile.containsAttribute(oidcConfigs.getUserNameClaim())) {
      throw new RuntimeException(
          String.format(
              "Failed to resolve user name claim from profile provided by Identity Provider. Missing attribute. Attribute: '%s', Regex: '%s', Profile: %s",
              oidcConfigs.getUserNameClaim(),
              oidcConfigs.getUserNameClaimRegex(),
              profile.getAttributes().toString()));
    }
    final String userNameClaim = (String) profile.getAttribute(oidcConfigs.getUserNameClaim());
    final Optional mappedUserName =
        extractRegexGroup(oidcConfigs.getUserNameClaimRegex(), userNameClaim);
    return mappedUserName.orElseThrow(
        () ->
            new RuntimeException(
                String.format(
                    "Failed to extract DataHub username from username claim %s using regex %s. Profile: %s",
                    userNameClaim,
                    oidcConfigs.getUserNameClaimRegex(),
                    profile.getAttributes().toString())));
  }
  /** Attempts to map to an OIDC {@link CommonProfile} (userInfo) to a {@link CorpUserSnapshot}. */
  private CorpUserSnapshot extractUser(CorpuserUrn urn, CommonProfile profile) {
    log.debug(
        String.format(
            "Attempting to extract user from OIDC profile %s", profile.getAttributes().toString()));
    // Extracts these based on the default set of OIDC claims, described here:
    // https://developer.okta.com/blog/2017/07/25/oidc-primer-part-1
    String firstName = profile.getFirstName();
    String lastName = profile.getFamilyName();
    String email = profile.getEmail();
    URI picture = profile.getPictureUrl();
    String displayName = profile.getDisplayName();
    String fullName =
        (String)
            profile.getAttribute("name"); // Name claim is sometimes provided, including by Google.
    if (fullName == null && firstName != null && lastName != null) {
      fullName = String.format("%s %s", firstName, lastName);
    }
    // TODO: Support custom claims mapping. (e.g. department, title, etc)
    final CorpUserInfo userInfo = new CorpUserInfo();
    userInfo.setActive(true);
    userInfo.setFirstName(firstName, SetMode.IGNORE_NULL);
    userInfo.setLastName(lastName, SetMode.IGNORE_NULL);
    userInfo.setFullName(fullName, SetMode.IGNORE_NULL);
    userInfo.setEmail(email, SetMode.IGNORE_NULL);
    // If there is a display name, use it. Otherwise fall back to full name.
    userInfo.setDisplayName(
        displayName == null ? userInfo.getFullName() : displayName, SetMode.IGNORE_NULL);
    final CorpUserEditableInfo editableInfo = new CorpUserEditableInfo();
    try {
      if (picture != null) {
        editableInfo.setPictureLink(new Url(picture.toURL().toString()));
      }
    } catch (MalformedURLException e) {
      log.error("Failed to extract User Profile URL skipping.", e);
    }
    final CorpUserSnapshot corpUserSnapshot = new CorpUserSnapshot();
    corpUserSnapshot.setUrn(urn);
    final CorpUserAspectArray aspects = new CorpUserAspectArray();
    aspects.add(CorpUserAspect.create(userInfo));
    aspects.add(CorpUserAspect.create(editableInfo));
    corpUserSnapshot.setAspects(aspects);
    return corpUserSnapshot;
  }
  public static Collection getGroupNames(CommonProfile profile, Object groupAttribute, String groupsClaimName) {
      Collection groupNames = Collections.emptyList();
      try {
        if (groupAttribute instanceof Collection) {
          // List of group names
          groupNames = (Collection) profile.getAttribute(groupsClaimName, Collection.class);
        } else if (groupAttribute instanceof String) {
          String groupString = (String) groupAttribute;
          ObjectMapper objectMapper = new ObjectMapper();
          try {
            // Json list of group names
            groupNames = objectMapper.readValue(groupString, new TypeReference>(){});
          } catch (Exception e) {
            groupNames = Arrays.asList(groupString.split(","));
          }
        }
      } catch (Exception e) {
        log.error(String.format(
                "Failed to parse group names: Expected to find a list of strings for attribute with name %s, found %s",
                groupsClaimName, profile.getAttribute(groupsClaimName).getClass()));
      }
      return groupNames;
  }
  private List extractGroups(CommonProfile profile) {
    log.debug(
        String.format(
            "Attempting to extract groups from OIDC profile %s",
            profile.getAttributes().toString()));
    final OidcConfigs configs = (OidcConfigs) _ssoManager.getSsoProvider().configs();
    // First, attempt to extract a list of groups from the profile, using the group name attribute
    // config.
    final List extractedGroups = new ArrayList<>();
    final List groupsClaimNames =
        new ArrayList(Arrays.asList(configs.getGroupsClaimName().split(",")))
            .stream().map(String::trim).collect(Collectors.toList());
    for (final String groupsClaimName : groupsClaimNames) {
      if (profile.containsAttribute(groupsClaimName)) {
        try {
          final List groupSnapshots = new ArrayList<>();
          Collection groupNames = getGroupNames(profile, profile.getAttribute(groupsClaimName), groupsClaimName);
          for (String groupName : groupNames) {
            // Create a basic CorpGroupSnapshot from the information.
            try {
              final CorpGroupInfo corpGroupInfo = new CorpGroupInfo();
              corpGroupInfo.setAdmins(new CorpuserUrnArray());
              corpGroupInfo.setGroups(new CorpGroupUrnArray());
              corpGroupInfo.setMembers(new CorpuserUrnArray());
              corpGroupInfo.setEmail("");
              corpGroupInfo.setDisplayName(groupName);
              // To deal with the possibility of spaces, we url encode the URN group name.
              final String urlEncodedGroupName =
                  URLEncoder.encode(groupName, StandardCharsets.UTF_8.toString());
              final CorpGroupUrn groupUrn = new CorpGroupUrn(urlEncodedGroupName);
              final CorpGroupSnapshot corpGroupSnapshot = new CorpGroupSnapshot();
              corpGroupSnapshot.setUrn(groupUrn);
              final CorpGroupAspectArray aspects = new CorpGroupAspectArray();
              aspects.add(CorpGroupAspect.create(corpGroupInfo));
              corpGroupSnapshot.setAspects(aspects);
              groupSnapshots.add(corpGroupSnapshot);
            } catch (UnsupportedEncodingException ex) {
              log.error(
                  String.format(
                      "Failed to URL encoded extracted group name %s. Skipping", groupName));
            }
          }
          if (groupSnapshots.isEmpty()) {
            log.warn(
                String.format(
                    "Failed to extract groups: No OIDC claim with name %s found", groupsClaimName));
          } else {
            extractedGroups.addAll(groupSnapshots);
          }
        } catch (Exception e) {
          log.error(
              String.format(
                  "Failed to extract groups: Expected to find a list of strings for attribute with name %s, found %s",
                  groupsClaimName, profile.getAttribute(groupsClaimName).getClass()));
        }
      }
    }
    return extractedGroups;
  }
  private GroupMembership createGroupMembership(final List extractedGroups) {
    final GroupMembership groupMembershipAspect = new GroupMembership();
    groupMembershipAspect.setGroups(
        new UrnArray(
            extractedGroups.stream().map(CorpGroupSnapshot::getUrn).collect(Collectors.toList())));
    return groupMembershipAspect;
  }
  private void tryProvisionUser(CorpUserSnapshot corpUserSnapshot) {
    log.debug(String.format("Attempting to provision user with urn %s", corpUserSnapshot.getUrn()));
    // 1. Check if this user already exists.
    try {
      final Entity corpUser = _entityClient.get(corpUserSnapshot.getUrn(), _systemAuthentication);
      final CorpUserSnapshot existingCorpUserSnapshot = corpUser.getValue().getCorpUserSnapshot();
      log.debug(String.format("Fetched GMS user with urn %s", corpUserSnapshot.getUrn()));
      // If we find more than the key aspect, then the entity "exists".
      if (existingCorpUserSnapshot.getAspects().size() <= 1) {
        log.debug(
            String.format(
                "Extracted user that does not yet exist %s. Provisioning...",
                corpUserSnapshot.getUrn()));
        // 2. The user does not exist. Provision them.
        final Entity newEntity = new Entity();
        newEntity.setValue(Snapshot.create(corpUserSnapshot));
        _entityClient.update(newEntity, _systemAuthentication);
        log.debug(String.format("Successfully provisioned user %s", corpUserSnapshot.getUrn()));
      }
      log.debug(
          String.format(
              "User %s already exists. Skipping provisioning", corpUserSnapshot.getUrn()));
      // Otherwise, the user exists. Skip provisioning.
    } catch (RemoteInvocationException e) {
      // Failing provisioning is something worth throwing about.
      throw new RuntimeException(
          String.format("Failed to provision user with urn %s.", corpUserSnapshot.getUrn()), e);
    }
  }
  private void tryProvisionGroups(List corpGroups) {
    log.debug(
        String.format(
            "Attempting to provision groups with urns %s",
            corpGroups.stream().map(CorpGroupSnapshot::getUrn).collect(Collectors.toList())));
    // 1. Check if this user already exists.
    try {
      final Set urnsToFetch =
          corpGroups.stream().map(CorpGroupSnapshot::getUrn).collect(Collectors.toSet());
      final Map existingGroups =
          _entityClient.batchGet(urnsToFetch, _systemAuthentication);
      log.debug(String.format("Fetched GMS groups with urns %s", existingGroups.keySet()));
      final List groupsToCreate = new ArrayList<>();
      for (CorpGroupSnapshot extractedGroup : corpGroups) {
        if (existingGroups.containsKey(extractedGroup.getUrn())) {
          final Entity groupEntity = existingGroups.get(extractedGroup.getUrn());
          final CorpGroupSnapshot corpGroupSnapshot = groupEntity.getValue().getCorpGroupSnapshot();
          // If more than the key aspect exists, then the group already "exists".
          if (corpGroupSnapshot.getAspects().size() <= 1) {
            log.debug(
                String.format(
                    "Extracted group that does not yet exist %s. Provisioning...",
                    corpGroupSnapshot.getUrn()));
            groupsToCreate.add(extractedGroup);
          }
          log.debug(
              String.format(
                  "Group %s already exists. Skipping provisioning", corpGroupSnapshot.getUrn()));
        } else {
          // Should not occur until we stop returning default Key aspects for unrecognized entities.
          log.debug(
              String.format(
                  "Extracted group that does not yet exist %s. Provisioning...",
                  extractedGroup.getUrn()));
          groupsToCreate.add(extractedGroup);
        }
      }
      List groupsToCreateUrns =
          groupsToCreate.stream().map(CorpGroupSnapshot::getUrn).collect(Collectors.toList());
      log.debug(String.format("Provisioning groups with urns %s", groupsToCreateUrns));
      // Now batch create all entities identified to create.
      _entityClient.batchUpdate(
          groupsToCreate.stream()
              .map(groupSnapshot -> new Entity().setValue(Snapshot.create(groupSnapshot)))
              .collect(Collectors.toSet()),
          _systemAuthentication);
      log.debug(String.format("Successfully provisioned groups with urns %s", groupsToCreateUrns));
    } catch (RemoteInvocationException e) {
      // Failing provisioning is something worth throwing about.
      throw new RuntimeException(
          String.format(
              "Failed to provision groups with urns %s.",
              corpGroups.stream().map(CorpGroupSnapshot::getUrn).collect(Collectors.toList())),
          e);
    }
  }
  private void updateGroupMembership(Urn urn, GroupMembership groupMembership) {
    log.debug(String.format("Updating group membership for user %s", urn));
    final MetadataChangeProposal proposal = new MetadataChangeProposal();
    proposal.setEntityUrn(urn);
    proposal.setEntityType(CORP_USER_ENTITY_NAME);
    proposal.setAspectName(GROUP_MEMBERSHIP_ASPECT_NAME);
    proposal.setAspect(GenericRecordUtils.serializeAspect(groupMembership));
    proposal.setChangeType(ChangeType.UPSERT);
    try {
      _entityClient.ingestProposal(proposal, _systemAuthentication);
    } catch (RemoteInvocationException e) {
      throw new RuntimeException(
          String.format("Failed to update group membership for user with urn %s", urn), e);
    }
  }
  private void verifyPreProvisionedUser(CorpuserUrn urn) {
    // Validate that the user exists in the system (there is more than just a key aspect for them,
    // as of today).
    try {
      final Entity corpUser = _entityClient.get(urn, _systemAuthentication);
      log.debug(String.format("Fetched GMS user with urn %s", urn));
      // If we find more than the key aspect, then the entity "exists".
      if (corpUser.getValue().getCorpUserSnapshot().getAspects().size() <= 1) {
        log.debug(
            String.format(
                "Found user that does not yet exist %s. Invalid login attempt. Throwing...", urn));
        throw new RuntimeException(
            String.format(
                "User with urn %s has not yet been provisioned in DataHub. "
                    + "Please contact your DataHub admin to provision an account.",
                urn));
      }
      // Otherwise, the user exists.
    } catch (RemoteInvocationException e) {
      // Failing validation is something worth throwing about.
      throw new RuntimeException(String.format("Failed to validate user with urn %s.", urn), e);
    }
  }
  private void setUserStatus(final Urn urn, final CorpUserStatus newStatus) throws Exception {
    // Update status aspect to be active.
    final MetadataChangeProposal proposal = new MetadataChangeProposal();
    proposal.setEntityUrn(urn);
    proposal.setEntityType(Constants.CORP_USER_ENTITY_NAME);
    proposal.setAspectName(Constants.CORP_USER_STATUS_ASPECT_NAME);
    proposal.setAspect(GenericRecordUtils.serializeAspect(newStatus));
    proposal.setChangeType(ChangeType.UPSERT);
    _entityClient.ingestProposal(proposal, _systemAuthentication);
  }
  private Optional extractRegexGroup(final String patternStr, final String target) {
    final Pattern pattern = Pattern.compile(patternStr);
    final Matcher matcher = pattern.matcher(target);
    if (matcher.find()) {
      final String extractedValue = matcher.group();
      return Optional.of(extractedValue);
    }
    return Optional.empty();
  }
}