mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-08 06:29:29 +00:00
Working on table parsing
This commit is contained in:
parent
88937c6e40
commit
1ef66fd313
@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from bs4 import BeautifulSoup
|
||||
@ -21,76 +22,33 @@ from .katex.render import compare_rendered_equations, render_equation
|
||||
__test__ = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TableData:
|
||||
"""Class to hold table data and metadata about headers."""
|
||||
"""Class which holds table data as a graph of cells. Ex. you can access the value at any row, col.
|
||||
|
||||
data: np.ndarray # The actual table data
|
||||
header_rows: Set[int] = field(default_factory=set) # Indices of rows that are headers
|
||||
header_cols: Set[int] = field(default_factory=set) # Indices of columns that are headers
|
||||
col_headers: dict = field(default_factory=dict) # Maps column index to header text, handling colspan
|
||||
row_headers: dict = field(default_factory=dict) # Maps row index to header text, handling rowspan
|
||||
Cell texts are only ever present one time, so ex if on row 0, you have a colspan 1 and a colspan 2 column,
|
||||
then text gets stored at (0,0) and (0,1) only, (0,2) is not present in the data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Returns a concise representation of the TableData object for debugging."""
|
||||
return f"TableData(shape={self.data.shape}, header_rows={len(self.header_rows)}, header_cols={len(self.header_cols)})"
|
||||
However, you can also ask, given a row, col, which set of row,col pairs is "left", "right", "up", and "down"
|
||||
from that one. There can be multiple values returned, because rowspans and colspans mean that you can have multiple cells in each direction.
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Returns a pretty string representation of the table with header information."""
|
||||
output = []
|
||||
Furthermore, you can also query "top_heading" and "left_heading", where we also mark cells that are considered "headings", e.g. if they are in a thead
|
||||
html tag.
|
||||
Then, for each cell, you may request top_heading/left_heading, which returns all cell positions that are headings in those directions
|
||||
Cells inside of <th> tags are considered headings automatically, but if these are not present, then the leftmost
|
||||
cell in a row is automatically considered a header, and the same for the topmost cell in a column.
|
||||
"""
|
||||
cell_text: Dict[tuple[int, int], str] # Stores map from row, col to cell text
|
||||
heading_cells: Set[tuple[int, int]] # Contains the row, col pairs which are headings
|
||||
|
||||
# Table dimensions
|
||||
output.append(f"Table: {self.data.shape[0]} rows × {self.data.shape[1]} columns")
|
||||
up_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
|
||||
down_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
|
||||
left_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
|
||||
right_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
|
||||
|
||||
# Header info
|
||||
output.append(f"Header rows: {sorted(self.header_rows)}")
|
||||
output.append(f"Header columns: {sorted(self.header_cols)}")
|
||||
|
||||
# Table content with formatting
|
||||
separator = "+" + "+".join(["-" * 17] * self.data.shape[1]) + "+"
|
||||
|
||||
# Add a header for row indices
|
||||
output.append(separator)
|
||||
headers = [""] + [f"Column {i}" for i in range(self.data.shape[1])]
|
||||
output.append("| {:<5} | ".format("Row") + " | ".join(["{:<15}".format(h) for h in headers[1:]]) + " |")
|
||||
output.append(separator)
|
||||
|
||||
# Format each row
|
||||
for i in range(min(self.data.shape[0], 15)): # Limit to 15 rows for readability
|
||||
# Format cells, mark header cells
|
||||
cells = []
|
||||
for j in range(self.data.shape[1]):
|
||||
cell = str(self.data[i, j])
|
||||
if len(cell) > 15:
|
||||
cell = cell[:12] + "..."
|
||||
# Mark header cells with *
|
||||
if i in self.header_rows or j in self.header_cols:
|
||||
cell = f"*{cell}*"
|
||||
cells.append(cell)
|
||||
|
||||
row_str = "| {:<5} | ".format(i) + " | ".join(["{:<15}".format(c) for c in cells]) + " |"
|
||||
output.append(row_str)
|
||||
output.append(separator)
|
||||
|
||||
# If table is too large, indicate truncation
|
||||
if self.data.shape[0] > 15:
|
||||
output.append(f"... {self.data.shape[0] - 15} more rows ...")
|
||||
|
||||
# Column header details if available
|
||||
if self.col_headers:
|
||||
output.append("\nColumn header mappings:")
|
||||
for col, headers in sorted(self.col_headers.items()):
|
||||
header_strs = [f"({row}, '{text}')" for row, text in headers]
|
||||
output.append(f" Column {col}: {', '.join(header_strs)}")
|
||||
|
||||
# Row header details if available
|
||||
if self.row_headers:
|
||||
output.append("\nRow header mappings:")
|
||||
for row, headers in sorted(self.row_headers.items()):
|
||||
header_strs = [f"({col}, '{text}')" for col, text in headers]
|
||||
output.append(f" Row {row}: {', '.join(header_strs)}")
|
||||
|
||||
return "\n".join(output)
|
||||
top_heading_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
|
||||
left_heading_relations: Dict[tuple[int, int], Set[tuple[int, int]]]
|
||||
|
||||
|
||||
class TestType(str, Enum):
|
||||
@ -145,6 +103,225 @@ def normalize_text(md_content: str) -> str:
|
||||
|
||||
return md_content
|
||||
|
||||
def _safe_span_int(value: Optional[Union[str, int]], default: int = 1) -> int:
    """Convert a rowspan/colspan attribute value to a positive integer.

    Args:
        value: Raw span value: typically the string taken from an HTML
            attribute, an int, or None when the attribute is absent.
        default: Value returned when ``value`` is missing, cannot be converted
            with ``int()``, or is not strictly positive.

    Returns:
        ``int(value)`` when that is a positive integer, otherwise ``default``.
    """
    if value in (None, "", 0):
        return default
    try:
        span = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers non-finite floats (int(float("inf"))), which can
        # reach this function because callers forward any non-int span value.
        return default
    return span if span > 0 else default
