| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  | # -*- coding: utf-8 -*- | 
					
						
							|  |  |  | # @Time    : 2019/12/4 18:06 | 
					
						
							|  |  |  | # @Author  : zhoujun | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | import imgaug | 
					
						
							|  |  |  | import imgaug.augmenters as iaa | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class AugmenterBuilder(object): | 
					
						
							|  |  |  |     def __init__(self): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def build(self, args, root=True): | 
					
						
							|  |  |  |         if args is None or len(args) == 0: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         elif isinstance(args, list): | 
					
						
							|  |  |  |             if root: | 
					
						
							|  |  |  |                 sequence = [self.build(value, root=False) for value in args] | 
					
						
							|  |  |  |                 return iaa.Sequential(sequence) | 
					
						
							|  |  |  |             else: | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  |                 return getattr(iaa, args[0])( | 
					
						
							|  |  |  |                     *[self.to_tuple_if_list(a) for a in args[1:]] | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  |         elif isinstance(args, dict): | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  |             cls = getattr(iaa, args["type"]) | 
					
						
							|  |  |  |             return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()}) | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  |             raise RuntimeError("unknown augmenter arg: " + str(args)) | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def to_tuple_if_list(self, obj): | 
					
						
							|  |  |  |         if isinstance(obj, list): | 
					
						
							|  |  |  |             return tuple(obj) | 
					
						
							|  |  |  |         return obj | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  | class IaaAugment: | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  |     def __init__(self, augmenter_args): | 
					
						
							|  |  |  |         self.augmenter_args = augmenter_args | 
					
						
							|  |  |  |         self.augmenter = AugmenterBuilder().build(self.augmenter_args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, data): | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  |         image = data["img"] | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  |         shape = image.shape | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if self.augmenter: | 
					
						
							|  |  |  |             aug = self.augmenter.to_deterministic() | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  |             data["img"] = aug.augment_image(image) | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  |             data = self.may_augment_annotation(aug, data, shape) | 
					
						
							|  |  |  |         return data | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def may_augment_annotation(self, aug, data, shape): | 
					
						
							|  |  |  |         if aug is None: | 
					
						
							|  |  |  |             return data | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         line_polys = [] | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  |         for poly in data["text_polys"]: | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  |             new_poly = self.may_augment_poly(aug, shape, poly) | 
					
						
							|  |  |  |             line_polys.append(new_poly) | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  |         data["text_polys"] = np.array(line_polys) | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  |         return data | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def may_augment_poly(self, aug, img_shape, poly): | 
					
						
							|  |  |  |         keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] | 
					
						
							|  |  |  |         keypoints = aug.augment_keypoints( | 
					
						
							| 
									
										
										
										
											2024-04-21 21:46:20 +08:00
										 |  |  |             [imgaug.KeypointsOnImage(keypoints, shape=img_shape)] | 
					
						
							|  |  |  |         )[0].keypoints | 
					
						
							| 
									
										
										
										
											2023-02-08 15:52:30 +08:00
										 |  |  |         poly = [(p.x, p.y) for p in keypoints] | 
					
						
							|  |  |  |         return poly |