mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-11-03 19:29:18 +00:00 
			
		
		
		
	Merge pull request #4547 from andyjpaddle/add_ref_for_sar
add refer to backbone and head of sar
This commit is contained in:
		
						commit
						aade74b03e
					
				@ -1,3 +1,22 @@
 | 
			
		||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""
 | 
			
		||||
This code is refer from: 
 | 
			
		||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py
 | 
			
		||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
@ -18,12 +37,12 @@ def conv3x3(in_channel, out_channel, stride=1):
 | 
			
		||||
        kernel_size=3,
 | 
			
		||||
        stride=stride,
 | 
			
		||||
        padding=1,
 | 
			
		||||
        bias_attr=False
 | 
			
		||||
    )
 | 
			
		||||
        bias_attr=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BasicBlock(nn.Layer):
 | 
			
		||||
    expansion = 1
 | 
			
		||||
 | 
			
		||||
    def __init__(self, in_channels, channels, stride=1, downsample=False):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.conv1 = conv3x3(in_channels, channels, stride)
 | 
			
		||||
@ -34,9 +53,13 @@ class BasicBlock(nn.Layer):
 | 
			
		||||
        self.downsample = downsample
 | 
			
		||||
        if downsample:
 | 
			
		||||
            self.downsample = nn.Sequential(
 | 
			
		||||
                nn.Conv2D(in_channels, channels * self.expansion, 1, stride, bias_attr=False),
 | 
			
		||||
                nn.BatchNorm2D(channels * self.expansion),
 | 
			
		||||
            )
 | 
			
		||||
                nn.Conv2D(
 | 
			
		||||
                    in_channels,
 | 
			
		||||
                    channels * self.expansion,
 | 
			
		||||
                    1,
 | 
			
		||||
                    stride,
 | 
			
		||||
                    bias_attr=False),
 | 
			
		||||
                nn.BatchNorm2D(channels * self.expansion), )
 | 
			
		||||
        else:
 | 
			
		||||
            self.downsample = nn.Sequential()
 | 
			
		||||
        self.stride = stride
 | 
			
		||||
@ -57,7 +80,7 @@ class BasicBlock(nn.Layer):
 | 
			
		||||
        out += residual
 | 
			
		||||
        out = self.relu(out)
 | 
			
		||||
 | 
			
		||||
        return out        
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ResNet31(nn.Layer):
 | 
			
		||||
