mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-08 05:44:08 +00:00
104 lines
4.4 KiB
Python
104 lines
4.4 KiB
Python
import os
|
|
import subprocess
|
|
import hashlib
|
|
from time import time
|
|
import json
|
|
|
|
|
|
class WandbUtils:
    """Configure Weights & Biases (wandb) logging through environment variables.

    The constructor decides whether wandb runs in "online" or "disabled" mode
    (stored in the ``WANDB_MODE`` environment variable); ``set_wandb_per_run``
    and ``set_wandb_per_trial`` then call ``wandb.init`` only in online mode.
    """

    # Documentation on the wandb setting:
    # There are two ways to initialize wandb in tune.run:
    # (1) using WandbLoggerCallback, by adding the following argument to tune.run:
    #     callbacks=[WandbLoggerCallback(
    #         project="hpo",
    #         api_key = os.environ["WANDB_API_KEY"],
    #         group = os.environ["WANDB_RUN_GROUP"],
    #         log_config=True)]
    # (2) using the wandb_mixin decorator (the current implementation)
    # The current implementation uses (2) because (1) has the following bug.
    # In Ray 1.2, when WandbLoggerCallback is combined with a time limit set via
    # the time_budget_s argument, the previous run does not clear its cache after
    # tune.run returns. After the later run has already started, some zombie
    # trials from the previous run remain in memory and never stop.
    # This bug can be reproduced by switching to (1), i.e. adding the above
    # callbacks argument and removing the wandb_mixin decorator.
    # https://docs.ray.io/en/master/tune/tutorials/tune-wandb.html

    def __init__(self,
                 is_wandb_on=False,
                 wandb_key_path=None,
                 jobid_config=None):
        """Set up the wandb environment for this process.

        Args:
            is_wandb_on: whether wandb logging is requested.
            wandb_key_path: directory expected to contain ``key.json`` with a
                ``"wandb_key"`` entry; only read when ``is_wandb_on`` is true.
            jobid_config: object stored on the instance; the ``set_wandb_*``
                methods call ``get_jobid_full_data_name()`` and
                ``to_wandb_string()`` on it.
        """
        wandb_key = WandbUtils.get_wandb_key(wandb_key_path) if is_wandb_on else ""
        if wandb_key:
            try:
                subprocess.run(["wandb", "login", "--relogin", wandb_key])
            except FileNotFoundError:
                # The wandb executable is missing; keep going in the same
                # best-effort style as the rest of this class. The API key is
                # still exported through the environment below.
                print("Cannot run 'wandb login' because the wandb command is not found")
            os.environ["WANDB_API_KEY"] = wandb_key
            os.environ["WANDB_MODE"] = "online"
        else:
            # TODO coverage
            # Always define WANDB_MODE so later lookups never see it unset.
            os.environ["WANDB_MODE"] = "disabled"
        self.jobid_config = jobid_config
        # Filled in by set_wandb_per_run(); defined here so that
        # set_wandb_per_trial() does not fail if it is called first.
        self.wandb_group_name = None

    @staticmethod
    def get_wandb_key(key_path):
        """Return the wandb API key stored in ``<key_path>/key.json``.

        Args:
            key_path: directory containing ``key.json``; may be ``None``.

        Returns:
            The value of the ``"wandb_key"`` entry, or ``""`` (after printing
            an explanation) when ``key_path`` is ``None``, the file does not
            exist, or the entry is missing.
        """
        if key_path is None:
            # os.path.join(None, ...) would raise TypeError; treat a missing
            # path like a missing key file.
            print("Cannot use wandb module because key.json is not found under key_path")
            return ""
        try:
            # The context manager guarantees the file handle is closed.
            with open(os.path.join(key_path, "key.json"), "r") as key_file:
                key_json = json.load(key_file)
            return key_json["wandb_key"]
        except FileNotFoundError:
            print("Cannot use wandb module because key.json is not found under key_path")
            return ""
        except KeyError:
            print("Cannot use wandb module because wandb key is not specified")
            return ""

    def set_wandb_per_trial(self):
        """Call ``wandb.init`` for a single trial when wandb is online.

        Returns:
            The value returned by ``wandb.init`` in online mode; ``None`` when
            wandb is not installed, ``WANDB_MODE`` is not ``"online"``, or
            ``wandb.init`` raises ``wandb.errors.UsageError``.
        """
        # TODO coverage
        print("before wandb.init\n\n\n")
        try:
            import wandb
        except ImportError:
            print("Cannot use wandb module because wandb is not installed, run pip install wandb==0.10.26")
            return None

        # .get() so that an unset WANDB_MODE counts as "not online" instead of
        # raising KeyError.
        if os.environ.get("WANDB_MODE") != "online":
            return None

        os.environ["WANDB_SILENT"] = "false"
        try:
            return wandb.init(project=self.jobid_config.get_jobid_full_data_name(),
                              # None when set_wandb_per_run() has not run yet.
                              group=getattr(self, "wandb_group_name", None),
                              name=str(WandbUtils._get_next_trial_ids()),
                              settings=wandb.Settings(
                                  _disable_stats=True),
                              reinit=False)
        except wandb.errors.UsageError as err:
            print(err)
            return None

    @staticmethod
    def _get_next_trial_ids():
        """Return ``"trial_"`` plus the first three hex digits of a SHA-1 of the current time."""
        digest = hashlib.sha1()
        digest.update(str(time()).encode('utf-8'))
        return "trial_" + digest.hexdigest()[:3]

    def set_wandb_per_run(self):
        """Create a run group name and call ``wandb.init`` when wandb is online.

        The group name is ``jobid_config.to_wandb_string()`` followed by a
        wandb-generated id; it is stored both in ``WANDB_RUN_GROUP`` and in
        ``self.wandb_group_name``.

        Returns:
            The value returned by ``wandb.init`` in online mode; ``None`` when
            wandb is not installed, ``WANDB_MODE`` is not ``"online"``, or
            ``wandb.init`` raises ``wandb.errors.UsageError``.
        """
        try:
            import wandb
        except ImportError:
            print("Cannot use wandb module because wandb is not installed, run pip install wandb==0.10.26")
            return None

        os.environ["WANDB_RUN_GROUP"] = self.jobid_config.to_wandb_string() + wandb.util.generate_id()
        self.wandb_group_name = os.environ["WANDB_RUN_GROUP"]

        # .get() so that an unset WANDB_MODE counts as "not online" instead of
        # raising KeyError.
        if os.environ.get("WANDB_MODE") != "online":
            return None

        os.environ["WANDB_SILENT"] = "false"
        try:
            return wandb.init(project=self.jobid_config.get_jobid_full_data_name(),
                              group=os.environ["WANDB_RUN_GROUP"],
                              settings=wandb.Settings(
                                  _disable_stats=True),
                              reinit=False)
        except wandb.errors.UsageError as err:
            print(err)
            return None
|