Working on table parsing

This commit is contained in:
Jake Poznanski 2025-10-23 21:19:27 +00:00
parent 88937c6e40
commit 1ef66fd313
3 changed files with 729 additions and 258 deletions

View File

@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
from collections import defaultdict
import numpy as np
from bs4 import BeautifulSoup
@ -21,76 +22,33 @@ from .katex.render import compare_rendered_equations, render_equation
__test__ = False
@dataclass
@dataclass(frozen=True)
class TableData:
"""Class to hold table data and metadata about headers."""
"""Class which holds table data as a graph of cells. Ex. you can access the value at any row, col.
data: np.ndarray # The actual table data
header_rows: Set[int] = field(default_factory=set) # Indices of rows that are headers
header_cols: Set[int] = field(default_factory=set) # Indices of columns that are headers
col_headers: dict = field(default_factory=dict) # Maps column index to header text, handling colspan
row_headers: dict = field(default_factory=dict) # Maps row index to header text, handling rowspan
Cell texts are only ever present one time, so ex if on row 0, you have a colspan 1 and a colspan 2 column,
then text gets stored at (0,0) and (0,1) only, (0,2) is not present in the data
def __repr__(self) -> str:
"""Returns a concise representation of the TableData object for debugging."""
return f"TableData(shape={self.data.shape}, header_rows={len(self.header_rows)}, header_cols={len(self.header_cols)})"
However, you can also ask, given a row, col, which set of row,col pairs is "left", "right", "up", and "down"
from that one. There can be multiple values returned, because rowspans and colspans mean that you can have multiple cells in each direction.
def __str__(self) -> str:
"""Returns a pretty string representation of the table with header information."""
output = []
Furthermore, you can also query "top_heading" and "left_heading". Cells can be marked as "headings", e.g. if they are in a thead
html tag.
Then, for each cell, you may request top_heading/left_heading, which returns all cell positions that are headings in those directions
Cells inside of <th> tags are considered headings automatically, but if these are not present, then the leftmost
cell in a row is automatically considered a header, and the same for the topmost cell in a column.
"""
cell_text: Dict[tuple[int, int], str] # Stores map from row, col to cell text
heading_cells: Set[tuple[int, int]] # Contains the row, col pairs which are headings
# Table dimensions
output.append(f"Table: {self.data.shape[0]} rows × {self.data.shape[1]} columns")
up_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
down_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
left_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
right_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
# Header info
output.append(f"Header rows: {sorted(self.header_rows)}")
output.append(f"Header columns: {sorted(self.header_cols)}")
# Table content with formatting
separator = "+" + "+".join(["-" * 17] * self.data.shape[1]) + "+"
# Add a header for row indices
output.append(separator)
headers = [""] + [f"Column {i}" for i in range(self.data.shape[1])]
output.append("| {:<5} | ".format("Row") + " | ".join(["{:<15}".format(h) for h in headers[1:]]) + " |")
output.append(separator)
# Format each row
for i in range(min(self.data.shape[0], 15)): # Limit to 15 rows for readability
# Format cells, mark header cells
cells = []
for j in range(self.data.shape[1]):
cell = str(self.data[i, j])
if len(cell) > 15:
cell = cell[:12] + "..."
# Mark header cells with *
if i in self.header_rows or j in self.header_cols:
cell = f"*{cell}*"
cells.append(cell)
row_str = "| {:<5} | ".format(i) + " | ".join(["{:<15}".format(c) for c in cells]) + " |"
output.append(row_str)
output.append(separator)
# If table is too large, indicate truncation
if self.data.shape[0] > 15:
output.append(f"... {self.data.shape[0] - 15} more rows ...")
# Column header details if available
if self.col_headers:
output.append("\nColumn header mappings:")
for col, headers in sorted(self.col_headers.items()):
header_strs = [f"({row}, '{text}')" for row, text in headers]
output.append(f" Column {col}: {', '.join(header_strs)}")
# Row header details if available
if self.row_headers:
output.append("\nRow header mappings:")
for row, headers in sorted(self.row_headers.items()):
header_strs = [f"({col}, '{text}')" for col, text in headers]
output.append(f" Row {row}: {', '.join(header_strs)}")
return "\n".join(output)
top_heading_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
left_heading_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
class TestType(str, Enum):
@ -145,6 +103,225 @@ def normalize_text(md_content: str) -> str:
return md_content
def _safe_span_int(value: Optional[Union[str, int]], default: int = 1) -> int:
"""Convert rowspan/colspan attributes to positive integers."""
if value in (None, "", 0):
return default
try:
span = int(value)
except (TypeError, ValueError):
return default
if span <= 0:
return default
return span
def _first_other_cell(slots, cell_id):
    """Return the first occupied slot in ``slots`` that belongs to a cell other than ``cell_id``.

    ``slots`` is an iterable of occupancy entries (cell ids, or ``None`` for
    empty grid positions). Returns ``None`` when no other cell is found.
    """
    for slot in slots:
        if slot is not None and slot != cell_id:
            return slot
    return None