@ -69,12 +92,13 @@ class ResNet31(nn.Layer):
 | 
			
		||||
        out_indices (None | Sequence[int]): Indices of output stages.
 | 
			
		||||
        last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
 | 
			
		||||
    '''
 | 
			
		||||
    def __init__(self, 
 | 
			
		||||
                in_channels=3, 
 | 
			
		||||
                layers=[1, 2, 5, 3],
 | 
			
		||||
                channels=[64, 128, 256, 256, 512, 512, 512],
 | 
			
		||||
                out_indices=None,
 | 
			
		||||
                last_stage_pool=False):
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 in_channels=3,
 | 
			
		||||
                 layers=[1, 2, 5, 3],
 | 
			
		||||
                 channels=[64, 128, 256, 256, 512, 512, 512],
 | 
			
		||||
                 out_indices=None,
 | 
			
		||||
                 last_stage_pool=False):
 | 
			
		||||
        super(ResNet31, self).__init__()
 | 
			
		||||
        assert isinstance(in_channels, int)
 | 
			
		||||
        assert isinstance(last_stage_pool, bool)
 | 
			
		||||
@ -83,46 +107,56 @@ class ResNet31(nn.Layer):
 | 
			
		||||
        self.last_stage_pool = last_stage_pool
 | 
			
		||||
 | 
			
		||||
        # conv 1 (Conv Conv)
 | 
			
		||||
        self.conv1_1 = nn.Conv2D(in_channels, channels[0], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.conv1_1 = nn.Conv2D(
 | 
			
		||||
            in_channels, channels[0], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.bn1_1 = nn.BatchNorm2D(channels[0])
 | 
			
		||||
        self.relu1_1 = nn.ReLU()
 | 
			
		||||
 | 
			
		||||
        self.conv1_2 = nn.Conv2D(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.conv1_2 = nn.Conv2D(
 | 
			
		||||
            channels[0], channels[1], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.bn1_2 = nn.BatchNorm2D(channels[1])
 | 
			
		||||
        self.relu1_2 = nn.ReLU()
 | 
			
		||||
 | 
			
		||||
        # conv 2 (Max-pooling, Residual block, Conv)
 | 
			
		||||
        self.pool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
 | 
			
		||||
        self.pool2 = nn.MaxPool2D(
 | 
			
		||||
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
 | 
			
		||||
        self.block2 = self._make_layer(channels[1], channels[2], layers[0])
 | 
			
		||||
        self.conv2 = nn.Conv2D(channels[2], channels[2], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.conv2 = nn.Conv2D(
 | 
			
		||||
            channels[2], channels[2], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.bn2 = nn.BatchNorm2D(channels[2])
 | 
			
		||||
        self.relu2 = nn.ReLU()
 | 
			
		||||
 | 
			
		||||
        # conv 3 (Max-pooling, Residual block, Conv)
 | 
			
		||||
        self.pool3 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
 | 
			
		||||
        self.pool3 = nn.MaxPool2D(
 | 
			
		||||
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
 | 
			
		||||
        self.block3 = self._make_layer(channels[2], channels[3], layers[1])
 | 
			
		||||
        self.conv3 = nn.Conv2D(channels[3], channels[3], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.conv3 = nn.Conv2D(
 | 
			
		||||
            channels[3], channels[3], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.bn3 = nn.BatchNorm2D(channels[3])
 | 
			
		||||
        self.relu3 = nn.ReLU()
 | 
			
		||||
 | 
			
		||||
        # conv 4 (Max-pooling, Residual block, Conv)
 | 
			
		||||
        self.pool4 = nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
 | 
			
		||||
        self.pool4 = nn.MaxPool2D(
 | 
			
		||||
            kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
 | 
			
		||||
        self.block4 = self._make_layer(channels[3], channels[4], layers[2])
 | 
			
		||||
        self.conv4 = nn.Conv2D(channels[4], channels[4], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.conv4 = nn.Conv2D(
 | 
			
		||||
            channels[4], channels[4], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.bn4 = nn.BatchNorm2D(channels[4])
 | 
			
		||||
        self.relu4 = nn.ReLU()
 | 
			
		||||
 | 
			
		||||
        # conv 5 ((Max-pooling), Residual block, Conv)
 | 
			
		||||
        self.pool5 = None
 | 
			
		||||
        if self.last_stage_pool:
 | 
			
		||||
            self.pool5 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
 | 
			
		||||
            self.pool5 = nn.MaxPool2D(
 | 
			
		||||
                kernel_size=2, stride=2, padding=0, ceil_mode=True)
 | 
			
		||||
        self.block5 = self._make_layer(channels[4], channels[5], layers[3])
 | 
			
		||||
        self.conv5 = nn.Conv2D(channels[5], channels[5], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.conv5 = nn.Conv2D(
 | 
			
		||||
            channels[5], channels[5], kernel_size=3, stride=1, padding=1)
 | 
			
		||||
        self.bn5 = nn.BatchNorm2D(channels[5])
 | 
			
		||||
        self.relu5 = nn.ReLU()
 | 
			
		||||
 | 
			
		||||
        self.out_channels = channels[-1]
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def _make_layer(self, input_channels, output_channels, blocks):
 | 
			
		||||
        layers = []
 | 
			
		||||
        for _ in range(blocks):
 | 
			
		||||
@ -130,19 +164,19 @@ class ResNet31(nn.Layer):
 | 
			
		||||
            if input_channels != output_channels:
 | 
			
		||||
                downsample = nn.Sequential(
 | 
			
		||||
                    nn.Conv2D(
 | 
			
		||||
                        input_channels, 
 | 
			
		||||
                        output_channels, 
 | 
			
		||||
                        kernel_size=1, 
 | 
			
		||||
                        stride=1, 
 | 
			
		||||
                        input_channels,
 | 
			
		||||
                        output_channels,
 | 
			
		||||
                        kernel_size=1,
 | 
			
		||||
                        stride=1,
 | 
			
		||||
                        bias_attr=False),
 | 
			
		||||
                    nn.BatchNorm2D(output_channels),
 | 
			
		||||
                )
 | 
			
		||||
                
 | 
			
		||||
            layers.append(BasicBlock(input_channels, output_channels, downsample=downsample))
 | 
			
		||||
                    nn.BatchNorm2D(output_channels), )
 | 
			
		||||
 | 
			
		||||
            layers.append(
 | 
			
		||||
                BasicBlock(
 | 
			
		||||
                    input_channels, output_channels, downsample=downsample))
 | 
			
		||||
            input_channels = output_channels
 | 
			
		||||
        return nn.Sequential(*layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = self.conv1_1(x)
 | 
			
		||||
        x = self.bn1_1(x)
 | 
			
		||||
@ -166,11 +200,11 @@ class ResNet31(nn.Layer):
 | 
			
		||||
            x = block_layer(x)
 | 
			
		||||
            x = conv_layer(x)
 | 
			
		||||
            x = bn_layer(x)
 | 
			
		||||
            x= relu_layer(x)
 | 
			
		||||
            x = relu_layer(x)
 | 
			
		||||
 | 
			
		||||
            outs.append(x)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        if self.out_indices is not None:
 | 
			
		||||
            return tuple([outs[i] for i in self.out_indices])
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,22 @@
 | 
			
		||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""
 | 
			
		||||
This code is refer from: 
 | 
			
		||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
 | 
			
		||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user