|
||||
|
||||
|
||||
def _build_table_data_from_specs(row_specs: List[List[Dict[str, Union[str, int, bool]]]]) -> Optional[TableData]:
    """
    Build a TableData object from a list of row specifications.

    Each row specification is a list of dictionaries with keys:
    - text: cell text content
    - rowspan: integer rowspan (>= 1)
    - colspan: integer colspan (>= 1)
    - is_heading: bool indicating if the cell should be treated as a heading

    Cells are placed left to right on a grid; a cell is identified by the
    (row, col) of its top-left slot, and slots still covered by a rowspan from
    an earlier row are skipped. Rowspans that reach past the last specified row
    are flushed into additional rows.

    For every cell the nearest neighbouring cell(s) in each direction are
    recorded, plus the heading cells above it and to its left. When a cell has
    no heading cell at all above it (or to its left), the topmost cell of each
    column it spans (or the leftmost cell of each row it spans) is reported
    instead; such fallback cells are not added to ``heading_cells``.

    Returns:
        The populated TableData, or None when ``row_specs`` is empty or
        contains no cells.
    """
    if not row_specs:
        return None

    cell_text: Dict[Tuple[int, int], str] = {}
    heading_cells: Set[Tuple[int, int]] = set()
    # Cell id -> (rowspan, colspan) as declared by its spec.
    cell_spans: Dict[Tuple[int, int], Tuple[int, int]] = {}
    # occupancy[row][col] is the id of the cell covering that slot, or None.
    occupancy: List[List[Optional[Tuple[int, int]]]] = []
    # Per column: (cell id, number of further rows it still covers), or None.
    active_rowspans: List[Optional[Tuple[Tuple[int, int], int]]] = []

    for row_idx, cells in enumerate(row_specs):
        row_entries: List[Optional[Tuple[int, int]]] = []
        col_index = 0
        spec_idx = 0
        total_specs = len(cells)

        while spec_idx < total_specs or col_index < len(active_rowspans):
            carried = active_rowspans[col_index] if col_index < len(active_rowspans) else None
            if carried is not None:
                # Slot is covered by a rowspan started in an earlier row.
                cell_id, remaining = carried
                row_entries.append(cell_id)
                active_rowspans[col_index] = (cell_id, remaining - 1) if remaining > 1 else None
                col_index += 1
                continue

            if spec_idx >= total_specs:
                # Out of specs but still inside the known table width: the loop
                # condition guarantees col_index < len(active_rowspans) here.
                row_entries.append(None)
                col_index += 1
                continue

            spec = cells[spec_idx]
            spec_idx += 1

            text = spec.get("text", "") or ""
            rowspan = spec.get("rowspan", 1)
            colspan = spec.get("colspan", 1)
            rowspan = max(1, rowspan if isinstance(rowspan, int) else _safe_span_int(rowspan))
            colspan = max(1, colspan if isinstance(colspan, int) else _safe_span_int(colspan))

            cell_id = (row_idx, col_index)
            cell_text[cell_id] = text
            if spec.get("is_heading", False):
                heading_cells.add(cell_id)
            cell_spans[cell_id] = (rowspan, colspan)

            required_len = col_index + colspan
            if len(active_rowspans) < required_len:
                active_rowspans.extend([None] * (required_len - len(active_rowspans)))

            # NOTE: a colspan that runs into a column still covered by a rowspan
            # from above overwrites that rowspan (overlapping cells in malformed
            # input); the relation pass below clamps spans to the grid for that case.
            carry = (cell_id, rowspan - 1) if rowspan > 1 else None
            for current_col in range(col_index, required_len):
                row_entries.append(cell_id)
                active_rowspans[current_col] = carry
            col_index = required_len

        occupancy.append(row_entries)

    # Flush any remaining active rowspans into additional rows
    while any(entry is not None for entry in active_rowspans):
        flushed_row: List[Optional[Tuple[int, int]]] = []
        for col_index, span_entry in enumerate(active_rowspans):
            if span_entry is None:
                flushed_row.append(None)
                continue
            cell_id, remaining = span_entry
            flushed_row.append(cell_id)
            active_rowspans[col_index] = (cell_id, remaining - 1) if remaining > 1 else None
        occupancy.append(flushed_row)

    if not cell_text:
        return None

    # Normalize occupancy to a consistent width based on populated columns
    populated_columns = [idx for row in occupancy for idx, value in enumerate(row) if value is not None]
    if not populated_columns:
        return None
    table_width = max(populated_columns) + 1
    for row in occupancy:
        shortfall = table_width - len(row)
        if shortfall > 0:
            row.extend([None] * shortfall)
        elif shortfall < 0:
            del row[table_width:]

    table_height = len(occupancy)

    def scan(cell_id: Tuple[int, int], positions: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
        """Return the distinct cells other than cell_id found along positions, nearest first."""
        found: List[Tuple[int, int]] = []
        for row, col in positions:
            neighbor = occupancy[row][col]
            if neighbor is None or neighbor == cell_id or neighbor in found:
                continue
            found.append(neighbor)
        return found

    def nearest(rays: List[List[Tuple[int, int]]]) -> Set[Tuple[int, int]]:
        """Return the closest cell of every non-empty ray."""
        return {ray[0] for ray in rays if ray}

    def headings_or_edge(rays: List[List[Tuple[int, int]]]) -> Set[Tuple[int, int]]:
        """Return all heading cells on the rays, or the farthest cell of each ray when there are none."""
        headings = {cell for ray in rays for cell in ray if cell in heading_cells}
        if headings:
            return headings
        return {ray[-1] for ray in rays if ray}

    up_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {}
    down_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {}
    left_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {}
    right_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {}
    top_heading_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {}
    left_heading_relations: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {}

    for cell_id, (rowspan, colspan) in cell_spans.items():
        row_start, col_start = cell_id
        # Clamp to the grid: a rowspan overwritten by an overlapping cell never
        # produced its flush rows, so the declared extent may exceed the table
        # and would otherwise index past the end of occupancy.
        row_end = min(row_start + rowspan - 1, table_height - 1)
        col_end = min(col_start + colspan - 1, table_width - 1)
        rows = range(row_start, row_end + 1)
        cols = range(col_start, col_end + 1)

        left_rays = [scan(cell_id, [(row, col) for col in range(col_start - 1, -1, -1)]) for row in rows]
        right_rays = [scan(cell_id, [(row, col) for col in range(col_end + 1, table_width)]) for row in rows]
        up_rays = [scan(cell_id, [(row, col) for row in range(row_start - 1, -1, -1)]) for col in cols]
        down_rays = [scan(cell_id, [(row, col) for row in range(row_end + 1, table_height)]) for col in cols]

        left_relations[cell_id] = nearest(left_rays)
        right_relations[cell_id] = nearest(right_rays)
        up_relations[cell_id] = nearest(up_rays)
        down_relations[cell_id] = nearest(down_rays)
        top_heading_relations[cell_id] = headings_or_edge(up_rays)
        left_heading_relations[cell_id] = headings_or_edge(left_rays)

    return TableData(
        cell_text=cell_text,
        heading_cells=heading_cells,
        up_relations=up_relations,
        down_relations=down_relations,
        left_relations=left_relations,
        right_relations=right_relations,
        top_heading_relations=top_heading_relations,
        left_heading_relations=left_heading_relations,
    )
|
||||
|
||||
|
||||
def parse_markdown_tables(md_content: str) -> List[TableData]:
|
||||
"""
|
||||
@ -183,74 +360,46 @@ def parse_markdown_tables(md_content: str) -> List[TableData]:
|
||||
if len(current_table_lines) >= 2:
|
||||
table_data = _process_table_lines(current_table_lines)
|
||||
if table_data and len(table_data) > 0:
|
||||
# Convert to numpy array for easier manipulation
|
||||
max_cols = max(len(row) for row in table_data)
|
||||
padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
|
||||
table_array = np.array(padded_data)
|
||||
|
||||
# In markdown tables, the first row is typically a header row
|
||||
header_rows = {0} if len(table_array) > 0 else set()
|
||||
|
||||
# Set up col_headers with first row headers for each column
|
||||
col_headers = {}
|
||||
if len(table_array) > 0:
|
||||
for col_idx in range(table_array.shape[1]):
|
||||
if col_idx < len(table_array[0]):
|
||||
col_headers[col_idx] = [(0, table_array[0, col_idx])]
|
||||
|
||||
# Set up row_headers with first column headers for each row
|
||||
row_headers = {}
|
||||
if table_array.shape[1] > 0:
|
||||
for row_idx in range(1, table_array.shape[0]): # Skip header row
|
||||
row_headers[row_idx] = [(0, table_array[row_idx, 0])] # First column as heading
|
||||
|
||||
# Create TableData object
|
||||
parsed_tables.append(
|
||||
TableData(
|
||||
data=table_array,
|
||||
header_rows=header_rows,
|
||||
header_cols={0} if table_array.shape[1] > 0 else set(), # First column as header
|
||||
col_headers=col_headers,
|
||||
row_headers=row_headers,
|
||||
row_specs: List[List[Dict[str, Union[str, int, bool]]]] = []
|
||||
for row_idx, row in enumerate(table_data):
|
||||
row_specs.append(
|
||||
[
|
||||
{
|
||||
"text": cell,
|
||||
"rowspan": 1,
|
||||
"colspan": 1,
|
||||
"is_heading": row_idx == 0 or col_idx == 0,
|
||||
}
|
||||
for col_idx, cell in enumerate(row)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
table = _build_table_data_from_specs(row_specs)
|
||||
if table:
|
||||
parsed_tables.append(table)
|
||||
in_table = False
|
||||
|
||||
# Process the last table if we're still tracking one at the end of the file
|
||||
if in_table and len(current_table_lines) >= 2:
|
||||
table_data = _process_table_lines(current_table_lines)
|
||||
if table_data and len(table_data) > 0:
|
||||
# Convert to numpy array
|
||||
max_cols = max(len(row) for row in table_data)
|
||||
padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
|
||||
table_array = np.array(padded_data)
|
||||
|
||||
# In markdown tables, the first row is typically a header row
|
||||
header_rows = {0} if len(table_array) > 0 else set()
|
||||
|
||||
# Set up col_headers with first row headers for each column
|
||||
col_headers = {}
|
||||
if len(table_array) > 0:
|
||||
for col_idx in range(table_array.shape[1]):
|
||||
if col_idx < len(table_array[0]):
|
||||
col_headers[col_idx] = [(0, table_array[0, col_idx])]
|
||||
|
||||
# Set up row_headers with first column headers for each row
|
||||
row_headers = {}
|
||||
if table_array.shape[1] > 0:
|
||||
for row_idx in range(1, table_array.shape[0]): # Skip header row
|
||||
row_headers[row_idx] = [(0, table_array[row_idx, 0])] # First column as heading
|
||||
|
||||
# Create TableData object
|
||||
parsed_tables.append(
|
||||
TableData(
|
||||
data=table_array,
|
||||
header_rows=header_rows,
|
||||
header_cols={0} if table_array.shape[1] > 0 else set(), # First column as header
|
||||
col_headers=col_headers,
|
||||
row_headers=row_headers,
|
||||
row_specs = []
|
||||
for row_idx, row in enumerate(table_data):
|
||||
row_specs.append(
|
||||
[
|
||||
{
|
||||
"text": cell,
|
||||
"rowspan": 1,
|
||||
"colspan": 1,
|
||||
"is_heading": row_idx == 0 or col_idx == 0,
|
||||
}
|
||||
for col_idx, cell in enumerate(row)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
table = _build_table_data_from_specs(row_specs)
|
||||
if table:
|
||||
parsed_tables.append(table)
|
||||
|
||||
return parsed_tables
|
||||
|
||||
@ -318,159 +467,47 @@ def parse_html_tables(html_content: str) -> List[TableData]:
|
||||
parsed_tables = []
|
||||
|
||||
for table in tables:
|
||||
rows = table.find_all(["tr"])
|
||||
table_data = []
|
||||
header_rows = set()
|
||||
header_cols = set()
|
||||
col_headers = {} # Maps column index to all header cells above it
|
||||
row_headers = {} # Maps row index to all header cells to its left
|
||||
rows = table.find_all("tr")
|
||||
if not rows:
|
||||
continue
|
||||
|
||||
# Find rows inside thead tags - these are definitely header rows
|
||||
thead = table.find("thead")
|
||||
if thead:
|
||||
thead_rows = thead.find_all("tr")
|
||||
for tr in thead_rows:
|
||||
header_rows.add(rows.index(tr))
|
||||
row_specs: List[List[Dict[str, Union[str, int, bool]]]] = []
|
||||
total_rows = len(rows)
|
||||
|
||||
# Initialize a grid to track filled cells due to rowspan/colspan
|
||||
cell_grid = {}
|
||||
col_span_info = {} # Tracks which columns contain headers
|
||||
row_span_info = {} # Tracks which rows contain headers
|
||||
|
||||
# First pass: process each row to build the raw table data and identify headers
|
||||
for row_idx, row in enumerate(rows):
|
||||
cells = row.find_all(["th", "td"])
|
||||
row_data = []
|
||||
col_idx = 0
|
||||
|
||||
# If there are th elements in this row, it's likely a header row
|
||||
if row.find("th"):
|
||||
header_rows.add(row_idx)
|
||||
cells = row.find_all(["th", "td"], recursive=False)
|
||||
heading_context = row.find_parent("thead") is not None
|
||||
|
||||
row_spec: List[Dict[str, Union[str, int, bool]]] = []
|
||||
for cell in cells:
|
||||
# Skip positions already filled by rowspans from above
|
||||
while (row_idx, col_idx) in cell_grid:
|
||||
row_data.append(cell_grid[(row_idx, col_idx)])
|
||||
col_idx += 1
|
||||
|
||||
# Replace <br> and <br/> tags with newlines before getting text
|
||||
for br in cell.find_all("br"):
|
||||
br.replace_with("\n")
|
||||
cell_text = cell.get_text().strip()
|
||||
|
||||
# Handle rowspan/colspan
|
||||
rowspan = int(cell.get("rowspan", 1))
|
||||
colspan = int(cell.get("colspan", 1))
|
||||
text = cell.get_text(separator="\n").strip()
|
||||
raw_rowspan = cell.get("rowspan")
|
||||
raw_colspan = cell.get("colspan")
|
||||
|
||||
# Add the cell to the row data
|
||||
row_data.append(cell_text)
|
||||
rowspan = _safe_span_int(raw_rowspan, 1)
|
||||
colspan = _safe_span_int(raw_colspan, 1)
|
||||
|
||||
# Fill the grid for this cell and its rowspan/colspan
|
||||
for i in range(rowspan):
|
||||
for j in range(colspan):
|
||||
if i == 0 and j == 0:
|
||||
continue # Skip the main cell position
|
||||
# For rowspan cells, preserve the text in all spanned rows
|
||||
if j == 0 and i > 0: # Only for cells directly below
|
||||
cell_grid[(row_idx + i, col_idx + j)] = cell_text
|
||||
else:
|
||||
cell_grid[(row_idx + i, col_idx + j)] = "" # Mark other spans as empty
|
||||
# HTML specifies rowspan=0 to extend to the end of the table section
|
||||
if isinstance(raw_rowspan, str) and raw_rowspan.strip() == "0":
|
||||
rowspan = max(1, total_rows - row_idx)
|
||||
|
||||
# If this is a header cell (th), mark it and its span
|
||||
if cell.name == "th":
|
||||
# Mark columns as header columns
|
||||
for j in range(colspan):
|
||||
header_cols.add(col_idx + j)
|
||||
is_heading = cell.name == "th" or heading_context
|
||||
|
||||
# For rowspan, mark spanned rows as part of header
|
||||
for i in range(1, rowspan):
|
||||
if row_idx + i < len(rows):
|
||||
header_rows.add(row_idx + i)
|
||||
row_spec.append({"text": text, "rowspan": rowspan, "colspan": colspan, "is_heading": is_heading})
|
||||
|
||||
# Record this header for all spanned columns
|
||||
for j in range(colspan):
|
||||
curr_col = col_idx + j
|
||||
if curr_col not in col_headers:
|
||||
col_headers[curr_col] = []
|
||||
col_headers[curr_col].append((row_idx, cell_text))
|
||||
row_specs.append(row_spec)
|
||||
|
||||
# Store which columns are covered by this header
|
||||
if cell_text and colspan > 1:
|
||||
if cell_text not in col_span_info:
|
||||
col_span_info[cell_text] = set()
|
||||
col_span_info[cell_text].add(curr_col)
|
||||
|
||||
# Store which rows are covered by this header for rowspan
|
||||
if cell_text and rowspan > 1:
|
||||
if cell_text not in row_span_info:
|
||||
row_span_info[cell_text] = set()
|
||||
for i in range(rowspan):
|
||||
row_span_info[cell_text].add(row_idx + i)
|
||||
|
||||
# Also handle row headers from data cells that have rowspan
|
||||
if cell.name == "td" and rowspan > 1 and col_idx in header_cols:
|
||||
for i in range(1, rowspan):
|
||||
if row_idx + i < len(rows):
|
||||
if row_idx + i not in row_headers:
|
||||
row_headers[row_idx + i] = []
|
||||
row_headers[row_idx + i].append((col_idx, cell_text))
|
||||
|
||||
col_idx += colspan
|
||||
|
||||
# Pad the row if needed to handle different row lengths
|
||||
table_data.append(row_data)
|
||||
|
||||
# Second pass: expand headers to cells that should inherit them
|
||||
# First handle column headers
|
||||
for header_text, columns in col_span_info.items():
|
||||
for col in columns:
|
||||
# Add this header to all columns it spans over
|
||||
for row_idx in range(len(table_data)):
|
||||
if row_idx not in header_rows: # Only apply to data rows
|
||||
for j in range(col, len(table_data[row_idx]) if row_idx < len(table_data) else 0):
|
||||
# Add header info to data cells in these columns
|
||||
if j not in col_headers:
|
||||
col_headers[j] = []
|
||||
if not any(h[1] == header_text for h in col_headers[j]):
|
||||
header_row = min([r for r, t in col_headers.get(col, [(0, "")])])
|
||||
col_headers[j].append((header_row, header_text))
|
||||
|
||||
# Handle row headers
|
||||
for header_text, rows in row_span_info.items():
|
||||
for row in rows:
|
||||
if row < len(table_data):
|
||||
# Find first header column
|
||||
header_col = min(header_cols) if header_cols else 0
|
||||
if row not in row_headers:
|
||||
row_headers[row] = []
|
||||
if not any(h[1] == header_text for h in row_headers.get(row, [])):
|
||||
row_headers[row].append((header_col, header_text))
|
||||
|
||||
# Process regular row headers - each cell in a header column becomes a header for its row
|
||||
for col_idx in header_cols:
|
||||
for row_idx, row in enumerate(table_data):
|
||||
if col_idx < len(row) and row[col_idx].strip():
|
||||
if row_idx not in row_headers:
|
||||
row_headers[row_idx] = []
|
||||
if not any(h[1] == row[col_idx] for h in row_headers.get(row_idx, [])):
|
||||
row_headers[row_idx].append((col_idx, row[col_idx]))
|
||||
|
||||
# Calculate max columns for padding
|
||||
max_cols = max(len(row) for row in table_data) if table_data else 0
|
||||
|
||||
# Ensure all rows have the same number of columns
|
||||
table_data = _build_table_data_from_specs(row_specs)
|
||||
if table_data:
|
||||
padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
|
||||
table_array = np.array(padded_data)
|
||||
|
||||
# Create TableData object with the table and header information
|
||||
parsed_tables.append(
|
||||
TableData(data=table_array, header_rows=header_rows, header_cols=header_cols, col_headers=col_headers, row_headers=row_headers)
|
||||
)
|
||||
parsed_tables.append(table_data)
|
||||
|
||||
return parsed_tables
|
||||
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class BasePDFTest:
|
||||
"""
|
||||
|
||||
272
scripts/run_infrapartner_benchmark.sh
Executable file
272
scripts/run_infrapartner_benchmark.sh
Executable file
@ -0,0 +1,272 @@
|
||||
#!/bin/bash

# Runs an olmocr-bench run using the full pipeline (no fallback) for infrapartner testing
# This version skips the performance task and adds support for --server, --model, and --beaker-secret arguments
#
# Usage examples:
#   ./scripts/run_infrapartner_benchmark.sh --server http://example.com --model your-model-name --beaker-secret my-api-key-secret
#   ./scripts/run_infrapartner_benchmark.sh --beaker-image jakep/olmocr-benchmark-0.3.3-780bc7d934 --server http://example.com

set -e

usage() {
    echo "Usage: $0 [--server SERVER_URL] [--model MODEL_NAME] [--beaker-secret SECRET_NAME] [--benchbranch BRANCH_NAME] [--beaker-image IMAGE_NAME]"
}

# Parse command line arguments
MODEL=""
SERVER=""
BEAKER_SECRET=""
BENCH_BRANCH=""
BEAKER_IMAGE=""
while [[ $# -gt 0 ]]; do
    case $1 in
        --model|--server|--beaker-secret|--benchbranch|--beaker-image)
            # Every option takes a value; report a missing one instead of
            # letting "shift 2" abort silently under "set -e".
            if [[ $# -lt 2 ]]; then
                echo "Error: $1 requires a value"
                usage
                exit 1
            fi
            ;;
        *)
            echo "Unknown option: $1"
            usage
            exit 1
            ;;
    esac
    case $1 in
        --model) MODEL="$2" ;;
        --server) SERVER="$2" ;;
        --beaker-secret) BEAKER_SECRET="$2" ;;
        --benchbranch) BENCH_BRANCH="$2" ;;
        --beaker-image) BEAKER_IMAGE="$2" ;;
    esac
    shift 2
