refactor(smoke test): centralise env variables (#15100)

This commit is contained in:
Aseem Bansal 2025-10-24 19:40:57 +05:30 committed by GitHub
parent 9bfb90e188
commit 61f9dd92ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 296 additions and 53 deletions

View File

@ -7,6 +7,7 @@ import requests
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph, get_default_graph
from tests.test_result_msg import send_message
from tests.utilities import env_vars
from tests.utils import (
TestSessionWrapper,
get_frontend_session,
@ -149,11 +150,9 @@ def bin_pack_tasks(tasks, n_buckets):
return buckets
def get_batch_start_end(num_tests: int) -> Tuple[int, int]:
batch_count_env = os.getenv("BATCH_COUNT", 1)
batch_count = int(batch_count_env)
batch_count = env_vars.get_batch_count()
batch_number_env = os.getenv("BATCH_NUMBER", 0)
batch_number = int(batch_number_env)
batch_number = env_vars.get_batch_number()
if batch_count == 0 or batch_count > num_tests:
raise ValueError(
@ -182,7 +181,7 @@ def get_batch_start_end(num_tests: int) -> Tuple[int, int]:
def pytest_collection_modifyitems(
session: pytest.Session, config: pytest.Config, items: List[Item]
) -> None:
if os.getenv("TEST_STRATEGY") == "cypress":
if env_vars.get_test_strategy() == "cypress":
return # We launch cypress via pytests, but needs a different batching mechanism at cypress level.
# If BATCH_COUNT and BATCH_ENV vars are set, splits the pytests to batches and runs filters only the BATCH_NUMBER

View File

@ -21,6 +21,7 @@ from datahub.ingestion.api.sink import NoopWriteCallback
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.ingestion.sink.file import FileSink, FileSinkConfig
from datahub.utilities.urns.urn import Urn
from tests.utilities import env_vars
from tests.utils import (
delete_urns_from_file,
get_gms_url,
@ -30,7 +31,7 @@ from tests.utils import (
logger = logging.getLogger(__name__)
DELETE_AFTER_TEST = os.getenv("DELETE_AFTER_TEST", "false").lower() == "true"
DELETE_AFTER_TEST = env_vars.get_delete_after_test()
class FileEmitter:

View File

@ -1,13 +1,14 @@
import logging
import os
import subprocess
import time
USE_STATIC_SLEEP: bool = bool(os.getenv("USE_STATIC_SLEEP", False))
ELASTICSEARCH_REFRESH_INTERVAL_SECONDS: int = int(
os.getenv("ELASTICSEARCH_REFRESH_INTERVAL_SECONDS", 3)
from tests.utilities import env_vars
USE_STATIC_SLEEP: bool = env_vars.get_use_static_sleep()
ELASTICSEARCH_REFRESH_INTERVAL_SECONDS: int = (
env_vars.get_elasticsearch_refresh_interval_seconds()
)
KAFKA_BOOTSTRAP_SERVER: str = str(os.getenv("KAFKA_BOOTSTRAP_SERVER", "broker:29092"))
KAFKA_BOOTSTRAP_SERVER: str = env_vars.get_kafka_bootstrap_server()
logger = logging.getLogger(__name__)
@ -35,7 +36,7 @@ def wait_for_writes_to_sync(
time.sleep(ELASTICSEARCH_REFRESH_INTERVAL_SECONDS)
return
KAFKA_BROKER_CONTAINER: str = str(
os.getenv("KAFKA_BROKER_CONTAINER", infer_kafka_broker_container())
env_vars.get_kafka_broker_container() or infer_kafka_broker_container()
)
start_time = time.time()
# get offsets

View File

@ -13,6 +13,7 @@ from tests.setup.lineage.ingest_time_lineage import (
get_time_lineage_urns,
ingest_time_lineage,
)
from tests.utilities import env_vars
from tests.utils import (
create_datahub_step_state_aspects,
delete_urns,
@ -234,15 +235,15 @@ def _get_cypress_tests_batch():
else:
tests_with_weights.append(test)
test_batches = bin_pack_tasks(tests_with_weights, int(os.getenv("BATCH_COUNT", 1)))
return test_batches[int(os.getenv("BATCH_NUMBER", 0))]
test_batches = bin_pack_tasks(tests_with_weights, env_vars.get_batch_count())
return test_batches[env_vars.get_batch_number()]
def test_run_cypress(auth_session):
# Run with --record option only if CYPRESS_RECORD_KEY is non-empty
record_key = os.getenv("CYPRESS_RECORD_KEY")
record_key = env_vars.get_cypress_record_key()
tag_arg = ""
test_strategy = os.getenv("TEST_STRATEGY", None)
test_strategy = env_vars.get_test_strategy()
if record_key:
record_arg = " --record "
batch_number = os.getenv("BATCH_NUMBER")

View File

@ -15,7 +15,6 @@ Note: Some tests may be skipped if required configuration (e.g., Mixpanel API se
"""
import json
import os
import time
from datetime import datetime, timezone
@ -24,6 +23,7 @@ import requests
from confluent_kafka import Consumer, KafkaError
from dotenv import load_dotenv
from tests.utilities import env_vars
from tests.utils import get_kafka_broker_url
# Load environment variables from .env file if it exists
@ -37,7 +37,7 @@ def test_tracking_api_mixpanel(auth_session, graph_client):
"""Test that we can post events to the tracking endpoint and verify they are sent to Mixpanel."""
# Check if Mixpanel API secret is available in environment variables
api_secret = os.environ.get("MIXPANEL_API_SECRET")
api_secret = env_vars.get_mixpanel_api_secret()
if not api_secret:
pytest.skip("MIXPANEL_API_SECRET environment variable not set, skipping test")
@ -80,9 +80,7 @@ def test_tracking_api_mixpanel(auth_session, graph_client):
# Query Mixpanel's JQL API to retrieve our test event
# Note: This requires a service account with access to JQL
project_id = os.environ.get(
"MIXPANEL_PROJECT_ID", "3653440"
) # Allow project ID to be configurable too
project_id = env_vars.get_mixpanel_project_id()
# log the unique_id
print(f"\nLooking for test event with customField: {unique_id}")
@ -357,8 +355,8 @@ def test_tracking_api_elasticsearch(auth_session):
# Query Elasticsearch to retrieve our test event
# This requires the Elasticsearch URL and credentials
es_url = os.environ.get("ELASTICSEARCH_URL", "http://localhost:9200")
es_index = os.environ.get("ELASTICSEARCH_INDEX", "datahub_usage_event")
es_url = env_vars.get_elasticsearch_url()
es_index = env_vars.get_elasticsearch_index()
# Create a query to find our test event by the unique browserId and customField
es_query = {

View File

@ -1,10 +1,11 @@
import os
import re
import pytest
from tests.utilities import env_vars
# Kept separate so that it does not cause failures in PRs
DATAHUB_VERSION = os.getenv("TEST_DATAHUB_VERSION")
DATAHUB_VERSION = env_vars.get_test_datahub_version()
def looks_like_a_short_sha(sha: str) -> bool:

View File

@ -1,7 +1,7 @@
import os
from slack_sdk import WebClient
from tests.utilities import env_vars
datahub_stats = {}
@ -10,10 +10,10 @@ def add_datahub_stats(stat_name, stat_val):
def send_to_slack(passed: str):
slack_api_token = os.getenv("SLACK_API_TOKEN")
slack_channel = os.getenv("SLACK_CHANNEL")
slack_thread_ts = os.getenv("SLACK_THREAD_TS")
test_identifier = os.getenv("TEST_IDENTIFIER", "LOCAL_TEST")
slack_api_token = env_vars.get_slack_api_token()
slack_channel = env_vars.get_slack_channel()
slack_thread_ts = env_vars.get_slack_thread_ts()
test_identifier = env_vars.get_test_identifier()
if slack_api_token is None or slack_channel is None:
return
client = WebClient(token=slack_api_token)

View File

@ -0,0 +1,244 @@
# ABOUTME: Central registry for all environment variables used in smoke-test.
# ABOUTME: All environment variable reads should go through this module for discoverability and maintainability.
import os
from typing import Optional
# ============================================================================
# Core DataHub Configuration
# ============================================================================
def get_telemetry_enabled() -> str:
    """DATAHUB_TELEMETRY_ENABLED as a raw string ("true"/"false"); defaults to "false"."""
    return os.environ.get("DATAHUB_TELEMETRY_ENABLED", "false")


def get_suppress_logging_manager() -> Optional[str]:
    """DATAHUB_SUPPRESS_LOGGING_MANAGER, or None when unset."""
    return os.environ.get("DATAHUB_SUPPRESS_LOGGING_MANAGER")


def get_gms_url() -> Optional[str]:
    """DATAHUB_GMS_URL, or None when unset (callers supply their own fallback)."""
    return os.environ.get("DATAHUB_GMS_URL")


def get_base_path() -> str:
    """DATAHUB_BASE_PATH (frontend base path); defaults to the empty string."""
    return os.environ.get("DATAHUB_BASE_PATH", "")


def get_gms_base_path() -> str:
    """DATAHUB_GMS_BASE_PATH (GMS API base path); defaults to the empty string."""
    return os.environ.get("DATAHUB_GMS_BASE_PATH", "")


def get_frontend_url() -> Optional[str]:
    """DATAHUB_FRONTEND_URL, or None when unset (callers supply their own fallback)."""
    return os.environ.get("DATAHUB_FRONTEND_URL")


def get_kafka_url() -> Optional[str]:
    """DATAHUB_KAFKA_URL (Kafka broker), or None when unset."""
    return os.environ.get("DATAHUB_KAFKA_URL")


def get_kafka_schema_registry_url() -> Optional[str]:
    """DATAHUB_KAFKA_SCHEMA_REGISTRY_URL, or None when unset."""
    return os.environ.get("DATAHUB_KAFKA_SCHEMA_REGISTRY_URL")
# ============================================================================
# Admin Credentials
# ============================================================================
def get_admin_username() -> str:
    """Username for logging in as the smoke-test admin user; defaults to "datahub"."""
    return os.environ.get("ADMIN_USERNAME", "datahub")


def get_admin_password() -> str:
    """Password for logging in as the smoke-test admin user; defaults to "datahub"."""
    return os.environ.get("ADMIN_PASSWORD", "datahub")
# ============================================================================
# Database Configuration
# ============================================================================
def get_db_type() -> Optional[str]:
    """DB_TYPE ("mysql"/"postgres"), or None when unset (callers may infer from profile)."""
    return os.environ.get("DB_TYPE")


def get_profile_name() -> Optional[str]:
    """PROFILE_NAME, used by callers to infer the database type; None when unset."""
    return os.environ.get("PROFILE_NAME")


def get_mysql_url() -> str:
    """MySQL host:port; defaults to "localhost:3306"."""
    return os.environ.get("DATAHUB_MYSQL_URL", "localhost:3306")


def get_mysql_username() -> str:
    """MySQL username; defaults to "datahub"."""
    return os.environ.get("DATAHUB_MYSQL_USERNAME", "datahub")


def get_mysql_password() -> str:
    """MySQL password; defaults to "datahub"."""
    return os.environ.get("DATAHUB_MYSQL_PASSWORD", "datahub")


def get_postgres_url() -> str:
    """PostgreSQL host:port; defaults to "localhost:5432"."""
    return os.environ.get("DATAHUB_POSTGRES_URL", "localhost:5432")


def get_postgres_username() -> str:
    """PostgreSQL username; defaults to "datahub"."""
    return os.environ.get("DATAHUB_POSTGRES_USERNAME", "datahub")


def get_postgres_password() -> str:
    """PostgreSQL password; defaults to "datahub"."""
    return os.environ.get("DATAHUB_POSTGRES_PASSWORD", "datahub")
# ============================================================================
# Testing Configuration
# ============================================================================
def get_batch_count() -> int:
    """How many batches the test suite is split into for parallel runs; defaults to 1."""
    raw = os.environ.get("BATCH_COUNT", "1")
    return int(raw)


def get_batch_number() -> int:
    """Zero-indexed batch this run executes; defaults to 0."""
    raw = os.environ.get("BATCH_NUMBER", "0")
    return int(raw)


def get_test_strategy() -> Optional[str]:
    """TEST_STRATEGY (e.g. "cypress"), or None when unset."""
    return os.environ.get("TEST_STRATEGY")


def get_test_sleep_between() -> int:
    """Seconds to sleep between test retry attempts; defaults to 20."""
    return int(os.environ.get("DATAHUB_TEST_SLEEP_BETWEEN", "20"))


def get_test_sleep_times() -> int:
    """How many retry attempts tests make; defaults to 3."""
    return int(os.environ.get("DATAHUB_TEST_SLEEP_TIMES", "3"))


def get_k8s_cluster_enabled() -> bool:
    """True when K8S_CLUSTER_ENABLED is "true" or "yes" (case-insensitive)."""
    flag = os.environ.get("K8S_CLUSTER_ENABLED", "false")
    return flag.lower() in ["true", "yes"]


def get_test_datahub_version() -> Optional[str]:
    """TEST_DATAHUB_VERSION (version under test), or None when unset."""
    return os.environ.get("TEST_DATAHUB_VERSION")
# ============================================================================
# Consistency Testing
# ============================================================================
def get_use_static_sleep() -> bool:
    """Whether to use a fixed sleep instead of waiting dynamically for writes to sync.

    Parses USE_STATIC_SLEEP as a boolean: only "true" (case-insensitive)
    enables it; unset or any other value disables it. The previous
    implementation used bool(os.getenv("USE_STATIC_SLEEP", False)), which
    treated ANY non-empty string — including "false" and "0" — as True.
    This now matches the parsing used by get_delete_after_test().
    """
    return os.getenv("USE_STATIC_SLEEP", "false").lower() == "true"
def get_elasticsearch_refresh_interval_seconds() -> int:
    """Elasticsearch index refresh interval in seconds; defaults to 3."""
    raw = os.environ.get("ELASTICSEARCH_REFRESH_INTERVAL_SECONDS", "3")
    return int(raw)
def get_kafka_bootstrap_server() -> str:
    """Kafka bootstrap server address for smoke tests; defaults to "broker:29092".

    The str() wrapper in the original was redundant: os.getenv always
    returns a str when both the environment value and the default are str.
    """
    return os.getenv("KAFKA_BOOTSTRAP_SERVER", "broker:29092")
def get_kafka_broker_container() -> Optional[str]:
    """Name of the Kafka broker container, or None when unset (callers infer it)."""
    return os.environ.get("KAFKA_BROKER_CONTAINER")
# ============================================================================
# Cypress Testing
# ============================================================================
def get_cypress_record_key() -> Optional[str]:
    """Cypress Cloud record key, or None when recording is disabled."""
    return os.environ.get("CYPRESS_RECORD_KEY")
# ============================================================================
# Cleanup Configuration
# ============================================================================
def get_delete_after_test() -> bool:
    """True when DELETE_AFTER_TEST is "true" (case-insensitive); False otherwise."""
    flag = os.environ.get("DELETE_AFTER_TEST", "false")
    return flag.lower() == "true"
# ============================================================================
# Integration Testing
# ============================================================================
def get_mixpanel_api_secret() -> Optional[str]:
    """Mixpanel API secret for tracking tests; None when unset (tests skip)."""
    return os.environ.get("MIXPANEL_API_SECRET")


def get_mixpanel_project_id() -> str:
    """Mixpanel project ID; defaults to "3653440"."""
    return os.environ.get("MIXPANEL_PROJECT_ID", "3653440")


def get_elasticsearch_url() -> str:
    """Elasticsearch base URL for integration tests; defaults to "http://localhost:9200"."""
    return os.environ.get("ELASTICSEARCH_URL", "http://localhost:9200")


def get_elasticsearch_index() -> str:
    """Elasticsearch index holding usage events; defaults to "datahub_usage_event"."""
    return os.environ.get("ELASTICSEARCH_INDEX", "datahub_usage_event")
# ============================================================================
# Slack Notifications
# ============================================================================
def get_slack_api_token() -> Optional[str]:
    """Slack API token for posting test results; None disables notifications."""
    return os.environ.get("SLACK_API_TOKEN")


def get_slack_channel() -> Optional[str]:
    """Slack channel to post test results to; None disables notifications."""
    return os.environ.get("SLACK_CHANNEL")


def get_slack_thread_ts() -> Optional[str]:
    """Timestamp of the Slack thread to reply in; None posts a new message."""
    return os.environ.get("SLACK_THREAD_TS")


def get_test_identifier() -> str:
    """Identifier for this test run in notifications; defaults to "LOCAL_TEST"."""
    return os.environ.get("TEST_IDENTIFIER", "LOCAL_TEST")

View File

@ -1,6 +1,5 @@
import json
import logging
import os
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
@ -17,6 +16,7 @@ from datahub.cli import cli_utils, env_utils
from datahub.entrypoints import datahub
from datahub.ingestion.run.pipeline import Pipeline
from tests.consistency_utils import wait_for_writes_to_sync
from tests.utilities import env_vars
TIME: int = 1581407189000
logger = logging.getLogger(__name__)
@ -43,18 +43,18 @@ def get_admin_username() -> str:
def get_admin_credentials():
return (
os.getenv("ADMIN_USERNAME", "datahub"),
os.getenv("ADMIN_PASSWORD", "datahub"),
env_vars.get_admin_username(),
env_vars.get_admin_password(),
)
def get_base_path():
base_path = os.getenv("DATAHUB_BASE_PATH", "")
base_path = env_vars.get_base_path()
return "" if base_path == "/" else base_path
def get_gms_base_path():
base_gms_path = os.getenv("DATAHUB_GMS_BASE_PATH", "")
base_gms_path = env_vars.get_gms_base_path()
return "" if base_gms_path == "/" else base_gms_path
@ -63,34 +63,32 @@ def get_root_urn():
def get_gms_url():
return os.getenv("DATAHUB_GMS_URL") or f"http://localhost:8080{get_gms_base_path()}"
return env_vars.get_gms_url() or f"http://localhost:8080{get_gms_base_path()}"
def get_frontend_url():
return (
os.getenv("DATAHUB_FRONTEND_URL") or f"http://localhost:9002{get_base_path()}"
)
return env_vars.get_frontend_url() or f"http://localhost:9002{get_base_path()}"
def get_kafka_broker_url():
return os.getenv("DATAHUB_KAFKA_URL") or "localhost:9092"
return env_vars.get_kafka_url() or "localhost:9092"
def get_kafka_schema_registry():
# internal registry "http://localhost:8080/schema-registry/api/"
return (
os.getenv("DATAHUB_KAFKA_SCHEMA_REGISTRY_URL")
env_vars.get_kafka_schema_registry_url()
or f"http://localhost:8080{get_gms_base_path()}/schema-registry/api"
)
def get_db_type():
db_type = os.getenv("DB_TYPE")
db_type = env_vars.get_db_type()
if db_type:
return db_type
else:
# infer from profile
profile_name = os.getenv("PROFILE_NAME")
profile_name = env_vars.get_profile_name()
if profile_name and "postgres" in profile_name:
return "postgres"
else:
@ -99,29 +97,29 @@ def get_db_type():
def get_db_url():
if get_db_type() == "mysql":
return os.getenv("DATAHUB_MYSQL_URL") or "localhost:3306"
return env_vars.get_mysql_url()
else:
return os.getenv("DATAHUB_POSTGRES_URL") or "localhost:5432"
return env_vars.get_postgres_url()
def get_db_username():
if get_db_type() == "mysql":
return os.getenv("DATAHUB_MYSQL_USERNAME") or "datahub"
return env_vars.get_mysql_username()
else:
return os.getenv("DATAHUB_POSTGRES_USERNAME") or "datahub"
return env_vars.get_postgres_username()
def get_db_password():
if get_db_type() == "mysql":
return os.getenv("DATAHUB_MYSQL_PASSWORD") or "datahub"
return env_vars.get_mysql_password()
else:
return os.getenv("DATAHUB_POSTGRES_PASSWORD") or "datahub"
return env_vars.get_postgres_password()
def get_sleep_info() -> Tuple[int, int]:
return (
int(os.getenv("DATAHUB_TEST_SLEEP_BETWEEN", 20)),
int(os.getenv("DATAHUB_TEST_SLEEP_TIMES", 3)),
env_vars.get_test_sleep_between(),
env_vars.get_test_sleep_times(),
)
@ -158,7 +156,7 @@ def with_test_retry(
def is_k8s_enabled():
return os.getenv("K8S_CLUSTER_ENABLED", "false").lower() in ["true", "yes"]
return env_vars.get_k8s_cluster_enabled()
def wait_for_healthcheck_util(auth_session):