mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-12-28 15:38:18 +00:00
add eps
This commit is contained in:
parent
7c04ff55c7
commit
60a1408d4e
@ -16,6 +16,7 @@
|
||||
class ClsMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.eps = 1e-5
|
||||
self.reset()
|
||||
|
||||
def __call__(self, pred_label, *args, **kwargs):
|
||||
@ -28,7 +29,7 @@ class ClsMetric(object):
|
||||
all_num += 1
|
||||
self.correct_num += correct_num
|
||||
self.all_num += all_num
|
||||
return {'acc': correct_num / all_num, }
|
||||
return {'acc': correct_num / (all_num + self.eps), }
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
@ -36,7 +37,7 @@ class ClsMetric(object):
|
||||
'acc': 0
|
||||
}
|
||||
"""
|
||||
acc = self.correct_num / self.all_num
|
||||
acc = self.correct_num / (self.all_num + self.eps)
|
||||
self.reset()
|
||||
return {'acc': acc}
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ class RecMetric(object):
|
||||
def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.is_filter = is_filter
|
||||
self.eps = 1e-5
|
||||
self.reset()
|
||||
|
||||
def _normalize_text(self, text):
|
||||
@ -47,8 +48,8 @@ class RecMetric(object):
|
||||
self.all_num += all_num
|
||||
self.norm_edit_dis += norm_edit_dis
|
||||
return {
|
||||
'acc': correct_num / all_num,
|
||||
'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3)
|
||||
'acc': correct_num / (all_num + self.eps),
|
||||
'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
|
||||
}
|
||||
|
||||
def get_metric(self):
|
||||
@ -58,8 +59,8 @@ class RecMetric(object):
|
||||
'norm_edit_dis': 0,
|
||||
}
|
||||
"""
|
||||
acc = 1.0 * self.correct_num / (self.all_num + 1e-3)
|
||||
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3)
|
||||
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
|
||||
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
|
||||
self.reset()
|
||||
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
||||
|
||||
|
||||
@ -12,9 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TableMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.eps = 1e-5
|
||||
self.reset()
|
||||
|
||||
def __call__(self, pred, batch, *args, **kwargs):
|
||||
@ -31,9 +34,7 @@ class TableMetric(object):
|
||||
correct_num += 1
|
||||
self.correct_num += correct_num
|
||||
self.all_num += all_num
|
||||
return {
|
||||
'acc': correct_num * 1.0 / all_num,
|
||||
}
|
||||
return {'acc': correct_num * 1.0 / (all_num + self.eps), }
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
@ -41,7 +42,7 @@ class TableMetric(object):
|
||||
'acc': 0,
|
||||
}
|
||||
"""
|
||||
acc = 1.0 * self.correct_num / self.all_num
|
||||
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
|
||||
self.reset()
|
||||
return {'acc': acc}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user