Fewer duplicated tests

This commit is contained in:
Jake Poznanski 2025-03-19 17:32:06 +00:00
parent 9855f70fee
commit 0afacd6ac7

View File

@ -246,39 +246,32 @@ def generate_table_tests(tables: List[np.ndarray], pdf_image: str, api_key: str,
# Try up to 3x max_tests_per_table candidate cells
candidate_positions = []
for _ in range(max_tests_per_table * 3):
row = random.randint(0, rows - 1)
col = random.randint(0, cols - 1)
for row in range(rows):
for col in range(cols):
if not table[row, col].strip():
continue
candidate_positions.append((row, col))
if row > 0:
candidate_positions.append((row, col, "up"))
if row < rows - 1:
candidate_positions.append((row, col, "down"))
if col > 0:
candidate_positions.append((row, col, "left"))
if col < cols - 1:
candidate_positions.append((row, col, "right"))
if row > 0:
candidate_positions.append((row, col, "top_heading"))
if col > 0:
candidate_positions.append((row, col, "left_heading"))
random.shuffle(candidate_positions)
tests_for_this_table = 0
for row, col in candidate_positions:
for row, col, relationship in candidate_positions:
if tests_for_this_table >= max_tests_per_table:
break
cell_value = table[row, col].strip()
# Determine valid relationship types based on candidate's position
valid_relationships = []
if row > 0:
valid_relationships.append("up")
if row < rows - 1:
valid_relationships.append("down")
if col > 0:
valid_relationships.append("left")
if col < cols - 1:
valid_relationships.append("right")
if row > 0:
valid_relationships.append("top_heading")
if col > 0:
valid_relationships.append("left_heading")
if not valid_relationships:
continue
relationship = random.choice(valid_relationships)
prompt = (
f"Given a cell in a table with value '{cell_value}', please answer: "
f"{prompt_map[relationship]} Provide only the cell's text or output 'null' if there is not a matching cell."