mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
fix re infer bug
parent 99de0353a6
commit c703a5891a
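
The diff below replaces the scattered `self.contains_re` / `not self.infer_mode` checks in VQATokenLabelEncode with a single precomputed `train_re` flag, so the relation-training bookkeeping (relations, id2label, empty_entity, entity_id_to_index_map) is only built and returned when the encoder is actually preparing RE training data. A minimal sketch of that gating, not PaddleOCR code: `make_sample`, its arguments, and the toy OCR records are invented; only the flag expression and the emitted dictionary keys mirror the diff.

# Illustrative sketch only: the train_re gating and the returned keys follow
# the diff below; everything else here is made up for demonstration.
def make_sample(ocr_info, contains_re, infer_mode):
    data = {"ocr_info": ocr_info}
    entities = []

    # for re: training-only structures are created once, behind one flag
    train_re = contains_re and not infer_mode
    if train_re:
        relations, id2label = [], {}
        entity_id_to_index_map, empty_entity = {}, set()
        for info in ocr_info:
            if len(info["text"]) == 0:
                empty_entity.add(info["id"])

    data["entities"] = entities
    if train_re:
        data["relations"] = relations
        data["id2label"] = id2label
        data["empty_entity"] = empty_entity
        data["entity_id_to_index_map"] = entity_id_to_index_map
    return data

print(make_sample([{"id": 0, "text": "Name"}], contains_re=True, infer_mode=True).keys())
# dict_keys(['ocr_info', 'entities'])  -> no training-only keys on the inference path
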
@@ -833,19 +833,20 @@ class VQATokenLabelEncode(object):
         segment_offset_id = []
         gt_label_list = []
 
-        if self.contains_re:
-            # for re
-            entities = []
-            if not self.infer_mode:
-                relations = []
-                id2label = {}
-                entity_id_to_index_map = {}
-                empty_entity = set()
+        entities = []
+
+        # for re
+        train_re = self.contains_re and not self.infer_mode
+        if train_re:
+            relations = []
+            id2label = {}
+            entity_id_to_index_map = {}
+            empty_entity = set()
 
         data['ocr_info'] = copy.deepcopy(ocr_info)
 
         for info in ocr_info:
-            if self.contains_re and not self.infer_mode:
+            if train_re:
                 # for re
                 if len(info["text"]) == 0:
                     empty_entity.add(info["id"])
@@ -872,24 +873,22 @@ class VQATokenLabelEncode(object):
             gt_label = self._parse_label(label, encode_res)
 
             # construct entities for re
-            if self.contains_re:
-                if not self.infer_mode:
-                    if gt_label[0] != self.label2id_map["O"]:
-                        entity_id_to_index_map[info["id"]] = len(entities)
-                        label = label.upper()
-                        entities.append({
-                            "start": len(input_ids_list),
-                            "end":
-                            len(input_ids_list) + len(encode_res["input_ids"]),
-                            "label": label.upper(),
-                        })
-                else:
+            if train_re:
+                if gt_label[0] != self.label2id_map["O"]:
+                    entity_id_to_index_map[info["id"]] = len(entities)
+                    label = label.upper()
                     entities.append({
                         "start": len(input_ids_list),
                         "end":
                         len(input_ids_list) + len(encode_res["input_ids"]),
-                        "label": 'O',
+                        "label": label.upper(),
                     })
+                else:
+                    entities.append({
+                        "start": len(input_ids_list),
+                        "end": len(input_ids_list) + len(encode_res["input_ids"]),
+                        "label": 'O',
+                    })
             input_ids_list.extend(encode_res["input_ids"])
             token_type_ids_list.extend(encode_res["token_type_ids"])
             bbox_list.extend([bbox] * len(encode_res["input_ids"]))
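
In the hunk above, each entity is recorded as a span over the flattened token sequence: "start" and "end" are offsets into input_ids_list (token indices, not character positions), together with an upper-cased label, or 'O' for spans that carry no entity label. A hedged sketch of that bookkeeping with a stand-in tokenizer; fake_encode and the sample texts are invented, only the {"start", "end", "label"} layout follows the diff.

# Stand-in for the real tokenizer; not part of PaddleOCR.
def fake_encode(text):
    return {"input_ids": [ord(ch) % 100 for ch in text]}

input_ids_list, entities = [], []
for text, label in [("Name:", "question"), ("John", "answer"), ("...", "O")]:
    encode_res = fake_encode(text)
    start = len(input_ids_list)
    end = start + len(encode_res["input_ids"])
    entities.append({"start": start, "end": end, "label": label.upper()})
    input_ids_list.extend(encode_res["input_ids"])

print(entities[1])  # {'start': 5, 'end': 9, 'label': 'ANSWER'}
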
@@ -908,19 +907,23 @@ class VQATokenLabelEncode(object):
             padding_side=self.tokenizer.padding_side,
             pad_token_type_id=self.tokenizer.pad_token_type_id,
             pad_token_id=self.tokenizer.pad_token_id)
+        data['entities'] = entities
 
-        if self.contains_re:
-            data['entities'] = entities
-            if self.infer_mode:
-                data['ocr_info'] = ocr_info
-            else:
-                data['relations'] = relations
-                data['id2label'] = id2label
-                data['empty_entity'] = empty_entity
-                data['entity_id_to_index_map'] = entity_id_to_index_map
+        if train_re:
+            data['relations'] = relations
+            data['id2label'] = id2label
+            data['empty_entity'] = empty_entity
+            data['entity_id_to_index_map'] = entity_id_to_index_map
         return data
 
     def _load_ocr_info(self, data):
+        def trans_poly_to_bbox(poly):
+            x1 = np.min([p[0] for p in poly])
+            x2 = np.max([p[0] for p in poly])
+            y1 = np.min([p[1] for p in poly])
+            y2 = np.max([p[1] for p in poly])
+            return [x1, y1, x2, y2]
+
         if self.infer_mode:
             ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
             ocr_info = []
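
The last hunk also introduces a trans_poly_to_bbox helper inside _load_ocr_info, collapsing a detected text polygon into an axis-aligned [x1, y1, x2, y2] box for the inference path. A standalone usage sketch; the sample quadrilateral is made up, the helper body matches the diff.

import numpy as np

def trans_poly_to_bbox(poly):
    # Axis-aligned bounding box of an N-point polygon: [x_min, y_min, x_max, y_max].
    x1 = np.min([p[0] for p in poly])
    x2 = np.max([p[0] for p in poly])
    y1 = np.min([p[1] for p in poly])
    y2 = np.max([p[1] for p in poly])
    return [x1, y1, x2, y2]

bbox = trans_poly_to_bbox([[10, 20], [110, 18], [112, 52], [12, 55]])
print([int(v) for v in bbox])  # [10, 18, 112, 55]
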