mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-12-06 20:06:48 +00:00
add style_text_rec
This commit is contained in:
parent
b1623d69a5
commit
f2d98c5e76
255
tools/style_text_rec/arch/base_module.py
Normal file
255
tools/style_text_rec/arch/base_module.py
Normal file
@ -0,0 +1,255 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import functools
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
from arch.spectral_norm import spectral_norm
|
||||||
|
|
||||||
|
|
||||||
|
class CBN(nn.Layer):
    """Conv2D followed by an optional normalization layer and activation.

    Args:
        name: Prefix for the names given to the conv weight/bias parameters
            and to the normalization and activation layers.
        in_channels: Forwarded to ``paddle.nn.Conv2D``.
        out_channels: Forwarded to ``paddle.nn.Conv2D``; also used as
            ``num_features`` of the normalization layer.
        kernel_size: Forwarded to ``paddle.nn.Conv2D``.
        stride: Forwarded to ``paddle.nn.Conv2D``.
        padding: Forwarded to ``paddle.nn.Conv2D``.
        dilation: Forwarded to ``paddle.nn.Conv2D``.
        groups: Forwarded to ``paddle.nn.Conv2D``.
        use_bias: If true, the conv bias gets a ``ParamAttr`` named
            ``name + "_bias"``; otherwise ``bias_attr=None`` is passed.
        norm_layer: Name of a class looked up on ``paddle.nn`` and built with
            ``num_features`` and ``name``; a falsy value disables it.
        act: Name of a class looked up on ``paddle.nn``; a falsy value
            disables the activation.
        act_attr: Optional dict of extra keyword arguments for the
            activation constructor.
    """

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(CBN, self).__init__()
        bias_attr = paddle.ParamAttr(name=name + "_bias") if use_bias else None
        self._conv = paddle.nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)

        # Normalization class is resolved by name from paddle.nn.
        norm = None
        if norm_layer:
            norm_cls = getattr(paddle.nn, norm_layer)
            norm = norm_cls(num_features=out_channels, name=name + "_bn")
        self._norm_layer = norm

        # Activation class is resolved the same way; act_attr supplies any
        # additional constructor keywords.
        activation = None
        if act:
            act_cls = getattr(paddle.nn, act)
            extra_kwargs = act_attr if act_attr else {}
            activation = act_cls(**extra_kwargs, name=name + "_" + act)
        self._act = activation

    def forward(self, x):
        """Apply the conv, then the norm layer and activation when present."""
        out = self._conv(x)
        for stage in (self._norm_layer, self._act):
            if stage:
                out = stage(out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class SNConv(nn.Layer):
    """Conv2D wrapped by ``spectral_norm`` plus optional norm and activation.

    Args:
        name: Prefix for the names given to the conv weight/bias parameters
            and to the normalization and activation layers.
        in_channels: Forwarded to ``paddle.nn.Conv2D``.
        out_channels: Forwarded to ``paddle.nn.Conv2D``; also used as
            ``num_features`` of the normalization layer.
        kernel_size: Forwarded to ``paddle.nn.Conv2D``.
        stride: Forwarded to ``paddle.nn.Conv2D``.
        padding: Forwarded to ``paddle.nn.Conv2D``.
        dilation: Forwarded to ``paddle.nn.Conv2D``.
        groups: Forwarded to ``paddle.nn.Conv2D``.
        use_bias: If true, the conv bias gets a ``ParamAttr`` named
            ``name + "_bias"``; otherwise ``bias_attr=None`` is passed.
        norm_layer: Name of a class looked up on ``paddle.nn`` and built with
            ``num_features`` and ``name``; a falsy value disables it.
        act: Name of a class looked up on ``paddle.nn``; a falsy value
            disables the activation.
        act_attr: Optional dict of extra keyword arguments for the
            activation constructor.
    """

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConv, self).__init__()
        bias_attr = paddle.ParamAttr(name=name + "_bias") if use_bias else None
        plain_conv = paddle.nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)
        # spectral_norm installs its hook on the conv and returns that conv.
        self._sn_conv = spectral_norm(plain_conv)

        # Normalization class is resolved by name from paddle.nn.
        norm = None
        if norm_layer:
            norm_cls = getattr(paddle.nn, norm_layer)
            norm = norm_cls(num_features=out_channels, name=name + "_bn")
        self._norm_layer = norm

        # Activation class is resolved the same way; act_attr supplies any
        # additional constructor keywords.
        activation = None
        if act:
            act_cls = getattr(paddle.nn, act)
            extra_kwargs = act_attr if act_attr else {}
            activation = act_cls(**extra_kwargs, name=name + "_" + act)
        self._act = activation

    def forward(self, x):
        """Apply the wrapped conv, then norm and activation when present."""
        out = self._sn_conv(x)
        for stage in (self._norm_layer, self._act):
            if stage:
                out = stage(out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class SNConvTranspose(nn.Layer):
    """Conv2DTranspose wrapped by ``spectral_norm`` plus optional norm/act.

    Args:
        name: Prefix for the names given to the conv weight/bias parameters
            and to the normalization and activation layers.
        in_channels: Forwarded to ``paddle.nn.Conv2DTranspose``.
        out_channels: Forwarded to ``paddle.nn.Conv2DTranspose``; also used
            as ``num_features`` of the normalization layer.
        kernel_size: Forwarded to ``paddle.nn.Conv2DTranspose``.
        stride: Forwarded to ``paddle.nn.Conv2DTranspose``.
        padding: Forwarded to ``paddle.nn.Conv2DTranspose``.
        output_padding: Forwarded to ``paddle.nn.Conv2DTranspose``.
        dilation: Forwarded to ``paddle.nn.Conv2DTranspose``.
        groups: Forwarded to ``paddle.nn.Conv2DTranspose``.
        use_bias: If true, the conv bias gets a ``ParamAttr`` named
            ``name + "_bias"``; otherwise ``bias_attr=None`` is passed.
        norm_layer: Name of a class looked up on ``paddle.nn`` and built with
            ``num_features`` and ``name``; a falsy value disables it.
        act: Name of a class looked up on ``paddle.nn``; a falsy value
            disables the activation.
        act_attr: Optional dict of extra keyword arguments for the
            activation constructor.
    """

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConvTranspose, self).__init__()
        bias_attr = paddle.ParamAttr(name=name + "_bias") if use_bias else None
        plain_deconv = paddle.nn.Conv2DTranspose(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)
        # spectral_norm installs its hook on the layer and returns that layer.
        self._sn_conv_transpose = spectral_norm(plain_deconv)

        # Normalization class is resolved by name from paddle.nn.
        norm = None
        if norm_layer:
            norm_cls = getattr(paddle.nn, norm_layer)
            norm = norm_cls(num_features=out_channels, name=name + "_bn")
        self._norm_layer = norm

        # Activation class is resolved the same way; act_attr supplies any
        # additional constructor keywords.
        activation = None
        if act:
            act_cls = getattr(paddle.nn, act)
            extra_kwargs = act_attr if act_attr else {}
            activation = act_cls(**extra_kwargs, name=name + "_" + act)
        self._act = activation

    def forward(self, x):
        """Apply the wrapped transposed conv, then norm/act when present."""
        out = self._sn_conv_transpose(x)
        for stage in (self._norm_layer, self._act):
            if stage:
                out = stage(out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class MiddleNet(nn.Layer):
    """Three spectrally normalized convs: 1x1, replicate-padded 3x3, 1x1.

    None of the three convs is given a normalization layer or activation.

    Args:
        name: Prefix for the sublayer names (``_sn_conv1`` .. ``_sn_conv3``
            are appended).
        in_channels: Input channels of the first 1x1 conv.
        mid_channels: Output channels of the first conv and both input and
            output channels of the 3x3 conv.
        out_channels: Output channels of the final 1x1 conv.
        use_bias: Forwarded to every ``SNConv``.
    """

    def __init__(self, name, in_channels, mid_channels, out_channels,
                 use_bias):
        super(MiddleNet, self).__init__()
        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            use_bias=use_bias,
            norm_layer=None,
            act=None)
        # One pixel of replicate padding on every side feeds the unpadded
        # 3x3 conv below.
        self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            use_bias=use_bias)
        self._sn_conv3 = SNConv(
            name=name + "_sn_conv3",
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_bias=use_bias)

    def forward(self, x):
        """Run the three convs in sequence and return the last output."""
        # Sublayers are invoked through __call__ rather than .forward() so
        # that any hooks registered on them are honoured.
        out = self._sn_conv1(x)
        out = self._pad2d(out)
        out = self._sn_conv2(out)
        return self._sn_conv3(out)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Layer):
    """Residual block: pad -> SNConv 3x3 -> pad -> SNConv 3x3, plus the input.

    Both convs use ``act="ReLU"`` and the supplied ``norm_layer``.

    Args:
        name: Prefix for the sublayer names.
        channels: Input and output channel count of both convs.
        norm_layer: Forwarded to both ``SNConv`` layers.
        use_dropout: If true, an ``nn.Dropout(0.5)`` is created.
            NOTE(review): ``forward`` never applies it, so this flag
            currently has no effect on the output -- confirm intent.
        use_dilation: Selects the first replicate padding: one pixel per
            side when true, zero when false.
            NOTE(review): with zero padding the unpadded 3x3 conv presumably
            shrinks the feature map so the residual add cannot match the
            input shape -- confirm that callers always pass a true value.
        use_bias: Forwarded to both ``SNConv`` layers.
    """

    def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
                 use_bias):
        super(ResBlock, self).__init__()
        first_padding = [1, 1, 1, 1] if use_dilation else [0, 0, 0, 0]
        self._pad1 = nn.Pad2D(first_padding, mode="replicate")

        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=0,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)
        # Kept for interface compatibility; see the class NOTE about it not
        # being used in forward().
        self._dropout = nn.Dropout(0.5) if use_dropout else None
        self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)

    def forward(self, x):
        """Return the two-conv branch output added to the input ``x``."""
        # Sublayers are invoked through __call__ rather than .forward() so
        # that any hooks registered on them are honoured.
        out = self._pad1(x)
        out = self._sn_conv1(out)
        out = self._pad2(out)
        out = self._sn_conv2(out)
        return out + x
