| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  | import importlib | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | import types | 
					
						
							|  |  |  | from typing import Any, cast | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | 
 | 
					
						
# Dotted path of the module these tests import (or reload) via ``import_sut``.
MODULE_UNDER_TEST = "datahub_actions.utils.kafka_msk_iam"
# ``sys.modules`` key under which ``ensure_fake_vendor`` looks up, or registers
# a stand-in for, the MSK IAM signer module.
VENDOR_MODULE = "aws_msk_iam_sasl_signer"
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  | def ensure_fake_vendor(monkeypatch: Any) -> Any: | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Ensure a fake MSKAuthTokenProvider is available at import path | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  |     aws_msk_iam_sasl_signer for environments where the vendor package is not installed. | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |     Returns the fake module so tests can monkeypatch its behavior. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     # If already present (package installed), just return the real module | 
					
						
							|  |  |  |     if VENDOR_MODULE in sys.modules: | 
					
						
							|  |  |  |         return sys.modules[VENDOR_MODULE] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  |     # Create a minimal fake module matching the direct import path | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |     fake_mod: Any = types.ModuleType(VENDOR_MODULE) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class MSKAuthTokenProvider: | 
					
						
							|  |  |  |         @staticmethod | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  |         def generate_auth_token( | 
					
						
							|  |  |  |             region: str | None = None, | 
					
						
							|  |  |  |         ) -> None:  # will be monkeypatched per test | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |             raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     fake_mod.MSKAuthTokenProvider = MSKAuthTokenProvider | 
					
						
							|  |  |  |     monkeypatch.setitem(sys.modules, VENDOR_MODULE, fake_mod) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return fake_mod | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  | def import_sut(monkeypatch: Any) -> Any: | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |     """Import or reload the module under test after ensuring the vendor symbol exists.""" | 
					
						
							|  |  |  |     ensure_fake_vendor(monkeypatch) | 
					
						
							|  |  |  |     if MODULE_UNDER_TEST in sys.modules: | 
					
						
							|  |  |  |         return importlib.reload(sys.modules[MODULE_UNDER_TEST]) | 
					
						
							|  |  |  |     return importlib.import_module(MODULE_UNDER_TEST) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  | def test_oauth_cb_success_converts_ms_to_seconds(monkeypatch: Any) -> None: | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |     sut = import_sut(monkeypatch) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Monkeypatch the provider to return a known token and expiry in ms | 
					
						
							|  |  |  |     provider = cast(Any, sut).MSKAuthTokenProvider | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  |     def fake_generate(region: str | None = None) -> tuple[str, int]: | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |         return "my-token", 12_345  # ms | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     monkeypatch.setattr(provider, "generate_auth_token", staticmethod(fake_generate)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     token, expiry_seconds = sut.oauth_cb({"any": "config"}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert token == "my-token" | 
					
						
							|  |  |  |     assert expiry_seconds == 12.345  # ms to seconds via division | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  | def test_oauth_cb_raises_and_logs_on_error(monkeypatch: Any, caplog: Any) -> None: | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |     sut = import_sut(monkeypatch) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  |     def boom(region: str | None = None) -> None: | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |         raise RuntimeError("signer blew up") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     provider = cast(Any, sut).MSKAuthTokenProvider | 
					
						
							|  |  |  |     monkeypatch.setattr(provider, "generate_auth_token", staticmethod(boom)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     caplog.set_level(logging.ERROR) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(RuntimeError, match="signer blew up"): | 
					
						
							|  |  |  |         sut.oauth_cb({}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Verify the error log is present and descriptive | 
					
						
							|  |  |  |     assert any( | 
					
						
							|  |  |  |         rec.levelno == logging.ERROR | 
					
						
							|  |  |  |         and "Error generating AWS MSK IAM authentication token" in rec.getMessage() | 
					
						
							|  |  |  |         for rec in caplog.records | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  | def test_oauth_cb_returns_tuple_types(monkeypatch: Any) -> None: | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |     sut = import_sut(monkeypatch) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     provider = cast(Any, sut).MSKAuthTokenProvider | 
					
						
							|  |  |  |     monkeypatch.setattr( | 
					
						
							|  |  |  |         provider, | 
					
						
							|  |  |  |         "generate_auth_token", | 
					
						
							| 
									
										
										
										
											2025-10-07 11:10:05 +02:00
										 |  |  |         staticmethod(lambda region=None: ("tkn", 1_000)),  # 1000 ms | 
					
						
							| 
									
										
										
										
											2025-08-21 14:08:32 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     result = sut.oauth_cb(None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert isinstance(result, tuple) | 
					
						
							|  |  |  |     token, expiry = result | 
					
						
							|  |  |  |     assert token == "tkn" | 
					
						
							|  |  |  |     assert isinstance(expiry, float) | 
					
						
							|  |  |  |     assert expiry == 1.0 |