done

# Check for uncommitted changes (only matters when we build an image from the working tree)
if [ -n "$BEAKER_IMAGE" ]; then
    echo "Skipping docker build"
else
    if ! git diff-index --quiet HEAD --; then
        echo "Error: There are uncommitted changes in the repository."
        echo "Please commit or stash your changes before running the benchmark."
        echo ""
        echo "Uncommitted changes:"
        git status --short
        exit 1
    fi
fi

# Use conda environment Python if available, otherwise use system Python
if [ -n "$CONDA_PREFIX" ]; then
    PYTHON="$CONDA_PREFIX/bin/python"
    echo "Using conda Python from: $CONDA_PREFIX"
else
    PYTHON="python"
    echo "Warning: No conda environment detected, using system Python"
fi

# Get version from version.py
VERSION=$("$PYTHON" -c 'import olmocr.version; print(olmocr.version.VERSION)')
echo "OlmOCR version: $VERSION"

# Get first 10 characters of git hash
GIT_HASH=$(git rev-parse HEAD | cut -c1-10)
echo "Git hash: $GIT_HASH"

# Get current git branch name
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD)
echo "Git branch: $GIT_BRANCH"

# Check if a Beaker image was provided
if [ -n "$BEAKER_IMAGE" ]; then
    echo "Using provided Beaker image: $BEAKER_IMAGE"
    IMAGE_TAG="$BEAKER_IMAGE"