|
||||||
250
tools/style_text_rec/arch/decoder.py
Normal file
250
tools/style_text_rec/arch/decoder.py
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
from arch.base_module import SNConv, SNConvTranspose, ResBlock
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Layer):
    """Residual blocks, three stride-2 transposed convs, then an output conv.

    Channel widths go ``8*encode_dim -> 4* -> 2* -> 1* -> out_channels``.

    Args:
        name: Prefix for all sublayer names.
        encode_dim: Base channel width.
        out_channels: Channels produced by the final conv.
        use_bias: Forwarded to every conv wrapper.
        norm_layer: Forwarded to the residual blocks and upsampling layers
            (the output conv uses no norm layer).
        act: Activation name for the upsampling layers.
        act_attr: Activation kwargs for the upsampling layers.
        conv_block_dropout: ``use_dropout`` of each ``ResBlock``.
        conv_block_num: Number of ``ResBlock`` layers.
        conv_block_dilation: ``use_dilation`` of each ``ResBlock``.
        out_conv_act: Activation name for the output conv.
        out_conv_act_attr: Activation kwargs for the output conv.
    """

    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(Decoder, self).__init__()

        def make_up(tag, in_ch, out_ch):
            # The upsampling stages differ only in name and channel widths.
            return SNConvTranspose(
                name=name + tag,
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                use_bias=use_bias,
                norm_layer=norm_layer,
                act=act,
                act_attr=act_attr)

        res_blocks = [
            ResBlock(
                name="{}_conv_block_{}".format(name, i),
                channels=encode_dim * 8,
                norm_layer=norm_layer,
                use_dropout=conv_block_dropout,
                use_dilation=conv_block_dilation,
                use_bias=use_bias) for i in range(conv_block_num)
        ]
        # Attribute name intentionally left without a leading underscore:
        # it is the name under which the sublayer is registered.
        self.conv_blocks = nn.Sequential(*res_blocks)
        self._up1 = make_up("_up1", encode_dim * 8, encode_dim * 4)
        self._up2 = make_up("_up2", encode_dim * 4, encode_dim * 2)
        self._up3 = make_up("_up3", encode_dim * 2, encode_dim)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x):
        """Decode ``x`` and return every intermediate output.

        Args:
            x: A tensor, or a list/tuple of tensors that is concatenated
                along axis 1 first.

        Returns:
            dict with keys ``conv_blocks``, ``up1``, ``up2``, ``up3``,
            ``pad2d`` and ``out_conv``.
        """
        if isinstance(x, (list, tuple)):
            x = paddle.concat(x, axis=1)
        # Sublayers are invoked through __call__ rather than .forward() so
        # that any hooks registered on them are honoured.
        outputs = {}
        outputs["conv_blocks"] = self.conv_blocks(x)
        outputs["up1"] = self._up1(outputs["conv_blocks"])
        outputs["up2"] = self._up2(outputs["up1"])
        outputs["up3"] = self._up3(outputs["up2"])
        outputs["pad2d"] = self._pad2d(outputs["up3"])
        outputs["out_conv"] = self._out_conv(outputs["pad2d"])
        return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderUnet(nn.Layer):
    """Decoder whose second and third upsampling stages take skip features.

    ``up2`` and ``up3`` receive the previous stage's output concatenated on
    axis 1 with an extra feature map, so their ``in_channels`` are
    ``8*encode_dim`` and ``4*encode_dim`` respectively.

    Args:
        name: Prefix for all sublayer names.
        encode_dim: Base channel width.
        out_channels: Channels produced by the final conv.
        use_bias: Forwarded to every conv wrapper.
        norm_layer: Forwarded to the residual blocks and upsampling layers
            (the output conv uses no norm layer).
        act: Activation name for the upsampling layers.
        act_attr: Activation kwargs for the upsampling layers.
        conv_block_dropout: ``use_dropout`` of each ``ResBlock``.
        conv_block_num: Number of ``ResBlock`` layers.
        conv_block_dilation: ``use_dilation`` of each ``ResBlock``.
        out_conv_act: Activation name for the output conv.
        out_conv_act_attr: Activation kwargs for the output conv.
    """

    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(DecoderUnet, self).__init__()

        def make_up(tag, in_ch, out_ch):
            # The upsampling stages differ only in name and channel widths.
            return SNConvTranspose(
                name=name + tag,
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                use_bias=use_bias,
                norm_layer=norm_layer,
                act=act,
                act_attr=act_attr)

        res_blocks = [
            ResBlock(
                name="{}_conv_block_{}".format(name, i),
                channels=encode_dim * 8,
                norm_layer=norm_layer,
                use_dropout=conv_block_dropout,
                use_dilation=conv_block_dilation,
                use_bias=use_bias) for i in range(conv_block_num)
        ]
        self._conv_blocks = nn.Sequential(*res_blocks)
        self._up1 = make_up("_up1", encode_dim * 8, encode_dim * 4)
        self._up2 = make_up("_up2", encode_dim * 8, encode_dim * 2)
        self._up3 = make_up("_up3", encode_dim * 4, encode_dim)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, y, feature2, feature1):
        """Decode the concatenation of ``x`` and ``y`` with two skip inputs.

        Args:
            x: First tensor; concatenated with ``y`` along axis 1.
            y: Second tensor.
            feature2: Concatenated (axis 1) with the ``up1`` output before
                ``up2``.
            feature1: Concatenated (axis 1) with the ``up2`` output before
                ``up3``.

        Returns:
            dict with keys ``conv_blocks``, ``up1``, ``up2``, ``up3``,
            ``pad2d`` and ``out_conv``.
        """
        # Sublayers are invoked through __call__ rather than .forward() so
        # that any hooks registered on them are honoured.
        outputs = {}
        outputs["conv_blocks"] = self._conv_blocks(
            paddle.concat((x, y), axis=1))
        outputs["up1"] = self._up1(outputs["conv_blocks"])
        outputs["up2"] = self._up2(
            paddle.concat((outputs["up1"], feature2), axis=1))
        outputs["up3"] = self._up3(
            paddle.concat((outputs["up2"], feature1), axis=1))
        outputs["pad2d"] = self._pad2d(outputs["up3"])
        outputs["out_conv"] = self._out_conv(outputs["pad2d"])
        return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class SingleDecoder(nn.Layer):
    """Skip-connected decoder operating on a single ``4*encode_dim`` input.

    The residual blocks and ``up1`` work at ``4*encode_dim`` channels;
    ``up2`` and ``up3`` take the previous output concatenated on axis 1 with
    a skip feature (``in_channels`` of ``8*`` and ``4*encode_dim``).

    Args:
        name: Prefix for all sublayer names.
        encode_dim: Base channel width.
        out_channels: Channels produced by the final conv.
        use_bias: Forwarded to every conv wrapper.
        norm_layer: Forwarded to the residual blocks and upsampling layers
            (the output conv uses no norm layer).
        act: Activation name for the upsampling layers.
        act_attr: Activation kwargs for the upsampling layers.
        conv_block_dropout: ``use_dropout`` of each ``ResBlock``.
        conv_block_num: Number of ``ResBlock`` layers.
        conv_block_dilation: ``use_dilation`` of each ``ResBlock``.
        out_conv_act: Activation name for the output conv.
        out_conv_act_attr: Activation kwargs for the output conv.
    """

    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(SingleDecoder, self).__init__()

        def make_up(tag, in_ch, out_ch):
            # The upsampling stages differ only in name and channel widths.
            return SNConvTranspose(
                name=name + tag,
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                use_bias=use_bias,
                norm_layer=norm_layer,
                act=act,
                act_attr=act_attr)

        res_blocks = [
            ResBlock(
                name="{}_conv_block_{}".format(name, i),
                channels=encode_dim * 4,
                norm_layer=norm_layer,
                use_dropout=conv_block_dropout,
                use_dilation=conv_block_dilation,
                use_bias=use_bias) for i in range(conv_block_num)
        ]
        self._conv_blocks = nn.Sequential(*res_blocks)
        self._up1 = make_up("_up1", encode_dim * 4, encode_dim * 4)
        self._up2 = make_up("_up2", encode_dim * 8, encode_dim * 2)
        self._up3 = make_up("_up3", encode_dim * 4, encode_dim)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, feature2, feature1):
        """Decode ``x`` using two skip features.

        Args:
            x: Input to the residual blocks.
            feature2: Concatenated (axis 1) with the ``up1`` output before
                ``up2``.
            feature1: Concatenated (axis 1) with the ``up2`` output before
                ``up3``.

        Returns:
            dict with keys ``conv_blocks``, ``up1``, ``up2``, ``up3``,
            ``pad2d`` and ``out_conv``.
        """
        # Sublayers are invoked through __call__ rather than .forward() so
        # that any hooks registered on them are honoured.
        outputs = {}
        outputs["conv_blocks"] = self._conv_blocks(x)
        outputs["up1"] = self._up1(outputs["conv_blocks"])
        outputs["up2"] = self._up2(
            paddle.concat((outputs["up1"], feature2), axis=1))
        outputs["up3"] = self._up3(
            paddle.concat((outputs["up2"], feature1), axis=1))
        outputs["pad2d"] = self._pad2d(outputs["up3"])
        outputs["out_conv"] = self._out_conv(outputs["pad2d"])
        return outputs
