mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-31 09:49:30 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			78 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			78 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| from .base_logger import BaseLogger
 | |
| 
 | |
class WandbLogger(BaseLogger):
    """Logger that sends training metrics and model checkpoints to Weights & Biases.

    A wandb run is obtained as soon as the logger is constructed: an
    already-active run in this process is reused, otherwise a new one is
    started with ``wandb.init``.

    Args:
        project: wandb project name, forwarded to ``wandb.init``.
        name: display name of the run, forwarded to ``wandb.init``.
        id: run id, forwarded to ``wandb.init`` (which is always called with
            ``resume="allow"``).
        entity: wandb entity, forwarded to ``wandb.init``.
        save_dir: passed to ``wandb.init`` as ``dir``; also the directory in
            which ``log_model`` looks for ``<prefix>.pdparams``.
        config: optional mapping merged into ``run.config`` after the run
            is obtained (skipped when falsy).
        **kwargs: extra keyword arguments forwarded to ``wandb.init``; they
            override the values built from the parameters above.

    Raises:
        ModuleNotFoundError: if the ``wandb`` package is not installed.
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        try:
            import wandb
            self.wandb = wandb
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`"
            )

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None  # created lazily by the `run` property
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow"
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        self._wandb_init.update(**kwargs)

        # Touch the property so the run exists before config is written.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)

    @property
    def run(self):
        """Return the wandb run, obtaining it on first access.

        If ``wandb.run`` is already set (a run is in progress in this
        process) it is reused; otherwise ``wandb.init`` is called with the
        arguments assembled in ``__init__``. The result is cached.
        """
        if self._run is None:
            if self.wandb.run is not None:
                # This module defines no `logger` of its own; use a stdlib
                # module-level logger so this branch does not raise NameError.
                import logging
                logging.getLogger(__name__).info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()`"
                    " before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        """Log a mapping of metric name -> value to the run.

        Keys are namespaced as ``"<prefix lowercased>/<key>"``. When *prefix*
        is empty or None, keys are logged as-is (no leading "/").

        Args:
            metrics: mapping of string metric names to values.
            prefix: optional namespace, e.g. a phase name.
            step: optional step forwarded to ``run.log``.
        """
        head = prefix.lower() + "/" if prefix else ""
        updated_metrics = {head + k: v for k, v in metrics.items()}

        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Upload ``<save_dir>/<prefix>.pdparams`` as a wandb model artifact.

        The artifact is named ``model-<run id>``, stores the file as
        ``model_ckpt.pdparams``, and is tagged with *prefix* as an alias,
        plus ``"best"`` when *is_best* is true.

        Note: requires ``save_dir`` to have been set; with ``save_dir=None``
        ``os.path.join`` raises TypeError.
        """
        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            'model-{}'.format(self.run.id), type='model', metadata=metadata)
        artifact.add_file(model_path, name="model_ckpt.pdparams")

        aliases = [prefix]
        if is_best:
            aliases.append("best")

        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        """Finish the wandb run."""
        self.run.finish()
