Mirror of https://github.com/PaddlePaddle/PaddleOCR.git
fix attention loss for ce
parent 1bbf6e6a92
commit f26846ccd1
@@ -45,6 +45,7 @@ class AttentionHead(nn.Layer):
        output_hiddens = []

        if targets is not None:
            print("target is not None")
            for i in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets[:, i], onehot_dim=self.num_classes)
@@ -53,8 +54,8 @@ class AttentionHead(nn.Layer):
                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            probs = self.generator(output)

        else:
            print("target is None")
            targets = paddle.zeros(shape=[batch_size], dtype="int32")
            probs = None
            char_onehots = None
@@ -75,6 +76,7 @@ class AttentionHead(nn.Layer):
                            probs_step, axis=1)], axis=1)
                next_input = probs_step.argmax(axis=1)
                targets = next_input
            if not self.training:
                probs = paddle.nn.functional.softmax(probs, axis=2)
        return probs
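The two branches touched by this diff correspond to teacher-forced training (targets provided, raw logits returned for the cross-entropy attention loss) and greedy autoregressive decoding at inference (no targets, softmax applied when the layer is not in training mode). Below is a minimal usage sketch, not part of the commit: the import path follows the usual PaddleOCR layout, and the sizes and class count are hypothetical example values chosen only for illustration.

import paddle
from ppocr.modeling.heads.rec_att_head import AttentionHead

# Hypothetical sizes for illustration only.
head = AttentionHead(in_channels=96, out_channels=38, hidden_size=96)

feats = paddle.rand([8, 25, 96])          # encoder features: [batch, seq_len, in_channels]
labels = paddle.randint(0, 38, [8, 25])   # character indices used for teacher forcing

head.train()
logits = head(feats, targets=labels)      # targets branch: raw logits fed to the CE attention loss

head.eval()
with paddle.no_grad():
    probs = head(feats)                   # no-targets branch: greedy decode, softmax applied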