|
||||||
185
tools/style_text_rec/arch/encoder.py
Normal file
185
tools/style_text_rec/arch/encoder.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
from arch.base_module import SNConv, SNConvTranspose, ResBlock
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Layer):
    """7x7 input conv, three stride-2 convs, then a stack of residual blocks.

    Channel widths go ``in_channels -> encode_dim -> 2* -> 4* -> 4*``.

    Args:
        name: Prefix for all sublayer names.
        in_channels: Channels of the input tensor.
        encode_dim: Base channel width.
        use_bias: Forwarded to every conv wrapper.
        norm_layer: Forwarded to every conv wrapper and residual block.
        act: Activation name for the input and down-sampling convs.
        act_attr: Activation kwargs for those convs.
        conv_block_dropout: ``use_dropout`` of each ``ResBlock``.
        conv_block_num: Number of ``ResBlock`` layers.
        conv_block_dilation: ``use_dilation`` of each ``ResBlock``.
    """

    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation):
        super(Encoder, self).__init__()

        def make_down(tag, in_ch, out_ch):
            # The down-sampling stages differ only in name and widths.
            return SNConv(
                name=name + tag,
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                stride=2,
                padding=1,
                use_bias=use_bias,
                norm_layer=norm_layer,
                act=act,
                act_attr=act_attr)

        # Three pixels of replicate padding feed the unpadded 7x7 conv.
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = make_down("_down1", encode_dim, encode_dim * 2)
        self._down2 = make_down("_down2", encode_dim * 2, encode_dim * 4)
        self._down3 = make_down("_down3", encode_dim * 4, encode_dim * 4)
        res_blocks = [
            ResBlock(
                name="{}_conv_block_{}".format(name, i),
                channels=encode_dim * 4,
                norm_layer=norm_layer,
                use_dropout=conv_block_dropout,
                use_dilation=conv_block_dilation,
                use_bias=use_bias) for i in range(conv_block_num)
        ]
        self._conv_blocks = nn.Sequential(*res_blocks)

    def forward(self, x):
        """Encode ``x`` and return every intermediate output.

        Returns:
            dict with keys ``in_conv``, ``down1``, ``down2``, ``down3`` and
            ``res_blocks``.
        """
        # Sublayers are invoked through __call__ rather than .forward() so
        # that any hooks registered on them are honoured.
        outputs = {}
        padded = self._pad2d(x)
        outputs["in_conv"] = self._in_conv(padded)
        outputs["down1"] = self._down1(outputs["in_conv"])
        outputs["down2"] = self._down2(outputs["down1"])
        outputs["down3"] = self._down3(outputs["down2"])
        outputs["res_blocks"] = self._conv_blocks(outputs["down3"])
        return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderUnet(nn.Layer):
    """U-shaped encoder: four stride-2 convs, two transposed convs with skips.

    ``up2`` consumes ``down3`` concatenated with ``up1`` on axis 1, and the
    final ``concat`` output joins ``down2`` with ``up2``.

    NOTE(review): the transposed convs are built without ``output_padding``,
    so the concatenations presumably only line up for particular spatial
    sizes -- confirm against the callers' input shapes.

    Args:
        name: Prefix for all sublayer names.
        in_channels: Channels of the input tensor.
        encode_dim: Base channel width.
        use_bias: Forwarded to every conv wrapper.
        norm_layer: Forwarded to every conv wrapper.
        act: Activation name for every conv wrapper.
        act_attr: Activation kwargs for every conv wrapper.
    """

    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr):
        super(EncoderUnet, self).__init__()

        # Keyword arguments shared by every down- and up-sampling stage.
        stage_kwargs = dict(
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)

        def make_down(tag, in_ch, out_ch):
            return SNConv(
                name=name + tag,
                in_channels=in_ch,
                out_channels=out_ch,
                **stage_kwargs)

        def make_up(tag, in_ch, out_ch):
            return SNConvTranspose(
                name=name + tag,
                in_channels=in_ch,
                out_channels=out_ch,
                **stage_kwargs)

        # Three pixels of replicate padding feed the unpadded 7x7 conv.
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = make_down("_down1", encode_dim, encode_dim * 2)
        self._down2 = make_down("_down2", encode_dim * 2, encode_dim * 2)
        self._down3 = make_down("_down3", encode_dim * 2, encode_dim * 2)
        self._down4 = make_down("_down4", encode_dim * 2, encode_dim * 2)
        self._up1 = make_up("_up1", encode_dim * 2, encode_dim * 2)
        self._up2 = make_up("_up2", encode_dim * 4, encode_dim * 4)

    def forward(self, x):
        """Encode ``x`` and return every intermediate output.

        Returns:
            dict with keys ``in_conv``, ``down1`` .. ``down4``, ``up1``,
            ``up2`` and ``concat``.
        """
        # Sublayers are invoked through __call__ rather than .forward() so
        # that any hooks registered on them are honoured.
        outputs = {}
        padded = self._pad2d(x)
        outputs['in_conv'] = self._in_conv(padded)
        outputs['down1'] = self._down1(outputs['in_conv'])
        outputs['down2'] = self._down2(outputs['down1'])
        outputs['down3'] = self._down3(outputs['down2'])
        outputs['down4'] = self._down4(outputs['down3'])
        outputs['up1'] = self._up1(outputs['down4'])
        outputs['up2'] = self._up2(
            paddle.concat((outputs['down3'], outputs['up1']), axis=1))
        outputs['concat'] = paddle.concat(
            (outputs['down2'], outputs['up2']), axis=1)
        return outputs
