autogen/flaml/nlp/result_analysis/wandb_utils.py

104 lines
4.4 KiB
Python
Raw Normal View History

import os
import subprocess
import hashlib
from time import time
import json
class WandbUtils:
# Documentation on the wandb setting:
# There are two ways to initialize wandb in tune.run:
# (1) using WandbLoggerCallback, by adding the following argument to tune.run:
# callbacks=[WandbLoggerCallback(
# project="hpo",
# api_key = os.environ["WANDB_API_KEY"],
# group = os.environ["WANDB_RUN_GROUP"],
# log_config=True)]
# (2) using wandb_mixin decorator (the current implementation)
# The current implementation uses (2) because (1) has the following bug.
# In Ray 1.2, when using WandbLoggerCallback + setting time limit using the time_budget_s argument,
# A bug exists which is the previous run will not clear the cache after tune.run returns. After the
# later run has already starts, some zombie trials in the previous run remain in the memory and never stop.
# This bug can be reproduced by switching to (1) by adding the above callbacks argument
# and removing the wandb_mixin decorator
# https://docs.ray.io/en/master/tune/tutorials/tune-wandb.html
def __init__(self,
is_wandb_on=False,
wandb_key_path=None,
jobid_config=None):
if is_wandb_on:
wandb_key = WandbUtils.get_wandb_key(wandb_key_path)
if wandb_key != "":
subprocess.run(["wandb", "login", "--relogin", wandb_key])
os.environ["WANDB_API_KEY"] = wandb_key
os.environ["WANDB_MODE"] = "online"
else:
2021-09-04 20:28:37 -07:00
# TODO coverage
os.environ["WANDB_MODE"] = "disabled"
self.jobid_config = jobid_config
@staticmethod
def get_wandb_key(key_path):
try:
try:
key_json = json.load(open(os.path.join(key_path, "key.json"), "r"))
wandb_key = key_json["wandb_key"]
return wandb_key
except FileNotFoundError:
print("Cannot use wandb module because key.json is not found under key_path")
return ""
except KeyError:
print("Cannot use wandb module because wandb key is not specified")
return ""
def set_wandb_per_trial(self):
2021-09-04 20:28:37 -07:00
# TODO coverage
print("before wandb.init\n\n\n")
try:
import wandb
try:
if os.environ["WANDB_MODE"] == "online":
os.environ["WANDB_SILENT"] = "false"
return wandb.init(project=self.jobid_config.get_jobid_full_data_name(),
group=self.wandb_group_name,
name=str(WandbUtils._get_next_trial_ids()),
settings=wandb.Settings(
_disable_stats=True),
reinit=False)
else:
return None
except wandb.errors.UsageError as err:
print(err)
return None
except ImportError:
print("Cannot use wandb module because wandb is not installed, run pip install wandb==0.10.26")
@staticmethod
def _get_next_trial_ids():
hash = hashlib.sha1()
hash.update(str(time()).encode('utf-8'))
return "trial_" + hash.hexdigest()[:3]
def set_wandb_per_run(self):
try:
import wandb
os.environ["WANDB_RUN_GROUP"] = self.jobid_config.to_wandb_string() + wandb.util.generate_id()
self.wandb_group_name = os.environ["WANDB_RUN_GROUP"]
try:
if os.environ["WANDB_MODE"] == "online":
os.environ["WANDB_SILENT"] = "false"
return wandb.init(project=self.jobid_config.get_jobid_full_data_name(),
group=os.environ["WANDB_RUN_GROUP"],
settings=wandb.Settings(
_disable_stats=True),
reinit=False)
else:
return None
except wandb.errors.UsageError as err:
print(err)
return None
except ImportError:
print("Cannot use wandb module because wandb is not installed, run pip install wandb==0.10.26")