mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-03 19:29:18 +00:00
add LayoutLM ser
This commit is contained in:
parent
f01dbb5648
commit
9131c4a7ac
@ -195,7 +195,7 @@ export CUDA_VISIBLE_DEVICES=0
|
||||
python3.7 infer_ser.py \
|
||||
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
|
||||
--ser_model_type "LayoutXLM" \
|
||||
--output_dir "output_res/" \
|
||||
--output_dir "output/ser/" \
|
||||
--infer_imgs "XFUND/zh_val/image/" \
|
||||
--ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
|
||||
```
|
||||
@ -210,7 +210,7 @@ python3.7 infer_ser_e2e.py \
|
||||
--model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
|
||||
--ser_model_type "LayoutXLM" \
|
||||
--max_seq_length 512 \
|
||||
--output_dir "output_res_e2e/" \
|
||||
--output_dir "output/ser_e2e/" \
|
||||
--infer_imgs "images/input/zh_val_0.jpg"
|
||||
```
|
||||
|
||||
@ -284,7 +284,7 @@ python3 eval_re.py \
|
||||
--eval_data_dir "XFUND/zh_val/image" \
|
||||
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
|
||||
--label_map_path 'labels/labels_ser.txt' \
|
||||
--output_dir "output/re_test/" \
|
||||
--output_dir "output/re/" \
|
||||
--per_gpu_eval_batch_size 8 \
|
||||
--num_workers 8 \
|
||||
--seed 2048
|
||||
@ -302,7 +302,7 @@ python3 infer_re.py \
|
||||
--eval_data_dir "XFUND/zh_val/image" \
|
||||
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
|
||||
--label_map_path 'labels/labels_ser.txt' \
|
||||
--output_dir "output_res" \
|
||||
--output_dir "output/re/" \
|
||||
--per_gpu_eval_batch_size 1 \
|
||||
--seed 2048
|
||||
```
|
||||
@ -317,7 +317,7 @@ python3.7 infer_ser_re_e2e.py \
|
||||
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
|
||||
--re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
|
||||
--max_seq_length 512 \
|
||||
--output_dir "output_ser_re_e2e_train/" \
|
||||
--output_dir "output/ser_re_e2e/" \
|
||||
--infer_imgs "images/input/zh_val_21.jpg"
|
||||
```
|
||||
|
||||
|
||||
61
ppstructure/vqa/infer.sh
Normal file
61
ppstructure/vqa/infer.sh
Normal file
@ -0,0 +1,61 @@
|
||||
export CUDA_VISIBLE_DEVICES=6
|
||||
# python3.7 infer_ser_e2e.py \
|
||||
# --model_name_or_path "output/ser_distributed/best_model" \
|
||||
# --max_seq_length 512 \
|
||||
# --output_dir "output_res_e2e/" \
|
||||
# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
|
||||
|
||||
|
||||
# python3.7 infer_ser_re_e2e.py \
|
||||
# --model_name_or_path "output/ser_distributed/best_model" \
|
||||
# --re_model_name_or_path "output/re_test/best_model" \
|
||||
# --max_seq_length 512 \
|
||||
# --output_dir "output_ser_re_e2e_train/" \
|
||||
# --infer_imgs "images/input/zh_val_21.jpg"
|
||||
|
||||
# python3.7 infer_ser.py \
|
||||
# --model_name_or_path "output/ser_LayoutLM/best_model" \
|
||||
# --ser_model_type "LayoutLM" \
|
||||
# --output_dir "ser_LayoutLM/" \
|
||||
# --infer_imgs "images/input/zh_val_21.jpg" \
|
||||
# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
|
||||
|
||||
python3.7 infer_ser.py \
|
||||
--model_name_or_path "output/ser_new/best_model" \
|
||||
--ser_model_type "LayoutXLM" \
|
||||
--output_dir "ser_new/" \
|
||||
--infer_imgs "images/input/zh_val_21.jpg" \
|
||||
--ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
|
||||
|
||||
# python3.7 infer_ser_e2e.py \
|
||||
# --model_name_or_path "output/ser_new/best_model" \
|
||||
# --ser_model_type "LayoutXLM" \
|
||||
# --max_seq_length 512 \
|
||||
# --output_dir "output/ser_new/" \
|
||||
# --infer_imgs "images/input/zh_val_0.jpg"
|
||||
|
||||
|
||||
# python3.7 infer_ser_e2e.py \
|
||||
# --model_name_or_path "output/ser_LayoutLM/best_model" \
|
||||
# --ser_model_type "LayoutLM" \
|
||||
# --max_seq_length 512 \
|
||||
# --output_dir "output/ser_LayoutLM/" \
|
||||
# --infer_imgs "images/input/zh_val_0.jpg"
|
||||
|
||||
# python3 infer_re.py \
|
||||
# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
|
||||
# --max_seq_length 512 \
|
||||
# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
|
||||
# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
|
||||
# --label_map_path 'labels/labels_ser.txt' \
|
||||
# --output_dir "output_res" \
|
||||
# --per_gpu_eval_batch_size 1 \
|
||||
# --seed 2048
|
||||
|
||||
# python3.7 infer_ser_re_e2e.py \
|
||||
# --model_name_or_path "output/ser_LayoutLM/best_model" \
|
||||
# --ser_model_type "LayoutLM" \
|
||||
# --re_model_name_or_path "output/re_new/best_model" \
|
||||
# --max_seq_length 512 \
|
||||
# --output_dir "output_ser_re_e2e/" \
|
||||
# --infer_imgs "images/input/zh_val_21.jpg"
|
||||
@ -56,19 +56,19 @@ def infer(args):
|
||||
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
|
||||
|
||||
for idx, batch in enumerate(eval_dataloader):
|
||||
ocr_info = ocr_info_list[idx]
|
||||
image_path = ocr_info['image_path']
|
||||
ocr_info = ocr_info['ocr_info']
|
||||
|
||||
save_img_path = os.path.join(
|
||||
args.output_dir,
|
||||
os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
|
||||
os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
|
||||
logger.info("[Infer] process: {}/{}, save_result to {}".format(
|
||||
idx, len(eval_dataloader), save_img_path))
|
||||
with paddle.no_grad():
|
||||
outputs = model(**batch)
|
||||
pred_relations = outputs['pred_relations']
|
||||
|
||||
ocr_info = ocr_info_list[idx]
|
||||
image_path = ocr_info['image_path']
|
||||
ocr_info = ocr_info['ocr_info']
|
||||
|
||||
# 根据entity里的信息,做token解码后去过滤不要的ocr_info
|
||||
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
|
||||
|
||||
|
||||
@ -98,13 +98,13 @@ class SerPredictor(object):
|
||||
ocr_info=ocr_info,
|
||||
max_seq_len=self.max_seq_length)
|
||||
|
||||
if args.ser_model_type == 'LayoutLM':
|
||||
if self.args.ser_model_type == 'LayoutLM':
|
||||
preds = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
bbox=inputs["bbox"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"])
|
||||
elif args.ser_model_type == 'LayoutXLM':
|
||||
elif self.args.ser_model_type == 'LayoutXLM':
|
||||
preds = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
bbox=inputs["bbox"],
|
||||
|
||||
@ -117,7 +117,11 @@ if __name__ == "__main__":
|
||||
"w",
|
||||
encoding='utf-8') as fout:
|
||||
for idx, img_path in enumerate(infer_imgs):
|
||||
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
|
||||
save_img_path = os.path.join(
|
||||
args.output_dir,
|
||||
os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
|
||||
print("process: [{}/{}], save_result to {}".format(
|
||||
idx, len(infer_imgs), save_img_path))
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
@ -128,7 +132,4 @@ if __name__ == "__main__":
|
||||
}, ensure_ascii=False) + "\n")
|
||||
|
||||
img_res = draw_re_results(img, result)
|
||||
cv2.imwrite(
|
||||
os.path.join(args.output_dir,
|
||||
os.path.splitext(os.path.basename(img_path))[0] +
|
||||
"_re.jpg"), img_res)
|
||||
cv2.imwrite(save_img_path, img_res)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user