mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-26 08:54:01 +00:00
Better markdown table parsing
This commit is contained in:
parent
3fef3f914f
commit
a2b5ca8d41
@ -186,59 +186,65 @@ class TableTest(BasePDFTest):
|
|||||||
def parse_markdown_tables(self, md_content: str) -> List[np.ndarray]:
|
def parse_markdown_tables(self, md_content: str) -> List[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Extract and parse all markdown tables from the provided content.
|
Extract and parse all markdown tables from the provided content.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
md_content: The markdown content containing tables
|
md_content: The markdown content containing tables
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of numpy arrays, each representing a parsed table
|
A list of numpy arrays, each representing a parsed table
|
||||||
"""
|
"""
|
||||||
# Extract all tables from markdown
|
import re
|
||||||
table_pattern = r'(\|(?:[^|]*\|)+)\s*\n\|(?:[ :-]+\|)+\s*\n((?:\|(?:[^|]*\|)+\s*\n)+)'
|
import numpy as np
|
||||||
|
|
||||||
|
# Updated regex to allow optional leading and trailing pipes
|
||||||
|
table_pattern = (
|
||||||
|
r'(\|?(?:[^|\n]*\|)+[^|\n]*\|?)\s*\n'
|
||||||
|
r'\|?(?:[ :-]+\|)+[ :-]+\|?\s*\n'
|
||||||
|
r'((?:\|?(?:[^|\n]*\|)+[^|\n]*\|?\s*\n)+)'
|
||||||
|
)
|
||||||
table_matches = re.finditer(table_pattern, md_content)
|
table_matches = re.finditer(table_pattern, md_content)
|
||||||
|
|
||||||
parsed_tables = []
|
parsed_tables = []
|
||||||
|
|
||||||
for table_match in table_matches:
|
for table_match in table_matches:
|
||||||
# Extract header and body from the table match
|
# Extract header and body from the table match
|
||||||
header_row = table_match.group(1).strip()
|
header_row = table_match.group(1).strip()
|
||||||
body_rows = table_match.group(2).strip().split('\n')
|
body_rows = table_match.group(2).strip().split('\n')
|
||||||
|
|
||||||
# Process header and rows to remove leading/trailing |
|
# Process header and rows to remove leading/trailing pipes
|
||||||
header_cells = [cell.strip() for cell in header_row.split('|')]
|
header_cells = [cell.strip() for cell in header_row.split('|')]
|
||||||
if header_cells[0] == '':
|
if header_cells and header_cells[0] == '':
|
||||||
header_cells = header_cells[1:]
|
header_cells = header_cells[1:]
|
||||||
if header_cells[-1] == '':
|
if header_cells and header_cells[-1] == '':
|
||||||
header_cells = header_cells[:-1]
|
header_cells = header_cells[:-1]
|
||||||
|
|
||||||
# Process table body rows
|
# Process table body rows
|
||||||
table_data = []
|
table_data = []
|
||||||
for row in [header_row] + body_rows:
|
for row in [header_row] + body_rows:
|
||||||
if '|' not in row: # Skip separator row
|
if '|' not in row: # Skip separator row
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cells = [cell.strip() for cell in row.split('|')]
|
cells = [cell.strip() for cell in row.split('|')]
|
||||||
if cells[0] == '':
|
if cells and cells[0] == '':
|
||||||
cells = cells[1:]
|
cells = cells[1:]
|
||||||
if cells[-1] == '':
|
if cells and cells[-1] == '':
|
||||||
cells = cells[:-1]
|
cells = cells[:-1]
|
||||||
|
|
||||||
table_data.append(cells)
|
table_data.append(cells)
|
||||||
|
|
||||||
# Skip separator row (second row with dashes)
|
# Skip separator row (second row with dashes)
|
||||||
if len(table_data) > 1 and all('-' in cell for cell in table_data[1]):
|
if len(table_data) > 1 and all('-' in cell for cell in table_data[1]):
|
||||||
table_data = [table_data[0]] + table_data[2:]
|
table_data = [table_data[0]] + table_data[2:]
|
||||||
|
|
||||||
# Convert to numpy array for easier manipulation
|
# Convert to numpy array for easier manipulation
|
||||||
# First ensure all rows have the same number of columns by padding if necessary
|
# Ensure all rows have the same number of columns by padding if necessary
|
||||||
max_cols = max(len(row) for row in table_data)
|
max_cols = max(len(row) for row in table_data)
|
||||||
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
|
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
|
||||||
table_array = np.array(padded_data)
|
table_array = np.array(padded_data)
|
||||||
|
|
||||||
parsed_tables.append(table_array)
|
|
||||||
|
|
||||||
return parsed_tables
|
|
||||||
|
|
||||||
|
parsed_tables.append(table_array)
|
||||||
|
|
||||||
|
return parsed_tables
|
||||||
|
|
||||||
def parse_html_tables(self, html_content: str) -> List[np.ndarray]:
|
def parse_html_tables(self, html_content: str) -> List[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user