diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a94ef86abd..3c26460ba0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,7 @@ repos: - id: flake8 args: - --count - - --select=E9,F63,F7,F82 + - --select=E9,F63,F7,F82,E721 - --show-source - --statistics exclude: ^benchmark/|^test_tipc/ diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index cd66bca89d..98c9c546a5 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -78,9 +78,9 @@ class DistillationDMLLoss(DMLLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None @@ -174,9 +174,9 @@ class DistillationKLDivLoss(KLDivLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None @@ -282,9 +282,9 @@ class DistillationDKDLoss(DKDLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None @@ -428,9 +428,9 @@ class DistillationKLDivLoss(KLDivLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None @@ -536,9 +536,9 @@ class DistillationDKDLoss(DKDLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None diff --git a/tools/program.py b/tools/program.py index d9aa068209..b2f3dbf107 100755 --- a/tools/program.py +++ b/tools/program.py @@ -209,7 +209,7 @@ def train( if "global_step" in pre_best_model_dict: global_step = pre_best_model_dict["global_step"] start_eval_step = 0 - if type(eval_batch_step) == list and len(eval_batch_step) >= 2: + if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2: start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0 eval_batch_step = ( eval_batch_step[1]