def _build_table_data_from_specs(row_specs: List[List[Dict[str, Union[str, int, bool]]]]) -> Optional[TableData]:
    """
    Build a TableData object from a list of row specifications.

    Each row specification is a list of dictionaries with keys:
    - text: cell text content
    - rowspan: integer rowspan (>= 1)
    - colspan: integer colspan (>= 1)
    - is_heading: bool indicating if the cell should be treated as a heading

    Cells are laid out on a grid ("occupancy") in which every slot holds the id
    ``(row, col)`` of the cell anchored there or spanning over it, or ``None``
    when nothing covers the slot. Relations are then derived by scanning that
    grid outwards from each cell:

    - up/down/left/right: the nearest other cell in that direction, for every
      row/column the cell spans.
    - top_heading/left_heading: every heading cell above / to the left. When a
      cell has no explicit heading in a direction, the top-most cell of its
      column(s) / left-most cell of its row(s) is used instead, so tables
      without explicit heading markup still resolve headings.

    Returns None when ``row_specs`` is empty or contains no cells.
    """
    if not row_specs:
        return None

    cell_text: Dict[Tuple[int, int], str] = {}
    heading_cells: Set[Tuple[int, int]] = set()
    # cell id -> (rowspan, colspan) as declared by the spec
    cell_spans: Dict[Tuple[int, int], Tuple[int, int]] = {}
    occupancy: List[List[Optional[Tuple[int, int]]]] = []
    # Per column: (cell id, rows still to cover) for a rowspan carried down from above
    active_rowspans: List[Optional[Tuple[Tuple[int, int], int]]] = []

    for row_idx, cells in enumerate(row_specs):
        row_entries: List[Optional[Tuple[int, int]]] = []
        col_index = 0
        spec_idx = 0
        total_specs = len(cells)

        while spec_idx < total_specs or col_index < len(active_rowspans):
            # A rowspan from an earlier row still covers this column
            if col_index < len(active_rowspans) and active_rowspans[col_index] is not None:
                cell_id, remaining = active_rowspans[col_index]
                row_entries.append(cell_id)
                remaining -= 1
                active_rowspans[col_index] = (cell_id, remaining) if remaining > 0 else None
                col_index += 1
                continue

            # Out of cells for this row: the loop condition guarantees a known
            # column remains, so record it as empty and keep scanning for
            # rowspans further to the right.
            if spec_idx >= total_specs:
                row_entries.append(None)
                col_index += 1
                continue

            spec = cells[spec_idx]
            spec_idx += 1

            text = spec.get("text", "") or ""
            # Normalizes strings, non-positive values and junk to a span >= 1
            rowspan = _safe_span_int(spec.get("rowspan", 1))
            colspan = _safe_span_int(spec.get("colspan", 1))

            cell_id = (row_idx, col_index)
            cell_text[cell_id] = text
            if spec.get("is_heading", False):
                heading_cells.add(cell_id)
            cell_spans[cell_id] = (rowspan, colspan)

            required_len = col_index + colspan
            if len(active_rowspans) < required_len:
                active_rowspans.extend([None] * (required_len - len(active_rowspans)))

            # A new cell replaces whatever was tracked for the columns it covers
            carry = (cell_id, rowspan - 1) if rowspan > 1 else None
            for current_col in range(col_index, required_len):
                row_entries.append(cell_id)
                active_rowspans[current_col] = carry
            col_index = required_len

        occupancy.append(row_entries)

    # Flush any remaining active rowspans into additional rows
    while any(entry is not None for entry in active_rowspans):
        row_entries = []
        for col_index, span_entry in enumerate(active_rowspans):
            if span_entry is None:
                row_entries.append(None)
                continue
            cell_id, remaining = span_entry
            row_entries.append(cell_id)
            active_rowspans[col_index] = (cell_id, remaining - 1) if remaining > 1 else None
        occupancy.append(row_entries)

    if not cell_text:
        return None

    # Normalize occupancy to a consistent width based on populated columns
    valid_columns = {idx for row in occupancy for idx, value in enumerate(row) if value is not None}
    if not valid_columns:
        return None
    table_width = max(valid_columns) + 1
    for row in occupancy:
        if len(row) < table_width:
            row.extend([None] * (table_width - len(row)))
        else:
            del row[table_width:]

    table_height = len(occupancy)

    # Every cell gets an entry in every relation, even if it stays empty
    up_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {cell_id: set() for cell_id in cell_text}
    down_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {cell_id: set() for cell_id in cell_text}
    left_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {cell_id: set() for cell_id in cell_text}
    right_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {cell_id: set() for cell_id in cell_text}
    top_heading_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {cell_id: set() for cell_id in cell_text}
    left_heading_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {cell_id: set() for cell_id in cell_text}

    for cell_id, (rowspan, colspan) in cell_spans.items():
        row_start, col_start = cell_id
        # Clamp to the grid: when a later cell overlaps this one, its rowspan is
        # dropped from tracking and never flushed, so the declared span can run
        # past the last occupancy row.
        row_end = min(row_start + rowspan - 1, table_height - 1)
        col_end = col_start + colspan - 1

        leftmost_cells: Set[Tuple[int, int]] = set()
        for row in range(row_start, row_end + 1):
            slots = occupancy[row]
            to_left = slots[:col_start][::-1]  # nearest first

            right = _first_other_cell(slots[col_end + 1 :], cell_id)
            if right is not None:
                right_relations[cell_id].add(right)

            left = _first_other_cell(to_left, cell_id)
            if left is not None:
                left_relations[cell_id].add(left)

            left_heading_relations[cell_id].update(slot for slot in to_left if slot != cell_id and slot in heading_cells)

            leftmost = _first_other_cell(slots[:col_start], cell_id)
            if leftmost is not None:
                leftmost_cells.add(leftmost)

        topmost_cells: Set[Tuple[int, int]] = set()
        for col in range(col_start, col_end + 1):
            above = [occupancy[row][col] for row in range(row_start - 1, -1, -1)]  # nearest first

            down = _first_other_cell((occupancy[row][col] for row in range(row_end + 1, table_height)), cell_id)
            if down is not None:
                down_relations[cell_id].add(down)

            up = _first_other_cell(above, cell_id)
            if up is not None:
                up_relations[cell_id].add(up)

            top_heading_relations[cell_id].update(slot for slot in above if slot != cell_id and slot in heading_cells)

            topmost = _first_other_cell(reversed(above), cell_id)
            if topmost is not None:
                topmost_cells.add(topmost)

        # No explicit heading in a direction: fall back to the outermost cell
        if not left_heading_relations[cell_id]:
            left_heading_relations[cell_id].update(leftmost_cells)
        if not top_heading_relations[cell_id]:
            top_heading_relations[cell_id].update(topmost_cells)

    return TableData(
        cell_text=cell_text,
        heading_cells=heading_cells,
        up_relations=up_relations,
        down_relations=down_relations,
        left_relations=left_relations,
        right_relations=right_relations,
        top_heading_relations=top_heading_relations,
        left_heading_relations=left_heading_relations,
    )
