Mirror of https://github.com/PaddlePaddle/PaddleOCR.git, synced 2025-12-29 07:58:41 +00:00
fix ConvBn

commit 6658f5b23b
parent b6cd60476c
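
Summary of the change: the diff drops ConvBNLayer from the det_mobilenet_v3 import and defines a local replacement in its place. The new block is a depthwise-separable unit: a depthwise conv followed by BN, a 1x1 expansion to 4x the input channels with BN and an optional relu/hardswish activation, a 1x1 projection down to out_channels, and a residual add through a 1x1 conv_end shortcut when in_channels != out_channels.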
@@ -27,7 +27,80 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
 sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..')))
 
-from ppocr.modeling.backbones.det_mobilenet_v3 import SEModule, ConvBNLayer
+from ppocr.modeling.backbones.det_mobilenet_v3 import SEModule
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 padding,
+                 stride=1,
+                 groups=None,
+                 if_act=True,
+                 act="relu"):
+        super(ConvBNLayer, self).__init__()
+        if groups is None:
+            groups = in_channels
+        self.if_act = if_act
+        self.act = act
+        # Depthwise conv: groups defaults to in_channels and the
+        # channel count is preserved.
+        self.conv1 = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups,
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm(num_channels=in_channels, act=None)
+        # 1x1 expansion to 4x the input channels.
+        self.conv2 = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=int(in_channels * 4),
+            kernel_size=1,
+            stride=1,
+            bias_attr=False)
+        self.bn2 = nn.BatchNorm(num_channels=int(in_channels * 4), act=None)
+        # 1x1 projection back down to out_channels.
+        self.conv3 = nn.Conv2D(
+            in_channels=int(in_channels * 4),
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+            bias_attr=False)
+        self._c = [in_channels, out_channels]
+        # 1x1 shortcut so the residual add stays shape-compatible
+        # when the channel counts differ.
+        if in_channels != out_channels:
+            self.conv_end = nn.Conv2D(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                stride=1,
+                bias_attr=False)
+
+    def forward(self, inputs):
+        x = self.conv1(inputs)
+        x = self.bn1(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        if self.if_act:
+            if self.act == "relu":
+                x = F.relu(x)
+            elif self.act == "hardswish":
+                x = F.hardswish(x)
+            else:
+                raise ValueError(
+                    "Unsupported activation function: {}".format(self.act))
+        x = self.conv3(x)
+        if self._c[0] != self._c[1]:
+            x = x + self.conv_end(inputs)
+        return x
 
 
 class DBFPN(nn.Layer):
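
For context, a minimal smoke test of the rewritten layer, assuming the ConvBNLayer defined above is in scope (the surrounding module already imports paddle.nn as nn and paddle.nn.functional as F); the channel counts and input shape here are illustrative assumptions, not taken from the commit:

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    # Hypothetical check: 16 -> 32 channels exercises the conv_end
    # shortcut branch, since in_channels != out_channels.
    layer = ConvBNLayer(
        in_channels=16,
        out_channels=32,
        kernel_size=3,
        padding=1,
        act="hardswish")
    x = paddle.rand([1, 16, 64, 64])
    y = layer(x)
    print(y.shape)  # [1, 32, 64, 64]

With stride=1 and padding=1 on a 3x3 depthwise kernel the spatial size is unchanged, so the residual add lines up even when the shortcut projection is needed.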