feat(ingest): allow custom SF API version (#11145)

This commit is contained in:
skrydal 2024-08-16 10:46:42 +02:00 committed by GitHub
parent 12b3da3d71
commit 11890e5445
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 121 additions and 24 deletions

View File

@ -3,7 +3,7 @@ import logging
import time
from datetime import datetime
from enum import Enum
from typing import Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional
import requests
from pydantic import Field, validator
@ -124,6 +124,9 @@ class SalesforceConfig(DatasetSourceConfigMixin):
default=dict(),
description='Regex patterns for tables/schemas to describe domain_key domain key (domain_key can be any string like "sales".) There can be multiple domain keys specified.',
)
api_version: Optional[str] = Field(
description="If specified, overrides default version used by the Salesforce package. Example value: '59.0'"
)
profiling: SalesforceProfilingConfig = SalesforceProfilingConfig()
@ -222,6 +225,12 @@ class SalesforceSource(Source):
self.session = requests.Session()
self.platform: str = "salesforce"
self.fieldCounts = {}
common_args: Dict[str, Any] = {
"domain": "test" if self.config.is_sandbox else None,
"session": self.session,
}
if self.config.api_version:
common_args["version"] = self.config.api_version
try:
if self.config.auth is SalesforceAuthType.DIRECT_ACCESS_TOKEN:
@ -236,8 +245,7 @@ class SalesforceSource(Source):
self.sf = Salesforce(
instance_url=self.config.instance_url,
session_id=self.config.access_token,
session=self.session,
domain="test" if self.config.is_sandbox else None,
**common_args,
)
elif self.config.auth is SalesforceAuthType.USERNAME_PASSWORD:
logger.debug("Username/Password Provided in Config")
@ -255,8 +263,7 @@ class SalesforceSource(Source):
username=self.config.username,
password=self.config.password,
security_token=self.config.security_token,
session=self.session,
domain="test" if self.config.is_sandbox else None,
**common_args,
)
elif self.config.auth is SalesforceAuthType.JSON_WEB_TOKEN:
@ -275,14 +282,13 @@ class SalesforceSource(Source):
username=self.config.username,
consumer_key=self.config.consumer_key,
privatekey=self.config.private_key,
session=self.session,
domain="test" if self.config.is_sandbox else None,
**common_args,
)
except Exception as e:
logger.error(e)
raise ConfigurationError("Salesforce login failed") from e
else:
if not self.config.api_version:
# List all REST API versions and use latest one
versions_url = "https://{instance}/services/data/".format(
instance=self.sf.sf_instance,
@ -290,17 +296,22 @@ class SalesforceSource(Source):
versions_response = self.sf._call_salesforce("GET", versions_url).json()
latest_version = versions_response[-1]
version = latest_version["version"]
# we could avoid setting the version like below (after the Salesforce object has been already initiated
# above), since, according to the docs:
# https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/dome_versions.htm
# we don't need to be authenticated to list the versions (so we could perform this call before even
# authenticating)
self.sf.sf_version = version
self.base_url = "https://{instance}/services/data/v{sf_version}/".format(
instance=self.sf.sf_instance, sf_version=version
)
self.base_url = "https://{instance}/services/data/v{sf_version}/".format(
instance=self.sf.sf_instance, sf_version=self.sf.sf_version
)
logger.debug(
"Using Salesforce REST API with {label} version: {version}".format(
label=latest_version["label"], version=latest_version["version"]
)
logger.debug(
"Using Salesforce REST API version: {version}".format(
version=self.sf.sf_version
)
)
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
sObjects = self.get_salesforce_objects()

View File

@ -1,10 +1,12 @@
import json
import pathlib
from unittest import mock
from unittest.mock import Mock
from freezegun import freeze_time
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.salesforce import SalesforceConfig, SalesforceSource
from tests.test_helpers import mce_helpers
FROZEN_TIME = "2022-05-12 11:00:00"
@ -19,15 +21,16 @@ def _read_response(file_name: str) -> dict:
return data
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code
def json(self):
return self.json_data
def side_effect_call_salesforce(type, url):
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code
def json(self):
return self.json_data
if url.endswith("/services/data/"):
return MockResponse(_read_response("versions_response.json"), 200)
if url.endswith("FROM EntityDefinition WHERE IsCustomizable = true"):
@ -55,9 +58,92 @@ def side_effect_call_salesforce(type, url):
return MockResponse({}, 404)
@mock.patch("datahub.ingestion.source.salesforce.Salesforce")
def test_latest_version(mock_sdk):
mock_sf = mock.Mock()
mocked_call = mock.Mock()
mocked_call.side_effect = side_effect_call_salesforce
mock_sf._call_salesforce = mocked_call
mock_sdk.return_value = mock_sf
config = SalesforceConfig.parse_obj(
{
"auth": "DIRECT_ACCESS_TOKEN",
"instance_url": "https://mydomain.my.salesforce.com/",
"access_token": "access_token`",
"ingest_tags": True,
"object_pattern": {
"allow": [
"^Account$",
"^Property__c$",
],
},
"domain": {"sales": {"allow": {"^Property__c$"}}},
"profiling": {"enabled": True},
"profile_pattern": {
"allow": [
"^Property__c$",
]
},
}
)
SalesforceSource(config=config, ctx=Mock())
calls = mock_sf._call_salesforce.mock_calls
assert (
len(calls) == 1
), "We didn't specify version but source didn't call SF API to get the latest one"
assert calls[0].ends_with(
"/services/data"
), "Source didn't call proper SF API endpoint to get all versions"
assert (
mock_sf.sf_version == "54.0"
), "API version was not correctly set (see versions_responses.json)"
@mock.patch("datahub.ingestion.source.salesforce.Salesforce")
def test_custom_version(mock_sdk):
mock_sf = mock.Mock()
mocked_call = mock.Mock()
mocked_call.side_effect = side_effect_call_salesforce
mock_sf._call_salesforce = mocked_call
mock_sdk.return_value = mock_sf
config = SalesforceConfig.parse_obj(
{
"auth": "DIRECT_ACCESS_TOKEN",
"api_version": "46.0",
"instance_url": "https://mydomain.my.salesforce.com/",
"access_token": "access_token`",
"ingest_tags": True,
"object_pattern": {
"allow": [
"^Account$",
"^Property__c$",
],
},
"domain": {"sales": {"allow": {"^Property__c$"}}},
"profiling": {"enabled": True},
"profile_pattern": {
"allow": [
"^Property__c$",
]
},
}
)
SalesforceSource(config=config, ctx=Mock())
calls = mock_sf._call_salesforce.mock_calls
assert (
len(calls) == 0
), "Source called API to get all versions even though we specified proper version"
assert (
mock_sdk.call_args.kwargs["version"] == "46.0"
), "API client object was not correctly initialized with the custom version"
@freeze_time(FROZEN_TIME)
def test_salesforce_ingest(pytestconfig, tmp_path):
with mock.patch("simple_salesforce.Salesforce") as mock_sdk:
with mock.patch("datahub.ingestion.source.salesforce.Salesforce") as mock_sdk:
mock_sf = mock.Mock()
mocked_call = mock.Mock()
mocked_call.side_effect = side_effect_call_salesforce