def parse_markdown_tables(md_content: str) -> List[TableData]:
"""
@ -183,74 +360,46 @@ def parse_markdown_tables(md_content: str) -> List[TableData]:
if len(current_table_lines) >= 2:
table_data = _process_table_lines(current_table_lines)
if table_data and len(table_data) > 0:
# Convert to numpy array for easier manipulation
max_cols = max(len(row) for row in table_data)
padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
table_array = np.array(padded_data)
# In markdown tables, the first row is typically a header row
header_rows = {0} if len(table_array) > 0 else set()
# Set up col_headers with first row headers for each column
col_headers = {}
if len(table_array) > 0:
for col_idx in range(table_array.shape[1]):
if col_idx < len(table_array[0]):
col_headers[col_idx] = [(0, table_array[0, col_idx])]
# Set up row_headers with first column headers for each row
row_headers = {}
if table_array.shape[1] > 0:
for row_idx in range(1, table_array.shape[0]): # Skip header row
row_headers[row_idx] = [(0, table_array[row_idx, 0])] # First column as heading
# Create TableData object
parsed_tables.append(
TableData(
data=table_array,
header_rows=header_rows,
header_cols={0} if table_array.shape[1] > 0 else set(), # First column as header
col_headers=col_headers,
row_headers=row_headers,
row_specs: List[List[Dict[str, Union[str, int, bool]]]] = []
for row_idx, row in enumerate(table_data):
row_specs.append(
[
{
"text": cell,
"rowspan": 1,
"colspan": 1,
"is_heading": row_idx == 0 or col_idx == 0,
}
for col_idx, cell in enumerate(row)
]
)
)
table = _build_table_data_from_specs(row_specs)
if table:
parsed_tables.append(table)
in_table = False
# Process the last table if we're still tracking one at the end of the file
if in_table and len(current_table_lines) >= 2:
table_data = _process_table_lines(current_table_lines)
if table_data and len(table_data) > 0:
# Convert to numpy array
max_cols = max(len(row) for row in table_data)
padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
table_array = np.array(padded_data)
# In markdown tables, the first row is typically a header row
header_rows = {0} if len(table_array) > 0 else set()
# Set up col_headers with first row headers for each column
col_headers = {}
if len(table_array) > 0:
for col_idx in range(table_array.shape[1]):
if col_idx < len(table_array[0]):
col_headers[col_idx] = [(0, table_array[0, col_idx])]
# Set up row_headers with first column headers for each row
row_headers = {}
if table_array.shape[1] > 0:
for row_idx in range(1, table_array.shape[0]): # Skip header row
row_headers[row_idx] = [(0, table_array[row_idx, 0])] # First column as heading
# Create TableData object
parsed_tables.append(
TableData(
data=table_array,
header_rows=header_rows,
header_cols={0} if table_array.shape[1] > 0 else set(), # First column as header
col_headers=col_headers,
row_headers=row_headers,
row_specs = []
for row_idx, row in enumerate(table_data):
row_specs.append(
[
{
"text": cell,
"rowspan": 1,
"colspan": 1,
"is_heading": row_idx == 0 or col_idx == 0,
}
for col_idx, cell in enumerate(row)
]
)
)
table = _build_table_data_from_specs(row_specs)
if table:
parsed_tables.append(table)
return parsed_tables
@ -318,159 +467,47 @@ def parse_html_tables(html_content: str) -> List[TableData]:
parsed_tables = []
for table in tables:
rows = table.find_all(["tr"])
table_data = []
header_rows = set()
header_cols = set()
col_headers = {} # Maps column index to all header cells above it
row_headers = {} # Maps row index to all header cells to its left
rows = table.find_all("tr")
if not rows:
continue
# Find rows inside thead tags - these are definitely header rows
thead = table.find("thead")
if thead:
thead_rows = thead.find_all("tr")
for tr in thead_rows:
header_rows.add(rows.index(tr))
row_specs: List[List[Dict[str, Union[str, int, bool]]]] = []
total_rows = len(rows)
# Initialize a grid to track filled cells due to rowspan/colspan
cell_grid = {}
col_span_info = {} # Tracks which columns contain headers
row_span_info = {} # Tracks which rows contain headers
# First pass: process each row to build the raw table data and identify headers
for row_idx, row in enumerate(rows):
cells = row.find_all(["th", "td"])
row_data = []
col_idx = 0
# If there are th elements in this row, it's likely a header row
if row.find("th"):
header_rows.add(row_idx)
cells = row.find_all(["th", "td"], recursive=False)
heading_context = row.find_parent("thead") is not None
row_spec: List[Dict[str, Union[str, int, bool]]] = []
for cell in cells:
# Skip positions already filled by rowspans from above
while (row_idx, col_idx) in cell_grid:
row_data.append(cell_grid[(row_idx, col_idx)])
col_idx += 1
# Replace <br> and <br/> tags with newlines before getting text
for br in cell.find_all("br"):
br.replace_with("\n")
cell_text = cell.get_text().strip()
# Handle rowspan/colspan
rowspan = int(cell.get("rowspan", 1))
colspan = int(cell.get("colspan", 1))
text = cell.get_text(separator="\n").strip()
raw_rowspan = cell.get("rowspan")
raw_colspan = cell.get("colspan")
# Add the cell to the row data
row_data.append(cell_text)
rowspan = _safe_span_int(raw_rowspan, 1)
colspan = _safe_span_int(raw_colspan, 1)
# Fill the grid for this cell and its rowspan/colspan
for i in range(rowspan):
for j in range(colspan):
if i == 0 and j == 0:
continue # Skip the main cell position
# For rowspan cells, preserve the text in all spanned rows
if j == 0 and i > 0: # Only for cells directly below
cell_grid[(row_idx + i, col_idx + j)] = cell_text
else:
cell_grid[(row_idx + i, col_idx + j)] = "" # Mark other spans as empty
# HTML specifies rowspan=0 to extend to the end of the table section
if isinstance(raw_rowspan, str) and raw_rowspan.strip() == "0":
rowspan = max(1, total_rows - row_idx)
# If this is a header cell (th), mark it and its span
if cell.name == "th":
# Mark columns as header columns
for j in range(colspan):
header_cols.add(col_idx + j)
is_heading = cell.name == "th" or heading_context
# For rowspan, mark spanned rows as part of header
for i in range(1, rowspan):
if row_idx + i < len(rows):
header_rows.add(row_idx + i)
row_spec.append({"text": text, "rowspan": rowspan, "colspan": colspan, "is_heading": is_heading})
# Record this header for all spanned columns
for j in range(colspan):
curr_col = col_idx + j
if curr_col not in col_headers:
col_headers[curr_col] = []
col_headers[curr_col].append((row_idx, cell_text))
row_specs.append(row_spec)
# Store which columns are covered by this header
if cell_text and colspan > 1:
if cell_text not in col_span_info:
col_span_info[cell_text] = set()
col_span_info[cell_text].add(curr_col)
# Store which rows are covered by this header for rowspan
if cell_text and rowspan > 1:
if cell_text not in row_span_info:
row_span_info[cell_text] = set()
for i in range(rowspan):
row_span_info[cell_text].add(row_idx + i)
# Also handle row headers from data cells that have rowspan
if cell.name == "td" and rowspan > 1 and col_idx in header_cols:
for i in range(1, rowspan):
if row_idx + i < len(rows):
if row_idx + i not in row_headers:
row_headers[row_idx + i] = []
row_headers[row_idx + i].append((col_idx, cell_text))
col_idx += colspan
# Pad the row if needed to handle different row lengths
table_data.append(row_data)
# Second pass: expand headers to cells that should inherit them
# First handle column headers
for header_text, columns in col_span_info.items():
for col in columns:
# Add this header to all columns it spans over
for row_idx in range(len(table_data)):
if row_idx not in header_rows: # Only apply to data rows
for j in range(col, len(table_data[row_idx]) if row_idx < len(table_data) else 0):
# Add header info to data cells in these columns
if j not in col_headers:
col_headers[j] = []
if not any(h[1] == header_text for h in col_headers[j]):
header_row = min([r for r, t in col_headers.get(col, [(0, "")])])
col_headers[j].append((header_row, header_text))
# Handle row headers
for header_text, rows in row_span_info.items():
for row in rows:
if row < len(table_data):
# Find first header column
header_col = min(header_cols) if header_cols else 0
if row not in row_headers:
row_headers[row] = []
if not any(h[1] == header_text for h in row_headers.get(row, [])):
row_headers[row].append((header_col, header_text))
# Process regular row headers - each cell in a header column becomes a header for its row
for col_idx in header_cols:
for row_idx, row in enumerate(table_data):
if col_idx < len(row) and row[col_idx].strip():
if row_idx not in row_headers:
row_headers[row_idx] = []
if not any(h[1] == row[col_idx] for h in row_headers.get(row_idx, [])):
row_headers[row_idx].append((col_idx, row[col_idx]))
# Calculate max columns for padding
max_cols = max(len(row) for row in table_data) if table_data else 0
# Ensure all rows have the same number of columns
table_data = _build_table_data_from_specs(row_specs)
if table_data:
padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
table_array = np.array(padded_data)
# Create TableData object with the table and header information
parsed_tables.append(
TableData(data=table_array, header_rows=header_rows, header_cols=header_cols, col_headers=col_headers, row_headers=row_headers)
)
parsed_tables.append(table_data)
return parsed_tables
@dataclass(kw_only=True)
class BasePDFTest:
"""

