mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-31 09:49:30 +00:00 
			
		
		
		
	Update rec_nrtr_optim_head.py
This commit is contained in:
		
							parent
							
								
									c635925895
								
							
						
					
					
						commit
						c8094e6575
					
				| @ -216,7 +216,7 @@ class TransformerOptim(nn.Layer): | |||||||
|             new_shape = (n_curr_active_inst * n_bm, *d_hs) |             new_shape = (n_curr_active_inst * n_bm, *d_hs) | ||||||
| 
 | 
 | ||||||
|             beamed_tensor = beamed_tensor.reshape( |             beamed_tensor = beamed_tensor.reshape( | ||||||
|                 [n_prev_active_inst, -1])  #contiguous() |                 [n_prev_active_inst, -1]) | ||||||
|             beamed_tensor = beamed_tensor.index_select( |             beamed_tensor = beamed_tensor.index_select( | ||||||
|                 paddle.to_tensor(curr_active_inst_idx), axis=0) |                 paddle.to_tensor(curr_active_inst_idx), axis=0) | ||||||
|             beamed_tensor = beamed_tensor.reshape([*new_shape]) |             beamed_tensor = beamed_tensor.reshape([*new_shape]) | ||||||
| @ -337,7 +337,7 @@ class TransformerOptim(nn.Layer): | |||||||
|             n_inst, len_s, d_h = src_enc.shape |             n_inst, len_s, d_h = src_enc.shape | ||||||
|             src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1) |             src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1) | ||||||
|             src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose( |             src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose( | ||||||
|                 [1, 0, 2])  #repeat(1, n_bm, 1) |                 [1, 0, 2]) | ||||||
|             #-- Prepare beams |             #-- Prepare beams | ||||||
|             inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)] |             inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)] | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 topduke
						topduke