diff --git a/olmocr/bench/tests.py b/olmocr/bench/tests.py
index 0045ea1..adc44f9 100644
--- a/olmocr/bench/tests.py
+++ b/olmocr/bench/tests.py
@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
+from collections import defaultdict
import numpy as np
from bs4 import BeautifulSoup
@@ -21,76 +22,33 @@ from .katex.render import compare_rendered_equations, render_equation
__test__ = False
-@dataclass
+
+@dataclass(frozen=True)
class TableData:
- """Class to hold table data and metadata about headers."""
+ """Class which holds table data as a graph of cells. Ex. you can access the value at any row, col.
- data: np.ndarray # The actual table data
- header_rows: Set[int] = field(default_factory=set) # Indices of rows that are headers
- header_cols: Set[int] = field(default_factory=set) # Indices of columns that are headers
- col_headers: dict = field(default_factory=dict) # Maps column index to header text, handling colspan
- row_headers: dict = field(default_factory=dict) # Maps row index to header text, handling rowspan
+ Cell texts are only ever present one time, so ex if on row 0, you have a colspan 1 and a colspan 2 column,
+ then text gets stored at (0,0) and (0,1) only, (0,2) is not present in the data
- def __repr__(self) -> str:
- """Returns a concise representation of the TableData object for debugging."""
- return f"TableData(shape={self.data.shape}, header_rows={len(self.header_rows)}, header_cols={len(self.header_cols)})"
+ However, you can also ask, given a row, col, which set of row,col pairs is "left", "right", "up", and "down"
+ from that one. There can be multiple values returned, because rowspans and colspans mean that you can have multiple cells in each direction.
- def __str__(self) -> str:
- """Returns a pretty string representation of the table with header information."""
- output = []
+ Further more, you can also query "top_heading" and "left_heading". Where we also mark cells that are considered "headings", ex. if they are in a thead
+ html tag.
+ Then, for each cell, you may request top_heading/left_heading, which returns all cell positions that are headings in those directions
+ Cells inside of
tags are considered headings automatically, but if these are not present, then the leftmost
+ cell in a row in automaticaly considered a header, and the same for the top most cell in a column.
+ """
+ cell_text: Dict[tuple[int, int], str] # Stores map from row, col to cell text
+ heading_cells: Set[tuple[int, int]] # Contains the row, col pairs which are headings
- # Table dimensions
- output.append(f"Table: {self.data.shape[0]} rows × {self.data.shape[1]} columns")
+ up_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
+ down_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
+ left_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
+ right_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
- # Header info
- output.append(f"Header rows: {sorted(self.header_rows)}")
- output.append(f"Header columns: {sorted(self.header_cols)}")
-
- # Table content with formatting
- separator = "+" + "+".join(["-" * 17] * self.data.shape[1]) + "+"
-
- # Add a header for row indices
- output.append(separator)
- headers = [""] + [f"Column {i}" for i in range(self.data.shape[1])]
- output.append("| {:<5} | ".format("Row") + " | ".join(["{:<15}".format(h) for h in headers[1:]]) + " |")
- output.append(separator)
-
- # Format each row
- for i in range(min(self.data.shape[0], 15)): # Limit to 15 rows for readability
- # Format cells, mark header cells
- cells = []
- for j in range(self.data.shape[1]):
- cell = str(self.data[i, j])
- if len(cell) > 15:
- cell = cell[:12] + "..."
- # Mark header cells with *
- if i in self.header_rows or j in self.header_cols:
- cell = f"*{cell}*"
- cells.append(cell)
-
- row_str = "| {:<5} | ".format(i) + " | ".join(["{:<15}".format(c) for c in cells]) + " |"
- output.append(row_str)
- output.append(separator)
-
- # If table is too large, indicate truncation
- if self.data.shape[0] > 15:
- output.append(f"... {self.data.shape[0] - 15} more rows ...")
-
- # Column header details if available
- if self.col_headers:
- output.append("\nColumn header mappings:")
- for col, headers in sorted(self.col_headers.items()):
- header_strs = [f"({row}, '{text}')" for row, text in headers]
- output.append(f" Column {col}: {', '.join(header_strs)}")
-
- # Row header details if available
- if self.row_headers:
- output.append("\nRow header mappings:")
- for row, headers in sorted(self.row_headers.items()):
- header_strs = [f"({col}, '{text}')" for col, text in headers]
- output.append(f" Row {row}: {', '.join(header_strs)}")
-
- return "\n".join(output)
+ top_heading_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
+ left_heading_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
class TestType(str, Enum):
@@ -145,6 +103,225 @@ def normalize_text(md_content: str) -> str:
return md_content
+def _safe_span_int(value: Optional[Union[str, int]], default: int = 1) -> int:
+ """Convert rowspan/colspan attributes to positive integers."""
+ if value in (None, "", 0):
+ return default
+ try:
+ span = int(value)
+ except (TypeError, ValueError):
+ return default
+ if span <= 0:
+ return default
+ return span
+
+
+def _build_table_data_from_specs(row_specs: List[List[Dict[str, Union[str, int, bool]]]]) -> Optional[TableData]:
+ """
+ Build a TableData object from a list of row specifications.
+
+ Each row specification is a list of dictionaries with keys:
+ - text: cell text content
+ - rowspan: integer rowspan (>= 1)
+ - colspan: integer colspan (>= 1)
+ - is_heading: bool indicating if the cell should be treated as a heading
+ """
+ if not row_specs:
+ return None
+
+ cell_text: Dict[Tuple[int, int], str] = {}
+ heading_cells: Set[Tuple[int, int]] = set()
+ cell_meta: Dict[Tuple[int, int], Dict[str, Union[int, bool]]] = {}
+ occupancy: List[List[Optional[Tuple[int, int]]]] = []
+ active_rowspans: List[Optional[Tuple[Tuple[int, int], int]]] = []
+
+ for row_idx, cells in enumerate(row_specs):
+ row_entries: List[Optional[Tuple[int, int]]] = []
+ col_index = 0
+ spec_idx = 0
+ total_specs = len(cells)
+
+ while spec_idx < total_specs or col_index < len(active_rowspans):
+ if col_index < len(active_rowspans) and active_rowspans[col_index] is not None:
+ cell_id, remaining = active_rowspans[col_index]
+ row_entries.append(cell_id)
+ remaining -= 1
+ active_rowspans[col_index] = (cell_id, remaining) if remaining > 0 else None
+ col_index += 1
+ continue
+
+ if spec_idx >= total_specs:
+ if col_index < len(active_rowspans):
+ row_entries.append(None)
+ col_index += 1
+ continue
+ break
+
+ spec = cells[spec_idx]
+ spec_idx += 1
+
+ text = spec.get("text", "") or ""
+ rowspan = spec.get("rowspan", 1)
+ colspan = spec.get("colspan", 1)
+ is_heading = bool(spec.get("is_heading", False))
+
+ rowspan = rowspan if isinstance(rowspan, int) else _safe_span_int(rowspan)
+ colspan = colspan if isinstance(colspan, int) else _safe_span_int(colspan)
+ rowspan = max(1, rowspan)
+ colspan = max(1, colspan)
+
+ cell_id = (row_idx, col_index)
+ cell_text[cell_id] = text
+ if is_heading:
+ heading_cells.add(cell_id)
+
+ cell_meta[cell_id] = {
+ "row": row_idx,
+ "col": col_index,
+ "rowspan": rowspan,
+ "colspan": colspan,
+ }
+
+ required_len = col_index + colspan
+ if len(active_rowspans) < required_len:
+ active_rowspans.extend([None] * (required_len - len(active_rowspans)))
+
+ for offset in range(colspan):
+ current_col = col_index + offset
+ row_entries.append(cell_id)
+ if rowspan > 1:
+ active_rowspans[current_col] = (cell_id, rowspan - 1)
+ else:
+ active_rowspans[current_col] = None
+
+ col_index += colspan
+
+ occupancy.append(row_entries)
+
+ # Flush any remaining active rowspans into additional rows
+ while any(entry is not None for entry in active_rowspans):
+ row_entries: List[Optional[Tuple[int, int]]] = []
+ for col_index, span_entry in enumerate(active_rowspans):
+ if span_entry is None:
+ row_entries.append(None)
+ continue
+ cell_id, remaining = span_entry
+ row_entries.append(cell_id)
+ remaining -= 1
+ active_rowspans[col_index] = (cell_id, remaining) if remaining > 0 else None
+ occupancy.append(row_entries)
+
+ if not cell_text:
+ return None
+
+ # Normalize occupancy to a consistent width based on populated columns
+ valid_columns = {idx for row in occupancy for idx, value in enumerate(row) if value is not None}
+ if valid_columns:
+ table_width = max(valid_columns) + 1
+ for row in occupancy:
+ if len(row) < table_width:
+ row.extend([None] * (table_width - len(row)))
+ elif len(row) > table_width:
+ del row[table_width:]
+ else:
+ return None
+
+ table_height = len(occupancy)
+
+ up_rel = defaultdict(set)
+ down_rel = defaultdict(set)
+ left_rel = defaultdict(set)
+ right_rel = defaultdict(set)
+ top_heading_rel = defaultdict(set)
+ left_heading_rel = defaultdict(set)
+
+ for cell_id, meta in cell_meta.items():
+ row_start = meta["row"]
+ col_start = meta["col"]
+ rowspan = meta["rowspan"]
+ colspan = meta["colspan"]
+ row_end = row_start + rowspan - 1
+ col_end = col_start + colspan - 1
+
+ # Right relations
+ for row in range(row_start, row_end + 1):
+ for col in range(col_end + 1, table_width):
+ neighbor = occupancy[row][col]
+ if neighbor is None or neighbor == cell_id:
+ continue
+ right_rel[cell_id].add(neighbor)
+ break
+
+ # Left relations
+ for row in range(row_start, row_end + 1):
+ for col in range(col_start - 1, -1, -1):
+ neighbor = occupancy[row][col]
+ if neighbor is None or neighbor == cell_id:
+ continue
+ left_rel[cell_id].add(neighbor)
+ break
+
+ # Down relations
+ for col in range(col_start, col_end + 1):
+ for row in range(row_end + 1, table_height):
+ if col >= len(occupancy[row]):
+ continue
+ neighbor = occupancy[row][col]
+ if neighbor is None or neighbor == cell_id:
+ continue
+ down_rel[cell_id].add(neighbor)
+ break
+
+ # Up relations
+ for col in range(col_start, col_end + 1):
+ for row in range(row_start - 1, -1, -1):
+ neighbor = occupancy[row][col]
+ if neighbor is None or neighbor == cell_id:
+ continue
+ up_rel[cell_id].add(neighbor)
+ break
+
+ # Top heading relations
+ for col in range(col_start, col_end + 1):
+ seen = set()
+ for row in range(row_start - 1, -1, -1):
+ neighbor = occupancy[row][col]
+ if neighbor is None or neighbor == cell_id or neighbor in seen:
+ continue
+ seen.add(neighbor)
+ if neighbor in heading_cells:
+ top_heading_rel[cell_id].add(neighbor)
+
+ # Left heading relations
+ for row in range(row_start, row_end + 1):
+ seen = set()
+ for col in range(col_start - 1, -1, -1):
+ neighbor = occupancy[row][col]
+ if neighbor is None or neighbor == cell_id or neighbor in seen:
+ continue
+ seen.add(neighbor)
+ if neighbor in heading_cells:
+ left_heading_rel[cell_id].add(neighbor)
+
+ # Ensure every cell has an entry in relations dictionaries
+ up_relations = {cell_id: set(up_rel[cell_id]) for cell_id in cell_text}
+ down_relations = {cell_id: set(down_rel[cell_id]) for cell_id in cell_text}
+ left_relations = {cell_id: set(left_rel[cell_id]) for cell_id in cell_text}
+ right_relations = {cell_id: set(right_rel[cell_id]) for cell_id in cell_text}
+ top_heading_relations = {cell_id: set(top_heading_rel[cell_id]) for cell_id in cell_text}
+ left_heading_relations = {cell_id: set(left_heading_rel[cell_id]) for cell_id in cell_text}
+
+ return TableData(
+ cell_text=cell_text,
+ heading_cells=heading_cells,
+ up_relations=up_relations,
+ down_relations=down_relations,
+ left_relations=left_relations,
+ right_relations=right_relations,
+ top_heading_relations=top_heading_relations,
+ left_heading_relations=left_heading_relations,
+ )
+
def parse_markdown_tables(md_content: str) -> List[TableData]:
"""
@@ -183,74 +360,46 @@ def parse_markdown_tables(md_content: str) -> List[TableData]:
if len(current_table_lines) >= 2:
table_data = _process_table_lines(current_table_lines)
if table_data and len(table_data) > 0:
- # Convert to numpy array for easier manipulation
- max_cols = max(len(row) for row in table_data)
- padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
- table_array = np.array(padded_data)
-
- # In markdown tables, the first row is typically a header row
- header_rows = {0} if len(table_array) > 0 else set()
-
- # Set up col_headers with first row headers for each column
- col_headers = {}
- if len(table_array) > 0:
- for col_idx in range(table_array.shape[1]):
- if col_idx < len(table_array[0]):
- col_headers[col_idx] = [(0, table_array[0, col_idx])]
-
- # Set up row_headers with first column headers for each row
- row_headers = {}
- if table_array.shape[1] > 0:
- for row_idx in range(1, table_array.shape[0]): # Skip header row
- row_headers[row_idx] = [(0, table_array[row_idx, 0])] # First column as heading
-
- # Create TableData object
- parsed_tables.append(
- TableData(
- data=table_array,
- header_rows=header_rows,
- header_cols={0} if table_array.shape[1] > 0 else set(), # First column as header
- col_headers=col_headers,
- row_headers=row_headers,
+ row_specs: List[List[Dict[str, Union[str, int, bool]]]] = []
+ for row_idx, row in enumerate(table_data):
+ row_specs.append(
+ [
+ {
+ "text": cell,
+ "rowspan": 1,
+ "colspan": 1,
+ "is_heading": row_idx == 0 or col_idx == 0,
+ }
+ for col_idx, cell in enumerate(row)
+ ]
)
- )
+
+ table = _build_table_data_from_specs(row_specs)
+ if table:
+ parsed_tables.append(table)
in_table = False
# Process the last table if we're still tracking one at the end of the file
if in_table and len(current_table_lines) >= 2:
table_data = _process_table_lines(current_table_lines)
if table_data and len(table_data) > 0:
- # Convert to numpy array
- max_cols = max(len(row) for row in table_data)
- padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
- table_array = np.array(padded_data)
-
- # In markdown tables, the first row is typically a header row
- header_rows = {0} if len(table_array) > 0 else set()
-
- # Set up col_headers with first row headers for each column
- col_headers = {}
- if len(table_array) > 0:
- for col_idx in range(table_array.shape[1]):
- if col_idx < len(table_array[0]):
- col_headers[col_idx] = [(0, table_array[0, col_idx])]
-
- # Set up row_headers with first column headers for each row
- row_headers = {}
- if table_array.shape[1] > 0:
- for row_idx in range(1, table_array.shape[0]): # Skip header row
- row_headers[row_idx] = [(0, table_array[row_idx, 0])] # First column as heading
-
- # Create TableData object
- parsed_tables.append(
- TableData(
- data=table_array,
- header_rows=header_rows,
- header_cols={0} if table_array.shape[1] > 0 else set(), # First column as header
- col_headers=col_headers,
- row_headers=row_headers,
+ row_specs = []
+ for row_idx, row in enumerate(table_data):
+ row_specs.append(
+ [
+ {
+ "text": cell,
+ "rowspan": 1,
+ "colspan": 1,
+ "is_heading": row_idx == 0 or col_idx == 0,
+ }
+ for col_idx, cell in enumerate(row)
+ ]
)
- )
+
+ table = _build_table_data_from_specs(row_specs)
+ if table:
+ parsed_tables.append(table)
return parsed_tables
@@ -318,159 +467,47 @@ def parse_html_tables(html_content: str) -> List[TableData]:
parsed_tables = []
for table in tables:
- rows = table.find_all(["tr"])
- table_data = []
- header_rows = set()
- header_cols = set()
- col_headers = {} # Maps column index to all header cells above it
- row_headers = {} # Maps row index to all header cells to its left
+ rows = table.find_all("tr")
+ if not rows:
+ continue
- # Find rows inside thead tags - these are definitely header rows
- thead = table.find("thead")
- if thead:
- thead_rows = thead.find_all("tr")
- for tr in thead_rows:
- header_rows.add(rows.index(tr))
+ row_specs: List[List[Dict[str, Union[str, int, bool]]]] = []
+ total_rows = len(rows)
- # Initialize a grid to track filled cells due to rowspan/colspan
- cell_grid = {}
- col_span_info = {} # Tracks which columns contain headers
- row_span_info = {} # Tracks which rows contain headers
-
- # First pass: process each row to build the raw table data and identify headers
for row_idx, row in enumerate(rows):
- cells = row.find_all(["th", "td"])
- row_data = []
- col_idx = 0
-
- # If there are th elements in this row, it's likely a header row
- if row.find("th"):
- header_rows.add(row_idx)
+ cells = row.find_all(["th", "td"], recursive=False)
+ heading_context = row.find_parent("thead") is not None
+ row_spec: List[Dict[str, Union[str, int, bool]]] = []
for cell in cells:
- # Skip positions already filled by rowspans from above
- while (row_idx, col_idx) in cell_grid:
- row_data.append(cell_grid[(row_idx, col_idx)])
- col_idx += 1
-
- # Replace and tags with newlines before getting text
for br in cell.find_all("br"):
br.replace_with("\n")
- cell_text = cell.get_text().strip()
- # Handle rowspan/colspan
- rowspan = int(cell.get("rowspan", 1))
- colspan = int(cell.get("colspan", 1))
+ text = cell.get_text(separator="\n").strip()
+ raw_rowspan = cell.get("rowspan")
+ raw_colspan = cell.get("colspan")
- # Add the cell to the row data
- row_data.append(cell_text)
+ rowspan = _safe_span_int(raw_rowspan, 1)
+ colspan = _safe_span_int(raw_colspan, 1)
- # Fill the grid for this cell and its rowspan/colspan
- for i in range(rowspan):
- for j in range(colspan):
- if i == 0 and j == 0:
- continue # Skip the main cell position
- # For rowspan cells, preserve the text in all spanned rows
- if j == 0 and i > 0: # Only for cells directly below
- cell_grid[(row_idx + i, col_idx + j)] = cell_text
- else:
- cell_grid[(row_idx + i, col_idx + j)] = "" # Mark other spans as empty
+ # HTML specifies rowspan=0 to extend to the end of the table section
+ if isinstance(raw_rowspan, str) and raw_rowspan.strip() == "0":
+ rowspan = max(1, total_rows - row_idx)
- # If this is a header cell (th), mark it and its span
- if cell.name == "th":
- # Mark columns as header columns
- for j in range(colspan):
- header_cols.add(col_idx + j)
+ is_heading = cell.name == "th" or heading_context
- # For rowspan, mark spanned rows as part of header
- for i in range(1, rowspan):
- if row_idx + i < len(rows):
- header_rows.add(row_idx + i)
+ row_spec.append({"text": text, "rowspan": rowspan, "colspan": colspan, "is_heading": is_heading})
- # Record this header for all spanned columns
- for j in range(colspan):
- curr_col = col_idx + j
- if curr_col not in col_headers:
- col_headers[curr_col] = []
- col_headers[curr_col].append((row_idx, cell_text))
+ row_specs.append(row_spec)
- # Store which columns are covered by this header
- if cell_text and colspan > 1:
- if cell_text not in col_span_info:
- col_span_info[cell_text] = set()
- col_span_info[cell_text].add(curr_col)
-
- # Store which rows are covered by this header for rowspan
- if cell_text and rowspan > 1:
- if cell_text not in row_span_info:
- row_span_info[cell_text] = set()
- for i in range(rowspan):
- row_span_info[cell_text].add(row_idx + i)
-
- # Also handle row headers from data cells that have rowspan
- if cell.name == "td" and rowspan > 1 and col_idx in header_cols:
- for i in range(1, rowspan):
- if row_idx + i < len(rows):
- if row_idx + i not in row_headers:
- row_headers[row_idx + i] = []
- row_headers[row_idx + i].append((col_idx, cell_text))
-
- col_idx += colspan
-
- # Pad the row if needed to handle different row lengths
- table_data.append(row_data)
-
- # Second pass: expand headers to cells that should inherit them
- # First handle column headers
- for header_text, columns in col_span_info.items():
- for col in columns:
- # Add this header to all columns it spans over
- for row_idx in range(len(table_data)):
- if row_idx not in header_rows: # Only apply to data rows
- for j in range(col, len(table_data[row_idx]) if row_idx < len(table_data) else 0):
- # Add header info to data cells in these columns
- if j not in col_headers:
- col_headers[j] = []
- if not any(h[1] == header_text for h in col_headers[j]):
- header_row = min([r for r, t in col_headers.get(col, [(0, "")])])
- col_headers[j].append((header_row, header_text))
-
- # Handle row headers
- for header_text, rows in row_span_info.items():
- for row in rows:
- if row < len(table_data):
- # Find first header column
- header_col = min(header_cols) if header_cols else 0
- if row not in row_headers:
- row_headers[row] = []
- if not any(h[1] == header_text for h in row_headers.get(row, [])):
- row_headers[row].append((header_col, header_text))
-
- # Process regular row headers - each cell in a header column becomes a header for its row
- for col_idx in header_cols:
- for row_idx, row in enumerate(table_data):
- if col_idx < len(row) and row[col_idx].strip():
- if row_idx not in row_headers:
- row_headers[row_idx] = []
- if not any(h[1] == row[col_idx] for h in row_headers.get(row_idx, [])):
- row_headers[row_idx].append((col_idx, row[col_idx]))
-
- # Calculate max columns for padding
- max_cols = max(len(row) for row in table_data) if table_data else 0
-
- # Ensure all rows have the same number of columns
+ table_data = _build_table_data_from_specs(row_specs)
if table_data:
- padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
- table_array = np.array(padded_data)
-
- # Create TableData object with the table and header information
- parsed_tables.append(
- TableData(data=table_array, header_rows=header_rows, header_cols=header_cols, col_headers=col_headers, row_headers=row_headers)
- )
+ parsed_tables.append(table_data)
return parsed_tables
+
@dataclass(kw_only=True)
class BasePDFTest:
"""
diff --git a/scripts/run_infrapartner_benchmark.sh b/scripts/run_infrapartner_benchmark.sh
new file mode 100755
index 0000000..3c8aa6b
--- /dev/null
+++ b/scripts/run_infrapartner_benchmark.sh
@@ -0,0 +1,272 @@
+#!/bin/bash
+
+# Runs an olmocr-bench run using the full pipeline (no fallback) for infrapartner testing
+# This version skips the performance task and adds support for --server, --model, and --beaker-secret arguments
+#
+# Usage examples:
+# ./scripts/run_infrapartner_benchmark.sh --server http://example.com --model your-model-name --beaker-secret my-api-key-secret
+# ./scripts/run_infrapartner_benchmark.sh --beaker-image jakep/olmocr-benchmark-0.3.3-780bc7d934 --server http://example.com
+
+set -e
+
+# Parse command line arguments
+MODEL=""
+SERVER=""
+BEAKER_SECRET=""
+BENCH_BRANCH=""
+BEAKER_IMAGE=""
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --model)
+ MODEL="$2"
+ shift 2
+ ;;
+ --server)
+ SERVER="$2"
+ shift 2
+ ;;
+ --beaker-secret)
+ BEAKER_SECRET="$2"
+ shift 2
+ ;;
+ --benchbranch)
+ BENCH_BRANCH="$2"
+ shift 2
+ ;;
+ --beaker-image)
+ BEAKER_IMAGE="$2"
+ shift 2
+ ;;
+ *)
+ echo "Unknown option: $1"
+ echo "Usage: $0 [--server SERVER_URL] [--model MODEL_NAME] [--beaker-secret SECRET_NAME] [--benchbranch BRANCH_NAME] [--beaker-image IMAGE_NAME]"
+ exit 1
+ ;;
+ esac
+done
+
+# Check for uncommitted changes
+if [ -n "$BEAKER_IMAGE" ]; then
+ echo "Skipping docker build"
+else
+ if ! git diff-index --quiet HEAD --; then
+ echo "Error: There are uncommitted changes in the repository."
+ echo "Please commit or stash your changes before running the benchmark."
+ echo ""
+ echo "Uncommitted changes:"
+ git status --short
+ exit 1
+ fi
+fi
+
+# Use conda environment Python if available, otherwise use system Python
+if [ -n "$CONDA_PREFIX" ]; then
+ PYTHON="$CONDA_PREFIX/bin/python"
+ echo "Using conda Python from: $CONDA_PREFIX"
+else
+ PYTHON="python"
+ echo "Warning: No conda environment detected, using system Python"
+fi
+
+# Get version from version.py
+VERSION=$($PYTHON -c 'import olmocr.version; print(olmocr.version.VERSION)')
+echo "OlmOCR version: $VERSION"
+
+# Get first 10 characters of git hash
+GIT_HASH=$(git rev-parse HEAD | cut -c1-10)
+echo "Git hash: $GIT_HASH"
+
+# Get current git branch name
+GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD)
+echo "Git branch: $GIT_BRANCH"
+
+# Check if a Beaker image was provided
+if [ -n "$BEAKER_IMAGE" ]; then
+ echo "Using provided Beaker image: $BEAKER_IMAGE"
+ IMAGE_TAG="$BEAKER_IMAGE"
+else
+ # Create full image tag
+ IMAGE_TAG="olmocr-benchmark-${VERSION}-${GIT_HASH}"
+ echo "Building Docker image with tag: $IMAGE_TAG"
+
+ # Build the Docker image
+ echo "Building Docker image..."
+ docker build --platform linux/amd64 -f ./Dockerfile -t $IMAGE_TAG .
+
+ # Push image to beaker
+ echo "Trying to push image to Beaker..."
+ if ! beaker image create --workspace ai2/oe-data-pdf --name $IMAGE_TAG $IMAGE_TAG 2>/dev/null; then
+ echo "Warning: Beaker image with tag $IMAGE_TAG already exists. Using existing image."
+ fi
+fi
+
+# Get Beaker username
+BEAKER_USER=$(beaker account whoami --format json | jq -r '.[0].name')
+echo "Beaker user: $BEAKER_USER"
+
+# Create Python script to run beaker experiment
+cat << 'EOF' > /tmp/run_infrapartner_benchmark_experiment.py
+import sys
+from beaker import Beaker, ExperimentSpec, TaskSpec, TaskContext, ResultSpec, TaskResources, ImageSource, Priority, Constraints, EnvVar
+
+# Get image tag, beaker user, git branch, git hash, optional model, server, beaker_secret, and bench branch from command line
+image_tag = sys.argv[1]
+beaker_user = sys.argv[2]
+git_branch = sys.argv[3]
+git_hash = sys.argv[4]
+model = None
+server = None
+beaker_secret = None
+bench_branch = None
+
+# Parse remaining arguments
+arg_idx = 5
+while arg_idx < len(sys.argv):
+ if sys.argv[arg_idx] == "--benchbranch":
+ bench_branch = sys.argv[arg_idx + 1]
+ arg_idx += 2
+ elif sys.argv[arg_idx] == "--model":
+ model = sys.argv[arg_idx + 1]
+ arg_idx += 2
+ elif sys.argv[arg_idx] == "--server":
+ server = sys.argv[arg_idx + 1]
+ arg_idx += 2
+ elif sys.argv[arg_idx] == "--beaker-secret":
+ beaker_secret = sys.argv[arg_idx + 1]
+ arg_idx += 2
+ else:
+ print(f"Unknown argument: {sys.argv[arg_idx]}")
+ arg_idx += 1
+
+# Initialize Beaker client
+b = Beaker.from_env(default_workspace="ai2/olmocr")
+
+# Build the pipeline command with optional parameters
+pipeline_cmd = "python -m olmocr.pipeline ./localworkspace --markdown --pdfs ./olmOCR-bench/bench_data/pdfs/**/*.pdf"
+if model:
+ pipeline_cmd += f" --model {model}"
+if server:
+ pipeline_cmd += f" --server {server}"
+if beaker_secret:
+ pipeline_cmd += " --api_key \"$API_KEY\""
+
+# Check if AWS credentials secret exists
+aws_creds_secret = f"{beaker_user}-AWS_CREDENTIALS_FILE"
+try:
+ # Try to get the secret to see if it exists
+ b.secret.get(aws_creds_secret, workspace="ai2/olmocr")
+ has_aws_creds = True
+ print(f"Found AWS credentials secret: {aws_creds_secret}")
+except:
+ has_aws_creds = False
+ print(f"AWS credentials secret not found: {aws_creds_secret}")
+
+# First experiment: Original benchmark job
+commands = []
+if has_aws_creds:
+ commands.extend([
+ "mkdir -p ~/.aws",
+ 'echo "$AWS_CREDENTIALS_FILE" > ~/.aws/credentials'
+ ])
+
+# If beaker_secret is provided, export it as API_KEY environment variable
+if beaker_secret:
+ commands.append('export API_KEY="$BEAKER_API_KEY"')
+
+# Build git clone command with optional branch
+git_clone_cmd = "git clone https://huggingface.co/datasets/allenai/olmOCR-bench"
+if bench_branch:
+ git_clone_cmd += f" -b {bench_branch}"
+
+commands.extend([
+ git_clone_cmd,
+ "cd olmOCR-bench && git lfs pull && cd ..",
+ pipeline_cmd,
+ "python olmocr/bench/scripts/workspace_to_bench.py localworkspace/ olmOCR-bench/bench_data/olmocr --bench-path ./olmOCR-bench/",
+ "pip install s5cmd",
+ "s5cmd cp localworkspace/ s3://ai2-oe-data/jakep/olmocr-bench-runs/$BEAKER_WORKLOAD_ID/",
+ "python -m olmocr.bench.benchmark --dir ./olmOCR-bench/bench_data"
+])
+
+# Build task spec with optional env vars
+# If image_tag contains '/', it's already a full beaker image reference
+if '/' in image_tag:
+ image_ref = image_tag
+else:
+ image_ref = f"{beaker_user}/{image_tag}"
+
+task_spec_args = {
+ "name": "olmocr-infrapartner-benchmark",
+ "image": ImageSource(beaker=image_ref),
+ "command": [
+ "bash", "-c",
+ " && ".join(commands)
+ ],
+ "context": TaskContext(
+ priority=Priority.normal,
+ preemptible=True,
+ ),
+ "resources": TaskResources(gpu_count=0),
+ "constraints": Constraints(cluster=["ai2/phobos", "ai2/neptune", "ai2/saturn"]),
+ "result": ResultSpec(path="/noop-results"),
+}
+
+# Build env vars list
+env_vars = []
+if has_aws_creds:
+ env_vars.append(EnvVar(name="AWS_CREDENTIALS_FILE", secret=aws_creds_secret))
+if beaker_secret:
+ env_vars.append(EnvVar(name="BEAKER_API_KEY", secret=beaker_secret))
+
+# Add env vars if any exist
+if env_vars:
+ task_spec_args["env_vars"] = env_vars
+
+# Create experiment spec
+experiment_spec = ExperimentSpec(
+ description=f"OlmOCR InfraPartner Benchmark Run - Branch: {git_branch}, Commit: {git_hash}",
+ budget="ai2/oe-base",
+ tasks=[TaskSpec(**task_spec_args)],
+)
+
+# Create the experiment
+experiment = b.experiment.create(spec=experiment_spec, workspace="ai2/olmocr")
+print(f"Created benchmark experiment: {experiment.id}")
+print(f"View at: https://beaker.org/ex/{experiment.id}")
+print("-------")
+print("")
+print("Note: Performance test has been skipped for infrapartner benchmark")
+EOF
+
+# Run the Python script to create the experiment
+echo "Creating Beaker experiment..."
+
+# Build command with appropriate arguments
+CMD="$PYTHON /tmp/run_infrapartner_benchmark_experiment.py $IMAGE_TAG $BEAKER_USER $GIT_BRANCH $GIT_HASH"
+
+if [ -n "$MODEL" ]; then
+ echo "Using model: $MODEL"
+ CMD="$CMD --model $MODEL"
+fi
+
+if [ -n "$SERVER" ]; then
+ echo "Using server: $SERVER"
+ CMD="$CMD --server $SERVER"
+fi
+
+if [ -n "$BEAKER_SECRET" ]; then
+ echo "Using beaker secret for API key: $BEAKER_SECRET"
+ CMD="$CMD --beaker-secret $BEAKER_SECRET"
+fi
+
+if [ -n "$BENCH_BRANCH" ]; then
+ echo "Using bench branch: $BENCH_BRANCH"
+ CMD="$CMD --benchbranch $BENCH_BRANCH"
+fi
+
+eval $CMD
+
+# Clean up temporary file
+rm /tmp/run_infrapartner_benchmark_experiment.py
+
+echo "InfraPartner benchmark experiment submitted successfully!"
\ No newline at end of file
diff --git a/tests/test_table_parsing.py b/tests/test_table_parsing.py
new file mode 100644
index 0000000..aa3a26e
--- /dev/null
+++ b/tests/test_table_parsing.py
@@ -0,0 +1,162 @@
+import unittest
+
+from olmocr.bench.tests import (
+ parse_html_tables,
+ parse_markdown_tables,
+)
+
+
+class TestParseHtmlTables(unittest.TestCase):
+ def test_basic_table(self):
+ data = parse_html_tables("""
+
+
+
+ |
+ ArXiv |
+ Old scans math |
+ Tables |
+ Old scans |
+ Headers & footers |
+ Multi column |
+ Long tiny text |
+ Base |
+ Overall |
+
+
+
+
+ | Mistral OCR API |
+ 77.2 |
+ 67.5 |
+ 60.6 |
+ 29.3 |
+ 93.6 |
+ 71.3 |
+ 77.1 |
+ 99.4 |
+ 72.0±1.1 |
+
+ """)[0]
+
+ print(data)
+
+ self.assertEqual(data.cell_text[0,0], "")
+ self.assertEqual(data.cell_text[0,1], "ArXiv")
+
+ self.assertEqual(data.left_relations[0,0], set())
+ self.assertEqual(data.up_relations[0,0], set())
+
+ self.assertEqual(data.left_relations[0,1], {(0,0)})
+ self.assertEqual(data.up_relations[1,0], {(0,0)})
+
+ self.assertEqual(data.heading_cells, {
+ (0,0), (0,1), (0,2), (0,3),(0,4), (0,5),(0,6), (0,7), (0,8), (0,9)
+ })
+
+ self.assertEqual(data.top_heading_relations[1,3], {(0,3)})
+
+ # If there are no left headings defined, then the left most column is considered the left heading
+ self.assertEqual(data.left_heading_relations[1,3], {(1,0)})
+
+ def test_multiple_top_headings(self):
+ data = parse_html_tables("""
+
+
+
+ | Fruit Costs in Unittest land |
+
+
+ | Fruit Type |
+ Cost |
+
+
+
+
+ | Apples |
+ $1.00 |
+
+
+ | Oranges |
+ $2.00 |
+
+ """)[0]
+
+ print(data)
+
+ self.assertEqual(data.cell_text[0,0], "Fruit Costs in Unittest land")
+ self.assertEqual(data.cell_text[1,0], "Fruit Type")
+ self.assertEqual(data.cell_text[1,1], "Cost")
+ self.assertEqual(data.cell_text[2,0], "Apples")
+ self.assertEqual(data.cell_text[2,1], "$1.00")
+ self.assertEqual(data.cell_text[3,0], "Oranges")
+ self.assertEqual(data.cell_text[3,1], "$2.00")
+
+
+ self.assertEqual(data.up_relations[1,0], {(0,0)})
+ self.assertEqual(data.up_relations[1,1], {(0,0)})
+
+ self.assertEqual(data.up_relations[2,0], {(1,0)})
+ self.assertEqual(data.up_relations[2,1], {(1,1)})
+
+ self.assertEqual(data.top_heading_relations[1,0], {(0,0)})
+ self.assertEqual(data.top_heading_relations[1,1], {(0,0)})
+
+ self.assertEqual(data.top_heading_relations[2,0], {(0,0), (1,0)})
+ self.assertEqual(data.top_heading_relations[2,1], {(0,0), (1,1)})
+
+ def test_4x4_table_with_spans(self):
+ """Test a 4x4 table with various row spans and column spans"""
+ data = parse_html_tables("""
+
+
+
+ | Header 1 |
+ Header 2-3 |
+ Header 4 |
+
+
+
+
+ | Cell A (spans 2 rows) |
+ Cell B |
+ Cell C |
+ Cell D (spans 2 rows) |
+
+
+ | Cell E-F (spans 2 cols) |
+
+
+ | Cell G |
+ Cell H-I-J (spans 3 cols) |
+
+
+ """)[0]
+
+ print(data)
+
+ # Test header row
+ self.assertEqual(data.cell_text[0,0], "Header 1")
+ self.assertEqual(data.cell_text[0,1], "Header 2-3")
+
+ self.assertNotIn((0,2), data.cell_text) # colspan=2, so that next cell is empty
+ self.assertEqual(data.cell_text[0,3], "Header 4")
+
+ # Test first body row
+ self.assertEqual(data.cell_text[1,0], "Cell A (spans 2 rows)")
+ self.assertEqual(data.cell_text[1,1], "Cell B")
+ self.assertEqual(data.cell_text[1,2], "Cell C")
+ self.assertEqual(data.cell_text[1,3], "Cell D (spans 2 rows)")
+
+ # Test second body row
+ self.assertNotIn((2,0), data.cell_text)
+ self.assertEqual(data.cell_text[2,1], "Cell E-F (spans 2 cols)")
+
+ # Test third body row
+ self.assertEqual(data.cell_text[3,0], "Cell G")
+ self.assertEqual(data.cell_text[3,1], "Cell H-I-J (spans 3 cols)")
+
+ # Test heading cells
+ self.assertEqual(data.heading_cells, {
+ (0,0), (0,1), (0,3)
+ })
|