|
||||||
154
tools/style_text_rec/arch/spectral_norm.py
Normal file
154
tools/style_text_rec/arch/spectral_norm.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def normal_(x, mean=0., std=1.):
    """Overwrite ``x`` in place with ``paddle.normal`` samples and return it.

    Args:
        x: Tensor whose value is replaced; its shape sets the sample shape.
        mean: Mean passed to ``paddle.normal``.
        std: Standard deviation passed to ``paddle.normal``.

    Returns:
        The same ``x`` object, after ``set_value``.
    """
    x.set_value(paddle.normal(mean, std, shape=x.shape))
    return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpectralNorm(object):
    """Forward pre-hook that divides a layer's weight by ``u^T W v``.

    ``u`` and ``v`` are vectors stored as buffers on the layer and refined
    by power iteration each time the hook runs while the layer is training.

    Args:
        name: Name of the weight attribute on the target layer.
        n_power_iterations: Power-iteration steps per hook call; must be
            positive.
        dim: Weight axis that becomes the rows of the 2-D matrix view.
        eps: Epsilon passed to ``F.normalize``.

    Raises:
        ValueError: If ``n_power_iterations`` is not positive.
    """

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(
                                 n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        """Return ``weight`` as a 2-D matrix whose rows index ``self.dim``."""
        weight_mat = weight
        if self.dim != 0:
            # Move self.dim to the front, keeping the other axes in order.
            other_axes = [d for d in range(weight_mat.dim()) if d != self.dim]
            weight_mat = weight_mat.transpose([self.dim] + other_axes)

        rows = weight_mat.shape[0]
        return weight_mat.reshape([rows, -1])

    def compute_weight(self, module, do_power_iteration):
        """Return the original weight divided by the current sigma estimate.

        Args:
            module: Layer holding ``<name>_orig``, ``<name>_u``, ``<name>_v``.
            do_power_iteration: When true, ``u`` and ``v`` are updated in
                place (under ``no_grad``) before sigma is computed.
        """
        orig_weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(orig_weight)

        if do_power_iteration:
            with paddle.no_grad():
                for _ in range(self.n_power_iterations):
                    # v <- normalize(W^T u)
                    next_v = F.normalize(
                        paddle.matmul(
                            weight_mat,
                            u,
                            transpose_x=True,
                            transpose_y=False),
                        axis=0,
                        epsilon=self.eps)
                    v.set_value(next_v)

                    # u <- normalize(W v)
                    next_u = F.normalize(
                        paddle.matmul(weight_mat, v),
                        axis=0,
                        epsilon=self.eps)
                    u.set_value(next_u)
                if self.n_power_iterations > 0:
                    # Work on copies so sigma below does not alias buffers
                    # that were just modified in place.
                    u = u.clone()
                    v = v.clone()

        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        return orig_weight / sigma

    def remove(self, module):
        """Replace the hook-managed attributes with one normalized weight.

        NOTE(review): this does not unregister the forward pre-hook, and it
        hands a detached tensor to ``add_parameter`` -- confirm both are
        acceptable to ``paddle.nn.Layer`` before relying on this method.
        """
        with paddle.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        for suffix in ('', '_u', '_v', '_orig'):
            delattr(module, self.name + suffix)

        module.add_parameter(self.name, weight.detach())

    def __call__(self, module, inputs):
        """Hook entry point: refresh ``module.<name>`` before the forward."""
        normalized = self.compute_weight(
            module, do_power_iteration=module.training)
        setattr(module, self.name, normalized)

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        """Install a ``SpectralNorm`` hook for ``module.<name>``.

        The parameter is re-registered as ``<name>_orig``, random unit
        vectors are stored as buffers ``<name>_u`` / ``<name>_v``, and the
        hook is registered as a forward pre-hook.

        Returns:
            The created ``SpectralNorm`` instance.

        Raises:
            RuntimeError: If a ``SpectralNorm`` hook for ``name`` is already
                registered on ``module``.
        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape

            # Randomly initialize u and v, then scale them to unit length.
            u = module.create_parameter([h])
            u = normal_(u, 0., 1.)
            v = module.create_parameter([w])
            v = normal_(v, 0., 1.)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)

        # Remove fn.name from the parameters first; otherwise the attribute
        # of the same name cannot be set below.
        del module._parameters[fn.name]
        module.add_parameter(fn.name + "_orig", weight)
        # The weight still has to exist under fn.name because other code may
        # assume it does (e.g. weight initialization). Assigning the
        # Parameter itself would register it as a parameter again, so
        # weight * 1.0 is assigned instead.
        setattr(module, fn.name, weight * 1.0)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        return fn
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    """Attach a ``SpectralNorm`` hook to ``module`` and return ``module``.

    Args:
        module: Layer whose parameter ``name`` is to be normalized.
        name: Name of the weight parameter.
        n_power_iterations: Forwarded to ``SpectralNorm.apply``.
        eps: Forwarded to ``SpectralNorm.apply``.
        dim: Weight axis used as the matrix rows. When ``None`` it is 1 for
            ``Conv1DTranspose``, ``Conv2DTranspose``, ``Conv3DTranspose``
            and ``Linear`` layers, and 0 for everything else.

    Returns:
        The same ``module`` object, with the hook installed.
    """
    if dim is None:
        dim_one_types = (nn.Conv1DTranspose, nn.Conv2DTranspose,
                         nn.Conv3DTranspose, nn.Linear)
        dim = 1 if isinstance(module, dim_one_types) else 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
|
||||||
288
tools/style_text_rec/arch/style_text_rec.py
Normal file
288
tools/style_text_rec/arch/style_text_rec.py
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
from arch.base_module import MiddleNet, ResBlock
|
||||||
|
from arch.encoder import Encoder
|
||||||
|
from arch.decoder import Decoder, DecoderUnet, SingleDecoder
|
||||||
|
from utils.load_params import load_dygraph_pretrain
|
||||||
|
from utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class StyleTextRec(nn.Layer):
    """Full style-text model: text, background and fusion generators.

    The three sub-generators are built from ``config["Predictor"]`` and
    each one is initialised from the ``pretrain`` path given in its own
    config section.
    """

    def __init__(self, config):
        """Build the sub-generators and load their pretrained weights.

        Args:
            config: dict-like with a ``"Predictor"`` section holding the
                ``"text_generator"``, ``"bg_generator"`` and
                ``"fusion_generator"`` sub-sections; each sub-section must
                provide a ``"pretrain"`` entry.
        """
        super(StyleTextRec, self).__init__()
        self.logger = get_logger()
        predictor_config = config["Predictor"]
        self.text_generator = TextGenerator(predictor_config["text_generator"])
        self.bg_generator = BgGeneratorWithMask(predictor_config[
            "bg_generator"])
        self.fusion_generator = FusionGeneratorSimple(predictor_config[
            "fusion_generator"])

        # Resolve every pretrain path first so a missing key fails before
        # any weights are loaded; load order is bg, text, fusion.
        pretrained = [
            (self.bg_generator, predictor_config["bg_generator"]["pretrain"]),
            (self.text_generator,
             predictor_config["text_generator"]["pretrain"]),
            (self.fusion_generator,
             predictor_config["fusion_generator"]["pretrain"]),
        ]
        for generator, pretrain_path in pretrained:
            load_dygraph_pretrain(
                generator,
                self.logger,
                path=pretrain_path,
                load_static_weights=False)

    def forward(self, style_input, text_input):
        """Generate the fused image and its intermediate results.

        Args:
            style_input: input passed to both the text generator (as the
                style reference) and the background generator.
            text_input: input passed to the text generator as the text
                content.

        Returns:
            dict with keys ``"fake_fusion"``, ``"fake_text"``,
            ``"fake_sk"`` and ``"fake_bg"``.
        """
        text_gen_output = self.text_generator.forward(style_input, text_input)
        fake_text = text_gen_output["fake_text"]
        fake_sk = text_gen_output["fake_sk"]

        # Only the generated background is consumed here; the encoder and
        # decoder feature maps the background generator also returns are
        # not needed by the simple fusion generator.
        fake_bg = self.bg_generator.forward(style_input)["fake_bg"]

        fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
        return {
            "fake_fusion": fusion_gen_output["fake_fusion"],
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class TextGenerator(nn.Layer):
    """Generator built from two encoders, two decoders and a ``MiddleNet``.

    One encoder processes the text input and one the style input. Both
    decoders receive the pair ``[text_feature, style_feature]``: one emits
    ``encode_dim / 2`` channels through a Tanh output convolution, the
    other emits a single channel through a Sigmoid output convolution
    (``fake_sk``). The two decoder outputs are concatenated on axis 1 and
    mapped to 3 channels by the ``MiddleNet``.
    """

    def __init__(self, config):
        """Create the sub-layers.

        Args:
            config: dict-like providing ``module_name``, ``encode_dim``,
                ``norm_layer``, ``conv_block_dropout``, ``conv_block_num``
                and ``conv_block_dilation``.
        """
        super(TextGenerator, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        # A convolution bias is only enabled together with InstanceNorm2D.
        use_bias = norm_layer == "InstanceNorm2D"
        half_dim = int(encode_dim / 2)

        # Keyword arguments shared by all encoders and decoders.
        common = dict(
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)

        # Sub-layers are created in the same order as before so that the
        # registration order of parameters does not change.
        self.encoder_text = Encoder(
            name=name + "_encoder_text", in_channels=3, **common)
        self.encoder_style = Encoder(
            name=name + "_encoder_style", in_channels=3, **common)
        self.decoder_text = Decoder(
            name=name + "_decoder_text",
            out_channels=half_dim,
            out_conv_act="Tanh",
            out_conv_act_attr=None,
            **common)
        self.decoder_sk = Decoder(
            name=name + "_decoder_sk",
            out_channels=1,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None,
            **common)

        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=half_dim + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input, text_input):
        """Run the generator.

        Returns:
            dict with ``"fake_sk"`` (output of the Sigmoid decoder) and
            ``"fake_text"`` (output of the ``MiddleNet``).
        """
        style_feature = self.encoder_style.forward(style_input)["res_blocks"]
        text_feature = self.encoder_text.forward(text_input)["res_blocks"]
        text_decoded = self.decoder_text.forward(
            [text_feature, style_feature])["out_conv"]
        fake_sk = self.decoder_sk.forward(
            [text_feature, style_feature])["out_conv"]
        merged = paddle.concat((text_decoded, fake_sk), axis=1)
        return {"fake_sk": fake_sk, "fake_text": self.middle(merged)}
|
||||||
|
|
||||||
|
|
||||||
|
class BgGeneratorWithMask(nn.Layer):
    """Background generator with an additional single-channel mask branch.

    An ``Encoder`` processes the style image. A ``SingleDecoder`` (fed
    with the encoder's ``res_blocks``, ``down2`` and ``down1`` outputs)
    produces a 3-channel Tanh output, and a ``Decoder`` produces a
    1-channel Sigmoid mask. Both are concatenated on axis 1 and mapped to
    the final 3-channel background by a ``MiddleNet``.
    """

    def __init__(self, config):
        """Create the sub-layers.

        Args:
            config: dict-like providing ``module_name``, ``encode_dim``,
                ``norm_layer``, ``conv_block_dropout``, ``conv_block_num``
                and ``conv_block_dilation``; ``output_factor`` is optional
                and defaults to ``1.0``.
        """
        super(BgGeneratorWithMask, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        # Stored on the instance; nothing inside this class reads it.
        self.output_factor = config.get("output_factor", 1.0)

        # A convolution bias is only enabled together with InstanceNorm2D.
        use_bias = norm_layer == "InstanceNorm2D"

        # Keyword arguments shared by the encoder and both decoders.
        block_opts = dict(
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)

        self.encoder_bg = Encoder(
            name=name + "_encoder_bg",
            in_channels=3,
            encode_dim=encode_dim,
            **block_opts)

        self.decoder_bg = SingleDecoder(
            name=name + "_decoder_bg",
            encode_dim=encode_dim,
            out_channels=3,
            out_conv_act="Tanh",
            out_conv_act_attr=None,
            **block_opts)

        # The mask decoder is built with half the encode dimension.
        self.decoder_mask = Decoder(
            name=name + "_decoder_mask",
            encode_dim=encode_dim // 2,
            out_channels=1,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None,
            **block_opts)

        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=3 + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input):
        """Generate the background for ``style_input``.

        Returns:
            dict with ``"bg_encode_feature"``, ``"bg_decode_feature1"``,
            ``"bg_decode_feature2"``, ``"fake_bg"`` and ``"fake_bg_mask"``.
        """
        encoded = self.encoder_bg(style_input)
        res_blocks = encoded["res_blocks"]
        decoded = self.decoder_bg(res_blocks, encoded["down2"],
                                  encoded["down1"])

        bg_decoded = decoded["out_conv"]
        fake_bg_mask = self.decoder_mask.forward(res_blocks)["out_conv"]
        fake_bg = self.middle(
            paddle.concat((bg_decoded, fake_bg_mask), axis=1))
        return {
            "bg_encode_feature": res_blocks,
            "bg_decode_feature1": decoded["up1"],
            "bg_decode_feature2": decoded["up2"],
            "fake_bg": fake_bg,
            "fake_bg_mask": fake_bg_mask,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class FusionGeneratorSimple(nn.Layer):
    """Fuse generated text and background with conv -> ResBlock -> conv.

    The two inputs are concatenated on axis 1 (the first convolution
    expects 6 input channels), lifted to ``encode_dim`` channels, passed
    through one ``ResBlock`` and reduced to 3 channels. Both convolutions
    are 3x3, stride 1, padding 1 and have no bias.
    """

    def __init__(self, config):
        """Create the sub-layers.

        Args:
            config: dict-like providing ``module_name``, ``encode_dim``,
                ``norm_layer``, ``conv_block_dropout`` and
                ``conv_block_dilation``.
        """
        super(FusionGeneratorSimple, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_dilation = config["conv_block_dilation"]
        # A bias inside the ResBlock is only enabled with InstanceNorm2D.
        use_bias = norm_layer == "InstanceNorm2D"

        def conv3x3(in_channels, out_channels, weight_name):
            # Bias-free 3x3 convolution that keeps the spatial size.
            return nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=1,
                weight_attr=paddle.ParamAttr(name=weight_name),
                bias_attr=False)

        self._conv = conv3x3(6, encode_dim, name + "_conv_weights")

        self._res_block = ResBlock(
            name="{}_conv_block".format(name),
            channels=encode_dim,
            norm_layer=norm_layer,
            use_dropout=conv_block_dropout,
            use_dilation=conv_block_dilation,
            use_bias=use_bias)

        self._reduce_conv = conv3x3(encode_dim, 3,
                                    name + "_reduce_conv_weights")

    def forward(self, fake_text, fake_bg):
        """Return ``{"fake_fusion": ...}`` for the given text/background."""
        stacked = paddle.concat((fake_text, fake_bg), axis=1)
        features = self._res_block(self._conv(stacked))
        return {"fake_fusion": self._reduce_conv(features)}
|
||||||
54
tools/style_text_rec/configs/config.yml
Normal file
54
tools/style_text_rec/configs/config.yml
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
Global:
|
||||||
|
output_num: 10
|
||||||
|
output_dir: output_data
|
||||||
|
use_gpu: false
|
||||||
|
image_height: 32
|
||||||
|
image_width: 320
|
||||||
|
TextDrawer:
|
||||||
|
fonts:
|
||||||
|
en: fonts/en_standard.ttf
|
||||||
|
ch: fonts/ch_standard.ttf
|
||||||
|
ko: fonts/ko_standard.ttf
|
||||||
|
Predictor:
|
||||||
|
method: StyleTextRecPredictor
|
||||||
|
algorithm: StyleTextRec
|
||||||
|
scale: 0.00392156862745098
|
||||||
|
mean:
|
||||||
|
- 0.5
|
||||||
|
- 0.5
|
||||||
|
- 0.5
|
||||||
|
std:
|
||||||
|
- 0.5
|
||||||
|
- 0.5
|
||||||
|
- 0.5
|
||||||
|
expand_result: false
|
||||||
|
bg_generator:
|
||||||
|
pretrain: style_text_models/bg_generator
|
||||||
|
module_name: bg_generator
|
||||||
|
generator_type: BgGeneratorWithMask
|
||||||
|
encode_dim: 64
|
||||||
|
norm_layer: null
|
||||||
|
conv_block_num: 4
|
||||||
|
conv_block_dropout: false
|
||||||
|
conv_block_dilation: true
|
||||||
|
output_factor: 1.05
|
||||||
|
text_generator:
|
||||||
|
pretrain: style_text_models/text_generator
|
||||||
|
module_name: text_generator
|
||||||
|
generator_type: TextGenerator
|
||||||
|
encode_dim: 64
|
||||||
|
norm_layer: InstanceNorm2D
|
||||||
|
conv_block_num: 4
|
||||||
|
conv_block_dropout: false
|
||||||
|
conv_block_dilation: true
|
||||||
|
fusion_generator:
|
||||||
|
pretrain: style_text_models/fusion_generator
|
||||||
|
module_name: fusion_generator
|
||||||
|
generator_type: FusionGeneratorSimple
|
||||||
|
encode_dim: 64
|
||||||
|
norm_layer: null
|
||||||
|
conv_block_num: 4
|
||||||
|
conv_block_dropout: false
|
||||||
|
conv_block_dilation: true
|
||||||
|
Writer:
|
||||||
|
method: SimpleWriter
|
||||||
64
tools/style_text_rec/configs/dataset_config.yml
Normal file
64
tools/style_text_rec/configs/dataset_config.yml
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
Global:
|
||||||
|
output_num: 10
|
||||||
|
output_dir: output_data
|
||||||
|
use_gpu: false
|
||||||
|
image_height: 32
|
||||||
|
image_width: 320
|
||||||
|
standard_font: fonts/en_standard.ttf
|
||||||
|
TextDrawer:
|
||||||
|
fonts:
|
||||||
|
en: fonts/en_standard.ttf
|
||||||
|
ch: fonts/ch_standard.ttf
|
||||||
|
ko: fonts/ko_standard.ttf
|
||||||
|
StyleSampler:
|
||||||
|
method: DatasetSampler
|
||||||
|
image_home: examples
|
||||||
|
label_file: examples/image_list.txt
|
||||||
|
with_label: true
|
||||||
|
CorpusGenerator:
|
||||||
|
method: FileCorpus
|
||||||
|
language: ch
|
||||||
|
corpus_file: examples/corpus/example.txt
|
||||||
|
Predictor:
|
||||||
|
method: StyleTextRecPredictor
|
||||||
|
algorithm: StyleTextRec
|
||||||
|
scale: 0.00392156862745098
|
||||||
|
mean:
|
||||||
|
- 0.5
|
||||||
|
- 0.5
|
||||||
|
- 0.5
|
||||||
|
std:
|
||||||
|
- 0.5
|
||||||
|
- 0.5
|
||||||
|
- 0.5
|
||||||
|
expand_result: false
|
||||||
|
bg_generator:
|
||||||
|
pretrain: style_text_models/bg_generator
|
||||||
|
module_name: bg_generator
|
||||||
|
generator_type: BgGeneratorWithMask
|
||||||
|
encode_dim: 64
|
||||||
|
norm_layer: null
|
||||||
|
conv_block_num: 4
|
||||||
|
conv_block_dropout: false
|
||||||
|
conv_block_dilation: true
|
||||||
|
output_factor: 1.05
|
||||||
|
text_generator:
|
||||||
|
pretrain: style_text_models/text_generator
|
||||||
|
module_name: text_generator
|
||||||
|
generator_type: TextGenerator
|
||||||
|
encode_dim: 64
|
||||||
|
norm_layer: InstanceNorm2D
|
||||||
|
conv_block_num: 4
|
||||||
|
conv_block_dropout: false
|
||||||
|
conv_block_dilation: true
|
||||||
|
fusion_generator:
|
||||||
|
pretrain: style_text_models/fusion_generator
|
||||||
|
module_name: fusion_generator
|
||||||
|
generator_type: FusionGeneratorSimple
|
||||||
|
encode_dim: 64
|
||||||
|
norm_layer: null
|
||||||
|
conv_block_num: 4
|
||||||
|
conv_block_dropout: false
|
||||||
|
conv_block_dilation: true
|
||||||
|
Writer:
|
||||||
|
method: SimpleWriter
|
||||||
54
tools/style_text_rec/engine/corpus_generators.py
Normal file
54
tools/style_text_rec/engine/corpus_generators.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import random
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import numpy as np
|
||||||
|
from utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class FileCorpus(object):
    """Corpus generator that serves the lines of a text file.

    The lines are shuffled once at construction time and reshuffled every
    time the whole list has been consumed.
    """

    def __init__(self, config):
        """Load and shuffle the corpus.

        Args:
            config: dict-like whose ``"CorpusGenerator"`` section provides
                ``corpus_file`` (path of a UTF-8 text file, one corpus
                entry per line) and ``language``.

        Raises:
            AssertionError: if the corpus file contains no lines.
        """
        self.logger = get_logger()
        self.logger.info("using FileCorpus")

        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

        corpus_file = config["CorpusGenerator"]["corpus_file"]
        self.language = config["CorpusGenerator"]["language"]
        self.corpus_list = self._read_corpus(corpus_file)
        assert len(self.corpus_list) > 0, \
            "corpus file {} is empty.".format(corpus_file)
        random.shuffle(self.corpus_list)
        self.index = 0

    @staticmethod
    def _read_corpus(corpus_file):
        """Return the lines of ``corpus_file`` as a list of strings.

        The file is decoded as UTF-8 regardless of the platform locale
        (the bundled example corpus contains Chinese text). Only the empty
        string produced by a trailing newline is dropped, so a last line
        without a final newline is kept instead of being lost.
        """
        with open(corpus_file, 'r', encoding="utf-8") as f:
            lines = f.read().split("\n")
        if lines and lines[-1] == "":
            lines.pop()
        return lines

    def generate(self, corpus_length=0):
        """Return the next corpus entry.

        Args:
            corpus_length: when non-zero, the entry is truncated to at
                most this many characters; a warning is logged if the
                entry is shorter than requested.

        Returns:
            Tuple ``(language, corpus)``.
        """
        if self.index >= len(self.corpus_list):
            # Every entry has been served once: start a new shuffled pass.
            self.index = 0
            random.shuffle(self.corpus_list)
        corpus = self.corpus_list[self.index]
        if corpus_length != 0:
            corpus = corpus[0:corpus_length]
        if corpus_length > len(corpus):
            self.logger.warning("generated corpus is shorter than expected.")
        self.index += 1
        return self.language, corpus
|
||||||
|
|
||||||
|
|
||||||
|
class EnNumCorpus(object):
    """Random corpus generator mixing digits with English letters."""

    def __init__(self, config):
        """Store the character pools and the image size from ``config``.

        Args:
            config: dict-like whose ``"Global"`` section provides
                ``image_height`` and ``image_width``.
        """
        self.logger = get_logger()
        self.logger.info("using NumberCorpus")
        self.num_list = "0123456789"
        self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        global_config = config["Global"]
        self.height = global_config["image_height"]
        self.max_width = global_config["image_width"]

    def generate(self, corpus_length=0):
        """Return ``("en", corpus)`` with a randomly drawn ``corpus``.

        Args:
            corpus_length: number of characters to draw; ``0`` selects a
                random length between 5 and 15 (inclusive).

        Each character is an English letter with probability 0.2 and a
        digit otherwise.
        """
        if corpus_length == 0:
            corpus_length = random.randint(5, 15)
        picked = []
        for _ in range(corpus_length):
            # Draw the pool first, then the character, in that order.
            pool = self.en_char_list if random.random() < 0.2 else self.num_list
            picked.append(random.choice(pool))
        return "en", "".join(picked)
|
||||||
115
tools/style_text_rec/engine/predictors.py
Normal file
115
tools/style_text_rec/engine/predictors.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from arch import style_text_rec
|
||||||
|
from utils.sys_funcs import check_gpu
|
||||||
|
from utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class StyleTextRecPredictor(object):
    """Run the style-text generator on numpy images.

    Handles pre-processing (normalise, resize, pad), the forward pass,
    post-processing back to ``uint8`` images and cropping of all outputs
    to the detected text area.
    """

    def __init__(self, config):
        """Build the generator and read the pre/post-processing settings.

        Args:
            config: dict-like with ``"Global"`` (``use_gpu``,
                ``image_height``, ``image_width``) and ``"Predictor"``
                (``algorithm``, ``scale``, ``mean``, ``std``,
                ``expand_result``) sections.
        """
        predictor_config = config['Predictor']
        algorithm = predictor_config['algorithm']
        assert algorithm in ["StyleTextRec"
                             ], "Generator {} not supported.".format(algorithm)
        use_gpu = config["Global"]['use_gpu']
        check_gpu(use_gpu)
        self.logger = get_logger()
        self.generator = getattr(style_text_rec, algorithm)(config)
        self.height = config["Global"]["image_height"]
        self.width = config["Global"]["image_width"]
        self.scale = predictor_config["scale"]
        self.mean = predictor_config["mean"]
        self.std = predictor_config["std"]
        self.expand_result = predictor_config["expand_result"]

    def predict(self, style_input, text_input):
        """Synthesise an image with the style of ``style_input``.

        Args:
            style_input: H x W x 3 style image (numpy array).
            text_input: H x W x 3 rendered text image (numpy array).

        Returns:
            dict with ``"fake_fusion"``, ``"fake_text"``, ``"fake_sk"``
            and ``"fake_bg"`` ``uint8`` images. When a text boundary is
            found in ``fake_text`` all four are cropped to it.
        """
        style_input = self.rep_style_input(style_input, text_input)
        style_tensor = self.preprocess(style_input)
        text_tensor = self.preprocess(text_input)
        raw_result = self.generator.forward(style_tensor, text_tensor)

        result_keys = ("fake_fusion", "fake_text", "fake_sk", "fake_bg")
        images = {}
        for key in result_keys:
            images[key] = self.postprocess(raw_result[key])

        bbox = self.get_text_boundary(images["fake_text"])
        if bbox:
            left, right, top, bottom = bbox
            for key in result_keys:
                images[key] = images[key][top:bottom, left:right, :]
        return images

    def preprocess(self, img):
        """Turn an H x W x 3 image into a 1 x 3 x height x width tensor.

        The image is normalised with ``scale``/``mean``/``std``, resized
        to ``self.height`` keeping the aspect ratio (width capped at
        ``self.width``) and pasted at the left of a zero-filled canvas.
        """
        img = (img.astype('float32') * self.scale - self.mean) / self.std
        img_height, img_width, channel = img.shape
        assert channel == 3, "Please use an rgb image."
        ratio = img_width / float(img_height)
        scaled_w = math.ceil(self.height * ratio)
        # Never exceed the canvas width.
        resized_w = self.width if scaled_w > self.width else int(scaled_w)
        img = cv2.resize(img, (resized_w, self.height))

        canvas = np.zeros([self.height, self.width, 3]).astype('float32')
        canvas[:, 0:resized_w, :] = img
        # HWC -> CHW, then add the batch axis.
        batched = canvas.transpose((2, 0, 1))[np.newaxis, :, :, :]
        return paddle.to_tensor(batched)

    def postprocess(self, tensor):
        """Convert the first sample of ``tensor`` back to a uint8 image.

        Inverts the normalisation of ``preprocess`` and clamps the values
        to ``[0, 255]``.
        """
        img = tensor.numpy()[0].transpose((1, 2, 0))
        img = (img * self.std + self.mean) / self.scale
        img = np.minimum(np.maximum(img, 0.0), 255.0)
        return img.astype('uint8')

    def rep_style_input(self, style_input, text_input):
        """Tile the style image horizontally to cover the text image.

        The number of repetitions is derived from the ratio of the two
        aspect ratios (with a 1.2 margin); the result is cropped to the
        width matching the model aspect ratio ``self.width / self.height``.
        """
        text_aspect = text_input.shape[1] / text_input.shape[0]
        style_aspect = style_input.shape[1] / style_input.shape[0]
        rep_num = int(1.2 * text_aspect / style_aspect) + 1
        tiled = np.tile(style_input, reps=[1, rep_num, 1])
        max_width = int(self.width / self.height * tiled.shape[0])
        return tiled[:, :max_width, :]

    def get_text_boundary(self, text_img):
        """Locate the text area of ``text_img`` using Canny edges.

        Returns:
            ``[left, right, top, bottom]`` expanded by a 3 pixel margin
            and clamped to the image, or ``None`` when no edge is found.
        """
        img_height = text_img.shape[0]
        img_width = text_img.shape[1]
        bounder = 3
        edges = cv2.Canny(text_img, 10, 20)
        # Columns / rows that contain at least one edge pixel.
        edge_cols = np.where(edges.sum(axis=0) > 0)[0]
        edge_rows = np.where(edges.sum(axis=1) > 0)[0]
        if len(edge_cols) == 0 or len(edge_rows) == 0:
            return None
        left = max(edge_cols[0] - bounder, 0)
        right = min(edge_cols[-1] + bounder, img_width)
        top = max(edge_rows[0] - bounder, 0)
        bottom = min(edge_rows[-1] + bounder, img_height)
        return [left, right, top, bottom]
|
||||||
62
tools/style_text_rec/engine/style_samplers.py
Normal file
62
tools/style_text_rec/engine/style_samplers.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetSampler(object):
    """Serve style images (and optional labels) listed in a label file.

    The entries are shuffled once at construction time and reshuffled
    every time the whole list has been consumed.
    """

    def __init__(self, config):
        """Read and shuffle the label file.

        Args:
            config: dict-like with a ``"StyleSampler"`` section
                (``image_home``, ``label_file``, ``with_label``) and
                ``config["Global"]["image_height"]``. The label file is
                read as UTF-8, one entry per line; with ``with_label`` an
                entry is ``<relative path>\\t<label>``, otherwise just the
                relative path.

        Raises:
            AssertionError: if the label file contains no lines.
        """
        self.image_home = config["StyleSampler"]["image_home"]
        label_file = config["StyleSampler"]["label_file"]
        self.dataset_with_label = config["StyleSampler"]["with_label"]
        self.height = config["Global"]["image_height"]
        self.index = 0
        # Decode explicitly as UTF-8: labels may contain non-ASCII text and
        # the platform default encoding is not reliable.
        with open(label_file, "r", encoding="utf-8") as f:
            entries = f.read().split("\n")
        # Drop only the empty string produced by a trailing newline so a
        # last line without a final newline is not lost.
        if entries and entries[-1] == "":
            entries.pop()
        self.path_label_list = entries
        assert len(self.path_label_list) > 0, \
            "label file {} is empty.".format(label_file)
        random.shuffle(self.path_label_list)

    def sample(self):
        """Return the next style image resized to the configured height.

        Returns:
            ``{"image": image, "label": label}`` when the entry carries a
            non-empty label, otherwise ``{"image": image}``.

        Raises:
            IOError: if the image cannot be read by ``cv2.imread``.
            ValueError: if ``with_label`` is set and the entry has no tab.
        """
        if self.index >= len(self.path_label_list):
            random.shuffle(self.path_label_list)
            self.index = 0
        if self.dataset_with_label:
            path_label = self.path_label_list[self.index]
            # Split on the first tab only so labels may contain tabs.
            rel_image_path, label = path_label.split('\t', 1)
        else:
            rel_image_path = self.path_label_list[self.index]
            label = None
        img_path = "{}/{}".format(self.image_home, rel_image_path)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError("failed to read image: {}".format(img_path))
        origin_height = image.shape[0]
        ratio = self.height / origin_height
        # Keep at least one column for extremely narrow images.
        width = max(1, int(image.shape[1] * ratio))
        # Use the target height directly: origin_height * ratio can come
        # out just below it in floating point and truncate to height - 1.
        height = int(self.height)
        image = cv2.resize(image, (width, height))

        self.index += 1
        if label:
            return {"image": image, "label": label}
        else:
            return {"image": image}
|
||||||
|
|
||||||
|
|
||||||
|
def duplicate_image(image, width):
    """Repeat ``image`` horizontally and crop the result to ``width``.

    Args:
        image: array indexed as ``[row, column, channel]``.
        width: number of columns of the returned image.

    Returns:
        The tiled image restricted to its first ``width`` columns.
    """
    # One repetition more than the integer quotient always covers `width`.
    repeats = width // image.shape[1] + 1
    tiled = np.tile(image, reps=[1, repeats, 1])
    return tiled[:, :width, :]
|
||||||
58
tools/style_text_rec/engine/synthesisers.py
Normal file
58
tools/style_text_rec/engine/synthesisers.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from utils.config import ArgsParser, load_config, override_config
|
||||||
|
from utils.logging import get_logger
|
||||||
|
from engine import style_samplers, corpus_generators, text_drawers, predictors, writers
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSynthesiser(object):
    """Synthesise a single image from a corpus string and a style image.

    Configuration comes from the command line: the config file named by
    the parsed arguments is loaded and then patched with the command-line
    overrides.
    """

    def __init__(self):
        """Parse arguments, load the config and build drawer/predictor."""
        self.FLAGS = ArgsParser().parse_args()
        self.config = load_config(self.FLAGS.config)
        self.config = override_config(self.config, options=self.FLAGS.override)
        self.output_dir = self.config["Global"]["output_dir"]
        # makedirs(exist_ok=True) also creates missing parent directories
        # and does not fail if the directory appears between a check and
        # the creation.
        os.makedirs(self.output_dir, exist_ok=True)
        self.logger = get_logger(
            log_file='{}/predict.log'.format(self.output_dir))

        self.text_drawer = text_drawers.StdTextDrawer(self.config)

        predictor_method = self.config["Predictor"]["method"]
        assert predictor_method is not None
        # The predictor class is looked up by name in the predictors module.
        self.predictor = getattr(predictors, predictor_method)(self.config)

    def synth_image(self, corpus, style_input, language="en"):
        """Render ``corpus`` and transfer the style of ``style_input``.

        Args:
            corpus: text to render.
            style_input: style image handed to the predictor.
            language: language key handed to the text drawer.

        Returns:
            The result of ``self.predictor.predict``.
        """
        corpus, text_input = self.text_drawer.draw_text(corpus, language)
        synth_result = self.predictor.predict(style_input, text_input)
        return synth_result
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetSynthesiser(ImageSynthesiser):
    """Synthesise ``output_num`` images and write them with their labels."""

    def __init__(self):
        """Build the corpus generator, style sampler and writer."""
        super(DatasetSynthesiser, self).__init__()
        self.tag = self.FLAGS.tag
        self.output_num = self.config["Global"]["output_num"]

        # The corpus generator class is looked up by its configured name.
        corpus_generator_method = self.config["CorpusGenerator"]["method"]
        corpus_generator_cls = getattr(corpus_generators,
                                       corpus_generator_method)
        self.corpus_generator = corpus_generator_cls(self.config)

        # NOTE(review): the configured sampler method is only checked for
        # None; DatasetSampler and SimpleWriter are always used.
        style_sampler_method = self.config["StyleSampler"]["method"]
        assert style_sampler_method is not None
        self.style_sampler = style_samplers.DatasetSampler(self.config)
        self.writer = writers.SimpleWriter(self.config, self.tag)

    def synth_dataset(self):
        """Generate the images, then write and merge the label files."""
        for _ in range(self.output_num):
            style_input = self.style_sampler.sample()["image"]
            language, label = self.corpus_generator.generate()
            # draw_text may shorten the label; keep the returned one.
            label, text_input = self.text_drawer.draw_text(label, language)

            prediction = self.predictor.predict(style_input, text_input)
            self.writer.save_image(prediction["fake_fusion"], label)
        self.writer.save_label()
        self.writer.merge_label()
|
||||||
58
tools/style_text_rec/engine/text_drawers.py
Normal file
58
tools/style_text_rec/engine/text_drawers.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import random
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import numpy as np
|
||||||
|
from utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class StdTextDrawer(object):
    """Render a corpus string as black text on a grey canvas."""

    def __init__(self, config):
        """Load one font per configured language.

        Args:
            config: dict-like with ``config["Global"]`` (``image_width``,
                ``image_height``) and ``config["TextDrawer"]["fonts"]``,
                a mapping from language key to font file path.
        """
        self.logger = get_logger()
        self.max_width = config["Global"]["image_width"]
        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.font_dict = {}
        self.load_fonts(config["TextDrawer"]["fonts"])
        self.support_languages = list(self.font_dict)

    @staticmethod
    def _text_size(font, text):
        """Return ``(width, height)`` of ``text`` rendered with ``font``.

        Uses ``font.getsize`` when the font still offers it, which keeps
        the previous behavior exactly. Pillow 10 removed that method, so
        otherwise the right/bottom edges of ``font.getbbox`` are used.
        """
        legacy_getsize = getattr(font, "getsize", None)
        if legacy_getsize is not None:
            return tuple(legacy_getsize(text))
        _, _, right, bottom = font.getbbox(text)
        # Force integers: the width is later used as a slice bound.
        return int(np.ceil(right)), int(np.ceil(bottom))

    def load_fonts(self, fonts_config):
        """Load every font of ``fonts_config`` at a size fitting the canvas."""
        for language, font_path in fonts_config.items():
            font_height = self.get_valid_height(font_path)
            self.font_dict[language] = ImageFont.truetype(font_path,
                                                          font_height)

    def get_valid_height(self, font_path):
        """Return a font size whose rendered text fits ``height - 4`` pixels.

        The font is first loaded at ``height - 4``; if ``char_list``
        renders taller than that, the size is scaled down proportionally.
        """
        target = self.height - 4
        font = ImageFont.truetype(font_path, target)
        _, font_height = self._text_size(font, self.char_list)
        if font_height <= target:
            return target
        else:
            return int(target**2 / font_height)

    def draw_text(self, corpus, language="en", crop=True):
        """Draw ``corpus`` and return ``(corpus, image)``.

        Args:
            corpus: text to draw.
            language: key into the loaded fonts; an unsupported key falls
                back to ``"en"`` with a warning.
            crop: when true the canvas width is capped at
                ``max_width + 4``.

        Returns:
            Tuple of the corpus actually drawn (truncated when the canvas
            is full) and the image as a ``uint8`` array cut at the end of
            the last drawn character.
        """
        if language not in self.support_languages:
            self.logger.warning(
                "language {} not supported, use en instead.".format(language))
            language = "en"
        if crop:
            width = min(self.max_width, len(corpus) * self.height) + 4
        else:
            width = len(corpus) * self.height + 4
        bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
        draw = ImageDraw.Draw(bg)

        char_x = 2
        font = self.font_dict[language]
        for i, char_i in enumerate(corpus):
            char_size = self._text_size(font, char_i)[0]
            draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
            char_x += char_size
            if char_x >= width:
                # NOTE(review): the character that reaches the right edge
                # stays in the label although it may be clipped -- confirm
                # that this is intended.
                corpus = corpus[0:i + 1]
                self.logger.warning("corpus length exceed limit: {}".format(
                    corpus))
                break

        text_input = np.array(bg).astype(np.uint8)
        text_input = text_input[:, 0:char_x, :]
        return corpus, text_input
|
||||||
71
tools/style_text_rec/engine/writers.py
Normal file
71
tools/style_text_rec/engine/writers.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import glob
|
||||||
|
|
||||||
|
from utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleWriter(object):
    """Write synthesised images and their labels below ``output_dir``.

    Images go to ``<output_dir>/images/<tag>/<counter>.png``. Labels are
    collected in memory and written to
    ``<output_dir>/label/<tag>_label.txt``; ``merge_label`` concatenates
    all per-tag label files into ``<output_dir>/label.txt``.
    """

    def __init__(self, config, tag):
        """Remember the output directory and the tag of this run.

        Args:
            config: dict-like providing ``config["Global"]["output_dir"]``.
            tag: sub-directory / file prefix used for this writer.
        """
        self.logger = get_logger()
        self.output_dir = config["Global"]["output_dir"]
        self.counter = 0
        self.label_dict = {}
        self.tag = tag
        self.label_file_index = 0

    def save_image(self, image, text_input_label):
        """Write ``image`` to disk and record its label.

        The label file is rewritten after every 100 saved images.
        """
        image_home = os.path.join(self.output_dir, "images", self.tag)
        os.makedirs(image_home, exist_ok=True)

        file_name = "{}.png".format(self.counter)
        image_path = os.path.join(image_home, file_name)
        # todo support continue synth
        cv2.imwrite(image_path, image)
        self.logger.info("generate image: {}".format(image_path))

        # Labels are keyed by the path relative to the images directory.
        image_name = os.path.join(self.tag, file_name)
        self.label_dict[image_name] = text_input_label

        self.counter += 1
        if not self.counter % 100:
            self.save_label()

    def save_label(self):
        """Write all recorded labels to ``label/<tag>_label.txt``.

        Each line is ``<image path>\\t<label>``. The file is written as
        UTF-8 so non-ASCII labels do not depend on the platform locale.
        """
        label_home = os.path.join(self.output_dir, "label")
        # Creates missing parents too and tolerates an existing directory.
        os.makedirs(label_home, exist_ok=True)
        label_raw = "".join("{}\t{}\n".format(image_path, label)
                            for image_path, label in self.label_dict.items())
        label_file_path = os.path.join(label_home,
                                       "{}_label.txt".format(self.tag))
        with open(label_file_path, "w", encoding="utf-8") as f:
            f.write(label_raw)
        self.label_file_index += 1

    def merge_label(self):
        """Concatenate every ``label/*_label.txt`` into ``label.txt``.

        Files are merged in sorted name order so the result is
        reproducible; they are read and written as UTF-8.
        """
        label_file_regex = os.path.join(self.output_dir, "label",
                                        "*_label.txt")
        parts = []
        for label_file_i in sorted(glob.glob(label_file_regex)):
            with open(label_file_i, "r", encoding="utf-8") as f:
                parts.append(f.read())
        label_file_path = os.path.join(self.output_dir, "label.txt")
        with open(label_file_path, "w", encoding="utf-8") as f:
            f.write("".join(parts))
|
||||||
2
tools/style_text_rec/examples/corpus/example.txt
Normal file
2
tools/style_text_rec/examples/corpus/example.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
PaddleOCR
|
||||||
|
飞桨文字识别
|
||||||
2
tools/style_text_rec/examples/image_list.txt
Normal file
2
tools/style_text_rec/examples/image_list.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
style_images/1.jpg NEATNESS
|
||||||
|
style_images/2.jpg 锁店君和宾馆
|
||||||
BIN
tools/style_text_rec/examples/style_images/1.jpg
Normal file
BIN
tools/style_text_rec/examples/style_images/1.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.5 KiB |
BIN
tools/style_text_rec/examples/style_images/2.jpg
Normal file
BIN
tools/style_text_rec/examples/style_images/2.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.8 KiB |
BIN
tools/style_text_rec/fonts/ch_standard.ttf
Executable file
BIN
tools/style_text_rec/fonts/ch_standard.ttf
Executable file
Binary file not shown.
BIN
tools/style_text_rec/fonts/en_standard.ttf
Executable file
BIN
tools/style_text_rec/fonts/en_standard.ttf
Executable file
Binary file not shown.
BIN
tools/style_text_rec/fonts/ko_standard.ttf
Executable file
BIN
tools/style_text_rec/fonts/ko_standard.ttf
Executable file
Binary file not shown.
10
tools/style_text_rec/tools/synth_dataset.py
Normal file
10
tools/style_text_rec/tools/synth_dataset.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from engine.synthesisers import DatasetSynthesiser
|
||||||
|
|
||||||
|
|
||||||
|
def synth_dataset():
    """Build a ``DatasetSynthesiser`` and run its dataset synthesis."""
    DatasetSynthesiser().synth_dataset()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
synth_dataset()
|
||||||
78
tools/style_text_rec/tools/synth_image.py
Normal file
78
tools/style_text_rec/tools/synth_image.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import sys
|
||||||
|
import glob
|
||||||
|
|
||||||
|
from engine.synthesisers import ImageSynthesiser
|
||||||
|
|
||||||
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
|
|
||||||
|
|
||||||
|
def synth_image():
    """Synthesise the demo text "PaddleOCR" in the style of the example image.

    The three result images are written to ``fake_fusion.jpg``,
    ``fake_text.jpg`` and ``fake_bg.jpg`` in the current directory.
    """
    style_image = cv2.imread("examples/style_images/1.jpg")
    synthesiser = ImageSynthesiser()
    synth_result = synthesiser.synth_image("PaddleOCR", style_image, "en")

    # Look up every output first so a missing key fails before any file is
    # written.
    outputs = [(key, synth_result[key])
               for key in ("fake_fusion", "fake_text", "fake_bg")]
    for key, image in outputs:
        cv2.imwrite("{}.jpg".format(key), image)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_synth_images():
    """Synthesise every corpus entry in the style of every style image.

    The corpus file holds one ``<text>\\t<language>`` pair per line. For each
    (corpus, style image) combination the fusion, text, background and input
    style images are saved twice under ``./output_data/``: once with a
    corpus-major prefix (``c<i>_s<j>_``) and once with a style-major prefix
    (``s<j>_c<i>_``).
    """
    image_synthesiser = ImageSynthesiser()

    corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
    style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
    save_path = "./output_data/"
    # cv2.imwrite does not create missing directories, so without this
    # nothing would be saved on a fresh checkout.
    os.makedirs(save_path, exist_ok=True)

    corpus_list = []
    with open(corpus_file, "rb") as fin:
        for line in fin:
            # Drop "\r" too so files with Windows line endings do not leak it
            # into the language field.
            text = line.decode("utf-8").rstrip("\r\n")
            if not text:
                # A blank line cannot be unpacked into (corpus, language).
                continue
            corpus_list.append(text.split("\t"))

    # glob order is arbitrary; sort so the s<j> indices are reproducible.
    style_img_list = sorted(glob.glob("{}/*.jpg".format(style_data_dir)))
    corpus_num = len(corpus_list)
    style_img_num = len(style_img_list)
    for cno, (corpus, lang) in enumerate(corpus_list):
        for sno, style_img_path in enumerate(style_img_list):
            img = cv2.imread(style_img_path)
            synth_result = image_synthesiser.synth_image(corpus, img, lang)
            fake_fusion = synth_result["fake_fusion"]
            fake_text = synth_result["fake_text"]
            fake_bg = synth_result["fake_bg"]
            prefixes = ("%s/c%d_s%d_" % (save_path, cno, sno),
                        "%s/s%d_c%d_" % (save_path, sno, cno))
            for prefix in prefixes:
                cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
                cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
                cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
                cv2.imwrite("%s_input_style.jpg" % prefix, img)
            print(cno, corpus_num, sno, style_img_num)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# batch_synth_images()
|
||||||
|
synth_image()
|
||||||
219
tools/style_text_rec/utils/config.py
Normal file
219
tools/style_text_rec/utils/config.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
||||||
|
|
||||||
|
|
||||||
|
def override(dl, ks, v):
    """
    Recursively replace one value inside a nested dict/list structure.

    Args:
        dl(dict or list): dict or list to be replaced
        ks(list): list of keys leading to the value; list levels are
            addressed by index
        v(str): value to be replaced; it is evaluated as a Python expression
            when possible, otherwise stored unchanged
    """

    def str2num(v):
        # NOTE(review): eval() runs arbitrary expressions. The values come
        # from command line overrides; confirm that input is trusted.
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('length of keys should be larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            # An unknown leaf key is allowed: it is added, with a warning.
            if ks[0] not in dl:
                # Local import: this module defines no logger of its own.
                import logging
                logging.getLogger(__name__).warning(
                    'A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            # Intermediate keys must already exist.
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)
|
||||||
|
|
||||||
|
|
||||||
|
def override_config(config, options=None):
    """
    Apply a list of ``key=value`` overrides to a config, in place.

    Args:
        config(dict): dict to be replaced
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]

    Returns:
        config(dict): replaced config
    """
    if options is None:
        return config

    for option in options:
        assert isinstance(option, str), (
            "option({}) should be a str".format(option))
        assert "=" in option, (
            "option({}) should contain a ="
            "to distinguish between key and value".format(option))
        parts = option.split('=')
        assert len(parts) == 2, ("there can be only a = in the option")
        dotted_key, raw_value = parts
        # The dotted key is the path through the nested config.
        override(config, dotted_key.split('.'), raw_value)

    return config
|
||||||
|
|
||||||
|
|
||||||
|
class ArgsParser(ArgumentParser):
    """Command line parser with ``--config``, ``--tag`` and ``--override``."""

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        # (flags, keyword arguments) for every supported option, in the
        # order they are registered.
        option_specs = (
            (("-c", "--config"), {
                "help": "configuration file to use"
            }),
            (("-t", "--tag"), {
                "default": "0",
                "help": "tag for marking worker"
            }),
            (('-o', '--override'), {
                "action": 'append',
                "default": [],
                "help": 'config options to be overridden'
            }),
        )
        for flags, kwargs in option_specs:
            self.add_argument(*flags, **kwargs)

    def parse_args(self, argv=None):
        """Parse ``argv`` and require that a config file was given."""
        parsed = super(ArgsParser, self).parse_args(argv)
        assert parsed.config is not None, \
            "Please specify --config=configure_file_path."
        return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(file_path):
    """
    Read a configuration from a ``.yml`` / ``.yaml`` file.

    Args:
        file_path (str): Path of the config file to be loaded.

    Returns: config
    """
    _, extension = os.path.splitext(file_path)
    assert extension in ['.yml', '.yaml'], "only support yaml files for now"
    # NOTE(review): yaml.Loader is the full loader; confirm config files are
    # always trusted before keeping it.
    with open(file_path, 'rb') as stream:
        return yaml.load(stream, Loader=yaml.Loader)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_config():
    """Dump the built-in default SRNet configuration to ``config.yml``.

    The file is written to the current working directory and overwrites any
    existing ``config.yml``.
    """

    def discriminator(name):
        # Both discriminators share every setting except the name. A fresh
        # dict is built per call so the dumped YAML holds two plain mappings.
        return {
            "name": name,
            "input_nc": 6,
            "ndf": 64,
            "netD": "basic",
            "norm": "none",
            "init_type": "xavier",
        }

    global_section = {
        "algorithm": "SRNet",
        "use_gpu": True,
        "start_epoch": 1,
        "stage1_epoch_num": 100,
        "stage2_epoch_num": 100,
        "log_smooth_window": 20,
        "print_batch_step": 2,
        "save_model_dir": "./output/SRNet",
        "use_visualdl": False,
        "save_epoch_step": 10,
        "vgg_pretrain": "./pretrained/VGG19_pretrained",
        "vgg_load_static_pretrain": True
    }

    architecture_section = {
        "model_type": "data_aug",
        "algorithm": "SRNet",
        "net_g": {
            "name": "srnet_net_g",
            "encode_dim": 64,
            "norm": "batch",
            "use_dropout": False,
            "init_type": "xavier",
            "init_gain": 0.02,
            "use_dilation": 1
        },
        "bg_discriminator": discriminator("srnet_bg_discriminator"),
        "fusion_discriminator": discriminator("srnet_fusion_discriminator")
    }

    loss_section = {
        "lamb": 10,
        "perceptual_lamb": 1,
        "muvar_lamb": 50,
        "style_lamb": 500
    }

    optimizer_section = {
        "name": "Adam",
        "learning_rate": {
            "name": "lambda",
            "lr": 0.0002,
            "lr_decay_iters": 50
        },
        "beta1": 0.5,
        "beta2": 0.999,
    }

    # Image pre-processing steps, listed in the order they appear in the
    # config.
    transforms = [
        {
            "DecodeImage": {
                "to_rgb": True,
                "to_np": False,
                "channel_first": False
            }
        },
        {
            "NormalizeImage": {
                "scale": 1. / 255.,
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
                "order": None
            }
        },
        {
            "ToCHWImage": None
        },
    ]

    train_section = {
        "batch_size_per_card": 8,
        "num_workers_per_card": 4,
        "dataset": {
            "delimiter": "\t",
            "data_dir": "/",
            "label_file": "tmp/label.txt",
            "transforms": transforms
        }
    }

    base_config = {
        "Global": global_section,
        "Architecture": architecture_section,
        "Loss": loss_section,
        "Optimizer": optimizer_section,
        "Train": train_section
    }
    with open("config.yml", "w") as f:
        yaml.dump(base_config, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
gen_config()
|
||||||
33
tools/style_text_rec/utils/load_params.py
Normal file
33
tools/style_text_rec/utils/load_params.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
__all__ = ['load_dygraph_pretrain']
|
||||||
|
|
||||||
|
|
||||||
|
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
    """Load pretrained parameters from ``<path>.pdparams`` into ``model``.

    Args:
        model: Object whose ``set_state_dict`` receives the loaded parameters.
        logger: Logger used to report the loaded path.
        path (str): Checkpoint path without the ``.pdparams`` suffix.
        load_static_weights (bool): Accepted for interface compatibility;
            currently ignored.

    Raises:
        ValueError: If ``path`` is None or ``<path>.pdparams`` does not exist.
    """
    if path is None:
        # The default value would otherwise fail with an opaque TypeError on
        # ``None + '.pdparams'``.
        raise ValueError("Model pretrain path must be specified, got None.")
    params_file = path + '.pdparams'
    if not os.path.exists(params_file):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))
    param_state_dict = paddle.load(params_file)
    model.set_state_dict(param_state_dict)
    logger.info("load pretrained model from {}".format(path))
    return
|
||||||
66
tools/style_text_rec/utils/logging.py
Normal file
66
tools/style_text_rec/utils/logging.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import functools
|
||||||
|
import paddle.distributed as dist
|
||||||
|
|
||||||
|
logger_initialized = {}
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache()
def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # A logger whose name starts with an already initialized name is
    # returned without adding handlers again.
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    # Only rank 0 writes the log file.
    if log_file is not None and dist.get_rank() == 0:
        log_file_folder = os.path.split(log_file)[0]
        # A bare file name has an empty directory part, and os.makedirs('')
        # raises FileNotFoundError.
        if log_file_folder:
            os.makedirs(log_file_folder, exist_ok=True)
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if dist.get_rank() == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    return logger
|
||||||
45
tools/style_text_rec/utils/math_functions.py
Normal file
45
tools/style_text_rec/utils/math_functions.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
def compute_mean_covariance(img):
    """Return the per-channel spatial mean and channel covariance of ``img``.

    Args:
        img: Tensor indexed as (batch, channel, height, width).

    Returns:
        tuple: ``(mu, covariance)``, where ``mu`` is
        batch_size * channel_num * 1 * 1 and ``covariance`` is
        batch_size * channel_num * channel_num.
    """
    shape = img.shape
    batch_size, channel_num = shape[0], shape[1]
    num_pixels = shape[2] * shape[3]

    # Average over height, then width, keeping both axes for broadcasting.
    mu = img.mean(2, keepdim=True).mean(3, keepdim=True)

    # Centre the image and flatten the spatial axes:
    # batch_size * channel_num * num_pixels
    centered = (img - mu.expand_as(img)).reshape(
        [batch_size, channel_num, num_pixels])

    # (C x N) @ (N x C) -> batch_size * channel_num * channel_num
    covariance = paddle.bmm(centered, centered.transpose([0, 2, 1]))
    covariance = covariance / num_pixels

    return mu, covariance
|
||||||
|
|
||||||
|
|
||||||
|
def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
    """Return the masked dice loss ``1 - 2 * intersection / union``.

    Args:
        y_true_cls: Ground-truth map.
        y_pred_cls: Predicted map.
        training_mask: Weights multiplied into both maps before summing.
    """
    eps = 1e-5  # keeps the denominator non-zero
    overlap = paddle.sum(y_true_cls * y_pred_cls * training_mask)
    true_area = paddle.sum(y_true_cls * training_mask)
    pred_area = paddle.sum(y_pred_cls * training_mask)
    total = true_area + pred_area + eps
    return 1. - (2 * overlap / total)
|
||||||
67
tools/style_text_rec/utils/sys_funcs.py
Normal file
67
tools/style_text_rec/utils/sys_funcs.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import errno
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
def get_check_global_params(mode):
    """Return the names of the global parameters to check for ``mode``.

    Args:
        mode: ``"train_eval"`` and ``"test"`` add their batch-size
            parameters; any other value yields only the common ones.

    Returns:
        list: Parameter names.
    """
    # 'image_shape' is listed twice on purpose to keep the returned list
    # unchanged.
    common = [
        'use_gpu', 'max_text_length', 'image_shape', 'image_shape',
        'character_type', 'loss_type'
    ]
    if mode == "train_eval":
        return common + [
            'train_batch_size_per_card', 'test_batch_size_per_card'
        ]
    if mode == "test":
        return common + ['test_batch_size_per_card']
    return common
|
||||||
|
|
||||||
|
|
||||||
|
def check_gpu(use_gpu):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.

    Args:
        use_gpu (bool): Whether the configuration requests GPU execution.
            Nothing is checked when it is falsy.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"
    if not use_gpu:
        return

    # Keep only the query inside the try block. The exit below must not be
    # caught here, or the misleading "Fail to check gpu state." message
    # would follow the real error.
    try:
        has_cuda = paddle.is_compiled_with_cuda()
    except Exception:
        print("Fail to check gpu state.")
        sys.exit(1)

    if not has_cuda:
        print(err)
        sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _mkdir_if_not_exist(path, logger):
|
||||||
|
"""
|
||||||
|
mkdir if not exists, ignore the exception when multiprocess mkdir together
|
||||||
|
"""
|
||||||
|
if not os.path.exists(path):
|
||||||
|
try:
|
||||||
|
os.makedirs(path)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno == errno.EEXIST and os.path.isdir(path):
|
||||||
|
logger.warning(
|
||||||
|
'be happy if some process has already created {}'.format(
|
||||||
|
path))
|
||||||
|
else:
|
||||||
|
raise OSError('Failed to mkdir {}'.format(path))
|
||||||
Loading…
x
Reference in New Issue
Block a user