diff --git a/python/packages/agbench/src/agbench/run_cmd.py b/python/packages/agbench/src/agbench/run_cmd.py index a8a8161c5..a8a781229 100644 --- a/python/packages/agbench/src/agbench/run_cmd.py +++ b/python/packages/agbench/src/agbench/run_cmd.py @@ -284,14 +284,8 @@ def get_scenario_env(token_provider: Optional[Callable[[], str]] = None, env_fil ## Support Azure auth tokens azure_openai_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN") - if not azure_openai_ad_token and token_provider: + if azure_openai_ad_token is None and token_provider is not None: azure_openai_ad_token = token_provider() - if not azure_openai_ad_token: - azure_token_provider = get_azure_token_provider() - if azure_token_provider: - azure_openai_ad_token = azure_token_provider() - else: - logging.warning("No Azure AD token provider found. Azure AD token not set.") if azure_openai_ad_token is not None and len(azure_openai_ad_token.strip()) > 0: env["AZURE_OPENAI_AD_TOKEN"] = azure_openai_ad_token @@ -888,6 +882,12 @@ def run_cli(args: Sequence[str]) -> None: help="The number of parallel processes to run (default: 1).", default=1, ) + parser.add_argument( + "-a", + "--azure", + action="store_true", + help="Use Azure identity to pass an AZURE_OPENAI_AD_TOKEN to the task environment. This is necessary when using Azure-hosted OpenAI models rather than those hosted by OpenAI.", + ) parser.add_argument( "-e", "--env", @@ -972,7 +972,9 @@ def run_cli(args: Sequence[str]) -> None: ) # Get the Azure bearer token generator if a token wasn't provided and there's any evidence of using Azure - azure_token_provider = get_azure_token_provider() + azure_token_provider = None + if parsed_args.azure: + azure_token_provider = get_azure_token_provider() # Run the scenario if parsed_args.parallel > 1: