mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-14 01:32:31 +00:00
Less duped tests
This commit is contained in:
parent
9855f70fee
commit
0afacd6ac7
@ -246,39 +246,32 @@ def generate_table_tests(tables: List[np.ndarray], pdf_image: str, api_key: str,
|
||||
|
||||
# Try up to 3x max_tests_per_table candidate cells
|
||||
candidate_positions = []
|
||||
for _ in range(max_tests_per_table * 3):
|
||||
row = random.randint(0, rows - 1)
|
||||
col = random.randint(0, cols - 1)
|
||||
if not table[row, col].strip():
|
||||
continue
|
||||
candidate_positions.append((row, col))
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
if not table[row, col].strip():
|
||||
continue
|
||||
if row > 0:
|
||||
candidate_positions.append((row, col, "up"))
|
||||
if row < rows - 1:
|
||||
candidate_positions.append((row, col, "down"))
|
||||
if col > 0:
|
||||
candidate_positions.append((row, col, "left"))
|
||||
if col < cols - 1:
|
||||
candidate_positions.append((row, col, "right"))
|
||||
if row > 0:
|
||||
candidate_positions.append((row, col, "top_heading"))
|
||||
if col > 0:
|
||||
candidate_positions.append((row, col, "left_heading"))
|
||||
|
||||
random.shuffle(candidate_positions)
|
||||
tests_for_this_table = 0
|
||||
|
||||
for row, col in candidate_positions:
|
||||
for row, col, relationship in candidate_positions:
|
||||
if tests_for_this_table >= max_tests_per_table:
|
||||
break
|
||||
|
||||
cell_value = table[row, col].strip()
|
||||
# Determine valid relationship types based on candidate's position
|
||||
valid_relationships = []
|
||||
if row > 0:
|
||||
valid_relationships.append("up")
|
||||
if row < rows - 1:
|
||||
valid_relationships.append("down")
|
||||
if col > 0:
|
||||
valid_relationships.append("left")
|
||||
if col < cols - 1:
|
||||
valid_relationships.append("right")
|
||||
if row > 0:
|
||||
valid_relationships.append("top_heading")
|
||||
if col > 0:
|
||||
valid_relationships.append("left_heading")
|
||||
if not valid_relationships:
|
||||
continue
|
||||
|
||||
relationship = random.choice(valid_relationships)
|
||||
prompt = (
|
||||
f"Given a cell in a table with value '{cell_value}', please answer: "
|
||||
f"{prompt_map[relationship]} Provide only the cell's text or output 'null' if there is not a matching cell."
|
||||
|
Loading…
x
Reference in New Issue
Block a user