| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #    http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | from paddle import nn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction | 
					
						
							|  |  |  | from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  | from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  | from paddlenlp.transformers import AutoModel | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  | __all__ = ["LayoutXLMForSer", "LayoutLMForSer"] | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-06 03:35:30 +00:00
										 |  |  | pretrained_model_dict = { | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |     LayoutXLMModel: { | 
					
						
							|  |  |  |         "base": "layoutxlm-base-uncased", | 
					
						
							| 
									
										
										
										
											2022-08-24 05:53:42 +00:00
										 |  |  |         "vi": "vi-layoutxlm-base-uncased", | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |     }, | 
					
						
							|  |  |  |     LayoutLMModel: { | 
					
						
							|  |  |  |         "base": "layoutlm-base-uncased", | 
					
						
							|  |  |  |     }, | 
					
						
							|  |  |  |     LayoutLMv2Model: { | 
					
						
							|  |  |  |         "base": "layoutlmv2-base-uncased", | 
					
						
							| 
									
										
										
										
											2022-08-24 05:55:06 +00:00
										 |  |  |         "vi": "vi-layoutlmv2-base-uncased", | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |     }, | 
					
						
							| 
									
										
										
										
											2022-01-06 03:35:30 +00:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | class NLPBaseModel(nn.Layer): | 
					
						
							|  |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  base_model_class, | 
					
						
							|  |  |  |                  model_class, | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |                  mode="base", | 
					
						
							|  |  |  |                  type="ser", | 
					
						
							| 
									
										
										
										
											2022-01-06 03:35:30 +00:00
										 |  |  |                  pretrained=True, | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |                  checkpoints=None, | 
					
						
							|  |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(NLPBaseModel, self).__init__() | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         if checkpoints is not None:  # load the trained model | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |             self.model = model_class.from_pretrained(checkpoints) | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         else:  # load the pretrained-model | 
					
						
							|  |  |  |             pretrained_model_name = pretrained_model_dict[base_model_class][ | 
					
						
							|  |  |  |                 mode] | 
					
						
							| 
									
										
										
										
											2022-07-06 13:58:46 +08:00
										 |  |  |             if pretrained is True: | 
					
						
							| 
									
										
										
										
											2022-01-06 03:35:30 +00:00
										 |  |  |                 base_model = base_model_class.from_pretrained( | 
					
						
							|  |  |  |                     pretrained_model_name) | 
					
						
							|  |  |  |             else: | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |                 base_model = base_model_class.from_pretrained(pretrained) | 
					
						
							|  |  |  |             if type == "ser": | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |                 self.model = model_class( | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |                     base_model, num_classes=kwargs["num_classes"], dropout=None) | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 self.model = model_class(base_model, dropout=None) | 
					
						
							|  |  |  |         self.out_channels = 1 | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         self.use_visual_backbone = True | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  | class LayoutLMForSer(NLPBaseModel): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  num_classes, | 
					
						
							|  |  |  |                  pretrained=True, | 
					
						
							|  |  |  |                  checkpoints=None, | 
					
						
							|  |  |  |                  mode="base", | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |                  **kwargs): | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  |         super(LayoutLMForSer, self).__init__( | 
					
						
							|  |  |  |             LayoutLMModel, | 
					
						
							|  |  |  |             LayoutLMForTokenClassification, | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             mode, | 
					
						
							|  |  |  |             "ser", | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  |             pretrained, | 
					
						
							|  |  |  |             checkpoints, | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             num_classes=num_classes, ) | 
					
						
							|  |  |  |         self.use_visual_backbone = False | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							|  |  |  |         x = self.model( | 
					
						
							|  |  |  |             input_ids=x[0], | 
					
						
							| 
									
										
										
										
											2022-07-01 08:52:08 +00:00
										 |  |  |             bbox=x[1], | 
					
						
							|  |  |  |             attention_mask=x[2], | 
					
						
							|  |  |  |             token_type_ids=x[3], | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  |             position_ids=None, | 
					
						
							|  |  |  |             output_hidden_states=False) | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class LayoutLMv2ForSer(NLPBaseModel): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  num_classes, | 
					
						
							|  |  |  |                  pretrained=True, | 
					
						
							|  |  |  |                  checkpoints=None, | 
					
						
							|  |  |  |                  mode="base", | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(LayoutLMv2ForSer, self).__init__( | 
					
						
							|  |  |  |             LayoutLMv2Model, | 
					
						
							|  |  |  |             LayoutLMv2ForTokenClassification, | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             mode, | 
					
						
							|  |  |  |             "ser", | 
					
						
							| 
									
										
										
										
											2022-01-06 03:35:30 +00:00
										 |  |  |             pretrained, | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |             checkpoints, | 
					
						
							|  |  |  |             num_classes=num_classes) | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         if hasattr(self.model.layoutlmv2, "use_visual_backbone" | 
					
						
							|  |  |  |                    ) and self.model.layoutlmv2.use_visual_backbone is False: | 
					
						
							|  |  |  |             self.use_visual_backbone = False | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         if self.use_visual_backbone is True: | 
					
						
							|  |  |  |             image = x[4] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             image = None | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |         x = self.model( | 
					
						
							|  |  |  |             input_ids=x[0], | 
					
						
							| 
									
										
										
										
											2022-07-01 08:52:08 +00:00
										 |  |  |             bbox=x[1], | 
					
						
							|  |  |  |             attention_mask=x[2], | 
					
						
							|  |  |  |             token_type_ids=x[3], | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             image=image, | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |             position_ids=None, | 
					
						
							|  |  |  |             head_mask=None, | 
					
						
							|  |  |  |             labels=None) | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         if self.training: | 
					
						
							|  |  |  |             res = {"backbone_out": x[0]} | 
					
						
							|  |  |  |             res.update(x[1]) | 
					
						
							|  |  |  |             return res | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2022-07-01 08:52:08 +00:00
										 |  |  |             return x | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  | class LayoutXLMForSer(NLPBaseModel): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  num_classes, | 
					
						
							|  |  |  |                  pretrained=True, | 
					
						
							|  |  |  |                  checkpoints=None, | 
					
						
							|  |  |  |                  mode="base", | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |                  **kwargs): | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  |         super(LayoutXLMForSer, self).__init__( | 
					
						
							|  |  |  |             LayoutXLMModel, | 
					
						
							|  |  |  |             LayoutXLMForTokenClassification, | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             mode, | 
					
						
							|  |  |  |             "ser", | 
					
						
							| 
									
										
										
										
											2022-01-06 03:35:30 +00:00
										 |  |  |             pretrained, | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |             checkpoints, | 
					
						
							|  |  |  |             num_classes=num_classes) | 
					
						
							| 
									
										
										
										
											2022-08-15 11:39:11 +08:00
										 |  |  |         if hasattr(self.model.layoutxlm, "use_visual_backbone" | 
					
						
							|  |  |  |                    ) and self.model.layoutxlm.use_visual_backbone is False: | 
					
						
							|  |  |  |             self.use_visual_backbone = False | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         if self.use_visual_backbone is True: | 
					
						
							|  |  |  |             image = x[4] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             image = None | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |         x = self.model( | 
					
						
							| 
									
										
										
										
											2022-07-01 09:59:11 +00:00
										 |  |  |             input_ids=x[0], | 
					
						
							|  |  |  |             bbox=x[1], | 
					
						
							|  |  |  |             attention_mask=x[2], | 
					
						
							|  |  |  |             token_type_ids=x[3], | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             image=image, | 
					
						
							| 
									
										
										
										
											2022-07-01 09:59:11 +00:00
										 |  |  |             position_ids=None, | 
					
						
							|  |  |  |             head_mask=None, | 
					
						
							|  |  |  |             labels=None) | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         if self.training: | 
					
						
							|  |  |  |             res = {"backbone_out": x[0]} | 
					
						
							|  |  |  |             res.update(x[1]) | 
					
						
							|  |  |  |             return res | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2022-07-01 08:52:08 +00:00
										 |  |  |             return x | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class LayoutLMv2ForRe(NLPBaseModel): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |     def __init__(self, pretrained=True, checkpoints=None, mode="base", | 
					
						
							|  |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(LayoutLMv2ForRe, self).__init__( | 
					
						
							|  |  |  |             LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re", | 
					
						
							|  |  |  |             pretrained, checkpoints) | 
					
						
							| 
									
										
										
										
											2022-08-15 11:39:11 +08:00
										 |  |  |         if hasattr(self.model.layoutlmv2, "use_visual_backbone" | 
					
						
							|  |  |  |                    ) and self.model.layoutlmv2.use_visual_backbone is False: | 
					
						
							|  |  |  |             self.use_visual_backbone = False | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							|  |  |  |         x = self.model( | 
					
						
							|  |  |  |             input_ids=x[0], | 
					
						
							|  |  |  |             bbox=x[1], | 
					
						
							| 
									
										
										
										
											2022-07-01 08:52:08 +00:00
										 |  |  |             attention_mask=x[2], | 
					
						
							|  |  |  |             token_type_ids=x[3], | 
					
						
							|  |  |  |             image=x[4], | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  |             position_ids=None, | 
					
						
							|  |  |  |             head_mask=None, | 
					
						
							| 
									
										
										
										
											2022-07-01 08:52:08 +00:00
										 |  |  |             labels=None, | 
					
						
							| 
									
										
										
										
											2022-02-12 07:17:38 +00:00
										 |  |  |             entities=x[5], | 
					
						
							|  |  |  |             relations=x[6]) | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class LayoutXLMForRe(NLPBaseModel): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |     def __init__(self, pretrained=True, checkpoints=None, mode="base", | 
					
						
							|  |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(LayoutXLMForRe, self).__init__( | 
					
						
							|  |  |  |             LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re", | 
					
						
							|  |  |  |             pretrained, checkpoints) | 
					
						
							|  |  |  |         if hasattr(self.model.layoutxlm, "use_visual_backbone" | 
					
						
							|  |  |  |                    ) and self.model.layoutxlm.use_visual_backbone is False: | 
					
						
							|  |  |  |             self.use_visual_backbone = False | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         if self.use_visual_backbone is True: | 
					
						
							|  |  |  |             image = x[4] | 
					
						
							| 
									
										
										
										
											2022-09-20 22:13:27 +08:00
										 |  |  |             entities = x[5] | 
					
						
							|  |  |  |             relations = x[6] | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             image = None | 
					
						
							| 
									
										
										
										
											2022-09-20 22:13:27 +08:00
										 |  |  |             entities = x[4] | 
					
						
							|  |  |  |             relations = x[5] | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |         x = self.model( | 
					
						
							|  |  |  |             input_ids=x[0], | 
					
						
							|  |  |  |             bbox=x[1], | 
					
						
							| 
									
										
										
										
											2022-07-01 08:52:08 +00:00
										 |  |  |             attention_mask=x[2], | 
					
						
							|  |  |  |             token_type_ids=x[3], | 
					
						
							| 
									
										
										
										
											2022-08-06 15:41:20 +08:00
										 |  |  |             image=image, | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |             position_ids=None, | 
					
						
							|  |  |  |             head_mask=None, | 
					
						
							| 
									
										
										
										
											2022-07-01 08:52:08 +00:00
										 |  |  |             labels=None, | 
					
						
							| 
									
										
										
										
											2022-09-20 22:13:27 +08:00
										 |  |  |             entities=entities, | 
					
						
							|  |  |  |             relations=relations) | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |         return x |