Mirror of https://github.com/allenai/olmocr.git (synced 2025-09-27 09:27:55 +00:00)

Commit 9da1f92628 (parent 53494d9c7e): Cleaner implementations of benchmark stuff
@@ -4,9 +4,11 @@ This script runs olmocr bench.
 It will take as an argument a folder, and scan it for .jsonl files which contain the various rules and properties that we will check.
 It will then validate the JSON files to make sure they are all valid.
 Then, each other folder in there (besides /pdfs) represents a pipeline tool that we will evaluate.
-We will validate that each one of those contains a .md file corresponding to its parse for every .pdf in the /pdfs folder.
+We will validate that each one of those contains at least one .md file (or repeated generations, e.g. _1.md, _2.md, etc.)
+corresponding to its parse for every .pdf in the /pdfs folder.
 Then, we will read each one, and check if they pass against all the rules.
-If a rule fails, a short explanation is printed.
+If a rule fails on some of the repeats, a short explanation is printed.
+The final score is averaged over the repeated generations.
 """

 import argparse
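For orientation, the folder layout this docstring describes looks roughly like the sketch below. The names are illustrative: only the /pdfs folder and the *.jsonl rule files are fixed conventions; every other folder is treated as a candidate pipeline tool (here hypothetically named after the runners touched later in this commit), with repeated generations named `<pdf basename>_<repeat>.md`.

```
bench_folder/
├── rules.jsonl          # rule definitions validated and checked by this script
├── pdfs/
│   ├── doc1.pdf
│   └── doc2.pdf
├── mineru/              # one folder per pipeline tool under evaluation
│   ├── doc1_1.md        # repeated generations: <pdf basename>_<repeat>.md
│   ├── doc1_2.md
│   └── doc2_1.md
└── marker/
    └── ...
```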
@@ -44,40 +46,35 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
             raise ValueError(f"Missing required fields in line {line_num} of {jsonl_path}: {data}")

         rule_id = data["id"]

         if rule_id in rule_ids:
             raise ValueError(f"Duplicate rule {rule_id} in {jsonl_path}")
         else:
             rule_ids.add(rule_id)

-        # Make sure the document referenced exists
+        # Make sure the referenced PDF exists
         if data["pdf"] not in all_pdf_basenames:
             raise ValueError(f"Missing pdf {data['pdf']} referenced by {rule_id} in {jsonl_path} line {line_num}")

-        # Additional validations depending on type
+        # Additional validations depending on rule type
         rule_type = data["type"]
         if rule_type in ("present", "absent"):
             if "text" not in data:
                 raise ValueError(f"'text' field required for rule type '{rule_type}' in {jsonl_path} line {line_num}")
         elif rule_type == "order":
-            # Check that anchor is present, and that either 'before' or 'after' is present
            if "before" not in data:
                 raise ValueError(f"'before' field required for rule type 'order' in {jsonl_path} line {line_num}")
             if len(data["before"]) < 10:
-                raise ValueError(f"'before' field too short {jsonl_path} line {line_num}")
+                raise ValueError(f"'before' field too short in {jsonl_path} line {line_num}")
             if "after" not in data:
-                raise ValueError(f"'after' required for rule type 'order' in {jsonl_path} line {line_num}")
+                raise ValueError(f"'after' field required for rule type 'order' in {jsonl_path} line {line_num}")
             if len(data["after"]) < 10:
-                raise ValueError(f"'after' field too short {jsonl_path} line {line_num}")
+                raise ValueError(f"'after' field too short in {jsonl_path} line {line_num}")
         else:
             raise ValueError(f"Unknown rule type '{rule_type}' in {jsonl_path} line {line_num}")

-        # If everything looks good, add to the rules list
         rules.append(data)

     return rules


 def run_rule(rule, md_file_path: str) -> (bool, str):
     """
     Run the given rule on the content of the provided .md file.
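Given the checks in this hunk, the rule entries in the .jsonl files have roughly the following shape. The values are made up for illustration; the required fields and the present/absent/order types follow the validation code above (note that `before`/`after` must be at least 10 characters and `threshold` is optional, defaulting to 1.0).

```jsonl
{"id": "doc1_present_01", "pdf": "doc1.pdf", "type": "present", "text": "Figure 3 shows the ablation results", "threshold": 0.95}
{"id": "doc1_absent_01", "pdf": "doc1.pdf", "type": "absent", "text": "Lorem ipsum placeholder text"}
{"id": "doc1_order_01", "pdf": "doc1.pdf", "type": "order", "before": "1. Introduction and motivation", "after": "6. Conclusions and future work"}
```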
@@ -95,9 +92,7 @@ def run_rule(rule, md_file_path: str) -> (bool, str):
     if rule_type in ("present", "absent"):
         reference_query = rule["text"]
         threshold = rule.get("threshold", 1.0)
-
         best_ratio = fuzz.partial_ratio(reference_query, md_content) / 100.0
-
         if rule_type == "present":
             if best_ratio >= threshold:
                 return (True, "")
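A quick illustration of how the threshold interacts with the fuzzy match. This is a sketch under the assumption that `fuzz` here is RapidFuzz's `fuzz` module (the import is outside the hunk shown), and the strings are invented:

```python
from rapidfuzz import fuzz

md_content = "## Results\nThe modell achieves 91.2% accuracy on the held-out set."
reference_query = "The model achieves 91.2% accuracy"

best_ratio = fuzz.partial_ratio(reference_query, md_content) / 100.0
# The single OCR-style typo ("modell") keeps best_ratio just below 1.0, so a
# "present" rule with the default threshold of 1.0 fails, while a slightly
# looser threshold such as 0.95 passes.
print(f"{best_ratio:.3f}")
```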
@@ -109,96 +104,96 @@ def run_rule(rule, md_file_path: str) -> (bool, str):
             else:
                 return (False, f"Expected '{reference_query[:40]}...' with threshold {threshold} but best match ratio was {best_ratio:.3f}")
     elif rule_type == "order":
-        # Implement a simple ordering check: ensure that the anchor text appears,
-        # and if 'before' is specified, it must appear before the anchor;
-        # if 'after' is specified, it must appear after the anchor.
         before = rule.get("before")
         after = rule.get("after")
         threshold = rule.get("threshold", 1.0)

         max_l_dist = round((1.0 - threshold) * len(before))

         before_matches = find_near_matches(before, md_content, max_l_dist=max_l_dist)
         after_matches = find_near_matches(after, md_content, max_l_dist=max_l_dist)

         if not before_matches:
-            return (False, f"'before' search text '{before[:40]}...' does not appear in parse with max_l_dist {max_l_dist}")
+            return (False, f"'before' search text '{before[:40]}...' not found with max_l_dist {max_l_dist}")

         if not after_matches:
-            return (False, f"'after' search text '{after[:40]}...' does not appear in parse with max_l_dist {max_l_dist}")
+            return (False, f"'after' search text '{after[:40]}...' not found with max_l_dist {max_l_dist}")

-        # Go through each combination of matches and see if there exists one where the before .start is sooner than the after .start
         for before_match, after_match in itertools.product(before_matches, after_matches):
             if before_match.start < after_match.start:
                 return (True, "")

-        return (False, f"Could not find a place in the text where '{before[:40]}...' appears before '{after[:40]}...'.")
+        return (False, f"Could not find a location where '{before[:40]}...' appears before '{after[:40]}...'.")

     else:
         raise NotImplementedError(f"Rule type '{rule_type}' is not implemented.")


 def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: list[str]):
     """
-    For the candidate folder (pipeline tool output), first validate that it contains
-    a .md file for every PDF in the pdf folder. Then, run each rule against the corresponding
-    .md file.
+    For the candidate folder (pipeline tool output), validate that it contains at least one .md file
+    (i.e. repeated generations like _1.md, _2.md, etc.) for every PDF in the pdf folder.
+    Then, run each rule against all corresponding .md files and average the results.

     Returns a tuple:
-    (num_passed, total_rules, candidate_errors, rule_failures, rule_type_breakdown)
-    where:
-    - candidate_errors is a list of error strings (e.g. missing files or exceptions)
-    - rule_failures is a list of rule failure messages (a rule returning False is not an error)
-    - rule_type_breakdown is a dict with rule type as key and a tuple (passed, total) as value
+    (overall_score, total_rules, candidate_errors, rule_failures, rule_type_breakdown)

-    NOTE: A rule returning False is not considered an 'error' but simply a rule failure.
-    Only exceptions and missing files are treated as candidate errors.
-    The rule_type_breakdown is added for a detailed breakdown of performance per rule type.
+    - overall_score: Average fraction of rules passed (averaged over repeats and rules).
+    - total_rules: Total number of rules evaluated.
+    - candidate_errors: List of candidate errors (e.g. missing files).
+    - rule_failures: List of failure messages for rules not passing on all repeats.
+    - rule_type_breakdown: Dictionary mapping rule type to list of average pass ratios for rules of that type.
     """
     candidate_errors = []
     rule_failures = []
-    rule_type_breakdown = {} # key: rule type, value: [passed_count, total_count]
+    rule_type_breakdown = {} # key: rule type, value: list of average pass ratios
     candidate_name = os.path.basename(candidate_folder)
-    num_passed = 0
-    total_rules = 0

-    # Validate that a .md file exists for every PDF.
+    # Map each PDF to its corresponding MD repeats (e.g., doc1_1.md, doc1_2.md, etc.)
+    pdf_to_md_files = {}
     for pdf_name in pdf_basenames:
-        # Change .pdf extension to .md (assumes pdf_name ends with .pdf)
-        md_name = os.path.splitext(pdf_name)[0] + ".md"
-        md_path = os.path.join(candidate_folder, md_name)
-        if not os.path.exists(md_path):
-            candidate_errors.append(f"Candidate '{candidate_name}' is missing {md_name} corresponding to {pdf_name}.")
+        md_base = os.path.splitext(pdf_name)[0]
+        md_pattern = os.path.join(candidate_folder, f"{md_base}_*.md")
+        md_files = glob.glob(md_pattern)
+        if not md_files:
+            candidate_errors.append(
+                f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md)."
+            )
+        else:
+            pdf_to_md_files[pdf_name] = md_files

-    # If there are missing .md files, we don't run the rules.
     if candidate_errors:
-        return (0, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown)
+        return (0.0, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown)

-    # Evaluate rules. Each rule references a PDF (e.g., "doc1.pdf"), and we expect the candidate to have "doc1.md".
+    total_rule_score = 0.0

+    # Evaluate each rule. Each rule references a PDF (e.g., "doc1.pdf") so we get all its MD repeats.
     for rule in all_rules:
         rule_type = rule["type"]
-        # Initialize breakdown counts for this rule type if not already
         if rule_type not in rule_type_breakdown:
-            rule_type_breakdown[rule_type] = [0, 0]
-        rule_type_breakdown[rule_type][1] += 1 # increment total count
+            rule_type_breakdown[rule_type] = []

         pdf_name = rule["pdf"]
-        md_name = os.path.splitext(pdf_name)[0] + ".md"
-        md_path = os.path.join(candidate_folder, md_name)
-        total_rules += 1
+        md_base = os.path.splitext(pdf_name)[0]
+        md_files = pdf_to_md_files.get(pdf_name, [])
+        if not md_files:
+            continue # Should not occur due to earlier check.
+        repeat_passes = 0
+        num_repeats = 0
+        explanations = []
+        for md_path in md_files:
+            num_repeats += 1
             try:
                 passed, explanation = run_rule(rule, md_path)
                 if passed:
-                    num_passed += 1
-                    rule_type_breakdown[rule_type][0] += 1 # increment passed count
+                    repeat_passes += 1
                 else:
-                    # A rule returning False is recorded as a rule failure, not an error.
-                    rule_failures.append(f"Rule {rule.get('id')} on {md_name} failed: {explanation}")
+                    explanations.append(explanation)
             except Exception as e:
-                # Exceptions are considered candidate errors.
-                candidate_errors.append(f"Error running rule {rule.get('id')} on {md_name}: {e}")
+                candidate_errors.append(f"Error running rule {rule.get('id')} on {md_path}: {e}")
+                explanations.append(str(e))
+        rule_avg = repeat_passes / num_repeats if num_repeats > 0 else 0.0
+        total_rule_score += rule_avg
+        if rule_avg < 1.0:
+            rule_failures.append(
+                f"Rule {rule.get('id')} on {md_base} average pass ratio: {rule_avg:.3f} ({repeat_passes}/{num_repeats} repeats passed). "
+                f"Example explanation: {explanations[0] if explanations else 'No explanation'}"
+            )
+        rule_type_breakdown[rule_type].append(rule_avg)

-    return (num_passed, total_rules, candidate_errors, rule_failures, rule_type_breakdown)
+    overall_score = total_rule_score / len(all_rules) if all_rules else 0.0
+    return (overall_score, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown)


 def main():
     parser = argparse.ArgumentParser(description="Run OLMOCR Bench.")
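To make the new scoring concrete, here is a small worked example with invented numbers (two rules, three repeats per PDF), mirroring the per-rule averaging in evaluate_candidate above:

```python
# Hypothetical repeat outcomes for one candidate: rule_a passes on 2 of 3
# repeats, rule_b passes on all 3.
repeat_outcomes = {
    "rule_a": [True, True, False],
    "rule_b": [True, True, True],
}
rule_avgs = [sum(passes) / len(passes) for passes in repeat_outcomes.values()]
overall_score = sum(rule_avgs) / len(rule_avgs)
print(rule_avgs)                      # [0.666..., 1.0]; rule_a is reported as a failure since its average < 1.0
print(f"{overall_score * 100:.1f}%")  # 83.3%, printed as the candidate's "Average Score"
```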
@@ -224,7 +219,7 @@ def main():
     # Get PDF basenames (e.g. "doc1.pdf")
     pdf_basenames = [os.path.basename(p) for p in all_pdf_files]

-    # Find .jsonl files in the input folder and validate them
+    # Find and validate .jsonl files in the input folder
     jsonl_files = glob.glob(os.path.join(input_folder, "*.jsonl"))
     if not jsonl_files:
         print(f"Error: No .jsonl files found in {input_folder}.", file=sys.stderr)
@@ -260,8 +255,8 @@ def main():
     print("\nRunning rules for each candidate:")
     for candidate in candidate_folders:
         candidate_name = os.path.basename(candidate)
-        num_passed, total_rules, candidate_errors, rule_failures, rule_type_breakdown = evaluate_candidate(candidate, all_rules, pdf_basenames)
-        summary.append((candidate_name, num_passed, total_rules, candidate_errors, rule_failures, rule_type_breakdown))
+        overall_score, total_rules, candidate_errors, rule_failures, rule_type_breakdown = evaluate_candidate(candidate, all_rules, pdf_basenames)
+        summary.append((candidate_name, overall_score, total_rules, candidate_errors, rule_failures, rule_type_breakdown))
         print(f"\nCandidate: {candidate_name}")
         if candidate_errors:
             for err in candidate_errors:
@@ -270,23 +265,24 @@ def main():
         if rule_failures:
             for fail in rule_failures:
                 print(f" [FAIL] {fail}")
-        print(f" Passed {num_passed} out of {total_rules} rules.")
+        print(f" Average Score: {overall_score * 100:.1f}% over {total_rules} rules.")

-    # Print a final summary (if only rule failures occurred, we output the score and breakdown)
+    # Print final summary with breakdown by rule type
     print("\n" + "="*50)
     print("Final Summary:")
-    for candidate_name, num_passed, total_rules, candidate_errors, _, rule_type_breakdown in summary:
+    for candidate_name, overall_score, total_rules, candidate_errors, _, rule_type_breakdown in summary:
         if candidate_errors:
             status = "FAILED (errors)"
         else:
-            status = f"{num_passed / total_rules * 100:0.1f}%"
-        print(f"{candidate_name:20s} : {num_passed:3d}/{total_rules:3d} rules passed - {status}")
+            status = f"{overall_score * 100:0.1f}%"
+        print(f"{candidate_name:20s} : Average Score: {overall_score * 100:0.1f}% over {total_rules:3d} rules - {status}")
         print(" Breakdown by rule type:")
-        for rtype, counts in rule_type_breakdown.items():
-            passed_count, total_count = counts
-            percentage = passed_count / total_count * 100 if total_count else 0
-            print(f" {rtype:8s}: {passed_count:2d}/{total_count:2d} rules passed ({percentage:0.1f}%)")
+        for rtype, scores in rule_type_breakdown.items():
+            if scores:
+                avg = sum(scores) / len(scores) * 100
+            else:
+                avg = 0.0
+            print(f" {rtype:8s}: {avg:0.1f}% average pass rate over {len(scores)} rules")
     print("="*50)

 if __name__ == "__main__":
@@ -1,39 +1,90 @@
 import argparse
 import os
 import glob
+import importlib
 from tqdm import tqdm

-# Import all of the runners
-from olmocr.bench.runners.run_gotocr import run_gotocr
-from olmocr.bench.runners.run_marker import run_marker
-# Goes through each pdf in the data folder, and converts them with each provided method
+def parse_method_arg(method_arg):
+    """
+    Parse a method configuration string of the form:
+        method_name[:key=value[:key2=value2...]]
+    Returns:
+        (method_name, kwargs_dict)
+    """
+    parts = method_arg.split(":")
+    name = parts[0]
+    kwargs = {}
+    for extra in parts[1:]:
+        if "=" in extra:
+            key, value = extra.split("=", 1)
+            try:
+                converted = int(value)
+            except ValueError:
+                try:
+                    converted = float(value)
+                except ValueError:
+                    converted = value
+            kwargs[key] = converted
+        else:
+            raise ValueError(f"Extra argument '{extra}' is not in key=value format")
+    return name, kwargs

 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run PDF conversion using specified OCR methods and extra parameters."
+    )
+    parser.add_argument(
+        "methods",
+        nargs="+",
+        help="Methods to run in the format method[:key=value ...]. "
+             "Example: gotocr mineru:temperature=2 marker:runs=3"
+    )
+    parser.add_argument(
+        "--repeats",
+        type=int,
+        default=1,
+        help="Number of times to repeat the conversion for each PDF."
+    )
+    args = parser.parse_args()

+    # Mapping of method names to a tuple: (module path, function name)
+    available_methods = {
+        "gotocr": ("olmocr.bench.runners.run_gotocr", "run_gotocr"),
+        "marker": ("olmocr.bench.runners.run_marker", "run_marker"),
+        "mineru": ("olmocr.bench.runners.run_mineru", "run_mineru"),
+    }

+    # Build config by importing only requested methods.
+    config = {}
+    for method_arg in args.methods:
+        method_name, extra_kwargs = parse_method_arg(method_arg)
+        if method_name not in available_methods:
+            parser.error(f"Unknown method: {method_name}. "
+                         f"Available methods: {', '.join(available_methods.keys())}")
+        module_path, function_name = available_methods[method_name]
+        # Dynamically import the module and get the function.
+        module = importlib.import_module(module_path)
+        function = getattr(module, function_name)
+        config[method_name] = {
+            "method": function,
+            "kwargs": extra_kwargs
+        }

     data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
     pdf_directory = os.path.join(data_directory, "pdfs")

-    config = {
-        "marker": {
-            "method": run_marker
-        },
-        "got_ocr": {
-            "method": run_gotocr,
-            "temperature": 0.0,
-        },
-    }
+    # Process each PDF using each specified method and repeat the conversion as needed.

     for candidate in config.keys():
-        print(f"Starting conversion using {candidate}")
-        os.makedirs(os.path.join(data_directory, candidate), exist_ok=True)
+        print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
+        candidate_output_dir = os.path.join(data_directory, candidate)
+        os.makedirs(candidate_output_dir, exist_ok=True)

         for pdf_path in tqdm(glob.glob(os.path.join(pdf_directory, "*.pdf")), desc=candidate):
-            markdown = config[candidate]["method"](pdf_path, page_num=1)
-            with open(os.path.join(data_directory, candidate, os.path.basename(pdf_path).replace(".pdf", ".md")), "w") as out_f:
+            base_name = os.path.basename(pdf_path).replace(".pdf", "")
+            # Repeat the conversion as many times as specified.
+            for i in range(1, args.repeats + 1):
+                markdown = config[candidate]["method"](pdf_path, page_num=1, **config[candidate]["kwargs"])
+                output_filename = f"{base_name}_{i}.md"
+                output_path = os.path.join(candidate_output_dir, output_filename)
+                with open(output_path, "w") as out_f:
                     out_f.write(markdown)
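As a usage sketch for the new parse_method_arg helper, assuming the snippet runs in the same module where it is defined: the first three argument strings echo the script's own --help example, while the last two are invented purely to show the int/float/string fallback.

```python
print(parse_method_arg("gotocr"))                 # ('gotocr', {})
print(parse_method_arg("mineru:temperature=2"))   # ('mineru', {'temperature': 2})     -> parsed as int
print(parse_method_arg("marker:runs=3"))          # ('marker', {'runs': 3})
print(parse_method_arg("marker:min_length=0.5"))  # ('marker', {'min_length': 0.5})    -> parsed as float
print(parse_method_arg("marker:device=cuda"))     # ('marker', {'device': 'cuda'})     -> kept as string
```

With `--repeats 3`, each PDF then yields doc1_1.md, doc1_2.md, and doc1_3.md inside the folder named after the method, which is exactly the naming the benchmark script's `{md_base}_*.md` glob expects.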
@@ -8,20 +8,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.config.enums import SupportedPdfParseMethod


-def run(pdf_folder):
-    """
-    Convert all PDF files in the specified folder to markdown using MinerU.
-    For each PDF file, the script outputs markdown files along with visual and JSON outputs.
-    The outputs are saved in a folder called "mineru" (with an "images" subfolder)
-    located in the same parent directory as pdf_folder.
-
-    :param pdf_folder: Path to the folder containing PDF files.
-    """
-    # Resolve absolute paths
-    pdf_folder = os.path.abspath(pdf_folder)
-    parent_dir = os.path.dirname(pdf_folder)
-    output_folder = os.path.join(parent_dir, "mineru")
-    image_output_folder = os.path.join(output_folder, "images")
+def run_mineru(pdf_path: str, page_num: int=1) -> str:

     # Create output directories if they don't exist
     os.makedirs(image_output_folder, exist_ok=True)
@@ -31,19 +18,6 @@ def run(pdf_folder):
     image_writer = FileBasedDataWriter(image_output_folder)
     md_writer = FileBasedDataWriter(output_folder)

-    # List all PDF files in the provided folder
-    pdf_files = [
-        os.path.join(pdf_folder, filename)
-        for filename in os.listdir(pdf_folder)
-        if filename.lower().endswith(".pdf")
-    ]
-
-    for pdf_path in pdf_files:
-        print(f"Processing {pdf_path}...")
-        # Get file name without suffix for naming outputs
-        pdf_file_name = os.path.basename(pdf_path)
-        name_without_suff = pdf_file_name.split(".")[0]
-
     # Read the PDF file bytes
     reader = FileBasedDataReader("")
     pdf_bytes = reader.read(pdf_path)
@@ -67,10 +41,13 @@ def run(pdf_folder):
     md_file_name = f"{name_without_suff}.md"
     pipe_result.dump_md(md_writer, md_file_name, image_dir_basename)

+    with open(os.path.join(output_folder, md_file_name), "r") as f:
+        md_data = f.read()
+
     # Remove useless image folder
     shutil.rmtree(image_output_folder)

-    print(f"Finished processing {pdf_file_name}. Outputs saved to {output_folder}.")
+    return md_data


 if __name__ == "__main__":
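A minimal sketch of how the refactored MinerU runner is consumed, matching the `available_methods` mapping and the conversion loop from the earlier hunk; the PDF path and output location below are illustrative, not taken from the repo:

```python
from olmocr.bench.runners.run_mineru import run_mineru

# Convert the first page of one PDF to markdown and save it as the first "repeat".
markdown = run_mineru("sample_data/pdfs/doc1.pdf", page_num=1)
with open("sample_data/mineru/doc1_1.md", "w") as out_f:
    out_f.write(markdown)
```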