From ea29b423d5daa6e33b9baf971a3984386e4925a1 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Mon, 27 May 2024 11:12:05 +0800 Subject: [PATCH] fix and enable flake8 E721 (#12258) --- .pre-commit-config.yaml | 2 +- ppocr/losses/distillation_loss.py | 20 ++++++++++---------- tools/program.py | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a94ef86abd..3c26460ba0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,7 @@ repos: - id: flake8 args: - --count - - --select=E9,F63,F7,F82 + - --select=E9,F63,F7,F82,E721 - --show-source - --statistics exclude: ^benchmark/|^test_tipc/ diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index cd66bca89d..98c9c546a5 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -78,9 +78,9 @@ class DistillationDMLLoss(DMLLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None @@ -174,9 +174,9 @@ class DistillationKLDivLoss(KLDivLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None @@ -282,9 +282,9 @@ class DistillationDKDLoss(DKDLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None @@ -428,9 +428,9 @@ class DistillationKLDivLoss(KLDivLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None @@ -536,9 +536,9 @@ class DistillationDKDLoss(DKDLoss): def _check_maps_name(self, maps_name): if maps_name is None: return None - elif type(maps_name) == str: + elif isinstance(maps_name, str): return [maps_name] - elif type(maps_name) == list: + elif isinstance(maps_name, list): return [maps_name] else: return None diff --git a/tools/program.py b/tools/program.py index d9aa068209..b2f3dbf107 100755 --- a/tools/program.py +++ b/tools/program.py @@ -209,7 +209,7 @@ def train( if "global_step" in pre_best_model_dict: global_step = pre_best_model_dict["global_step"] start_eval_step = 0 - if type(eval_batch_step) == list and len(eval_batch_step) >= 2: + if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2: start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0 eval_batch_step = ( eval_batch_step[1]