View File

@ -0,0 +1,272 @@
#!/bin/bash
# Runs an olmocr-bench run using the full pipeline (no fallback) for infrapartner testing
# This version skips the performance task and adds support for --server, --model, and --beaker-secret arguments
#
# Usage examples:
# ./scripts/run_infrapartner_benchmark.sh --server http://example.com --model your-model-name --beaker-secret my-api-key-secret
# ./scripts/run_infrapartner_benchmark.sh --beaker-image jakep/olmocr-benchmark-0.3.3-780bc7d934 --server http://example.com
set -e

# Print the supported command line options.
usage() {
    echo "Usage: $0 [--server SERVER_URL] [--model MODEL_NAME] [--beaker-secret SECRET_NAME] [--benchbranch BRANCH_NAME] [--beaker-image IMAGE_NAME]"
}

# Abort with a message when an option is the last argument. Without this check
# `shift 2` fails and `set -e` ends the script without any explanation.
# Call as: require_value "$@"
require_value() {
    if [[ $# -lt 2 ]]; then
        echo "Error: option $1 requires a value"
        usage
        exit 1
    fi
}

# Parse command line arguments
MODEL=""
SERVER=""
BEAKER_SECRET=""
BENCH_BRANCH=""
BEAKER_IMAGE=""
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            require_value "$@"
            MODEL="$2"
            shift 2
            ;;
        --server)
            require_value "$@"
            SERVER="$2"
            shift 2
            ;;
        --beaker-secret)
            require_value "$@"
            BEAKER_SECRET="$2"
            shift 2
            ;;
        --benchbranch)
            require_value "$@"
            BENCH_BRANCH="$2"
            shift 2
            ;;
        --beaker-image)
            require_value "$@"
            BEAKER_IMAGE="$2"
            shift 2
            ;;
        *)
            echo "Unknown option: $1"
            usage
            exit 1
            ;;
    esac
