diff --git a/olmocr/bench/tests.py b/olmocr/bench/tests.py index 0045ea1..adc44f9 100644 --- a/olmocr/bench/tests.py +++ b/olmocr/bench/tests.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict, dataclass, field from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Union +from collections import defaultdict import numpy as np from bs4 import BeautifulSoup @@ -21,76 +22,33 @@ from .katex.render import compare_rendered_equations, render_equation __test__ = False -@dataclass + +@dataclass(frozen=True) class TableData: - """Class to hold table data and metadata about headers.""" + """Class which holds table data as a graph of cells. Ex. you can access the value at any row, col. - data: np.ndarray # The actual table data - header_rows: Set[int] = field(default_factory=set) # Indices of rows that are headers - header_cols: Set[int] = field(default_factory=set) # Indices of columns that are headers - col_headers: dict = field(default_factory=dict) # Maps column index to header text, handling colspan - row_headers: dict = field(default_factory=dict) # Maps row index to header text, handling rowspan + Cell texts are only ever present one time, so ex if on row 0, you have a colspan 1 and a colspan 2 column, + then text gets stored at (0,0) and (0,1) only, (0,2) is not present in the data - def __repr__(self) -> str: - """Returns a concise representation of the TableData object for debugging.""" - return f"TableData(shape={self.data.shape}, header_rows={len(self.header_rows)}, header_cols={len(self.header_cols)})" + However, you can also ask, given a row, col, which set of row,col pairs is "left", "right", "up", and "down" + from that one. There can be multiple values returned, because rowspans and colspans mean that you can have multiple cells in each direction. - def __str__(self) -> str: - """Returns a pretty string representation of the table with header information.""" - output = [] + Further more, you can also query "top_heading" and "left_heading". Where we also mark cells that are considered "headings", ex. if they are in a thead + html tag. + Then, for each cell, you may request top_heading/left_heading, which returns all cell positions that are headings in those directions + Cells inside of tags are considered headings automatically, but if these are not present, then the leftmost + cell in a row in automaticaly considered a header, and the same for the top most cell in a column. + """ + cell_text: Dict[tuple[int, int], str] # Stores map from row, col to cell text + heading_cells: Set[tuple[int, int]] # Contains the row, col pairs which are headings - # Table dimensions - output.append(f"Table: {self.data.shape[0]} rows × {self.data.shape[1]} columns") + up_relations: Dict[tuple[int, int], Set[tuple[int, int]]] + down_relations: Dict[tuple[int, int], Set[tuple[int, int]]] + left_relations: Dict[tuple[int, int], Set[tuple[int, int]]] + right_relations: Dict[tuple[int, int], Set[tuple[int, int]]] - # Header info - output.append(f"Header rows: {sorted(self.header_rows)}") - output.append(f"Header columns: {sorted(self.header_cols)}") - - # Table content with formatting - separator = "+" + "+".join(["-" * 17] * self.data.shape[1]) + "+" - - # Add a header for row indices - output.append(separator) - headers = [""] + [f"Column {i}" for i in range(self.data.shape[1])] - output.append("| {:<5} | ".format("Row") + " | ".join(["{:<15}".format(h) for h in headers[1:]]) + " |") - output.append(separator) - - # Format each row - for i in range(min(self.data.shape[0], 15)): # Limit to 15 rows for readability - # Format cells, mark header cells - cells = [] - for j in range(self.data.shape[1]): - cell = str(self.data[i, j]) - if len(cell) > 15: - cell = cell[:12] + "..." - # Mark header cells with * - if i in self.header_rows or j in self.header_cols: - cell = f"*{cell}*" - cells.append(cell) - - row_str = "| {:<5} | ".format(i) + " | ".join(["{:<15}".format(c) for c in cells]) + " |" - output.append(row_str) - output.append(separator) - - # If table is too large, indicate truncation - if self.data.shape[0] > 15: - output.append(f"... {self.data.shape[0] - 15} more rows ...") - - # Column header details if available - if self.col_headers: - output.append("\nColumn header mappings:") - for col, headers in sorted(self.col_headers.items()): - header_strs = [f"({row}, '{text}')" for row, text in headers] - output.append(f" Column {col}: {', '.join(header_strs)}") - - # Row header details if available - if self.row_headers: - output.append("\nRow header mappings:") - for row, headers in sorted(self.row_headers.items()): - header_strs = [f"({col}, '{text}')" for col, text in headers] - output.append(f" Row {row}: {', '.join(header_strs)}") - - return "\n".join(output) + top_heading_relations: Dict[tuple[int, int], Set[tuple[int, int]]] + left_heading_relations: Dict[tuple[int, int], Set[tuple[int, int]]] class TestType(str, Enum): @@ -145,6 +103,225 @@ def normalize_text(md_content: str) -> str: return md_content +def _safe_span_int(value: Optional[Union[str, int]], default: int = 1) -> int: + """Convert rowspan/colspan attributes to positive integers.""" + if value in (None, "", 0): + return default + try: + span = int(value) + except (TypeError, ValueError): + return default + if span <= 0: + return default + return span + + +def _build_table_data_from_specs(row_specs: List[List[Dict[str, Union[str, int, bool]]]]) -> Optional[TableData]: + """ + Build a TableData object from a list of row specifications. + + Each row specification is a list of dictionaries with keys: + - text: cell text content + - rowspan: integer rowspan (>= 1) + - colspan: integer colspan (>= 1) + - is_heading: bool indicating if the cell should be treated as a heading + """ + if not row_specs: + return None + + cell_text: Dict[Tuple[int, int], str] = {} + heading_cells: Set[Tuple[int, int]] = set() + cell_meta: Dict[Tuple[int, int], Dict[str, Union[int, bool]]] = {} + occupancy: List[List[Optional[Tuple[int, int]]]] = [] + active_rowspans: List[Optional[Tuple[Tuple[int, int], int]]] = [] + + for row_idx, cells in enumerate(row_specs): + row_entries: List[Optional[Tuple[int, int]]] = [] + col_index = 0 + spec_idx = 0 + total_specs = len(cells) + + while spec_idx < total_specs or col_index < len(active_rowspans): + if col_index < len(active_rowspans) and active_rowspans[col_index] is not None: + cell_id, remaining = active_rowspans[col_index] + row_entries.append(cell_id) + remaining -= 1 + active_rowspans[col_index] = (cell_id, remaining) if remaining > 0 else None + col_index += 1 + continue + + if spec_idx >= total_specs: + if col_index < len(active_rowspans): + row_entries.append(None) + col_index += 1 + continue + break + + spec = cells[spec_idx] + spec_idx += 1 + + text = spec.get("text", "") or "" + rowspan = spec.get("rowspan", 1) + colspan = spec.get("colspan", 1) + is_heading = bool(spec.get("is_heading", False)) + + rowspan = rowspan if isinstance(rowspan, int) else _safe_span_int(rowspan) + colspan = colspan if isinstance(colspan, int) else _safe_span_int(colspan) + rowspan = max(1, rowspan) + colspan = max(1, colspan) + + cell_id = (row_idx, col_index) + cell_text[cell_id] = text + if is_heading: + heading_cells.add(cell_id) + + cell_meta[cell_id] = { + "row": row_idx, + "col": col_index, + "rowspan": rowspan, + "colspan": colspan, + } + + required_len = col_index + colspan + if len(active_rowspans) < required_len: + active_rowspans.extend([None] * (required_len - len(active_rowspans))) + + for offset in range(colspan): + current_col = col_index + offset + row_entries.append(cell_id) + if rowspan > 1: + active_rowspans[current_col] = (cell_id, rowspan - 1) + else: + active_rowspans[current_col] = None + + col_index += colspan + + occupancy.append(row_entries) + + # Flush any remaining active rowspans into additional rows + while any(entry is not None for entry in active_rowspans): + row_entries: List[Optional[Tuple[int, int]]] = [] + for col_index, span_entry in enumerate(active_rowspans): + if span_entry is None: + row_entries.append(None) + continue + cell_id, remaining = span_entry + row_entries.append(cell_id) + remaining -= 1 + active_rowspans[col_index] = (cell_id, remaining) if remaining > 0 else None + occupancy.append(row_entries) + + if not cell_text: + return None + + # Normalize occupancy to a consistent width based on populated columns + valid_columns = {idx for row in occupancy for idx, value in enumerate(row) if value is not None} + if valid_columns: + table_width = max(valid_columns) + 1 + for row in occupancy: + if len(row) < table_width: + row.extend([None] * (table_width - len(row))) + elif len(row) > table_width: + del row[table_width:] + else: + return None + + table_height = len(occupancy) + + up_rel = defaultdict(set) + down_rel = defaultdict(set) + left_rel = defaultdict(set) + right_rel = defaultdict(set) + top_heading_rel = defaultdict(set) + left_heading_rel = defaultdict(set) + + for cell_id, meta in cell_meta.items(): + row_start = meta["row"] + col_start = meta["col"] + rowspan = meta["rowspan"] + colspan = meta["colspan"] + row_end = row_start + rowspan - 1 + col_end = col_start + colspan - 1 + + # Right relations + for row in range(row_start, row_end + 1): + for col in range(col_end + 1, table_width): + neighbor = occupancy[row][col] + if neighbor is None or neighbor == cell_id: + continue + right_rel[cell_id].add(neighbor) + break + + # Left relations + for row in range(row_start, row_end + 1): + for col in range(col_start - 1, -1, -1): + neighbor = occupancy[row][col] + if neighbor is None or neighbor == cell_id: + continue + left_rel[cell_id].add(neighbor) + break + + # Down relations + for col in range(col_start, col_end + 1): + for row in range(row_end + 1, table_height): + if col >= len(occupancy[row]): + continue + neighbor = occupancy[row][col] + if neighbor is None or neighbor == cell_id: + continue + down_rel[cell_id].add(neighbor) + break + + # Up relations + for col in range(col_start, col_end + 1): + for row in range(row_start - 1, -1, -1): + neighbor = occupancy[row][col] + if neighbor is None or neighbor == cell_id: + continue + up_rel[cell_id].add(neighbor) + break + + # Top heading relations + for col in range(col_start, col_end + 1): + seen = set() + for row in range(row_start - 1, -1, -1): + neighbor = occupancy[row][col] + if neighbor is None or neighbor == cell_id or neighbor in seen: + continue + seen.add(neighbor) + if neighbor in heading_cells: + top_heading_rel[cell_id].add(neighbor) + + # Left heading relations + for row in range(row_start, row_end + 1): + seen = set() + for col in range(col_start - 1, -1, -1): + neighbor = occupancy[row][col] + if neighbor is None or neighbor == cell_id or neighbor in seen: + continue + seen.add(neighbor) + if neighbor in heading_cells: + left_heading_rel[cell_id].add(neighbor) + + # Ensure every cell has an entry in relations dictionaries + up_relations = {cell_id: set(up_rel[cell_id]) for cell_id in cell_text} + down_relations = {cell_id: set(down_rel[cell_id]) for cell_id in cell_text} + left_relations = {cell_id: set(left_rel[cell_id]) for cell_id in cell_text} + right_relations = {cell_id: set(right_rel[cell_id]) for cell_id in cell_text} + top_heading_relations = {cell_id: set(top_heading_rel[cell_id]) for cell_id in cell_text} + left_heading_relations = {cell_id: set(left_heading_rel[cell_id]) for cell_id in cell_text} + + return TableData( + cell_text=cell_text, + heading_cells=heading_cells, + up_relations=up_relations, + down_relations=down_relations, + left_relations=left_relations, + right_relations=right_relations, + top_heading_relations=top_heading_relations, + left_heading_relations=left_heading_relations, + ) + def parse_markdown_tables(md_content: str) -> List[TableData]: """ @@ -183,74 +360,46 @@ def parse_markdown_tables(md_content: str) -> List[TableData]: if len(current_table_lines) >= 2: table_data = _process_table_lines(current_table_lines) if table_data and len(table_data) > 0: - # Convert to numpy array for easier manipulation - max_cols = max(len(row) for row in table_data) - padded_data = [row + [""] * (max_cols - len(row)) for row in table_data] - table_array = np.array(padded_data) - - # In markdown tables, the first row is typically a header row - header_rows = {0} if len(table_array) > 0 else set() - - # Set up col_headers with first row headers for each column - col_headers = {} - if len(table_array) > 0: - for col_idx in range(table_array.shape[1]): - if col_idx < len(table_array[0]): - col_headers[col_idx] = [(0, table_array[0, col_idx])] - - # Set up row_headers with first column headers for each row - row_headers = {} - if table_array.shape[1] > 0: - for row_idx in range(1, table_array.shape[0]): # Skip header row - row_headers[row_idx] = [(0, table_array[row_idx, 0])] # First column as heading - - # Create TableData object - parsed_tables.append( - TableData( - data=table_array, - header_rows=header_rows, - header_cols={0} if table_array.shape[1] > 0 else set(), # First column as header - col_headers=col_headers, - row_headers=row_headers, + row_specs: List[List[Dict[str, Union[str, int, bool]]]] = [] + for row_idx, row in enumerate(table_data): + row_specs.append( + [ + { + "text": cell, + "rowspan": 1, + "colspan": 1, + "is_heading": row_idx == 0 or col_idx == 0, + } + for col_idx, cell in enumerate(row) + ] ) - ) + + table = _build_table_data_from_specs(row_specs) + if table: + parsed_tables.append(table) in_table = False # Process the last table if we're still tracking one at the end of the file if in_table and len(current_table_lines) >= 2: table_data = _process_table_lines(current_table_lines) if table_data and len(table_data) > 0: - # Convert to numpy array - max_cols = max(len(row) for row in table_data) - padded_data = [row + [""] * (max_cols - len(row)) for row in table_data] - table_array = np.array(padded_data) - - # In markdown tables, the first row is typically a header row - header_rows = {0} if len(table_array) > 0 else set() - - # Set up col_headers with first row headers for each column - col_headers = {} - if len(table_array) > 0: - for col_idx in range(table_array.shape[1]): - if col_idx < len(table_array[0]): - col_headers[col_idx] = [(0, table_array[0, col_idx])] - - # Set up row_headers with first column headers for each row - row_headers = {} - if table_array.shape[1] > 0: - for row_idx in range(1, table_array.shape[0]): # Skip header row - row_headers[row_idx] = [(0, table_array[row_idx, 0])] # First column as heading - - # Create TableData object - parsed_tables.append( - TableData( - data=table_array, - header_rows=header_rows, - header_cols={0} if table_array.shape[1] > 0 else set(), # First column as header - col_headers=col_headers, - row_headers=row_headers, + row_specs = [] + for row_idx, row in enumerate(table_data): + row_specs.append( + [ + { + "text": cell, + "rowspan": 1, + "colspan": 1, + "is_heading": row_idx == 0 or col_idx == 0, + } + for col_idx, cell in enumerate(row) + ] ) - ) + + table = _build_table_data_from_specs(row_specs) + if table: + parsed_tables.append(table) return parsed_tables @@ -318,159 +467,47 @@ def parse_html_tables(html_content: str) -> List[TableData]: parsed_tables = [] for table in tables: - rows = table.find_all(["tr"]) - table_data = [] - header_rows = set() - header_cols = set() - col_headers = {} # Maps column index to all header cells above it - row_headers = {} # Maps row index to all header cells to its left + rows = table.find_all("tr") + if not rows: + continue - # Find rows inside thead tags - these are definitely header rows - thead = table.find("thead") - if thead: - thead_rows = thead.find_all("tr") - for tr in thead_rows: - header_rows.add(rows.index(tr)) + row_specs: List[List[Dict[str, Union[str, int, bool]]]] = [] + total_rows = len(rows) - # Initialize a grid to track filled cells due to rowspan/colspan - cell_grid = {} - col_span_info = {} # Tracks which columns contain headers - row_span_info = {} # Tracks which rows contain headers - - # First pass: process each row to build the raw table data and identify headers for row_idx, row in enumerate(rows): - cells = row.find_all(["th", "td"]) - row_data = [] - col_idx = 0 - - # If there are th elements in this row, it's likely a header row - if row.find("th"): - header_rows.add(row_idx) + cells = row.find_all(["th", "td"], recursive=False) + heading_context = row.find_parent("thead") is not None + row_spec: List[Dict[str, Union[str, int, bool]]] = [] for cell in cells: - # Skip positions already filled by rowspans from above - while (row_idx, col_idx) in cell_grid: - row_data.append(cell_grid[(row_idx, col_idx)]) - col_idx += 1 - - # Replace
and
tags with newlines before getting text for br in cell.find_all("br"): br.replace_with("\n") - cell_text = cell.get_text().strip() - # Handle rowspan/colspan - rowspan = int(cell.get("rowspan", 1)) - colspan = int(cell.get("colspan", 1)) + text = cell.get_text(separator="\n").strip() + raw_rowspan = cell.get("rowspan") + raw_colspan = cell.get("colspan") - # Add the cell to the row data - row_data.append(cell_text) + rowspan = _safe_span_int(raw_rowspan, 1) + colspan = _safe_span_int(raw_colspan, 1) - # Fill the grid for this cell and its rowspan/colspan - for i in range(rowspan): - for j in range(colspan): - if i == 0 and j == 0: - continue # Skip the main cell position - # For rowspan cells, preserve the text in all spanned rows - if j == 0 and i > 0: # Only for cells directly below - cell_grid[(row_idx + i, col_idx + j)] = cell_text - else: - cell_grid[(row_idx + i, col_idx + j)] = "" # Mark other spans as empty + # HTML specifies rowspan=0 to extend to the end of the table section + if isinstance(raw_rowspan, str) and raw_rowspan.strip() == "0": + rowspan = max(1, total_rows - row_idx) - # If this is a header cell (th), mark it and its span - if cell.name == "th": - # Mark columns as header columns - for j in range(colspan): - header_cols.add(col_idx + j) + is_heading = cell.name == "th" or heading_context - # For rowspan, mark spanned rows as part of header - for i in range(1, rowspan): - if row_idx + i < len(rows): - header_rows.add(row_idx + i) + row_spec.append({"text": text, "rowspan": rowspan, "colspan": colspan, "is_heading": is_heading}) - # Record this header for all spanned columns - for j in range(colspan): - curr_col = col_idx + j - if curr_col not in col_headers: - col_headers[curr_col] = [] - col_headers[curr_col].append((row_idx, cell_text)) + row_specs.append(row_spec) - # Store which columns are covered by this header - if cell_text and colspan > 1: - if cell_text not in col_span_info: - col_span_info[cell_text] = set() - col_span_info[cell_text].add(curr_col) - - # Store which rows are covered by this header for rowspan - if cell_text and rowspan > 1: - if cell_text not in row_span_info: - row_span_info[cell_text] = set() - for i in range(rowspan): - row_span_info[cell_text].add(row_idx + i) - - # Also handle row headers from data cells that have rowspan - if cell.name == "td" and rowspan > 1 and col_idx in header_cols: - for i in range(1, rowspan): - if row_idx + i < len(rows): - if row_idx + i not in row_headers: - row_headers[row_idx + i] = [] - row_headers[row_idx + i].append((col_idx, cell_text)) - - col_idx += colspan - - # Pad the row if needed to handle different row lengths - table_data.append(row_data) - - # Second pass: expand headers to cells that should inherit them - # First handle column headers - for header_text, columns in col_span_info.items(): - for col in columns: - # Add this header to all columns it spans over - for row_idx in range(len(table_data)): - if row_idx not in header_rows: # Only apply to data rows - for j in range(col, len(table_data[row_idx]) if row_idx < len(table_data) else 0): - # Add header info to data cells in these columns - if j not in col_headers: - col_headers[j] = [] - if not any(h[1] == header_text for h in col_headers[j]): - header_row = min([r for r, t in col_headers.get(col, [(0, "")])]) - col_headers[j].append((header_row, header_text)) - - # Handle row headers - for header_text, rows in row_span_info.items(): - for row in rows: - if row < len(table_data): - # Find first header column - header_col = min(header_cols) if header_cols else 0 - if row not in row_headers: - row_headers[row] = [] - if not any(h[1] == header_text for h in row_headers.get(row, [])): - row_headers[row].append((header_col, header_text)) - - # Process regular row headers - each cell in a header column becomes a header for its row - for col_idx in header_cols: - for row_idx, row in enumerate(table_data): - if col_idx < len(row) and row[col_idx].strip(): - if row_idx not in row_headers: - row_headers[row_idx] = [] - if not any(h[1] == row[col_idx] for h in row_headers.get(row_idx, [])): - row_headers[row_idx].append((col_idx, row[col_idx])) - - # Calculate max columns for padding - max_cols = max(len(row) for row in table_data) if table_data else 0 - - # Ensure all rows have the same number of columns + table_data = _build_table_data_from_specs(row_specs) if table_data: - padded_data = [row + [""] * (max_cols - len(row)) for row in table_data] - table_array = np.array(padded_data) - - # Create TableData object with the table and header information - parsed_tables.append( - TableData(data=table_array, header_rows=header_rows, header_cols=header_cols, col_headers=col_headers, row_headers=row_headers) - ) + parsed_tables.append(table_data) return parsed_tables + @dataclass(kw_only=True) class BasePDFTest: """ diff --git a/scripts/run_infrapartner_benchmark.sh b/scripts/run_infrapartner_benchmark.sh new file mode 100755 index 0000000..3c8aa6b --- /dev/null +++ b/scripts/run_infrapartner_benchmark.sh @@ -0,0 +1,272 @@ +#!/bin/bash + +# Runs an olmocr-bench run using the full pipeline (no fallback) for infrapartner testing +# This version skips the performance task and adds support for --server, --model, and --beaker-secret arguments +# +# Usage examples: +# ./scripts/run_infrapartner_benchmark.sh --server http://example.com --model your-model-name --beaker-secret my-api-key-secret +# ./scripts/run_infrapartner_benchmark.sh --beaker-image jakep/olmocr-benchmark-0.3.3-780bc7d934 --server http://example.com + +set -e + +# Parse command line arguments +MODEL="" +SERVER="" +BEAKER_SECRET="" +BENCH_BRANCH="" +BEAKER_IMAGE="" +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --server) + SERVER="$2" + shift 2 + ;; + --beaker-secret) + BEAKER_SECRET="$2" + shift 2 + ;; + --benchbranch) + BENCH_BRANCH="$2" + shift 2 + ;; + --beaker-image) + BEAKER_IMAGE="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + echo "Usage: $0 [--server SERVER_URL] [--model MODEL_NAME] [--beaker-secret SECRET_NAME] [--benchbranch BRANCH_NAME] [--beaker-image IMAGE_NAME]" + exit 1 + ;; + esac +done + +# Check for uncommitted changes +if [ -n "$BEAKER_IMAGE" ]; then + echo "Skipping docker build" +else + if ! git diff-index --quiet HEAD --; then + echo "Error: There are uncommitted changes in the repository." + echo "Please commit or stash your changes before running the benchmark." + echo "" + echo "Uncommitted changes:" + git status --short + exit 1 + fi +fi + +# Use conda environment Python if available, otherwise use system Python +if [ -n "$CONDA_PREFIX" ]; then + PYTHON="$CONDA_PREFIX/bin/python" + echo "Using conda Python from: $CONDA_PREFIX" +else + PYTHON="python" + echo "Warning: No conda environment detected, using system Python" +fi + +# Get version from version.py +VERSION=$($PYTHON -c 'import olmocr.version; print(olmocr.version.VERSION)') +echo "OlmOCR version: $VERSION" + +# Get first 10 characters of git hash +GIT_HASH=$(git rev-parse HEAD | cut -c1-10) +echo "Git hash: $GIT_HASH" + +# Get current git branch name +GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD) +echo "Git branch: $GIT_BRANCH" + +# Check if a Beaker image was provided +if [ -n "$BEAKER_IMAGE" ]; then + echo "Using provided Beaker image: $BEAKER_IMAGE" + IMAGE_TAG="$BEAKER_IMAGE" +else + # Create full image tag + IMAGE_TAG="olmocr-benchmark-${VERSION}-${GIT_HASH}" + echo "Building Docker image with tag: $IMAGE_TAG" + + # Build the Docker image + echo "Building Docker image..." + docker build --platform linux/amd64 -f ./Dockerfile -t $IMAGE_TAG . + + # Push image to beaker + echo "Trying to push image to Beaker..." + if ! beaker image create --workspace ai2/oe-data-pdf --name $IMAGE_TAG $IMAGE_TAG 2>/dev/null; then + echo "Warning: Beaker image with tag $IMAGE_TAG already exists. Using existing image." + fi +fi + +# Get Beaker username +BEAKER_USER=$(beaker account whoami --format json | jq -r '.[0].name') +echo "Beaker user: $BEAKER_USER" + +# Create Python script to run beaker experiment +cat << 'EOF' > /tmp/run_infrapartner_benchmark_experiment.py +import sys +from beaker import Beaker, ExperimentSpec, TaskSpec, TaskContext, ResultSpec, TaskResources, ImageSource, Priority, Constraints, EnvVar + +# Get image tag, beaker user, git branch, git hash, optional model, server, beaker_secret, and bench branch from command line +image_tag = sys.argv[1] +beaker_user = sys.argv[2] +git_branch = sys.argv[3] +git_hash = sys.argv[4] +model = None +server = None +beaker_secret = None +bench_branch = None + +# Parse remaining arguments +arg_idx = 5 +while arg_idx < len(sys.argv): + if sys.argv[arg_idx] == "--benchbranch": + bench_branch = sys.argv[arg_idx + 1] + arg_idx += 2 + elif sys.argv[arg_idx] == "--model": + model = sys.argv[arg_idx + 1] + arg_idx += 2 + elif sys.argv[arg_idx] == "--server": + server = sys.argv[arg_idx + 1] + arg_idx += 2 + elif sys.argv[arg_idx] == "--beaker-secret": + beaker_secret = sys.argv[arg_idx + 1] + arg_idx += 2 + else: + print(f"Unknown argument: {sys.argv[arg_idx]}") + arg_idx += 1 + +# Initialize Beaker client +b = Beaker.from_env(default_workspace="ai2/olmocr") + +# Build the pipeline command with optional parameters +pipeline_cmd = "python -m olmocr.pipeline ./localworkspace --markdown --pdfs ./olmOCR-bench/bench_data/pdfs/**/*.pdf" +if model: + pipeline_cmd += f" --model {model}" +if server: + pipeline_cmd += f" --server {server}" +if beaker_secret: + pipeline_cmd += " --api_key \"$API_KEY\"" + +# Check if AWS credentials secret exists +aws_creds_secret = f"{beaker_user}-AWS_CREDENTIALS_FILE" +try: + # Try to get the secret to see if it exists + b.secret.get(aws_creds_secret, workspace="ai2/olmocr") + has_aws_creds = True + print(f"Found AWS credentials secret: {aws_creds_secret}") +except: + has_aws_creds = False + print(f"AWS credentials secret not found: {aws_creds_secret}") + +# First experiment: Original benchmark job +commands = [] +if has_aws_creds: + commands.extend([ + "mkdir -p ~/.aws", + 'echo "$AWS_CREDENTIALS_FILE" > ~/.aws/credentials' + ]) + +# If beaker_secret is provided, export it as API_KEY environment variable +if beaker_secret: + commands.append('export API_KEY="$BEAKER_API_KEY"') + +# Build git clone command with optional branch +git_clone_cmd = "git clone https://huggingface.co/datasets/allenai/olmOCR-bench" +if bench_branch: + git_clone_cmd += f" -b {bench_branch}" + +commands.extend([ + git_clone_cmd, + "cd olmOCR-bench && git lfs pull && cd ..", + pipeline_cmd, + "python olmocr/bench/scripts/workspace_to_bench.py localworkspace/ olmOCR-bench/bench_data/olmocr --bench-path ./olmOCR-bench/", + "pip install s5cmd", + "s5cmd cp localworkspace/ s3://ai2-oe-data/jakep/olmocr-bench-runs/$BEAKER_WORKLOAD_ID/", + "python -m olmocr.bench.benchmark --dir ./olmOCR-bench/bench_data" +]) + +# Build task spec with optional env vars +# If image_tag contains '/', it's already a full beaker image reference +if '/' in image_tag: + image_ref = image_tag +else: + image_ref = f"{beaker_user}/{image_tag}" + +task_spec_args = { + "name": "olmocr-infrapartner-benchmark", + "image": ImageSource(beaker=image_ref), + "command": [ + "bash", "-c", + " && ".join(commands) + ], + "context": TaskContext( + priority=Priority.normal, + preemptible=True, + ), + "resources": TaskResources(gpu_count=0), + "constraints": Constraints(cluster=["ai2/phobos", "ai2/neptune", "ai2/saturn"]), + "result": ResultSpec(path="/noop-results"), +} + +# Build env vars list +env_vars = [] +if has_aws_creds: + env_vars.append(EnvVar(name="AWS_CREDENTIALS_FILE", secret=aws_creds_secret)) +if beaker_secret: + env_vars.append(EnvVar(name="BEAKER_API_KEY", secret=beaker_secret)) + +# Add env vars if any exist +if env_vars: + task_spec_args["env_vars"] = env_vars + +# Create experiment spec +experiment_spec = ExperimentSpec( + description=f"OlmOCR InfraPartner Benchmark Run - Branch: {git_branch}, Commit: {git_hash}", + budget="ai2/oe-base", + tasks=[TaskSpec(**task_spec_args)], +) + +# Create the experiment +experiment = b.experiment.create(spec=experiment_spec, workspace="ai2/olmocr") +print(f"Created benchmark experiment: {experiment.id}") +print(f"View at: https://beaker.org/ex/{experiment.id}") +print("-------") +print("") +print("Note: Performance test has been skipped for infrapartner benchmark") +EOF + +# Run the Python script to create the experiment +echo "Creating Beaker experiment..." + +# Build command with appropriate arguments +CMD="$PYTHON /tmp/run_infrapartner_benchmark_experiment.py $IMAGE_TAG $BEAKER_USER $GIT_BRANCH $GIT_HASH" + +if [ -n "$MODEL" ]; then + echo "Using model: $MODEL" + CMD="$CMD --model $MODEL" +fi + +if [ -n "$SERVER" ]; then + echo "Using server: $SERVER" + CMD="$CMD --server $SERVER" +fi + +if [ -n "$BEAKER_SECRET" ]; then + echo "Using beaker secret for API key: $BEAKER_SECRET" + CMD="$CMD --beaker-secret $BEAKER_SECRET" +fi + +if [ -n "$BENCH_BRANCH" ]; then + echo "Using bench branch: $BENCH_BRANCH" + CMD="$CMD --benchbranch $BENCH_BRANCH" +fi + +eval $CMD + +# Clean up temporary file +rm /tmp/run_infrapartner_benchmark_experiment.py + +echo "InfraPartner benchmark experiment submitted successfully!" \ No newline at end of file diff --git a/tests/test_table_parsing.py b/tests/test_table_parsing.py new file mode 100644 index 0000000..aa3a26e --- /dev/null +++ b/tests/test_table_parsing.py @@ -0,0 +1,162 @@ +import unittest + +from olmocr.bench.tests import ( + parse_html_tables, + parse_markdown_tables, +) + + +class TestParseHtmlTables(unittest.TestCase): + def test_basic_table(self): + data = parse_html_tables(""" + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ArXivOld
scans
math
TablesOld
scans
Headers
&
footers
Multi
column
Long
tiny
text
BaseOverall
Mistral OCR API77.267.560.629.393.671.377.199.472.0±1.1
""")[0] + + print(data) + + self.assertEqual(data.cell_text[0,0], "") + self.assertEqual(data.cell_text[0,1], "ArXiv") + + self.assertEqual(data.left_relations[0,0], set()) + self.assertEqual(data.up_relations[0,0], set()) + + self.assertEqual(data.left_relations[0,1], {(0,0)}) + self.assertEqual(data.up_relations[1,0], {(0,0)}) + + self.assertEqual(data.heading_cells, { + (0,0), (0,1), (0,2), (0,3),(0,4), (0,5),(0,6), (0,7), (0,8), (0,9) + }) + + self.assertEqual(data.top_heading_relations[1,3], {(0,3)}) + + # If there are no left headings defined, then the left most column is considered the left heading + self.assertEqual(data.left_heading_relations[1,3], {(1,0)}) + + def test_multiple_top_headings(self): + data = parse_html_tables(""" + + + + + + + + + + + + + + + + + + + +
Fruit Costs in Unittest land
Fruit TypeCost
Apples$1.00
Oranges$2.00
""")[0] + + print(data) + + self.assertEqual(data.cell_text[0,0], "Fruit Costs in Unittest land") + self.assertEqual(data.cell_text[1,0], "Fruit Type") + self.assertEqual(data.cell_text[1,1], "Cost") + self.assertEqual(data.cell_text[2,0], "Apples") + self.assertEqual(data.cell_text[2,1], "$1.00") + self.assertEqual(data.cell_text[3,0], "Oranges") + self.assertEqual(data.cell_text[3,1], "$2.00") + + + self.assertEqual(data.up_relations[1,0], {(0,0)}) + self.assertEqual(data.up_relations[1,1], {(0,0)}) + + self.assertEqual(data.up_relations[2,0], {(1,0)}) + self.assertEqual(data.up_relations[2,1], {(1,1)}) + + self.assertEqual(data.top_heading_relations[1,0], {(0,0)}) + self.assertEqual(data.top_heading_relations[1,1], {(0,0)}) + + self.assertEqual(data.top_heading_relations[2,0], {(0,0), (1,0)}) + self.assertEqual(data.top_heading_relations[2,1], {(0,0), (1,1)}) + + def test_4x4_table_with_spans(self): + """Test a 4x4 table with various row spans and column spans""" + data = parse_html_tables(""" + + + + + + + + + + + + + + + + + + + + + + + +
Header 1Header 2-3Header 4
Cell A (spans 2 rows)Cell BCell CCell D (spans 2 rows)
Cell E-F (spans 2 cols)
Cell GCell H-I-J (spans 3 cols)
""")[0] + + print(data) + + # Test header row + self.assertEqual(data.cell_text[0,0], "Header 1") + self.assertEqual(data.cell_text[0,1], "Header 2-3") + + self.assertNotIn((0,2), data.cell_text) # colspan=2, so that next cell is empty + self.assertEqual(data.cell_text[0,3], "Header 4") + + # Test first body row + self.assertEqual(data.cell_text[1,0], "Cell A (spans 2 rows)") + self.assertEqual(data.cell_text[1,1], "Cell B") + self.assertEqual(data.cell_text[1,2], "Cell C") + self.assertEqual(data.cell_text[1,3], "Cell D (spans 2 rows)") + + # Test second body row + self.assertNotIn((2,0), data.cell_text) + self.assertEqual(data.cell_text[2,1], "Cell E-F (spans 2 cols)") + + # Test third body row + self.assertEqual(data.cell_text[3,0], "Cell G") + self.assertEqual(data.cell_text[3,1], "Cell H-I-J (spans 3 cols)") + + # Test heading cells + self.assertEqual(data.heading_cells, { + (0,0), (0,1), (0,3) + })