mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-11-04 03:39:22 +00:00 
			
		
		
		
	Update rec_nrtr_optim_head.py
This commit is contained in:
		
							parent
							
								
									c635925895
								
							
						
					
					
						commit
						c8094e6575
					
				@ -216,7 +216,7 @@ class TransformerOptim(nn.Layer):
 | 
				
			|||||||
            new_shape = (n_curr_active_inst * n_bm, *d_hs)
 | 
					            new_shape = (n_curr_active_inst * n_bm, *d_hs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            beamed_tensor = beamed_tensor.reshape(
 | 
					            beamed_tensor = beamed_tensor.reshape(
 | 
				
			||||||
                [n_prev_active_inst, -1])  #contiguous()
 | 
					                [n_prev_active_inst, -1])
 | 
				
			||||||
            beamed_tensor = beamed_tensor.index_select(
 | 
					            beamed_tensor = beamed_tensor.index_select(
 | 
				
			||||||
                paddle.to_tensor(curr_active_inst_idx), axis=0)
 | 
					                paddle.to_tensor(curr_active_inst_idx), axis=0)
 | 
				
			||||||
            beamed_tensor = beamed_tensor.reshape([*new_shape])
 | 
					            beamed_tensor = beamed_tensor.reshape([*new_shape])
 | 
				
			||||||
@ -337,7 +337,7 @@ class TransformerOptim(nn.Layer):
 | 
				
			|||||||
            n_inst, len_s, d_h = src_enc.shape
 | 
					            n_inst, len_s, d_h = src_enc.shape
 | 
				
			||||||
            src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
 | 
					            src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
 | 
				
			||||||
            src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
 | 
					            src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
 | 
				
			||||||
                [1, 0, 2])  #repeat(1, n_bm, 1)
 | 
					                [1, 0, 2])
 | 
				
			||||||
            #-- Prepare beams
 | 
					            #-- Prepare beams
 | 
				
			||||||
            inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
 | 
					            inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user