else
    # Create full image tag
    IMAGE_TAG="olmocr-benchmark-${VERSION}-${GIT_HASH}"
    echo "Building Docker image with tag: $IMAGE_TAG"

    # Build the Docker image
    echo "Building Docker image..."
    docker build --platform linux/amd64 -f ./Dockerfile -t "$IMAGE_TAG" .

    # Push image to beaker
    # NOTE(review): any failure here is reported as "already exists" because
    # stderr is discarded - confirm that is the only expected failure mode.
    echo "Trying to push image to Beaker..."
    if ! beaker image create --workspace ai2/oe-data-pdf --name "$IMAGE_TAG" "$IMAGE_TAG" 2>/dev/null; then
        echo "Warning: Beaker image with tag $IMAGE_TAG already exists. Using existing image."
    fi
fi

# Get Beaker username
BEAKER_USER=$(beaker account whoami --format json | jq -r '.[0].name')
echo "Beaker user: $BEAKER_USER"

# Create Python script to run beaker experiment.
# A unique temp file avoids clashes between concurrent runs, and the trap
# removes it even when a later step fails.
EXPERIMENT_SCRIPT=$(mktemp /tmp/run_infrapartner_benchmark_experiment.XXXXXX)
trap 'rm -f "$EXPERIMENT_SCRIPT"' EXIT

