mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-14 17:12:53 +00:00
commit
f8fb8a7b1b
@ -191,7 +191,6 @@ Eval:
|
|||||||
channel_first: False
|
channel_first: False
|
||||||
- DetLabelEncode: # Class handling label
|
- DetLabelEncode: # Class handling label
|
||||||
- DetResizeForTest:
|
- DetResizeForTest:
|
||||||
# image_shape: [736, 1280]
|
|
||||||
- NormalizeImage:
|
- NormalizeImage:
|
||||||
scale: 1./255.
|
scale: 1./255.
|
||||||
mean: [0.485, 0.456, 0.406]
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
|||||||
@ -24,6 +24,7 @@ Architecture:
|
|||||||
model_type: det
|
model_type: det
|
||||||
Models:
|
Models:
|
||||||
Student:
|
Student:
|
||||||
|
pretrained:
|
||||||
model_type: det
|
model_type: det
|
||||||
algorithm: DB
|
algorithm: DB
|
||||||
Transform: null
|
Transform: null
|
||||||
@ -40,6 +41,7 @@ Architecture:
|
|||||||
name: DBHead
|
name: DBHead
|
||||||
k: 50
|
k: 50
|
||||||
Student2:
|
Student2:
|
||||||
|
pretrained:
|
||||||
model_type: det
|
model_type: det
|
||||||
algorithm: DB
|
algorithm: DB
|
||||||
Transform: null
|
Transform: null
|
||||||
@ -91,14 +93,11 @@ Loss:
|
|||||||
- ["Student", "Student2"]
|
- ["Student", "Student2"]
|
||||||
maps_name: "thrink_maps"
|
maps_name: "thrink_maps"
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
# act: None
|
|
||||||
model_name_pairs: ["Student", "Student2"]
|
model_name_pairs: ["Student", "Student2"]
|
||||||
key: maps
|
key: maps
|
||||||
- DistillationDBLoss:
|
- DistillationDBLoss:
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
model_name_list: ["Student", "Student2"]
|
model_name_list: ["Student", "Student2"]
|
||||||
# key: maps
|
|
||||||
# name: DBLoss
|
|
||||||
balance_loss: true
|
balance_loss: true
|
||||||
main_loss_type: DiceLoss
|
main_loss_type: DiceLoss
|
||||||
alpha: 5
|
alpha: 5
|
||||||
@ -197,6 +196,7 @@ Train:
|
|||||||
drop_last: false
|
drop_last: false
|
||||||
batch_size_per_card: 8
|
batch_size_per_card: 8
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: SimpleDataSet
|
name: SimpleDataSet
|
||||||
@ -204,31 +204,21 @@ Eval:
|
|||||||
label_file_list:
|
label_file_list:
|
||||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage:
|
- DecodeImage: # load image
|
||||||
img_mode: BGR
|
img_mode: BGR
|
||||||
channel_first: false
|
channel_first: False
|
||||||
- DetLabelEncode: null
|
- DetLabelEncode: # Class handling label
|
||||||
- DetResizeForTest: null
|
- DetResizeForTest:
|
||||||
- NormalizeImage:
|
- NormalizeImage:
|
||||||
scale: 1./255.
|
scale: 1./255.
|
||||||
mean:
|
mean: [0.485, 0.456, 0.406]
|
||||||
- 0.485
|
std: [0.229, 0.224, 0.225]
|
||||||
- 0.456
|
order: 'hwc'
|
||||||
- 0.406
|
- ToCHWImage:
|
||||||
std:
|
- KeepKeys:
|
||||||
- 0.229
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
- 0.224
|
|
||||||
- 0.225
|
|
||||||
order: hwc
|
|
||||||
- ToCHWImage: null
|
|
||||||
- KeepKeys:
|
|
||||||
keep_keys:
|
|
||||||
- image
|
|
||||||
- shape
|
|
||||||
- polys
|
|
||||||
- ignore_tags
|
|
||||||
loader:
|
loader:
|
||||||
shuffle: false
|
shuffle: False
|
||||||
drop_last: false
|
drop_last: False
|
||||||
batch_size_per_card: 1
|
batch_size_per_card: 1 # must be 1
|
||||||
num_workers: 2
|
num_workers: 2
|
||||||
@ -60,19 +60,19 @@ class KLJSLoss(object):
|
|||||||
], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
|
], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
def __call__(self, p1, p2, reduction="mean"):
|
def __call__(self, p1, p2, reduction="mean", eps=1e-5):
|
||||||
|
|
||||||
if self.mode.lower() == 'kl':
|
if self.mode.lower() == 'kl':
|
||||||
loss = paddle.multiply(p2,
|
loss = paddle.multiply(p2,
|
||||||
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
paddle.log((p2 + eps) / (p1 + eps) + eps))
|
||||||
loss += paddle.multiply(
|
loss += paddle.multiply(p1,
|
||||||
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
paddle.log((p1 + eps) / (p2 + eps) + eps))
|
||||||
loss *= 0.5
|
loss *= 0.5
|
||||||
elif self.mode.lower() == "js":
|
elif self.mode.lower() == "js":
|
||||||
loss = paddle.multiply(
|
loss = paddle.multiply(
|
||||||
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
|
p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps))
|
||||||
loss += paddle.multiply(
|
loss += paddle.multiply(
|
||||||
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
|
p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps))
|
||||||
loss *= 0.5
|
loss *= 0.5
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user