mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-31 09:49:30 +00:00 
			
		
		
		
	Merge pull request #5396 from littletomatodonkey/dyg/fix_ips_calc
fix ips info and reduce interval of metric calc
This commit is contained in:
		
						commit
						84a42eb585
					
				| @ -31,7 +31,8 @@ class CTCLoss(nn.Layer): | ||||
|             predicts = predicts[-1] | ||||
|         predicts = predicts.transpose((1, 0, 2)) | ||||
|         N, B, _ = predicts.shape | ||||
|         preds_lengths = paddle.to_tensor([N] * B, dtype='int64') | ||||
|         preds_lengths = paddle.to_tensor( | ||||
|             [N] * B, dtype='int64', place=paddle.CPUPlace()) | ||||
|         labels = batch[1].astype("int32") | ||||
|         label_lengths = batch[2].astype('int64') | ||||
|         loss = self.loss_func(predicts, labels, preds_lengths, label_lengths) | ||||
|  | ||||
| @ -146,6 +146,7 @@ def train(config, | ||||
|           scaler=None): | ||||
|     cal_metric_during_train = config['Global'].get('cal_metric_during_train', | ||||
|                                                    False) | ||||
|     calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) | ||||
|     log_smooth_window = config['Global']['log_smooth_window'] | ||||
|     epoch_num = config['Global']['epoch_num'] | ||||
|     print_batch_step = config['Global']['print_batch_step'] | ||||
| @ -244,6 +245,16 @@ def train(config, | ||||
|                 optimizer.step() | ||||
|             optimizer.clear_grad() | ||||
| 
 | ||||
|             if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need | ||||
|                 batch = [item.numpy() for item in batch] | ||||
|                 if model_type in ['table', 'kie']: | ||||
|                     eval_class(preds, batch) | ||||
|                 else: | ||||
|                     post_result = post_process_class(preds, batch[1]) | ||||
|                     eval_class(post_result, batch) | ||||
|                 metric = eval_class.get_metric() | ||||
|                 train_stats.update(metric) | ||||
| 
 | ||||
|             train_batch_time = time.time() - reader_start | ||||
|             train_batch_cost += train_batch_time | ||||
|             eta_meter.update(train_batch_time) | ||||
| @ -258,16 +269,6 @@ def train(config, | ||||
|             stats['lr'] = lr | ||||
|             train_stats.update(stats) | ||||
| 
 | ||||
|             if cal_metric_during_train:  # only rec and cls need | ||||
|                 batch = [item.numpy() for item in batch] | ||||
|                 if model_type in ['table', 'kie']: | ||||
|                     eval_class(preds, batch) | ||||
|                 else: | ||||
|                     post_result = post_process_class(preds, batch[1]) | ||||
|                     eval_class(post_result, batch) | ||||
|                 metric = eval_class.get_metric() | ||||
|                 train_stats.update(metric) | ||||
| 
 | ||||
|             if vdl_writer is not None and dist.get_rank() == 0: | ||||
|                 for k, v in train_stats.get().items(): | ||||
|                     vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 xiaoting
						xiaoting