done
# A clean working tree is only required when the Docker image is built from it;
# a pre-built Beaker image makes the local checkout state irrelevant.
if [ -n "$BEAKER_IMAGE" ]; then
    echo "Skipping docker build"
else
    if ! git diff-index --quiet HEAD --; then
        echo "Error: There are uncommitted changes in the repository."
        echo "Please commit or stash your changes before running the benchmark."
        echo ""
        echo "Uncommitted changes:"
        git status --short
        exit 1
    fi
fi

# Use conda environment Python if available, otherwise use system Python
if [ -n "$CONDA_PREFIX" ]; then
    PYTHON="$CONDA_PREFIX/bin/python"
    echo "Using conda Python from: $CONDA_PREFIX"
else
    PYTHON="python"
    echo "Warning: No conda environment detected, using system Python"
fi

# Get version from version.py (quoted so a conda prefix containing spaces still works)
VERSION=$("$PYTHON" -c 'import olmocr.version; print(olmocr.version.VERSION)')
echo "OlmOCR version: $VERSION"

# Get first 10 characters of git hash
GIT_HASH=$(git rev-parse HEAD | cut -c1-10)
echo "Git hash: $GIT_HASH"

# Get current git branch name
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD)
echo "Git branch: $GIT_BRANCH"

# Check if a Beaker image was provided
if [ -n "$BEAKER_IMAGE" ]; then
    echo "Using provided Beaker image: $BEAKER_IMAGE"
    IMAGE_TAG="$BEAKER_IMAGE"