cat << 'EOF' > "$EXPERIMENT_SCRIPT"
import shlex
import sys
from beaker import Beaker, ExperimentSpec, TaskSpec, TaskContext, ResultSpec, TaskResources, ImageSource, Priority, Constraints, EnvVar

# Get image tag, beaker user, git branch, git hash, optional model, server, beaker_secret, and bench branch from command line
image_tag = sys.argv[1]
beaker_user = sys.argv[2]
git_branch = sys.argv[3]
git_hash = sys.argv[4]

# Parse remaining arguments (each recognised flag is followed by its value)
options = {"--benchbranch": None, "--model": None, "--server": None, "--beaker-secret": None}
arg_idx = 5
while arg_idx < len(sys.argv):
    flag = sys.argv[arg_idx]
    if flag in options and arg_idx + 1 < len(sys.argv):
        options[flag] = sys.argv[arg_idx + 1]
        arg_idx += 2
    else:
        print(f"Unknown argument: {flag}")
        arg_idx += 1

bench_branch = options["--benchbranch"]
model = options["--model"]
server = options["--server"]
beaker_secret = options["--beaker-secret"]

# Initialize Beaker client
b = Beaker.from_env(default_workspace="ai2/olmocr")

# Build the pipeline command with optional parameters.
# The commands are joined into one "bash -c" string below, so user-supplied
# values are shell-quoted; the --pdfs glob is intentionally left unquoted.
pipeline_cmd = "python -m olmocr.pipeline ./localworkspace --markdown --pdfs ./olmOCR-bench/bench_data/pdfs/**/*.pdf"
if model:
    pipeline_cmd += f" --model {shlex.quote(model)}"
