#!/usr/bin/env python3
"""
Compare PII Detection Rules and Calculate IoU

This script processes JSONL attribute files from two different folders,
applies different rules to each for PII detection, and calculates the
Intersection over Union (IoU) to measure how well they overlap.

Example usage:
    python pii_rule_comparison.py \
        --ref-folder s3://bucket/workspace/attributes/model_a \
        --hyp-folder s3://bucket/workspace/attributes/model_b \
        --ref-rule "gpt_4_1_contains_pii:any" \
        --hyp-rule "gpt_4_1_contains_email_addresses:any" \
        --output-file iou_results.json
"""

import argparse
import gzip
import json
import logging
import os
import sys
from collections import defaultdict
from enum import Enum, auto
from pathlib import Path

import boto3
import zstandard as zstd

# Initialize logger
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

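# Input format sketch (illustrative, hypothetical values): each JSONL line is a
# document of the form
#   {"id": "doc-0001",
#    "attributes": {"gpt_4_1_contains_pii": [[0, 120, true], [120, 250, false]]}}
# where each attribute maps to [start_pos, end_pos, value] spans over the document.
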
# Define token types for the rule expression parser
class TokenType(Enum):
    RULE = auto()
    AND = auto()
    OR = auto()
    NOT = auto()
    LPAREN = auto()
    RPAREN = auto()
    EOF = auto()

class Token:
    """Token for rule expression parsing"""
    def __init__(self, type, value=None):
        self.type = type
        self.value = value

    def __repr__(self):
        if self.value:
            return f"Token({self.type}, {self.value})"
        return f"Token({self.type})"

class ExpressionNode:
    """Base class for expression tree nodes"""
    pass

class RuleNode(ExpressionNode):
    """Leaf node representing a single rule"""
    def __init__(self, attribute_name, rule_type):
        self.attribute_name = attribute_name
        self.rule_type = rule_type

    def __repr__(self):
        return f"Rule({self.attribute_name}:{self.rule_type})"

class NotNode(ExpressionNode):
    """Unary NOT operation node"""
    def __init__(self, operand):
        self.operand = operand

    def __repr__(self):
        return f"NOT({self.operand})"

class BinaryNode(ExpressionNode):
    """Binary operation (AND/OR) node"""
    def __init__(self, left, right, operator):
        self.left = left
        self.right = right
        self.operator = operator

    def __repr__(self):
        return f"{self.operator}({self.left}, {self.right})"

def parse_args():
    parser = argparse.ArgumentParser(description="Compare PII detection rules and calculate IoU")
    parser.add_argument("--ref-folder", required=True, help="Reference attribute folder path (local or s3://)")
    parser.add_argument("--hyp-folder", required=True, help="Hypothesis attribute folder path (local or s3://)")
    parser.add_argument(
        "--ref-rule",
        required=True,
        help="""Reference rule expression. Can be a simple rule in format 'attribute_name:rule_type',
                where rule_type is 'any' or 'all'. Or a boolean expression like
                'not rule1:any and rule2:all' or '(rule1:any or rule2:any) and not rule3:all'""",
    )
    parser.add_argument(
        "--hyp-rule",
        required=True,
        help="""Hypothesis rule expression. Can be a simple rule in format 'attribute_name:rule_type',
                where rule_type is 'any' or 'all'. Or a boolean expression like
                'not rule1:any and rule2:all' or '(rule1:any or rule2:any) and not rule3:all'""",
    )
    parser.add_argument("--output-file", default="iou_results.json", help="Output JSON file to save results")
    parser.add_argument("--aws-profile", help="AWS profile for S3 access")
    parser.add_argument("--recursive", action="store_true", help="Recursively process folder structure")
    return parser.parse_args()

def parse_s3_path(s3_path):
    """Parse S3 path into bucket and prefix."""
    parts = s3_path.replace("s3://", "").split("/", 1)
    bucket = parts[0]
    prefix = parts[1] if len(parts) > 1 else ""
    return bucket, prefix

def get_s3_bytes(s3_client, s3_path):
    """Get bytes from S3 object."""
    bucket, key = parse_s3_path(s3_path)
    response = s3_client.get_object(Bucket=bucket, Key=key)
    return response["Body"].read()

# File suffixes treated as (optionally compressed) JSONL attribute files
JSONL_SUFFIXES = (".jsonl", ".json", ".jsonl.gz", ".jsonl.zst", ".jsonl.ztd", ".jsonl.zstd")

def list_jsonl_files(path, s3_client=None, recursive=False):
    """List all JSONL files in the given path, locally or in S3."""
    jsonl_files = []

    if path.startswith("s3://"):
        bucket, prefix = parse_s3_path(path)
        prefix = prefix.rstrip("/") + "/"

        # List objects in S3 bucket with given prefix
        paginator = s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            if "Contents" in page:
                for obj in page["Contents"]:
                    key = obj["Key"]
                    if key.endswith(JSONL_SUFFIXES):
                        jsonl_files.append(f"s3://{bucket}/{key}")
    else:
        # Local file system
        path_obj = Path(path)
        file_iter = path_obj.rglob("*") if recursive else path_obj.glob("*")
        for file_path in file_iter:
            if file_path.name.endswith(JSONL_SUFFIXES):
                jsonl_files.append(str(file_path))

    return jsonl_files

def load_jsonl_file(file_path, s3_client=None):
    """Load and decompress a JSONL file, either from local or S3."""
    try:
        # Get file content
        if file_path.startswith("s3://"):
            if s3_client is None:
                raise ValueError("S3 client is required for S3 paths")
            raw_data = get_s3_bytes(s3_client, file_path)
        else:
            with open(file_path, "rb") as f:
                raw_data = f.read()

        # Decompress if needed
        if file_path.endswith(".gz"):
            decompressed = gzip.decompress(raw_data)
        elif file_path.endswith((".zst", ".ztd", ".zstd")):
            dctx = zstd.ZstdDecompressor()
            decompressed = dctx.decompress(raw_data)
        else:
            decompressed = raw_data

        # Parse JSON lines
        lines = decompressed.decode("utf-8").strip().split("\n")
        return [json.loads(line) for line in lines if line.strip()]

    except Exception as e:
        logger.error(f"Error loading file {file_path}: {e}")
        return []

def apply_rule(doc, rule):
    """
    Apply a rule to determine if a document meets the PII criteria.

    Args:
        doc: The document JSON object
        rule: Either a tuple (attribute_name, rule_type) for simple rules,
              or an ExpressionNode for complex boolean expressions

    Returns:
        True if the document matches the rule, False otherwise
    """
    # Handle simple rule
    if not is_complex_expression(rule):
        return apply_simple_rule(doc, rule[0], rule[1])

    # Handle complex expression
    return evaluate_expression(doc, rule)

def apply_simple_rule(doc, attribute_name, rule_type):
    """
    Apply a simple rule to determine if a document meets the PII criteria.

    Args:
        doc: The document JSON object
        attribute_name: The attribute field to check (e.g., "gpt_4_1_contains_pii")
        rule_type: 'any' for any true value, 'all' for all true values

    Returns:
        True if the document matches the rule, False otherwise
    """
    if "attributes" not in doc or not doc["attributes"]:
        return False

    attributes = doc["attributes"]
    if attribute_name not in attributes or not attributes[attribute_name]:
        return False

    # Extract the boolean values from the attribute spans
    # Each span is formatted as [start_pos, end_pos, value]
    values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]

    if not values:
        return False

    if rule_type == "any":
        return any(values)
    elif rule_type == "all":
        return all(values)
    else:
        raise ValueError(f"Unknown rule type: {rule_type}")

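# Illustrative sketch (hypothetical attribute data): if a document's
# "gpt_4_1_contains_pii" spans are [[0, 120, True], [120, 250, False]], then
# apply_simple_rule(doc, "gpt_4_1_contains_pii", "any") returns True (one span
# is True), while rule_type "all" returns False (not every span is True).
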
def evaluate_expression(doc, expr):
    """
    Evaluate a boolean expression on a document.

    Args:
        doc: The document JSON object
        expr: An ExpressionNode representing a boolean expression

    Returns:
        True if the document matches the expression, False otherwise
    """
    if isinstance(expr, RuleNode):
        # Base case: evaluate a leaf rule node
        return apply_simple_rule(doc, expr.attribute_name, expr.rule_type)

    elif isinstance(expr, NotNode):
        # NOT operator
        return not evaluate_expression(doc, expr.operand)

    elif isinstance(expr, BinaryNode):
        # Binary operators (AND/OR)
        if expr.operator == "AND":
            # Short-circuit AND evaluation
            return evaluate_expression(doc, expr.left) and evaluate_expression(doc, expr.right)
        elif expr.operator == "OR":
            # Short-circuit OR evaluation
            return evaluate_expression(doc, expr.left) or evaluate_expression(doc, expr.right)

    # Should not reach here if the expression tree is well-formed
    raise ValueError(f"Invalid expression node type: {type(expr)}")

def _read_rule_token(expression, i):
    """
    Read a single 'attribute_name:rule_type' rule starting at position i.

    Returns a (Token, next_position) pair, or raises ValueError if no
    colon-separated rule is found.
    """
    rule_start = i
    while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
        if i + 1 < len(expression) and expression[i] == ":":
            break
        i += 1

    # Process rule if we found a colon
    if i < len(expression) and expression[i] == ":":
        rule_end = i
        i += 1  # Skip the colon

        # Find the rule type
        type_start = i
        while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
            i += 1

        rule_name = expression[rule_start:rule_end]
        rule_type = expression[type_start:i]
        return Token(TokenType.RULE, (rule_name, rule_type)), i

    raise ValueError(f"Invalid rule format at position {rule_start}")

def tokenize_expression(expression):
    """
    Tokenize a rule expression string into a list of tokens.

    Args:
        expression: A string containing a boolean rule expression
                    (e.g., "not rule1:any and rule2:all")

    Returns:
        A list of Token objects
    """
    expression = expression.strip()

    def is_standalone_keyword(pos, word):
        """Check that `word` occurs at `pos`, bounded by whitespace, parentheses, or the string ends."""
        end = pos + len(word)
        if expression[pos:end].lower() != word:
            return False
        before_ok = pos == 0 or expression[pos - 1].isspace() or expression[pos - 1] in "()"
        after_ok = end >= len(expression) or expression[end].isspace() or expression[end] in "()"
        return before_ok and after_ok

    tokens = []
    i = 0
    while i < len(expression):
        char = expression[i]

        # Skip whitespace
        if char.isspace():
            i += 1

        # Handle parentheses
        elif char == "(":
            tokens.append(Token(TokenType.LPAREN))
            i += 1
        elif char == ")":
            tokens.append(Token(TokenType.RPAREN))
            i += 1

        # Handle standalone operators
        elif is_standalone_keyword(i, "and"):
            tokens.append(Token(TokenType.AND))
            i += 3
        elif is_standalone_keyword(i, "or"):
            tokens.append(Token(TokenType.OR))
            i += 2
        elif is_standalone_keyword(i, "not"):
            tokens.append(Token(TokenType.NOT))
            i += 3

        # Anything else must be a rule (attribute_name:rule_type), including
        # attribute names that merely contain "and"/"or"/"not" as substrings
        else:
            token, i = _read_rule_token(expression, i)
            tokens.append(token)

    tokens.append(Token(TokenType.EOF))
    return tokens

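# Illustrative sketch: tokenize_expression("not rule1:any and rule2:all") yields
# tokens equivalent to [NOT, RULE(('rule1', 'any')), AND, RULE(('rule2', 'all')), EOF].
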
class Parser:
    """
    Parser for boolean rule expressions.
    Implements a recursive descent parser for expressions with the following grammar:

    expression → or_expr
    or_expr    → and_expr ("or" and_expr)*
    and_expr   → unary_expr ("and" unary_expr)*
    unary_expr → "not" unary_expr | primary
    primary    → rule | "(" expression ")"
    rule       → ATTRIBUTE ":" RULE_TYPE
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.current = 0

    def parse(self):
        """Parse the tokens into an expression tree."""
        return self.expression()

    def expression(self):
        """Parse an expression (top level)."""
        return self.or_expr()

    def or_expr(self):
        """Parse an OR expression."""
        expr = self.and_expr()

        while self.match(TokenType.OR):
            right = self.and_expr()
            expr = BinaryNode(expr, right, "OR")

        return expr

    def and_expr(self):
        """Parse an AND expression."""
        expr = self.unary_expr()

        while self.match(TokenType.AND):
            right = self.unary_expr()
            expr = BinaryNode(expr, right, "AND")

        return expr

    def unary_expr(self):
        """Parse a unary expression (NOT)."""
        if self.match(TokenType.NOT):
            operand = self.unary_expr()
            return NotNode(operand)

        return self.primary()

    def primary(self):
        """Parse a primary expression (rule or parenthesized expression)."""
        if self.match(TokenType.RULE):
            rule_tuple = self.previous().value
            attribute_name, rule_type = rule_tuple

            # Validate rule type
            if rule_type not in ["any", "all"]:
                raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all'")

            return RuleNode(attribute_name, rule_type)

        if self.match(TokenType.LPAREN):
            expr = self.expression()
            self.consume(TokenType.RPAREN, "Expected ')' after expression.")
            return expr

        raise ValueError(f"Expected rule or '(' at position {self.current}")

    def match(self, *token_types):
        """Check if the current token matches any of the given types."""
        for token_type in token_types:
            if self.check(token_type):
                self.advance()
                return True

        return False

    def check(self, token_type):
        """Check if the current token is of the given type without advancing."""
        if self.is_at_end():
            return False
        return self.peek().type == token_type

    def advance(self):
        """Advance to the next token and return the previous one."""
        if not self.is_at_end():
            self.current += 1
        return self.previous()

    def consume(self, token_type, message):
        """Consume the current token if it matches the expected type."""
        if self.check(token_type):
            return self.advance()

        raise ValueError(f"{message} at position {self.current}")

    def is_at_end(self):
        """Check if we've reached the end of the tokens."""
        return self.peek().type == TokenType.EOF

    def peek(self):
        """Return the current token without advancing."""
        return self.tokens[self.current]

    def previous(self):
        """Return the previous token."""
        return self.tokens[self.current - 1]

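# Illustrative sketch: parsing "(a:any or b:any) and not c:all" produces the tree
# AND(OR(Rule(a:any), Rule(b:any)), NOT(Rule(c:all))), following the node __repr__ formats.
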
def parse_rule(rule_string):
    """
    Parse a rule string into an expression tree or a simple attribute-rule_type tuple.

    Args:
        rule_string: A string containing a rule or boolean expression of rules

    Returns:
        Either a tuple (attribute_name, rule_type) for simple rules,
        or an ExpressionNode for complex boolean expressions
    """
    # Check if this is a simple rule. The substring test is conservative: an
    # attribute name that merely contains "and"/"or"/"not" is routed through
    # the full expression parser, which still tokenizes it as a single rule.
    lowered = rule_string.lower()
    is_simple = (
        "and" not in lowered
        and "or" not in lowered
        and "not" not in lowered
        and "(" not in rule_string
        and ")" not in rule_string
    )

    if is_simple:
        # Simple rule format: attribute_name:rule_type
        parts = rule_string.split(":", 1)
        if len(parts) != 2:
            raise ValueError(f"Invalid rule format: {rule_string}. Expected format: 'attribute_name:rule_type'")

        attribute_name, rule_type = parts
        if rule_type not in ["any", "all"]:
            raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all'")

        return attribute_name, rule_type
    else:
        # Complex rule expression
        try:
            tokens = tokenize_expression(rule_string)
            parser = Parser(tokens)
            return parser.parse()
        except Exception as e:
            raise ValueError(f"Error parsing expression '{rule_string}': {e}") from e

def is_complex_expression(rule):
    """Check if the rule is a complex boolean expression."""
    return isinstance(rule, ExpressionNode)

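# Illustrative sketch: parse_rule("gpt_4_1_contains_pii:any") returns the plain tuple
# ("gpt_4_1_contains_pii", "any"), while parse_rule("a:any and not b:all") returns
# an expression tree, AND(Rule(a:any), NOT(Rule(b:all))).
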
def get_matching_files(ref_files, hyp_files, ref_base, hyp_base):
    """
    Find files that exist in both reference and hypothesis folders,
    matching by their relative paths.

    Returns dict mapping ref_path -> hyp_path for matched files
    """
    # First, convert to relative paths for matching
    def get_relative_path(path, base_folder):
        if path.startswith("s3://"):
            _, full_key = parse_s3_path(path)
            _, base_key = parse_s3_path(base_folder)
            return full_key[len(base_key):].lstrip("/") if full_key.startswith(base_key) else full_key
        else:
            return os.path.relpath(path, base_folder)

    ref_relative = {get_relative_path(path, ref_base): path for path in ref_files}
    hyp_relative = {get_relative_path(path, hyp_base): path for path in hyp_files}

    # Find matching files
    matched_files = {}
    for rel_path in ref_relative:
        if rel_path in hyp_relative:
            matched_files[ref_relative[rel_path]] = hyp_relative[rel_path]

    return matched_files

def calculate_iou(ref_ids, hyp_ids):
    """Calculate Intersection over Union of two sets of document IDs."""
    ref_set = set(ref_ids)
    hyp_set = set(hyp_ids)

    intersection = ref_set.intersection(hyp_set)
    union = ref_set.union(hyp_set)

    if not union:
        return 0.0

    return len(intersection) / len(union)

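# Worked example: ref_ids = {"d1", "d2", "d3"} and hyp_ids = {"d2", "d3", "d4"}
# give intersection {"d2", "d3"} and union {"d1", "d2", "d3", "d4"}, so the
# IoU is 2 / 4 = 0.5.
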
def collect_rule_stats(expression, doc):
    """
    Collect statistics for all rules within a complex expression.

    Args:
        expression: A rule expression (either a tuple or ExpressionNode)
        doc: The document to analyze

    Returns:
        A dictionary with rule statistics
    """
    rule_stats = defaultdict(int)

    # Normalize simple (attribute_name, rule_type) tuples to a RuleNode so both
    # forms share the leaf-handling code below
    if not is_complex_expression(expression):
        expression = RuleNode(expression[0], expression[1])

    if isinstance(expression, RuleNode):
        # Base case: leaf node is a simple rule
        attribute_name, rule_type = expression.attribute_name, expression.rule_type
        # Only process if document has this attribute
        if ("attributes" in doc and doc["attributes"] and
                attribute_name in doc["attributes"] and doc["attributes"][attribute_name]):
            # The rule name will be the key for the statistics
            rule_name = f"{attribute_name}:{rule_type}"

            # Count entries in the attribute
            entries = doc["attributes"][attribute_name]
            rule_stats[f"{rule_name}_total_entries"] += len(entries)

            # Count positive values
            for span in entries:
                if len(span) >= 3 and span[2] is True:
                    rule_stats[f"{rule_name}_positive_entries"] += 1

            # Check if document matches the rule
            if apply_simple_rule(doc, attribute_name, rule_type):
                rule_stats[f"{rule_name}_matched_docs"] += 1

    elif isinstance(expression, NotNode):
        # Get stats from the operand and merge with current stats
        operand_stats = collect_rule_stats(expression.operand, doc)
        for key, value in operand_stats.items():
            rule_stats[key] += value

    elif isinstance(expression, BinaryNode):
        # Get stats from both sides and merge with current stats
        left_stats = collect_rule_stats(expression.left, doc)
        right_stats = collect_rule_stats(expression.right, doc)

        for key, value in left_stats.items():
            rule_stats[key] += value
        for key, value in right_stats.items():
            rule_stats[key] += value

    return rule_stats

def get_expression_summary(expression):
    """
    Generate a string representation of a rule expression.

    Args:
        expression: A rule expression (either a tuple or ExpressionNode)

    Returns:
        A string representation of the expression
    """
    if not is_complex_expression(expression):
        return f"{expression[0]}:{expression[1]}"

    if isinstance(expression, RuleNode):
        return f"{expression.attribute_name}:{expression.rule_type}"

    elif isinstance(expression, NotNode):
        return f"not {get_expression_summary(expression.operand)}"

    elif isinstance(expression, BinaryNode):
        left_summary = get_expression_summary(expression.left)
        right_summary = get_expression_summary(expression.right)
        return f"({left_summary} {expression.operator.lower()} {right_summary})"

    return str(expression)

def compare_files(ref_path, hyp_path, ref_rule, hyp_rule, s3_client=None):
    """
    Compare two JSONL files using the specified rules and calculate IoU.

    Args:
        ref_path: Path to reference JSONL file
        hyp_path: Path to hypothesis JSONL file
        ref_rule: Rule expression for reference (tuple or ExpressionNode)
        hyp_rule: Rule expression for hypothesis (tuple or ExpressionNode)
        s3_client: S3 client for S3 paths

    Returns:
        Dictionary with comparison results
    """
    # Load the files
    ref_docs = load_jsonl_file(ref_path, s3_client)
    hyp_docs = load_jsonl_file(hyp_path, s3_client)

    # Extract document IDs and create ID-to-document maps
    ref_id_to_doc = {doc["id"]: doc for doc in ref_docs if "id" in doc}
    hyp_id_to_doc = {doc["id"]: doc for doc in hyp_docs if "id" in doc}

    # Get common document IDs
    common_ids = set(ref_id_to_doc.keys()).intersection(set(hyp_id_to_doc.keys()))

    # Apply rules to each document
    ref_matches = set()
    hyp_matches = set()

    # Track rule statistics
    ref_rule_stats = defaultdict(int)
    hyp_rule_stats = defaultdict(int)

    for doc_id in common_ids:
        ref_doc = ref_id_to_doc[doc_id]
        hyp_doc = hyp_id_to_doc[doc_id]

        # Collect statistics for all rules in the expressions
        doc_ref_rule_stats = collect_rule_stats(ref_rule, ref_doc)
        doc_hyp_rule_stats = collect_rule_stats(hyp_rule, hyp_doc)

        # Merge with overall stats
        for key, value in doc_ref_rule_stats.items():
            ref_rule_stats[key] += value
        for key, value in doc_hyp_rule_stats.items():
            hyp_rule_stats[key] += value

        # Check if document matches the rule expressions
        if apply_rule(ref_doc, ref_rule):
            ref_matches.add(doc_id)
            ref_rule_stats["expression_matched_docs"] += 1

        if apply_rule(hyp_doc, hyp_rule):
            hyp_matches.add(doc_id)
            hyp_rule_stats["expression_matched_docs"] += 1

    # Calculate IoU
    iou = calculate_iou(ref_matches, hyp_matches)

    # Collect detailed statistics, treating the reference matches as ground
    # truth and the hypothesis matches as predictions
    tp = len(ref_matches.intersection(hyp_matches))
    fp = len(hyp_matches - ref_matches)
    fn = len(ref_matches - hyp_matches)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    # Generate string representations of the expressions
    ref_rule_str = get_expression_summary(ref_rule)
    hyp_rule_str = get_expression_summary(hyp_rule)

    return {
        "ref_file": ref_path,
        "hyp_file": hyp_path,
        "total_docs": len(common_ids),
        "ref_rule": ref_rule_str,
        "hyp_rule": hyp_rule_str,
        "ref_matches": len(ref_matches),
        "hyp_matches": len(hyp_matches),
        "intersection": tp,
        "union": tp + fp + fn,
        "true_positives": tp,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "iou": iou,
        "ref_rule_stats": dict(ref_rule_stats),
        "hyp_rule_stats": dict(hyp_rule_stats),
    }

def format_rule_stats(rule_stats):
    """Format rule statistics for display."""
    # Group the statistics by rule name
    grouped_stats = defaultdict(dict)

    # Process regular rule stats (format: "{rule_name}_{stat_type}")
    for key, value in rule_stats.items():
        if key == "expression_matched_docs":
            # Special case for the overall expression match count
            continue

        # Extract rule name and stat type
        if "_total_entries" in key:
            rule_name = key.replace("_total_entries", "")
            grouped_stats[rule_name]["total_entries"] = value
        elif "_positive_entries" in key:
            rule_name = key.replace("_positive_entries", "")
            grouped_stats[rule_name]["positive_entries"] = value
        elif "_matched_docs" in key:
            rule_name = key.replace("_matched_docs", "")
            grouped_stats[rule_name]["matched_docs"] = value

    # Format the grouped statistics as a list of strings
    formatted_stats = []
    for rule_name, stats in grouped_stats.items():
        formatted_stats.append(
            f"  {rule_name}:\n"
            f"    - Total Entries: {stats.get('total_entries', 0)}\n"
            f"    - Positive Entries: {stats.get('positive_entries', 0)}\n"
            f"    - Matched Documents: {stats.get('matched_docs', 0)}"
        )

    # Add the expression matched count if available
    if "expression_matched_docs" in rule_stats:
        formatted_stats.append(f"  Overall Expression Matched Documents: {rule_stats['expression_matched_docs']}")

    return "\n".join(formatted_stats)

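# Output sketch for format_rule_stats (hypothetical counts):
#   gpt_4_1_contains_pii:any:
#     - Total Entries: 42
#     - Positive Entries: 7
#     - Matched Documents: 5
#   Overall Expression Matched Documents: 5
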
def main():
    args = parse_args()

    # Set up S3 client if needed
    s3_client = None
    if args.ref_folder.startswith("s3://") or args.hyp_folder.startswith("s3://"):
        session = boto3.Session(profile_name=args.aws_profile) if args.aws_profile else boto3.Session()
        s3_client = session.client("s3")

    # Parse the rules
    logger.info(f"Parsing reference rule expression: {args.ref_rule}")
    ref_rule = parse_rule(args.ref_rule)

    logger.info(f"Parsing hypothesis rule expression: {args.hyp_rule}")
    hyp_rule = parse_rule(args.hyp_rule)

    # Generate string representations of the expressions
    ref_rule_str = get_expression_summary(ref_rule)
    hyp_rule_str = get_expression_summary(hyp_rule)

    logger.info(f"Reference rule parsed as: {ref_rule_str}")
    logger.info(f"Hypothesis rule parsed as: {hyp_rule_str}")

    # List JSONL files in both folders
    logger.info(f"Finding JSONL files in reference folder: {args.ref_folder}")
    ref_files = list_jsonl_files(args.ref_folder, s3_client, args.recursive)

    logger.info(f"Finding JSONL files in hypothesis folder: {args.hyp_folder}")
    hyp_files = list_jsonl_files(args.hyp_folder, s3_client, args.recursive)

    logger.info(f"Found {len(ref_files)} files in reference folder and {len(hyp_files)} files in hypothesis folder")

    # Find matching files
    matched_files = get_matching_files(ref_files, hyp_files, args.ref_folder, args.hyp_folder)
    logger.info(f"Found {len(matched_files)} matching files between folders")

    if not matched_files:
        logger.error("No matching files found between reference and hypothesis folders")
        sys.exit(1)

    # Process each pair of files
    results = []
    overall_stats = {
        "total_files": len(matched_files),
        "total_docs": 0,
        "ref_matches": 0,
        "hyp_matches": 0,
        "true_positives": 0,
        "false_positives": 0,
        "false_negatives": 0,
        # Initialize rule stats counters
        "ref_rule_stats": defaultdict(int),
        "hyp_rule_stats": defaultdict(int),
    }

    for i, (ref_path, hyp_path) in enumerate(matched_files.items()):
        logger.info(f"Processing file pair {i+1}/{len(matched_files)}: {os.path.basename(ref_path)}")

        file_result = compare_files(ref_path, hyp_path, ref_rule, hyp_rule, s3_client)
        results.append(file_result)

        # Accumulate overall statistics
        overall_stats["total_docs"] += file_result["total_docs"]
        overall_stats["ref_matches"] += file_result["ref_matches"]
        overall_stats["hyp_matches"] += file_result["hyp_matches"]
        overall_stats["true_positives"] += file_result["true_positives"]
        overall_stats["false_positives"] += file_result["false_positives"]
        overall_stats["false_negatives"] += file_result["false_negatives"]

        # Accumulate rule statistics
        for key, value in file_result["ref_rule_stats"].items():
            overall_stats["ref_rule_stats"][key] += value
        for key, value in file_result["hyp_rule_stats"].items():
            overall_stats["hyp_rule_stats"][key] += value

    # Calculate overall metrics
    tp = overall_stats["true_positives"]
    fp = overall_stats["false_positives"]
    fn = overall_stats["false_negatives"]

    overall_stats["precision"] = tp / (tp + fp) if (tp + fp) > 0 else 0
    overall_stats["recall"] = tp / (tp + fn) if (tp + fn) > 0 else 0
    overall_stats["f1"] = (
        2 * overall_stats["precision"] * overall_stats["recall"] /
        (overall_stats["precision"] + overall_stats["recall"])
        if (overall_stats["precision"] + overall_stats["recall"]) > 0 else 0
    )
    overall_stats["iou"] = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0

    # Convert defaultdicts to regular dicts for JSON serialization
    overall_stats["ref_rule_stats"] = dict(overall_stats["ref_rule_stats"])
    overall_stats["hyp_rule_stats"] = dict(overall_stats["hyp_rule_stats"])

    # Prepare final output
    output = {
        "config": {
            "ref_folder": args.ref_folder,
            "hyp_folder": args.hyp_folder,
            "ref_rule": args.ref_rule,
            "ref_rule_parsed": ref_rule_str,
            "hyp_rule": args.hyp_rule,
            "hyp_rule_parsed": hyp_rule_str,
        },
        "overall": overall_stats,
        "file_results": results,
    }

    # Save results
    with open(args.output_file, "w") as f:
        json.dump(output, f, indent=2)

    # Print summary
    logger.info("\n--- COMPARISON SUMMARY ---")
    logger.info(f"Reference Rule Expression: {args.ref_rule}")
    logger.info(f"  Parsed as: {ref_rule_str}")
    logger.info(f"Hypothesis Rule Expression: {args.hyp_rule}")
    logger.info(f"  Parsed as: {hyp_rule_str}")
    logger.info(f"Total Documents: {overall_stats['total_docs']}")

    # Print rule statistics
    logger.info("\n--- RULE MATCH STATISTICS ---")

    logger.info("\nReference Rules:")
    logger.info(format_rule_stats(overall_stats["ref_rule_stats"]))

    logger.info("\nHypothesis Rules:")
    logger.info(format_rule_stats(overall_stats["hyp_rule_stats"]))

    # Print comparison metrics
    logger.info("\n--- COMPARISON METRICS ---")
    logger.info(f"IoU: {overall_stats['iou']:.4f}")
    logger.info(f"Precision: {overall_stats['precision']:.4f}")
    logger.info(f"Recall: {overall_stats['recall']:.4f}")
    logger.info(f"F1 Score: {overall_stats['f1']:.4f}")
    logger.info(f"Detailed results saved to: {args.output_file}")

if __name__ == "__main__":
    main()