else
    # Create full image tag
    IMAGE_TAG="olmocr-benchmark-${VERSION}-${GIT_HASH}"
    echo "Building Docker image with tag: $IMAGE_TAG"

    # Build the Docker image
    echo "Building Docker image..."
    docker build --platform linux/amd64 -f ./Dockerfile -t "$IMAGE_TAG" .

    # Push image to beaker
    # NOTE(review): stderr is discarded and every failure is reported as
    # "already exists"; other failures (auth, network) are hidden by this.
    echo "Trying to push image to Beaker..."
    if ! beaker image create --workspace ai2/oe-data-pdf --name "$IMAGE_TAG" "$IMAGE_TAG" 2>/dev/null; then
        echo "Warning: Beaker image with tag $IMAGE_TAG already exists. Using existing image."
    fi
fi

# Get Beaker username
BEAKER_USER=$(beaker account whoami --format json | jq -r '.[0].name')
echo "Beaker user: $BEAKER_USER"
# Create Python script to run beaker experiment.
# The heredoc delimiter is quoted, so nothing inside is expanded by this shell;
# variables such as $API_KEY are resolved later inside the Beaker task.
cat << 'EOF' > /tmp/run_infrapartner_benchmark_experiment.py
import shlex
import sys

from beaker import Beaker, ExperimentSpec, TaskSpec, TaskContext, ResultSpec, TaskResources, ImageSource, Priority, Constraints, EnvVar

# Positional arguments: image tag, beaker user, git branch, git hash
image_tag = sys.argv[1]
beaker_user = sys.argv[2]
git_branch = sys.argv[3]
git_hash = sys.argv[4]

# Optional "--flag value" arguments that may follow the positional ones
options = {
    "--benchbranch": None,
    "--model": None,
    "--server": None,
    "--beaker-secret": None,
}
arg_idx = 5
while arg_idx < len(sys.argv):
    flag = sys.argv[arg_idx]
    if flag in options:
        if arg_idx + 1 >= len(sys.argv):
            sys.exit(f"Missing value for argument: {flag}")
        options[flag] = sys.argv[arg_idx + 1]
        arg_idx += 2
    else:
        print(f"Unknown argument: {flag}")
        arg_idx += 1

bench_branch = options["--benchbranch"]
model = options["--model"]
server = options["--server"]
beaker_secret = options["--beaker-secret"]

# Initialize Beaker client
b = Beaker.from_env(default_workspace="ai2/olmocr")

# Build the pipeline command with optional parameters.
# User-supplied values are shell-quoted because the command is run via `bash -c`.
pipeline_cmd = "python -m olmocr.pipeline ./localworkspace --markdown --pdfs ./olmOCR-bench/bench_data/pdfs/**/*.pdf"
if model:
    pipeline_cmd += f" --model {shlex.quote(model)}"
if server:
    pipeline_cmd += f" --server {shlex.quote(server)}"
if beaker_secret:
    pipeline_cmd += " --api_key \"$API_KEY\""

# Check if AWS credentials secret exists
aws_creds_secret = f"{beaker_user}-AWS_CREDENTIALS_FILE"
try:
    # Try to get the secret to see if it exists
    b.secret.get(aws_creds_secret, workspace="ai2/olmocr")
except Exception:
    # Any lookup failure is treated as "secret not available"; unlike a bare
    # except, Ctrl-C and SystemExit still propagate.
    has_aws_creds = False
    print(f"AWS credentials secret not found: {aws_creds_secret}")
else:
    has_aws_creds = True
    print(f"Found AWS credentials secret: {aws_creds_secret}")

# First experiment: Original benchmark job
commands = []
if has_aws_creds:
    commands.extend([
        "mkdir -p ~/.aws",
        'echo "$AWS_CREDENTIALS_FILE" > ~/.aws/credentials'
    ])

# If beaker_secret is provided, export it as API_KEY environment variable
if beaker_secret:
    commands.append('export API_KEY="$BEAKER_API_KEY"')