if server:
    pipeline_cmd += f" --server {shlex.quote(server)}"
if beaker_secret:
    pipeline_cmd += " --api_key \"$API_KEY\""

# Check if AWS credentials secret exists
aws_creds_secret = f"{beaker_user}-AWS_CREDENTIALS_FILE"
try:
    # Try to get the secret to see if it exists
    b.secret.get(aws_creds_secret, workspace="ai2/olmocr")
    has_aws_creds = True
    print(f"Found AWS credentials secret: {aws_creds_secret}")
except Exception as exc:
    # Not a bare "except:", so Ctrl-C / SystemExit still propagate.
    has_aws_creds = False
    print(f"AWS credentials secret not found: {aws_creds_secret}")
    print(f"  (lookup failed with {type(exc).__name__}: {exc})")

# Benchmark job commands
commands = []
if has_aws_creds:
    commands.extend([
        "mkdir -p ~/.aws",
        'echo "$AWS_CREDENTIALS_FILE" > ~/.aws/credentials'
    ])

# If beaker_secret is provided, export it as API_KEY environment variable
if beaker_secret:
    commands.append('export API_KEY="$BEAKER_API_KEY"')

# Build git clone command with optional branch
git_clone_cmd = "git clone https://huggingface.co/datasets/allenai/olmOCR-bench"
if bench_branch:
    git_clone_cmd += f" -b {shlex.quote(bench_branch)}"

