mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-14 09:42:47 +00:00
Less duped tests
This commit is contained in:
parent
9855f70fee
commit
0afacd6ac7
@ -246,39 +246,32 @@ def generate_table_tests(tables: List[np.ndarray], pdf_image: str, api_key: str,
|
|||||||
|
|
||||||
# Try up to 3x max_tests_per_table candidate cells
|
# Try up to 3x max_tests_per_table candidate cells
|
||||||
candidate_positions = []
|
candidate_positions = []
|
||||||
for _ in range(max_tests_per_table * 3):
|
for row in range(rows):
|
||||||
row = random.randint(0, rows - 1)
|
for col in range(cols):
|
||||||
col = random.randint(0, cols - 1)
|
if not table[row, col].strip():
|
||||||
if not table[row, col].strip():
|
continue
|
||||||
continue
|
if row > 0:
|
||||||
candidate_positions.append((row, col))
|
candidate_positions.append((row, col, "up"))
|
||||||
|
if row < rows - 1:
|
||||||
|
candidate_positions.append((row, col, "down"))
|
||||||
|
if col > 0:
|
||||||
|
candidate_positions.append((row, col, "left"))
|
||||||
|
if col < cols - 1:
|
||||||
|
candidate_positions.append((row, col, "right"))
|
||||||
|
if row > 0:
|
||||||
|
candidate_positions.append((row, col, "top_heading"))
|
||||||
|
if col > 0:
|
||||||
|
candidate_positions.append((row, col, "left_heading"))
|
||||||
|
|
||||||
random.shuffle(candidate_positions)
|
random.shuffle(candidate_positions)
|
||||||
tests_for_this_table = 0
|
tests_for_this_table = 0
|
||||||
|
|
||||||
for row, col in candidate_positions:
|
for row, col, relationship in candidate_positions:
|
||||||
if tests_for_this_table >= max_tests_per_table:
|
if tests_for_this_table >= max_tests_per_table:
|
||||||
break
|
break
|
||||||
|
|
||||||
cell_value = table[row, col].strip()
|
cell_value = table[row, col].strip()
|
||||||
# Determine valid relationship types based on candidate's position
|
|
||||||
valid_relationships = []
|
|
||||||
if row > 0:
|
|
||||||
valid_relationships.append("up")
|
|
||||||
if row < rows - 1:
|
|
||||||
valid_relationships.append("down")
|
|
||||||
if col > 0:
|
|
||||||
valid_relationships.append("left")
|
|
||||||
if col < cols - 1:
|
|
||||||
valid_relationships.append("right")
|
|
||||||
if row > 0:
|
|
||||||
valid_relationships.append("top_heading")
|
|
||||||
if col > 0:
|
|
||||||
valid_relationships.append("left_heading")
|
|
||||||
if not valid_relationships:
|
|
||||||
continue
|
|
||||||
|
|
||||||
relationship = random.choice(valid_relationships)
|
|
||||||
prompt = (
|
prompt = (
|
||||||
f"Given a cell in a table with value '{cell_value}', please answer: "
|
f"Given a cell in a table with value '{cell_value}', please answer: "
|
||||||
f"{prompt_map[relationship]} Provide only the cell's text or output 'null' if there is not a matching cell."
|
f"{prompt_map[relationship]} Provide only the cell's text or output 'null' if there is not a matching cell."
|
||||||
|
Loading…
x
Reference in New Issue
Block a user