# Build git clone command with optional branch
git_clone_cmd = "git clone https://huggingface.co/datasets/allenai/olmOCR-bench"
if bench_branch:
    git_clone_cmd += f" -b {shlex.quote(bench_branch)}"

commands.extend([
    git_clone_cmd,
    "cd olmOCR-bench && git lfs pull && cd ..",
    pipeline_cmd,
    "python olmocr/bench/scripts/workspace_to_bench.py localworkspace/ olmOCR-bench/bench_data/olmocr --bench-path ./olmOCR-bench/",
    "pip install s5cmd",
    "s5cmd cp localworkspace/ s3://ai2-oe-data/jakep/olmocr-bench-runs/$BEAKER_WORKLOAD_ID/",
    "python -m olmocr.bench.benchmark --dir ./olmOCR-bench/bench_data"
])

# Build task spec with optional env vars
# If image_tag contains '/', it's already a full beaker image reference
if '/' in image_tag:
    image_ref = image_tag
else:
    image_ref = f"{beaker_user}/{image_tag}"

task_spec_args = {
    "name": "olmocr-infrapartner-benchmark",
    "image": ImageSource(beaker=image_ref),
    "command": [
        "bash", "-c",
        " && ".join(commands)
    ],
    "context": TaskContext(
        priority=Priority.normal,
        preemptible=True,
    ),
    "resources": TaskResources(gpu_count=0),
    "constraints": Constraints(cluster=["ai2/phobos", "ai2/neptune", "ai2/saturn"]),
    "result": ResultSpec(path="/noop-results"),
}

# Build env vars list
env_vars = []
if has_aws_creds:
    env_vars.append(EnvVar(name="AWS_CREDENTIALS_FILE", secret=aws_creds_secret))
if beaker_secret:
    env_vars.append(EnvVar(name="BEAKER_API_KEY", secret=beaker_secret))

# Add env vars if any exist
if env_vars:
    task_spec_args["env_vars"] = env_vars

# Create experiment spec
experiment_spec = ExperimentSpec(
    description=f"OlmOCR InfraPartner Benchmark Run - Branch: {git_branch}, Commit: {git_hash}",
    budget="ai2/oe-base",
    tasks=[TaskSpec(**task_spec_args)],
)

# Create the experiment
experiment = b.experiment.create(spec=experiment_spec, workspace="ai2/olmocr")
print(f"Created benchmark experiment: {experiment.id}")
print(f"View at: https://beaker.org/ex/{experiment.id}")
print("-------")
print("")
print("Note: Performance test has been skipped for infrapartner benchmark")
EOF
# Run the Python script to create the experiment
echo "Creating Beaker experiment..."

# Remove the generated script on exit, including when the Python step fails
# and `set -e` aborts the script early.
trap 'rm -f /tmp/run_infrapartner_benchmark_experiment.py' EXIT

# Build the command as an array so every argument is passed through verbatim
# (no word splitting, globbing, or re-evaluation of user-supplied values).
CMD=("$PYTHON" /tmp/run_infrapartner_benchmark_experiment.py "$IMAGE_TAG" "$BEAKER_USER" "$GIT_BRANCH" "$GIT_HASH")

if [ -n "$MODEL" ]; then
    echo "Using model: $MODEL"
    CMD+=(--model "$MODEL")
fi

if [ -n "$SERVER" ]; then
    echo "Using server: $SERVER"
    CMD+=(--server "$SERVER")
fi

if [ -n "$BEAKER_SECRET" ]; then
    echo "Using beaker secret for API key: $BEAKER_SECRET"
    CMD+=(--beaker-secret "$BEAKER_SECRET")
fi

if [ -n "$BENCH_BRANCH" ]; then
    echo "Using bench branch: $BENCH_BRANCH"
    CMD+=(--benchbranch "$BENCH_BRANCH")
fi

"${CMD[@]}"

echo "InfraPartner benchmark experiment submitted successfully!"

162
tests/test_table_parsing.py Normal file
View File