commands.extend([
    git_clone_cmd,
    "cd olmOCR-bench && git lfs pull && cd ..",
    pipeline_cmd,
    "python olmocr/bench/scripts/workspace_to_bench.py localworkspace/ olmOCR-bench/bench_data/olmocr --bench-path ./olmOCR-bench/",
    "pip install s5cmd",
    "s5cmd cp localworkspace/ s3://ai2-oe-data/jakep/olmocr-bench-runs/$BEAKER_WORKLOAD_ID/",
    "python -m olmocr.bench.benchmark --dir ./olmOCR-bench/bench_data"
])

# Build task spec with optional env vars
# If image_tag contains '/', it's already a full beaker image reference
if '/' in image_tag:
    image_ref = image_tag
else:
    image_ref = f"{beaker_user}/{image_tag}"

task_spec_args = {
    "name": "olmocr-infrapartner-benchmark",
    "image": ImageSource(beaker=image_ref),
    "command": [
        "bash", "-c",
        " && ".join(commands)
    ],
    "context": TaskContext(
        priority=Priority.normal,
        preemptible=True,
    ),
    "resources": TaskResources(gpu_count=0),
    "constraints": Constraints(cluster=["ai2/phobos", "ai2/neptune", "ai2/saturn"]),
    "result": ResultSpec(path="/noop-results"),
}

# Build env vars list
env_vars = []
if has_aws_creds:
    env_vars.append(EnvVar(name="AWS_CREDENTIALS_FILE", secret=aws_creds_secret))
if beaker_secret:
    env_vars.append(EnvVar(name="BEAKER_API_KEY", secret=beaker_secret))

# Add env vars if any exist
if env_vars:
    task_spec_args["env_vars"] = env_vars

# Create experiment spec
experiment_spec = ExperimentSpec(
    description=f"OlmOCR InfraPartner Benchmark Run - Branch: {git_branch}, Commit: {git_hash}",
    budget="ai2/oe-base",
    tasks=[TaskSpec(**task_spec_args)],
)

# Create the experiment
experiment = b.experiment.create(spec=experiment_spec, workspace="ai2/olmocr")
print(f"Created benchmark experiment: {experiment.id}")
print(f"View at: https://beaker.org/ex/{experiment.id}")
print("-------")
print("")
print("Note: Performance test has been skipped for infrapartner benchmark")
EOF

# Run the Python script to create the experiment
echo "Creating Beaker experiment..."

# Build the command as an array so every argument is passed through verbatim
# (no word splitting or re-evaluation of user-supplied values).
CMD=("$PYTHON" "$EXPERIMENT_SCRIPT" "$IMAGE_TAG" "$BEAKER_USER" "$GIT_BRANCH" "$GIT_HASH")

if [ -n "$MODEL" ]; then
    echo "Using model: $MODEL"
    CMD+=(--model "$MODEL")
fi

if [ -n "$SERVER" ]; then
    echo "Using server: $SERVER"
    CMD+=(--server "$SERVER")
fi

if [ -n "$BEAKER_SECRET" ]; then
    echo "Using beaker secret for API key: $BEAKER_SECRET"
    CMD+=(--beaker-secret "$BEAKER_SECRET")
fi

if [ -n "$BENCH_BRANCH" ]; then
    echo "Using bench branch: $BENCH_BRANCH"
    CMD+=(--benchbranch "$BENCH_BRANCH")
fi

"${CMD[@]}"

# The temporary script is removed by the EXIT trap installed above.

