mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-25 14:35:58 +00:00
add LayoutLM ser
This commit is contained in:
parent
f01dbb5648
commit
9131c4a7ac
@ -195,7 +195,7 @@ export CUDA_VISIBLE_DEVICES=0
|
|||||||
python3.7 infer_ser.py \
|
python3.7 infer_ser.py \
|
||||||
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
|
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
|
||||||
--ser_model_type "LayoutXLM" \
|
--ser_model_type "LayoutXLM" \
|
||||||
--output_dir "output_res/" \
|
--output_dir "output/ser/" \
|
||||||
--infer_imgs "XFUND/zh_val/image/" \
|
--infer_imgs "XFUND/zh_val/image/" \
|
||||||
--ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
|
--ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
|
||||||
```
|
```
|
||||||
@ -210,7 +210,7 @@ python3.7 infer_ser_e2e.py \
|
|||||||
--model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
|
--model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
|
||||||
--ser_model_type "LayoutXLM" \
|
--ser_model_type "LayoutXLM" \
|
||||||
--max_seq_length 512 \
|
--max_seq_length 512 \
|
||||||
--output_dir "output_res_e2e/" \
|
--output_dir "output/ser_e2e/" \
|
||||||
--infer_imgs "images/input/zh_val_0.jpg"
|
--infer_imgs "images/input/zh_val_0.jpg"
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -284,7 +284,7 @@ python3 eval_re.py \
|
|||||||
--eval_data_dir "XFUND/zh_val/image" \
|
--eval_data_dir "XFUND/zh_val/image" \
|
||||||
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
|
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
|
||||||
--label_map_path 'labels/labels_ser.txt' \
|
--label_map_path 'labels/labels_ser.txt' \
|
||||||
--output_dir "output/re_test/" \
|
--output_dir "output/re/" \
|
||||||
--per_gpu_eval_batch_size 8 \
|
--per_gpu_eval_batch_size 8 \
|
||||||
--num_workers 8 \
|
--num_workers 8 \
|
||||||
--seed 2048
|
--seed 2048
|
||||||
@ -302,7 +302,7 @@ python3 infer_re.py \
|
|||||||
--eval_data_dir "XFUND/zh_val/image" \
|
--eval_data_dir "XFUND/zh_val/image" \
|
||||||
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
|
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
|
||||||
--label_map_path 'labels/labels_ser.txt' \
|
--label_map_path 'labels/labels_ser.txt' \
|
||||||
--output_dir "output_res" \
|
--output_dir "output/re/" \
|
||||||
--per_gpu_eval_batch_size 1 \
|
--per_gpu_eval_batch_size 1 \
|
||||||
--seed 2048
|
--seed 2048
|
||||||
```
|
```
|
||||||
@ -317,7 +317,7 @@ python3.7 infer_ser_re_e2e.py \
|
|||||||
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
|
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
|
||||||
--re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
|
--re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
|
||||||
--max_seq_length 512 \
|
--max_seq_length 512 \
|
||||||
--output_dir "output_ser_re_e2e_train/" \
|
--output_dir "output/ser_re_e2e/" \
|
||||||
--infer_imgs "images/input/zh_val_21.jpg"
|
--infer_imgs "images/input/zh_val_21.jpg"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
61
ppstructure/vqa/infer.sh
Normal file
61
ppstructure/vqa/infer.sh
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
export CUDA_VISIBLE_DEVICES=6
|
||||||
|
# python3.7 infer_ser_e2e.py \
|
||||||
|
# --model_name_or_path "output/ser_distributed/best_model" \
|
||||||
|
# --max_seq_length 512 \
|
||||||
|
# --output_dir "output_res_e2e/" \
|
||||||
|
# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
# python3.7 infer_ser_re_e2e.py \
|
||||||
|
# --model_name_or_path "output/ser_distributed/best_model" \
|
||||||
|
# --re_model_name_or_path "output/re_test/best_model" \
|
||||||
|
# --max_seq_length 512 \
|
||||||
|
# --output_dir "output_ser_re_e2e_train/" \
|
||||||
|
# --infer_imgs "images/input/zh_val_21.jpg"
|
||||||
|
|
||||||
|
# python3.7 infer_ser.py \
|
||||||
|
# --model_name_or_path "output/ser_LayoutLM/best_model" \
|
||||||
|
# --ser_model_type "LayoutLM" \
|
||||||
|
# --output_dir "ser_LayoutLM/" \
|
||||||
|
# --infer_imgs "images/input/zh_val_21.jpg" \
|
||||||
|
# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
|
||||||
|
|
||||||
|
python3.7 infer_ser.py \
|
||||||
|
--model_name_or_path "output/ser_new/best_model" \
|
||||||
|
--ser_model_type "LayoutXLM" \
|
||||||
|
--output_dir "ser_new/" \
|
||||||
|
--infer_imgs "images/input/zh_val_21.jpg" \
|
||||||
|
--ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
|
||||||
|
|
||||||
|
# python3.7 infer_ser_e2e.py \
|
||||||
|
# --model_name_or_path "output/ser_new/best_model" \
|
||||||
|
# --ser_model_type "LayoutXLM" \
|
||||||
|
# --max_seq_length 512 \
|
||||||
|
# --output_dir "output/ser_new/" \
|
||||||
|
# --infer_imgs "images/input/zh_val_0.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
# python3.7 infer_ser_e2e.py \
|
||||||
|
# --model_name_or_path "output/ser_LayoutLM/best_model" \
|
||||||
|
# --ser_model_type "LayoutLM" \
|
||||||
|
# --max_seq_length 512 \
|
||||||
|
# --output_dir "output/ser_LayoutLM/" \
|
||||||
|
# --infer_imgs "images/input/zh_val_0.jpg"
|
||||||
|
|
||||||
|
# python3 infer_re.py \
|
||||||
|
# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
|
||||||
|
# --max_seq_length 512 \
|
||||||
|
# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
|
||||||
|
# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
|
||||||
|
# --label_map_path 'labels/labels_ser.txt' \
|
||||||
|
# --output_dir "output_res" \
|
||||||
|
# --per_gpu_eval_batch_size 1 \
|
||||||
|
# --seed 2048
|
||||||
|
|
||||||
|
# python3.7 infer_ser_re_e2e.py \
|
||||||
|
# --model_name_or_path "output/ser_LayoutLM/best_model" \
|
||||||
|
# --ser_model_type "LayoutLM" \
|
||||||
|
# --re_model_name_or_path "output/re_new/best_model" \
|
||||||
|
# --max_seq_length 512 \
|
||||||
|
# --output_dir "output_ser_re_e2e/" \
|
||||||
|
# --infer_imgs "images/input/zh_val_21.jpg"
|
||||||
@ -56,19 +56,19 @@ def infer(args):
|
|||||||
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
|
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
|
||||||
|
|
||||||
for idx, batch in enumerate(eval_dataloader):
|
for idx, batch in enumerate(eval_dataloader):
|
||||||
|
ocr_info = ocr_info_list[idx]
|
||||||
|
image_path = ocr_info['image_path']
|
||||||
|
ocr_info = ocr_info['ocr_info']
|
||||||
|
|
||||||
save_img_path = os.path.join(
|
save_img_path = os.path.join(
|
||||||
args.output_dir,
|
args.output_dir,
|
||||||
os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
|
os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
|
||||||
logger.info("[Infer] process: {}/{}, save_result to {}".format(
|
logger.info("[Infer] process: {}/{}, save_result to {}".format(
|
||||||
idx, len(eval_dataloader), save_img_path))
|
idx, len(eval_dataloader), save_img_path))
|
||||||
with paddle.no_grad():
|
with paddle.no_grad():
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
pred_relations = outputs['pred_relations']
|
pred_relations = outputs['pred_relations']
|
||||||
|
|
||||||
ocr_info = ocr_info_list[idx]
|
|
||||||
image_path = ocr_info['image_path']
|
|
||||||
ocr_info = ocr_info['ocr_info']
|
|
||||||
|
|
||||||
# 根据entity里的信息,做token解码后去过滤不要的ocr_info
|
# 根据entity里的信息,做token解码后去过滤不要的ocr_info
|
||||||
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
|
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
|
||||||
|
|
||||||
|
|||||||
@ -98,13 +98,13 @@ class SerPredictor(object):
|
|||||||
ocr_info=ocr_info,
|
ocr_info=ocr_info,
|
||||||
max_seq_len=self.max_seq_length)
|
max_seq_len=self.max_seq_length)
|
||||||
|
|
||||||
if args.ser_model_type == 'LayoutLM':
|
if self.args.ser_model_type == 'LayoutLM':
|
||||||
preds = self.model(
|
preds = self.model(
|
||||||
input_ids=inputs["input_ids"],
|
input_ids=inputs["input_ids"],
|
||||||
bbox=inputs["bbox"],
|
bbox=inputs["bbox"],
|
||||||
token_type_ids=inputs["token_type_ids"],
|
token_type_ids=inputs["token_type_ids"],
|
||||||
attention_mask=inputs["attention_mask"])
|
attention_mask=inputs["attention_mask"])
|
||||||
elif args.ser_model_type == 'LayoutXLM':
|
elif self.args.ser_model_type == 'LayoutXLM':
|
||||||
preds = self.model(
|
preds = self.model(
|
||||||
input_ids=inputs["input_ids"],
|
input_ids=inputs["input_ids"],
|
||||||
bbox=inputs["bbox"],
|
bbox=inputs["bbox"],
|
||||||
|
|||||||
@ -117,7 +117,11 @@ if __name__ == "__main__":
|
|||||||
"w",
|
"w",
|
||||||
encoding='utf-8') as fout:
|
encoding='utf-8') as fout:
|
||||||
for idx, img_path in enumerate(infer_imgs):
|
for idx, img_path in enumerate(infer_imgs):
|
||||||
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
|
save_img_path = os.path.join(
|
||||||
|
args.output_dir,
|
||||||
|
os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
|
||||||
|
print("process: [{}/{}], save_result to {}".format(
|
||||||
|
idx, len(infer_imgs), save_img_path))
|
||||||
|
|
||||||
img = cv2.imread(img_path)
|
img = cv2.imread(img_path)
|
||||||
|
|
||||||
@ -128,7 +132,4 @@ if __name__ == "__main__":
|
|||||||
}, ensure_ascii=False) + "\n")
|
}, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
img_res = draw_re_results(img, result)
|
img_res = draw_re_results(img, result)
|
||||||
cv2.imwrite(
|
cv2.imwrite(save_img_path, img_res)
|
||||||
os.path.join(args.output_dir,
|
|
||||||
os.path.splitext(os.path.basename(img_path))[0] +
|
|
||||||
"_re.jpg"), img_res)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user