@ -0,0 +1,162 @@
import unittest
from olmocr.bench.tests import (
parse_html_tables,
parse_markdown_tables,
)
class TestParseHtmlTables(unittest.TestCase):
    """Checks cell placement, neighbour relations and heading relations for HTML tables."""

    def _parse_single_table(self, html):
        """Parse ``html`` and return its only table, failing clearly if the count is off."""
        tables = parse_html_tables(html)
        self.assertEqual(len(tables), 1)
        return tables[0]

    def test_basic_table(self):
        """A plain header row plus one body row, with no spans."""
        data = self._parse_single_table("""
<table border="1">
<thead>
<tr>
<th></th>
<th>ArXiv</th>
<th>Old<br>scans<br>math</th>
<th>Tables</th>
<th>Old<br>scans</th>
<th>Headers<br>&<br>footers</th>
<th>Multi<br>column</th>
<th>Long<br>tiny<br>text</th>
<th>Base</th>
<th>Overall</th>
</tr>
</thead>
<tbody>
<tr>
<td>Mistral OCR API</td>
<td>77.2</td>
<td>67.5</td>
<td>60.6</td>
<td>29.3</td>
<td>93.6</td>
<td>71.3</td>
<td>77.1</td>
<td>99.4</td>
<td>72.0±1.1</td>
</tr>
</tbody></table>""")

        self.assertEqual(data.cell_text[0, 0], "")
        self.assertEqual(data.cell_text[0, 1], "ArXiv")

        self.assertEqual(data.left_relations[0, 0], set())
        self.assertEqual(data.up_relations[0, 0], set())
        self.assertEqual(data.left_relations[0, 1], {(0, 0)})
        self.assertEqual(data.up_relations[1, 0], {(0, 0)})

        self.assertEqual(data.heading_cells, {(0, col) for col in range(10)})

        self.assertEqual(data.top_heading_relations[1, 3], {(0, 3)})
        # If there are no left headings defined, then the left most column is considered the left heading
        self.assertEqual(data.left_heading_relations[1, 3], {(1, 0)})

    def test_multiple_top_headings(self):
        """Two stacked header rows; the first spans both columns."""
        data = self._parse_single_table("""
<table border="1">
<thead>
<tr>
<th colspan="2">Fruit Costs in Unittest land</th>
</tr>
<tr>
<th>Fruit Type</th>
<th>Cost</th>
</tr>
</thead>
<tbody>
<tr>
<td>Apples</td>
<td>$1.00</td>
</tr>
<tr>
<td>Oranges</td>
<td>$2.00</td>
</tr>
</tbody></table>""")

        self.assertEqual(data.cell_text[0, 0], "Fruit Costs in Unittest land")
        self.assertEqual(data.cell_text[1, 0], "Fruit Type")
        self.assertEqual(data.cell_text[1, 1], "Cost")
        self.assertEqual(data.cell_text[2, 0], "Apples")
        self.assertEqual(data.cell_text[2, 1], "$1.00")
        self.assertEqual(data.cell_text[3, 0], "Oranges")
        self.assertEqual(data.cell_text[3, 1], "$2.00")

        # Both second-row headers sit under the single spanning header
        self.assertEqual(data.up_relations[1, 0], {(0, 0)})
        self.assertEqual(data.up_relations[1, 1], {(0, 0)})
        self.assertEqual(data.up_relations[2, 0], {(1, 0)})
        self.assertEqual(data.up_relations[2, 1], {(1, 1)})

        self.assertEqual(data.top_heading_relations[1, 0], {(0, 0)})
        self.assertEqual(data.top_heading_relations[1, 1], {(0, 0)})
        # Body cells see every heading above them, not just the nearest one
        self.assertEqual(data.top_heading_relations[2, 0], {(0, 0), (1, 0)})
        self.assertEqual(data.top_heading_relations[2, 1], {(0, 0), (1, 1)})

    def test_4x4_table_with_spans(self):
        """Test a 4x4 table with various row spans and column spans"""
        data = self._parse_single_table("""
<table border="1">
<thead>
<tr>
<th>Header 1</th>
<th colspan="2">Header 2-3</th>
<th>Header 4</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="2">Cell A (spans 2 rows)</td>
<td>Cell B</td>
<td>Cell C</td>
<td rowspan="2">Cell D (spans 2 rows)</td>
</tr>
<tr>
<td colspan="2">Cell E-F (spans 2 cols)</td>
</tr>
<tr>
<td>Cell G</td>
<td colspan="3">Cell H-I-J (spans 3 cols)</td>
</tr>
</tbody>
</table>""")

        # Test header row
        self.assertEqual(data.cell_text[0, 0], "Header 1")
        self.assertEqual(data.cell_text[0, 1], "Header 2-3")
        self.assertNotIn((0, 2), data.cell_text)  # colspan=2, so that next cell is empty
        self.assertEqual(data.cell_text[0, 3], "Header 4")

        # Test first body row
        self.assertEqual(data.cell_text[1, 0], "Cell A (spans 2 rows)")
        self.assertEqual(data.cell_text[1, 1], "Cell B")
        self.assertEqual(data.cell_text[1, 2], "Cell C")
        self.assertEqual(data.cell_text[1, 3], "Cell D (spans 2 rows)")

        # Test second body row: column 0 is covered by Cell A's rowspan
        self.assertNotIn((2, 0), data.cell_text)
        self.assertEqual(data.cell_text[2, 1], "Cell E-F (spans 2 cols)")

        # Test third body row
        self.assertEqual(data.cell_text[3, 0], "Cell G")
        self.assertEqual(data.cell_text[3, 1], "Cell H-I-J (spans 3 cols)")

        # Test heading cells
        self.assertEqual(data.heading_cells, {(0, 0), (0, 1), (0, 3)})