#!/usr/bin/env python3
"""
Compare PII Detection Rules and Calculate IoU

This script processes documents and their attributes from S3 or local storage,
applies different PII detection rules, and calculates the Intersection over
Union (IoU) to measure how well the rules overlap.

How it works:
1. Documents are stored in one location (--docs-folder)
2. Attributes are stored in separate folders (--ref-attr-folder and --hyp-attr-folder)
3. The script merges documents with their attributes by matching filenames and document IDs
4. PII detection rules are applied to each set of merged documents
5. IoU and other metrics are calculated to compare the results

Folder structure:
- s3://bucket/path/documents/ - Contains the main document JSONL files
- s3://bucket/path/attributes/attr_name/ - Contains attributes that can be matched with documents by ID

Document and attribute matching:
- Files are matched by basename (example.jsonl in documents matches example.jsonl in attributes)
- Within each file, documents are matched by their "id" field
- When a match is found, attributes from the attribute file are merged into the document

Example usage:
python pii_rule_comparison.py \
    --docs-folder s3://bucket/path/documents \
    --ref-attr-folder s3://bucket/path/attributes/model_a \
    --hyp-attr-folder s3://bucket/path/attributes/model_b \
    --ref-rule "gpt_4_1_contains_pii:any" \
    --hyp-rule "gpt_4_1_contains_email_addresses:any" \
    --output-file iou_results.json \
    --recursive

Rule expression syntax:
- Simple rule: "attribute_name:rule_type" where rule_type is "any" or "all"
- Boolean expressions: "not rule1:any and rule2:all"
- Parentheses for grouping: "(rule1:any or rule2:any) and not rule3:all"
"""

import argparse
import gzip
import io
import json
import logging
import os
from collections import defaultdict
from enum import Enum, auto
from pathlib import Path

import boto3
import zstandard as zstd

# Initialize logger
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# Define token types for the rule expression parser
class TokenType(Enum):
    RULE = auto()
    AND = auto()
    OR = auto()
    NOT = auto()
    LPAREN = auto()
    RPAREN = auto()
    EOF = auto()


class Token:
    """Token for rule expression parsing"""

    def __init__(self, type, value=None):
        self.type = type
        self.value = value

    def __repr__(self):
        if self.value:
            return f"Token({self.type}, {self.value})"
        return f"Token({self.type})"


class ExpressionNode:
    """Base class for expression tree nodes"""

    pass


class RuleNode(ExpressionNode):
    """Leaf node representing a single rule"""

    def __init__(self, attribute_name, rule_type):
        self.attribute_name = attribute_name
        self.rule_type = rule_type

    def __repr__(self):
        return f"Rule({self.attribute_name}:{self.rule_type})"


class NotNode(ExpressionNode):
    """Unary NOT operation node"""

    def __init__(self, operand):
        self.operand = operand

    def __repr__(self):
        return f"NOT({self.operand})"


class BinaryNode(ExpressionNode):
    """Binary operation (AND/OR) node"""

    def __init__(self, left, right, operator):
        self.left = left
        self.right = right
        self.operator = operator

    def __repr__(self):
        return f"{self.operator}({self.left}, {self.right})"

def parse_args():
    parser = argparse.ArgumentParser(description="Compare PII detection rules and calculate IoU")
    parser.add_argument("--docs-folder", required=True, help="Documents folder path containing JSONL files (local or s3://)")
    parser.add_argument("--ref-attr-folder", required=True, help="Reference attributes folder path (local or s3://)")
    parser.add_argument("--hyp-attr-folder", required=True, help="Hypothesis attributes folder path (local or s3://)")
    parser.add_argument(
        "--ref-rule",
        required=True,
        help="""Reference rule expression. Can be a simple rule in the format 'attribute_name:rule_type',
                where rule_type is 'any' or 'all', or a boolean expression like
                'not rule1:any and rule2:all' or '(rule1:any or rule2:any) and not rule3:all'""",
    )
    parser.add_argument(
        "--hyp-rule",
        required=True,
        help="""Hypothesis rule expression. Can be a simple rule in the format 'attribute_name:rule_type',
                where rule_type is 'any' or 'all', or a boolean expression like
                'not rule1:any and rule2:all' or '(rule1:any or rule2:any) and not rule3:all'""",
    )
    parser.add_argument("--output-file", default="iou_results.json", help="Output JSON file to save results")
    parser.add_argument("--aws-profile", help="AWS profile for S3 access")
    parser.add_argument("--recursive", action="store_true", help="Recursively process folder structure")
    parser.add_argument("--debug", action="store_true", help="Enable debug logging for more detailed output")
    return parser.parse_args()

def parse_s3_path(s3_path):
    """Parse an S3 path into bucket and prefix."""
    parts = s3_path.replace("s3://", "").split("/", 1)
    bucket = parts[0]
    prefix = parts[1] if len(parts) > 1 else ""
    return bucket, prefix


def get_s3_bytes(s3_client, s3_path):
    """Get the bytes of an S3 object."""
    bucket, key = parse_s3_path(s3_path)
    response = s3_client.get_object(Bucket=bucket, Key=key)
    return response["Body"].read()

# Recognized JSONL file extensions (".jsonl.ztd" is kept as a tolerated
# misspelling of ".jsonl.zst", matching the original suffix checks)
JSONL_SUFFIXES = (".jsonl", ".json", ".jsonl.gz", ".jsonl.zst", ".jsonl.ztd", ".jsonl.zstd")


def list_jsonl_files(path, s3_client=None, recursive=False):
    """List all JSONL files in the given path, locally or in S3."""
    jsonl_files = []

    if path.startswith("s3://"):
        # S3 prefix listing is inherently recursive, so --recursive has no effect here
        bucket, prefix = parse_s3_path(path)
        prefix = prefix.rstrip("/") + "/"

        # List objects in the S3 bucket with the given prefix
        paginator = s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if key.endswith(JSONL_SUFFIXES):
                    jsonl_files.append(f"s3://{bucket}/{key}")
    else:
        # Local file system
        path_obj = Path(path)
        glob = path_obj.rglob if recursive else path_obj.glob
        for file_path in glob("*"):
            if file_path.name.endswith(JSONL_SUFFIXES):
                jsonl_files.append(str(file_path))

    return jsonl_files

def load_jsonl_file(file_path, s3_client=None):
    """Load and decompress a JSONL file, either local or from S3."""
    try:
        # Get the file content
        if file_path.startswith("s3://"):
            if s3_client is None:
                raise ValueError("S3 client is required for S3 paths")
            raw_data = get_s3_bytes(s3_client, file_path)
        else:
            with open(file_path, "rb") as f:
                raw_data = f.read()

        # Decompress if needed
        if file_path.endswith(".gz"):
            decompressed = gzip.decompress(raw_data)
        elif file_path.endswith((".zst", ".ztd", ".zstd")):
            try:
                # First try one-shot decompression (requires the frame header
                # to declare the content size)
                dctx = zstd.ZstdDecompressor()
                decompressed = dctx.decompress(raw_data)
            except zstd.ZstdError as e:
                # If that fails, try stream decompression, which does not need
                # the content size up front
                logger.warning(f"Standard zstd decompression failed for {file_path}, trying stream decompression: {e}")

                try:
                    dctx = zstd.ZstdDecompressor(max_window_size=2147483648)  # Use a large window size
                    decompressor = dctx.stream_reader(io.BytesIO(raw_data))
                    decompressed = decompressor.read()
                except Exception as inner_e:
                    # If both methods fail, fall back to chunked reading
                    logger.warning(f"Stream decompression also failed, trying chunked reading: {inner_e}")

                    buffer = io.BytesIO()
                    dctx = zstd.ZstdDecompressor(max_window_size=2147483648)
                    with dctx.stream_reader(io.BytesIO(raw_data)) as reader:
                        while True:
                            chunk = reader.read(16384)  # Read in 16 KB chunks
                            if not chunk:
                                break
                            buffer.write(chunk)

                    buffer.seek(0)
                    decompressed = buffer.read()
        else:
            decompressed = raw_data

        # Parse JSON lines
        lines = decompressed.decode("utf-8").strip().split("\n")
        return [json.loads(line) for line in lines if line.strip()]

    except Exception as e:
        logger.error(f"Error loading file {file_path}: {e}")
        return []

def load_documents_and_attributes(docs_folder, attr_folder, s3_client=None, recursive=False):
    """
    Load documents and merge them with their attributes.

    Args:
        docs_folder: Path to the documents folder
        attr_folder: Path to the attributes folder
        s3_client: S3 client for S3 paths
        recursive: Whether to process folders recursively

    Returns:
        List of documents with their attributes merged in
    """
    try:
        # List all document files
        logger.info(f"Finding document files in: {docs_folder}")
        doc_files = list_jsonl_files(docs_folder, s3_client, recursive)
        logger.info(f"Found {len(doc_files)} document files")

        if not doc_files:
            logger.warning(f"No document files found in {docs_folder}. Check the path and permissions.")

        # List all attribute files
        logger.info(f"Finding attribute files in: {attr_folder}")
        attr_files = list_jsonl_files(attr_folder, s3_client, recursive)
        logger.info(f"Found {len(attr_files)} attribute files")

        if not attr_files:
            logger.warning(f"No attribute files found in {attr_folder}. Check the path and permissions.")

        # Map each attribute file's basename to its full path so document files
        # can be paired with attribute files that share a basename
        attr_file_map = {}

        for attr_path in attr_files:
            if attr_path.startswith("s3://"):
                _, attr_key = parse_s3_path(attr_path)
                basename = os.path.basename(attr_key)
            else:
                basename = os.path.basename(attr_path)

            attr_file_map[basename] = attr_path
            logger.debug(f"Mapped attribute file basename: {basename} -> {attr_path}")

        # Load and merge documents with their attributes
        merged_docs = []
        matched_files = 0
        docs_with_matched_attrs = 0

        for doc_path in doc_files:
            try:
                if doc_path.startswith("s3://"):
                    _, doc_key = parse_s3_path(doc_path)
                    basename = os.path.basename(doc_key)
                else:
                    basename = os.path.basename(doc_path)

                # Load documents
                docs = load_jsonl_file(doc_path, s3_client)
                if not docs:
                    logger.warning(f"No documents loaded from {basename} (path: {doc_path})")
                    continue

                logger.info(f"Loaded {len(docs)} documents from {basename}")

                # Find the matching attribute file
                if basename in attr_file_map:
                    matched_files += 1
                    attr_path = attr_file_map[basename]
                    attrs = load_jsonl_file(attr_path, s3_client)

                    if not attrs:
                        logger.warning(f"No attributes loaded from {os.path.basename(attr_path)} (path: {attr_path})")
                        merged_docs.extend(docs)
                        continue

                    logger.info(f"Loaded {len(attrs)} attributes from {os.path.basename(attr_path)}")

                    # Create a mapping from document ID to attributes
                    attr_by_id = {attr["id"]: attr for attr in attrs if "id" in attr}
                    logger.info(f"Created mapping for {len(attr_by_id)} attribute entries by ID")

                    # Count documents with matched attributes
                    docs_matched_in_file = 0

                    # Merge documents with their attributes
                    for doc in docs:
                        if "id" in doc and doc["id"] in attr_by_id:
                            docs_matched_in_file += 1

                            # If the document doesn't have an attributes field, create it
                            if "attributes" not in doc:
                                doc["attributes"] = {}

                            # If the attribute record has an attributes field, merge it in
                            if "attributes" in attr_by_id[doc["id"]]:
                                doc["attributes"].update(attr_by_id[doc["id"]]["attributes"])
                            else:
                                logger.debug(f"Attribute document {doc['id']} has no 'attributes' field")

                    docs_with_matched_attrs += docs_matched_in_file
                    logger.info(f"Matched attributes for {docs_matched_in_file}/{len(docs)} documents in {basename}")
                    merged_docs.extend(docs)
                else:
                    logger.warning(f"No matching attribute file found for {basename}")
                    merged_docs.extend(docs)
            except Exception as e:
                logger.error(f"Error processing document file {doc_path}: {e}")
                continue

        logger.info(f"Total documents processed: {len(merged_docs)}")
        logger.info(f"Files with matched attributes: {matched_files}/{len(doc_files)}")
        logger.info(f"Documents with matched attributes: {docs_with_matched_attrs}")

        # Warn if no attributes were merged at all
        if docs_with_matched_attrs == 0:
            logger.warning("No documents had matching attributes! This may indicate a problem with file naming or ID matching.")

        return merged_docs
    except Exception as e:
        logger.error(f"Error in load_documents_and_attributes: {e}")
        raise
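

# Illustrative shape of the merge above, with hypothetical records: a document row
#   {"id": "doc1", "text": "..."}
# and an attribute row from the matching attribute file
#   {"id": "doc1", "attributes": {"gpt_4_1_contains_pii": [[0, 42, True]]}}
# combine into
#   {"id": "doc1", "text": "...", "attributes": {"gpt_4_1_contains_pii": [[0, 42, True]]}}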

def apply_rule(doc, rule):
    """
    Apply a rule to determine if a document meets the PII criteria.

    Args:
        doc: The document JSON object
        rule: Either a tuple (attribute_name, rule_type) for simple rules,
              or an ExpressionNode for complex boolean expressions

    Returns:
        True if the document matches the rule, False otherwise
    """
    # Handle a simple rule
    if not is_complex_expression(rule):
        return apply_simple_rule(doc, rule[0], rule[1])

    # Handle a complex expression
    return evaluate_expression(doc, rule)

def apply_simple_rule(doc, attribute_name, rule_type):
    """
    Apply a simple rule to determine if a document meets the PII criteria.

    Args:
        doc: The document JSON object
        attribute_name: The attribute field to check (e.g., "gpt_4_1_contains_pii")
        rule_type: 'any' for any true value, 'all' for all true values

    Returns:
        True if the document matches the rule, False otherwise
    """
    # Check if the document has attributes
    if "attributes" not in doc or not doc["attributes"]:
        logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
        return False

    attributes = doc["attributes"]

    # Check if the specific attribute exists
    if attribute_name not in attributes:
        logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
        return False

    if not attributes[attribute_name]:
        logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
        return False

    # Extract the boolean values from the attribute spans
    # Each span is formatted as [start_pos, end_pos, value]
    values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]

    if not values:
        logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
        return False

    # Apply the rule
    if rule_type == "any":
        result = any(values)
        if result:
            logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (found True in {len(values)} values)")
        return result
    elif rule_type == "all":
        result = all(values)
        if result:
            logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (all {len(values)} values are True)")
        return result
    else:
        raise ValueError(f"Unknown rule type: {rule_type}")
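

# Example of the span format consumed above, with hypothetical attribute data:
#   doc["attributes"]["gpt_4_1_contains_pii"] = [[0, 120, True], [120, 250, False]]
# Rule "gpt_4_1_contains_pii:any" matches this document (at least one span is True),
# while "gpt_4_1_contains_pii:all" does not (one span is False).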

def evaluate_expression(doc, expr):
    """
    Evaluate a boolean expression on a document.

    Args:
        doc: The document JSON object
        expr: An ExpressionNode representing a boolean expression

    Returns:
        True if the document matches the expression, False otherwise
    """
    if isinstance(expr, RuleNode):
        # Base case: evaluate a leaf rule node
        return apply_simple_rule(doc, expr.attribute_name, expr.rule_type)

    elif isinstance(expr, NotNode):
        # NOT operator
        return not evaluate_expression(doc, expr.operand)

    elif isinstance(expr, BinaryNode):
        # Binary operators (AND/OR); Python's `and`/`or` give short-circuit evaluation
        if expr.operator == "AND":
            return evaluate_expression(doc, expr.left) and evaluate_expression(doc, expr.right)
        elif expr.operator == "OR":
            return evaluate_expression(doc, expr.left) or evaluate_expression(doc, expr.right)

    # Should not reach here if the expression tree is well-formed
    raise ValueError(f"Invalid expression node type: {type(expr)}")

def _is_standalone_keyword(expression, i, length):
    """Check that the keyword at expression[i:i+length] is bounded by whitespace,
    parentheses, or the ends of the string (i.e., not part of an attribute name)."""
    before_ok = i == 0 or expression[i - 1].isspace() or expression[i - 1] in "()"
    after = i + length
    after_ok = after >= len(expression) or expression[after].isspace() or expression[after] in "()"
    return before_ok and after_ok


def _read_rule_token(expression, start):
    """Read a rule of the form 'attribute_name:rule_type' starting at `start`.

    Returns (token, next_index); raises ValueError if no colon is found.
    """
    i = start
    # Scan the attribute name up to the colon
    while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
        if i + 1 < len(expression) and expression[i] == ":":
            break
        i += 1

    if i >= len(expression) or expression[i] != ":":
        raise ValueError(f"Invalid rule format at position {start}")

    rule_name = expression[start:i]
    i += 1  # Skip the colon

    # Scan the rule type
    type_start = i
    while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
        i += 1
    rule_type = expression[type_start:i]

    return Token(TokenType.RULE, (rule_name, rule_type)), i


def tokenize_expression(expression):
    """
    Tokenize a rule expression string into a list of tokens.

    Args:
        expression: A string containing a boolean rule expression
                    (e.g., "not rule1:any and rule2:all")

    Returns:
        A list of Token objects
    """
    keywords = [("and", TokenType.AND), ("or", TokenType.OR), ("not", TokenType.NOT)]
    tokens = []
    i = 0
    expression = expression.strip()

    while i < len(expression):
        char = expression[i]

        # Skip whitespace
        if char.isspace():
            i += 1
            continue

        # Handle parentheses
        if char == "(":
            tokens.append(Token(TokenType.LPAREN))
            i += 1
            continue
        if char == ")":
            tokens.append(Token(TokenType.RPAREN))
            i += 1
            continue

        # Handle the operators "and", "or", and "not" when they stand alone;
        # a word that merely starts with an operator (e.g. "android:any") falls
        # through and is read as a rule token instead
        for word, token_type in keywords:
            if expression[i : i + len(word)].lower() == word and _is_standalone_keyword(expression, i, len(word)):
                tokens.append(Token(token_type))
                i += len(word)
                break
        else:
            # Handle a rule (attribute:type)
            token, i = _read_rule_token(expression, i)
            tokens.append(token)

    tokens.append(Token(TokenType.EOF))
    return tokens
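

# Tokenization walk-through (illustrative, with hypothetical attribute names):
#   "not a:any and (b:all or c:any)"
# yields the token stream
#   [NOT, RULE(("a", "any")), AND, LPAREN, RULE(("b", "all")), OR, RULE(("c", "any")), RPAREN, EOF]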

class Parser:
    """
    Parser for boolean rule expressions.
    Implements a recursive descent parser for expressions with the following grammar:

        expression → or_expr
        or_expr    → and_expr ("or" and_expr)*
        and_expr   → unary_expr ("and" unary_expr)*
        unary_expr → "not" unary_expr | primary
        primary    → rule | "(" expression ")"
        rule       → ATTRIBUTE ":" RULE_TYPE
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.current = 0

    def parse(self):
        """Parse the tokens into an expression tree."""
        return self.expression()

    def expression(self):
        """Parse an expression (top level)."""
        return self.or_expr()

    def or_expr(self):
        """Parse an OR expression."""
        expr = self.and_expr()

        while self.match(TokenType.OR):
            right = self.and_expr()
            expr = BinaryNode(expr, right, "OR")

        return expr

    def and_expr(self):
        """Parse an AND expression."""
        expr = self.unary_expr()

        while self.match(TokenType.AND):
            right = self.unary_expr()
            expr = BinaryNode(expr, right, "AND")

        return expr

    def unary_expr(self):
        """Parse a unary expression (NOT)."""
        if self.match(TokenType.NOT):
            operand = self.unary_expr()
            return NotNode(operand)

        return self.primary()

    def primary(self):
        """Parse a primary expression (rule or parenthesized expression)."""
        if self.match(TokenType.RULE):
            attribute_name, rule_type = self.previous().value

            # Validate the rule type
            if rule_type not in ["any", "all"]:
                raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all'")

            return RuleNode(attribute_name, rule_type)

        if self.match(TokenType.LPAREN):
            expr = self.expression()
            self.consume(TokenType.RPAREN, "Expected ')' after expression.")
            return expr

        raise ValueError(f"Expected rule or '(' at position {self.current}")

    def match(self, *token_types):
        """If the current token matches any of the given types, consume it."""
        for token_type in token_types:
            if self.check(token_type):
                self.advance()
                return True

        return False

    def check(self, token_type):
        """Check if the current token is of the given type without advancing."""
        if self.is_at_end():
            return False
        return self.peek().type == token_type

    def advance(self):
        """Advance to the next token and return the previous one."""
        if not self.is_at_end():
            self.current += 1
        return self.previous()

    def consume(self, token_type, message):
        """Consume the current token if it matches the expected type."""
        if self.check(token_type):
            return self.advance()

        raise ValueError(f"{message} at position {self.current}")

    def is_at_end(self):
        """Check if we've reached the end of the tokens."""
        return self.peek().type == TokenType.EOF

    def peek(self):
        """Return the current token without advancing."""
        return self.tokens[self.current]

    def previous(self):
        """Return the previous token."""
        return self.tokens[self.current - 1]

def parse_rule(rule_string):
    """
    Parse a rule string into an expression tree or a simple (attribute_name, rule_type) tuple.

    Args:
        rule_string: A string containing a rule or boolean expression of rules

    Returns:
        Either a tuple (attribute_name, rule_type) for simple rules,
        or an ExpressionNode for complex boolean expressions
    """
    # Fast path: treat the string as a simple rule when it contains no operator
    # keywords or parentheses. Attribute names that merely contain "and"/"or"/"not"
    # as substrings (e.g. "author_pii:any") take the expression path below, where
    # the tokenizer still reads them as single rule tokens.
    lowered = rule_string.lower()
    if "and" not in lowered and "or" not in lowered and "not" not in lowered and "(" not in rule_string and ")" not in rule_string:
        # Simple rule format: attribute_name:rule_type
        parts = rule_string.split(":", 1)
        if len(parts) != 2:
            raise ValueError(f"Invalid rule format: {rule_string}. Expected format: 'attribute_name:rule_type'")

        attribute_name, rule_type = parts
        if rule_type not in ["any", "all"]:
            raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all'")

        return attribute_name, rule_type
    else:
        # Complex rule expression
        try:
            tokens = tokenize_expression(rule_string)
            parser = Parser(tokens)
            return parser.parse()
        except Exception as e:
            raise ValueError(f"Error parsing expression '{rule_string}': {e}")


def is_complex_expression(rule):
    """Check if the rule is a complex boolean expression."""
    return isinstance(rule, ExpressionNode)
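

# Parsing walk-through (illustrative): parse_rule("not a:any and (b:all or c:any)")
# returns the tree AND(NOT(Rule(a:any)), OR(Rule(b:all), Rule(c:any))): per the
# grammar above, "not" binds tighter than "and", which binds tighter than "or".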

def calculate_iou(ref_ids, hyp_ids):
    """Calculate the Intersection over Union of two sets of document IDs."""
    ref_set = set(ref_ids)
    hyp_set = set(hyp_ids)

    intersection = ref_set.intersection(hyp_set)
    union = ref_set.union(hyp_set)

    if not union:
        return 0.0

    return len(intersection) / len(union)
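

# Worked example (hypothetical IDs): ref_ids = {"a", "b", "c"} and hyp_ids = {"b", "c", "d"}
# give intersection {"b", "c"} and union {"a", "b", "c", "d"}, so IoU = 2 / 4 = 0.5.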

def _accumulate_simple_rule_stats(doc, attribute_name, rule_type, rule_stats):
    """Accumulate entry and match counts for one simple rule into rule_stats."""
    # Only process if the document has this attribute
    if not ("attributes" in doc and doc["attributes"] and attribute_name in doc["attributes"] and doc["attributes"][attribute_name]):
        return

    # The rule name is the key prefix for the statistics
    rule_name = f"{attribute_name}:{rule_type}"

    # Count entries in the attribute
    entries = doc["attributes"][attribute_name]
    rule_stats[f"{rule_name}_total_entries"] += len(entries)

    # Count positive values
    for span in entries:
        if len(span) >= 3 and span[2] is True:
            rule_stats[f"{rule_name}_positive_entries"] += 1

    # Check if the document matches the rule
    if apply_simple_rule(doc, attribute_name, rule_type):
        rule_stats[f"{rule_name}_matched_docs"] += 1


def collect_rule_stats(expression, doc):
    """
    Collect statistics for all rules within a complex expression.

    Args:
        expression: A rule expression (either a tuple or ExpressionNode)
        doc: The document to analyze

    Returns:
        A dictionary with rule statistics
    """
    rule_stats = defaultdict(int)

    # Handle a simple rule: a (attribute_name, rule_type) tuple
    if not is_complex_expression(expression):
        attribute_name, rule_type = expression
        _accumulate_simple_rule_stats(doc, attribute_name, rule_type, rule_stats)

    # For complex expressions, traverse the expression tree
    elif isinstance(expression, RuleNode):
        # Base case: a leaf node is a simple rule
        _accumulate_simple_rule_stats(doc, expression.attribute_name, expression.rule_type, rule_stats)

    elif isinstance(expression, NotNode):
        # Get stats from the operand and merge them in
        for key, value in collect_rule_stats(expression.operand, doc).items():
            rule_stats[key] += value

    elif isinstance(expression, BinaryNode):
        # Get stats from both sides and merge them in
        for key, value in collect_rule_stats(expression.left, doc).items():
            rule_stats[key] += value
        for key, value in collect_rule_stats(expression.right, doc).items():
            rule_stats[key] += value

    return rule_stats

def get_expression_summary(expression):
    """
    Generate a string representation of a rule expression.

    Args:
        expression: A rule expression (either a tuple or ExpressionNode)

    Returns:
        A string representation of the expression
    """
    if not is_complex_expression(expression):
        return f"{expression[0]}:{expression[1]}"

    if isinstance(expression, RuleNode):
        return f"{expression.attribute_name}:{expression.rule_type}"

    elif isinstance(expression, NotNode):
        return f"not {get_expression_summary(expression.operand)}"

    elif isinstance(expression, BinaryNode):
        left_summary = get_expression_summary(expression.left)
        right_summary = get_expression_summary(expression.right)
        return f"({left_summary} {expression.operator.lower()} {right_summary})"

    return str(expression)

def compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule):
    """
    Compare two sets of documents using the specified rules and calculate IoU.

    Args:
        ref_docs: List of reference documents
        hyp_docs: List of hypothesis documents
        ref_rule: Rule expression for reference (tuple or ExpressionNode)
        hyp_rule: Rule expression for hypothesis (tuple or ExpressionNode)

    Returns:
        Dictionary with comparison results
    """
    # Extract document IDs and create ID-to-document maps
    ref_id_to_doc = {doc["id"]: doc for doc in ref_docs if "id" in doc}
    hyp_id_to_doc = {doc["id"]: doc for doc in hyp_docs if "id" in doc}

    # Get the document IDs common to both sets
    common_ids = set(ref_id_to_doc.keys()).intersection(set(hyp_id_to_doc.keys()))

    # Apply rules to each document
    ref_matches = set()
    hyp_matches = set()

    # Track rule statistics
    ref_rule_stats = defaultdict(int)
    hyp_rule_stats = defaultdict(int)

    for doc_id in common_ids:
        ref_doc = ref_id_to_doc[doc_id]
        hyp_doc = hyp_id_to_doc[doc_id]

        # Collect statistics for all rules in the expressions
        doc_ref_rule_stats = collect_rule_stats(ref_rule, ref_doc)
        doc_hyp_rule_stats = collect_rule_stats(hyp_rule, hyp_doc)

        # Merge with the overall stats
        for key, value in doc_ref_rule_stats.items():
            ref_rule_stats[key] += value
        for key, value in doc_hyp_rule_stats.items():
            hyp_rule_stats[key] += value

        # Check whether the document matches each rule expression
        if apply_rule(ref_doc, ref_rule):
            ref_matches.add(doc_id)
            ref_rule_stats["expression_matched_docs"] += 1

        if apply_rule(hyp_doc, hyp_rule):
            hyp_matches.add(doc_id)
            hyp_rule_stats["expression_matched_docs"] += 1

    # Calculate IoU
    iou = calculate_iou(ref_matches, hyp_matches)

    # Collect detailed statistics, treating the reference matches as ground truth
    tp = len(ref_matches.intersection(hyp_matches))
    fp = len(hyp_matches - ref_matches)
    fn = len(ref_matches - hyp_matches)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    # Generate string representations of the expressions
    ref_rule_str = get_expression_summary(ref_rule)
    hyp_rule_str = get_expression_summary(hyp_rule)

    return {
        "total_docs": len(common_ids),
        "ref_rule": ref_rule_str,
        "hyp_rule": hyp_rule_str,
        "ref_matches": len(ref_matches),
        "hyp_matches": len(hyp_matches),
        "intersection": tp,
        "union": tp + fp + fn,
        "true_positives": tp,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "iou": iou,
        "ref_rule_stats": dict(ref_rule_stats),
        "hyp_rule_stats": dict(hyp_rule_stats),
    }
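

# Note on the metrics above: since ref_matches and hyp_matches are both drawn from
# the same pool of common document IDs, the union of the two match sets has size
# tp + fp + fn, so the reported IoU equals tp / (tp + fp + fn) (the Jaccard index).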

def format_rule_stats(rule_stats):
    """Format rule statistics for display."""
    # Group the statistics by rule name
    grouped_stats = defaultdict(dict)

    # Process per-rule stats (key format: "{rule_name}_{stat_type}")
    for key, value in rule_stats.items():
        if key == "expression_matched_docs":
            # Special case for the overall expression match count
            continue

        # Extract the rule name and stat type
        if key.endswith("_total_entries"):
            rule_name = key[: -len("_total_entries")]
            grouped_stats[rule_name]["total_entries"] = value
        elif key.endswith("_positive_entries"):
            rule_name = key[: -len("_positive_entries")]
            grouped_stats[rule_name]["positive_entries"] = value
        elif key.endswith("_matched_docs"):
            rule_name = key[: -len("_matched_docs")]
            grouped_stats[rule_name]["matched_docs"] = value

    # Format the grouped statistics as a list of strings
    formatted_stats = []
    for rule_name, stats in grouped_stats.items():
        formatted_stats.append(
            f"  {rule_name}:\n"
            f"    - Total Entries: {stats.get('total_entries', 0)}\n"
            f"    - Positive Entries: {stats.get('positive_entries', 0)}\n"
            f"    - Matched Documents: {stats.get('matched_docs', 0)}"
        )

    # Add the overall expression match count if available
    if "expression_matched_docs" in rule_stats:
        formatted_stats.append(f"  Overall Expression Matched Documents: {rule_stats['expression_matched_docs']}")

    return "\n".join(formatted_stats)

def main():
    args = parse_args()

    # Set up logging based on arguments
    if args.debug:
        logger.setLevel(logging.DEBUG)
        logger.debug("Debug logging enabled")

    # Set up an S3 client if any of the paths is on S3
    s3_client = None
    if args.docs_folder.startswith("s3://") or args.ref_attr_folder.startswith("s3://") or args.hyp_attr_folder.startswith("s3://"):
        session = boto3.Session(profile_name=args.aws_profile) if args.aws_profile else boto3.Session()
        s3_client = session.client("s3")

    # Parse the rules
    logger.info(f"Parsing reference rule expression: {args.ref_rule}")
    ref_rule = parse_rule(args.ref_rule)

    logger.info(f"Parsing hypothesis rule expression: {args.hyp_rule}")
    hyp_rule = parse_rule(args.hyp_rule)

    # Generate string representations of the expressions
    ref_rule_str = get_expression_summary(ref_rule)
    hyp_rule_str = get_expression_summary(hyp_rule)

    logger.info(f"Reference rule parsed as: {ref_rule_str}")
    logger.info(f"Hypothesis rule parsed as: {hyp_rule_str}")

    # Load documents and merge with attributes
    logger.info("Loading documents and merging with reference attributes...")
    ref_docs = load_documents_and_attributes(args.docs_folder, args.ref_attr_folder, s3_client, args.recursive)

    logger.info("Loading documents and merging with hypothesis attributes...")
    hyp_docs = load_documents_and_attributes(args.docs_folder, args.hyp_attr_folder, s3_client, args.recursive)

    # Compare the documents
    logger.info("Comparing documents using reference and hypothesis rules...")
    comparison_result = compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule)

    # Prepare overall statistics
    overall_stats = {
        "total_docs": comparison_result["total_docs"],
        "ref_matches": comparison_result["ref_matches"],
        "hyp_matches": comparison_result["hyp_matches"],
        "true_positives": comparison_result["true_positives"],
        "false_positives": comparison_result["false_positives"],
        "false_negatives": comparison_result["false_negatives"],
        "precision": comparison_result["precision"],
        "recall": comparison_result["recall"],
        "f1": comparison_result["f1"],
        "iou": comparison_result["iou"],
        "ref_rule_stats": comparison_result["ref_rule_stats"],
        "hyp_rule_stats": comparison_result["hyp_rule_stats"],
    }

    # Prepare the final output
    output = {
        "config": {
            "docs_folder": args.docs_folder,
            "ref_attr_folder": args.ref_attr_folder,
            "hyp_attr_folder": args.hyp_attr_folder,
            "ref_rule": args.ref_rule,
            "ref_rule_parsed": ref_rule_str,
            "hyp_rule": args.hyp_rule,
            "hyp_rule_parsed": hyp_rule_str,
        },
        "overall": overall_stats,
    }

    # Save results
    with open(args.output_file, "w") as f:
        json.dump(output, f, indent=2)

    # Print summary
    logger.info("\n--- COMPARISON SUMMARY ---")
    logger.info(f"Documents Folder: {args.docs_folder}")
    logger.info(f"Reference Attributes Folder: {args.ref_attr_folder}")
    logger.info(f"Hypothesis Attributes Folder: {args.hyp_attr_folder}")
    logger.info(f"Reference Rule Expression: {args.ref_rule}")
    logger.info(f"  Parsed as: {ref_rule_str}")
    logger.info(f"Hypothesis Rule Expression: {args.hyp_rule}")
    logger.info(f"  Parsed as: {hyp_rule_str}")
    logger.info(f"Total Documents: {overall_stats['total_docs']}")

    # Print rule statistics
    logger.info("\n--- RULE MATCH STATISTICS ---")

    logger.info("\nReference Rules:")
    logger.info(format_rule_stats(overall_stats["ref_rule_stats"]))

    logger.info("\nHypothesis Rules:")
    logger.info(format_rule_stats(overall_stats["hyp_rule_stats"]))

    # Print comparison metrics
    logger.info("\n--- COMPARISON METRICS ---")
    logger.info(f"IoU: {overall_stats['iou']:.4f}")
    logger.info(f"Precision: {overall_stats['precision']:.4f}")
    logger.info(f"Recall: {overall_stats['recall']:.4f}")
    logger.info(f"F1 Score: {overall_stats['f1']:.4f}")
    logger.info(f"Detailed results saved to: {args.output_file}")


if __name__ == "__main__":
    main()


# Example commands with actual S3 paths:
"""
# Example for AI2 OE data with resume detection:
python scripts/pii_rule_comparison.py \
    --docs-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/documents/ \
    --ref-attr-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/attributes/chatgpt_pii_vision/ \
    --hyp-attr-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/attributes/model_pii_tagging/ \
    --ref-rule "gpt_4_1_contains_pii:any and not gpt_4_1_is_public_document:all" \
    --hyp-rule "google_gemma-3-4b-it_is_resume_cv:any" \
    --output-file pii_resume_comparison.json \
    --recursive \
    --debug

# Example for Dolma data with PII detection:
python scripts/pii_rule_comparison.py \
    --docs-folder s3://allenai-dolma/documents/v1.5 \
    --ref-attr-folder s3://allenai-dolma/attributes/v1.5/pii_detection/gpt4_1 \
    --hyp-attr-folder s3://allenai-dolma/attributes/v1.5/pii_detection/custom_rule \
    --ref-rule "contains_pii:any" \
    --hyp-rule "(contains_email_addresses:any or contains_phone_numbers:any) and not false_positive:any" \
    --output-file pii_detection_comparison.json \
    --recursive \
    --aws-profile dolma
"""