echo "InfraPartner benchmark experiment submitted successfully!"
|
||||
162
tests/test_table_parsing.py
Normal file
162
tests/test_table_parsing.py
Normal file
@ -0,0 +1,162 @@
|
||||
import unittest
|
||||
|
||||
from olmocr.bench.tests import (
|
||||
parse_html_tables,
|
||||
parse_markdown_tables,
|
||||
)
|
||||
|
||||
|
||||
class TestParseHtmlTables(unittest.TestCase):
    """Checks on the cell graph that parse_html_tables builds from HTML tables."""

    def test_basic_table(self):
        """One header row above one body row, with no row or column spans."""
        html = """
            <table border="1">
            <thead>
            <tr>
            <th></th>
            <th>ArXiv</th>
            <th>Old<br>scans<br>math</th>
            <th>Tables</th>
            <th>Old<br>scans</th>
            <th>Headers<br>&<br>footers</th>
            <th>Multi<br>column</th>
            <th>Long<br>tiny<br>text</th>
            <th>Base</th>
            <th>Overall</th>
            </tr>
            </thead>
            <tbody>
            <tr>
            <td>Mistral OCR API</td>
            <td>77.2</td>
            <td>67.5</td>
            <td>60.6</td>
            <td>29.3</td>
            <td>93.6</td>
            <td>71.3</td>
            <td>77.1</td>
            <td>99.4</td>
            <td>72.0±1.1</td>
            </tr>
            </tbody></table>"""
        table = parse_html_tables(html)[0]

        print(table)

        for position, text in (((0, 0), ""), ((0, 1), "ArXiv")):
            self.assertEqual(table.cell_text[position], text)

        # The top-left cell has nothing to its left and nothing above it.
        origin = (0, 0)
        self.assertEqual(table.left_relations[origin], set())
        self.assertEqual(table.up_relations[origin], set())

        # Its right-hand and lower neighbours both point back at it.
        self.assertEqual(table.left_relations[0, 1], {origin})
        self.assertEqual(table.up_relations[1, 0], {origin})

        # All ten cells of the header row are heading cells.
        self.assertEqual(table.heading_cells, {(0, col) for col in range(10)})

        self.assertEqual(table.top_heading_relations[1, 3], {(0, 3)})

        # With no explicit left headings, the left-most column acts as the left heading.
        self.assertEqual(table.left_heading_relations[1, 3], {(1, 0)})

    def test_multiple_top_headings(self):
        """A full-width title row stacked on top of a regular header row."""
        html = """
            <table border="1">
            <thead>
            <tr>
            <th colspan="2">Fruit Costs in Unittest land</th>
            </tr>
            <tr>
            <th>Fruit Type</th>
            <th>Cost</th>
            </tr>
            </thead>
            <tbody>
            <tr>
            <td>Apples</td>
            <td>$1.00</td>
            </tr>
            <tr>
            <td>Oranges</td>
            <td>$2.00</td>
            </tr>
            </tbody></table>"""
        table = parse_html_tables(html)[0]

        print(table)

        expected_text = {
            (0, 0): "Fruit Costs in Unittest land",
            (1, 0): "Fruit Type",
            (1, 1): "Cost",
            (2, 0): "Apples",
            (2, 1): "$1.00",
            (3, 0): "Oranges",
            (3, 1): "$2.00",
        }
        for position, text in expected_text.items():
            self.assertEqual(table.cell_text[position], text)

        # Both second-row headers sit under the spanning title; body cells sit
        # under the header of their own column.
        expected_up = {
            (1, 0): {(0, 0)},
            (1, 1): {(0, 0)},
            (2, 0): {(1, 0)},
            (2, 1): {(1, 1)},
        }
        for position, above in expected_up.items():
            self.assertEqual(table.up_relations[position], above)

        # Body cells collect every heading above them, including the title.
        expected_top_headings = {
            (1, 0): {(0, 0)},
            (1, 1): {(0, 0)},
            (2, 0): {(0, 0), (1, 0)},
            (2, 1): {(0, 0), (1, 1)},
        }
        for position, headings in expected_top_headings.items():
            self.assertEqual(table.top_heading_relations[position], headings)

    def test_4x4_table_with_spans(self):
        """Test a 4x4 table with various row spans and column spans"""
        html = """
            <table border="1">
            <thead>
            <tr>
            <th>Header 1</th>
            <th colspan="2">Header 2-3</th>
            <th>Header 4</th>
            </tr>
            </thead>
            <tbody>
            <tr>
            <td rowspan="2">Cell A (spans 2 rows)</td>
            <td>Cell B</td>
            <td>Cell C</td>
            <td rowspan="2">Cell D (spans 2 rows)</td>
            </tr>
            <tr>
            <td colspan="2">Cell E-F (spans 2 cols)</td>
            </tr>
            <tr>
            <td>Cell G</td>
            <td colspan="3">Cell H-I-J (spans 3 cols)</td>
            </tr>
            </tbody>
            </table>"""
        table = parse_html_tables(html)[0]

        print(table)

        # Text is stored once per cell, at the cell's top-left position.
        expected_text = {
            # header row
            (0, 0): "Header 1",
            (0, 1): "Header 2-3",
            (0, 3): "Header 4",
            # first body row
            (1, 0): "Cell A (spans 2 rows)",
            (1, 1): "Cell B",
            (1, 2): "Cell C",
            (1, 3): "Cell D (spans 2 rows)",
            # second body row
            (2, 1): "Cell E-F (spans 2 cols)",
            # third body row
            (3, 0): "Cell G",
            (3, 1): "Cell H-I-J (spans 3 cols)",
        }
        for position, text in expected_text.items():
            self.assertEqual(table.cell_text[position], text)

        # Positions covered by a span hold no text of their own:
        # (0,2) is inside the colspan=2 header, (2,0) is inside the rowspan=2 "Cell A".
        for covered in ((0, 2), (2, 0)):
            self.assertNotIn(covered, table.cell_text)

        # Only the three <th> cells count as headings.
        self.assertEqual(table.heading_cells, {(0, 0), (0, 1), (0, 3)})
|
||||
Loading…
x
Reference in New Issue
Block a user