Hypothesis checker

Jake Poznanski 2025-05-08 17:58:50 +00:00
parent 3aba3a5c10
commit 80645c886e


@ -2,22 +2,46 @@
"""
Compare PII Detection Rules and Calculate IoU
This script processes documents and their attributes from S3 or local storage,
applies different rules for PII detection, and calculates the
Intersection over Union (IoU) to measure how well they overlap.
How it works:
1. Documents are stored in one location (--docs-folder)
2. Attributes are stored in separate folders (--ref-attr-folder and --hyp-attr-folder)
3. The script merges documents with their attributes by matching filenames and document IDs
4. PII detection rules are applied to each set of merged documents
5. IoU and other metrics are calculated to compare the results
Folder structure:
- s3://bucket/path/documents/ - Contains the main document JSONL files
- s3://bucket/path/attributes/attr_name/ - Contains attributes that can be matched with documents by ID
Document and attribute matching:
- Files are matched by basename (example.jsonl in documents matches example.jsonl in attributes)
- Within each file, documents are matched by their "id" field
- When a match is found, attributes from the attribute file are merged into the document
Example usage:
python pii_rule_comparison.py \
--docs-folder s3://bucket/path/documents \
--ref-attr-folder s3://bucket/path/attributes/model_a \
--hyp-attr-folder s3://bucket/path/attributes/model_b \
--ref-rule "gpt_4_1_contains_pii:any" \
--hyp-rule "gpt_4_1_contains_email_addresses:any" \
--output-file iou_results.json \
--recursive
Rule expression syntax:
- Simple rule: "attribute_name:rule_type" where rule_type is "any" or "all"
- Boolean expressions: "not rule1:any and rule2:all"
- Parentheses for grouping: "(rule1:any or rule2:any) and not rule3:all"
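Attribute value format (illustrative, matching how rule values are read below):
- Each attribute maps to a list of spans like [start, end, value]; only the third
  element (the value) is read when a rule is evaluated.
- Example: a document {"id": "doc-1", "attributes": {"gpt_4_1_contains_pii": [[0, 128, true], [128, 512, false]]}}
  matches "gpt_4_1_contains_pii:any" but not "gpt_4_1_contains_pii:all".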
"""
import argparse
import boto3
import gzip
import io
import json
import logging
import os
@ -90,8 +114,9 @@ class BinaryNode(ExpressionNode):
def parse_args():
parser = argparse.ArgumentParser(description="Compare PII detection rules and calculate IoU")
parser.add_argument("--ref-folder", required=True, help="Reference attribute folder path (local or s3://)")
parser.add_argument("--hyp-folder", required=True, help="Hypothesis attribute folder path (local or s3://)")
parser.add_argument("--docs-folder", required=True, help="Documents folder path containing JSONL files (local or s3://)")
parser.add_argument("--ref-attr-folder", required=True, help="Reference attributes folder path (local or s3://)")
parser.add_argument("--hyp-attr-folder", required=True, help="Hypothesis attributes folder path (local or s3://)")
parser.add_argument("--ref-rule", required=True,
help="""Reference rule expression. Can be a simple rule in format 'attribute_name:rule_type',
where rule_type is 'any' or 'all'. Or a boolean expression like
@ -103,6 +128,7 @@ def parse_args():
parser.add_argument("--output-file", default="iou_results.json", help="Output JSON file to save results")
parser.add_argument("--aws-profile", help="AWS profile for S3 access")
parser.add_argument("--recursive", action="store_true", help="Recursively process folder structure")
parser.add_argument("--debug", action="store_true", help="Enable debug logging for more detailed output")
return parser.parse_args()
def parse_s3_path(s3_path):
@ -170,8 +196,35 @@ def load_jsonl_file(file_path, s3_client=None):
if file_path.endswith(".gz"):
decompressed = gzip.decompress(raw_data)
elif file_path.endswith((".zst", ".ztd", ".zstd")):
try:
# First try with standard decompression
dctx = zstd.ZstdDecompressor()
decompressed = dctx.decompress(raw_data)
except zstd.ZstdError as e:
# If that fails, try with stream decompression
logger.warning(f"Standard zstd decompression failed for {file_path}, trying stream decompression: {e}")
try:
# Stream decompression does not require the content size in the frame header
dctx = zstd.ZstdDecompressor(max_window_size=2147483648) # Use a large window size
decompressor = dctx.stream_reader(io.BytesIO(raw_data))
decompressed = decompressor.read()
except Exception as inner_e:
# If both methods fail, try with chunking
logger.warning(f"Stream decompression also failed, trying chunked reading: {inner_e}")
# Chunked reading approach
buffer = io.BytesIO()
dctx = zstd.ZstdDecompressor(max_window_size=2147483648)
with dctx.stream_reader(io.BytesIO(raw_data)) as reader:
while True:
chunk = reader.read(16384) # Read in 16KB chunks
if not chunk:
break
buffer.write(chunk)
buffer.seek(0)
decompressed = buffer.read()
else:
decompressed = raw_data
@ -183,6 +236,129 @@ def load_jsonl_file(file_path, s3_client=None):
logger.error(f"Error loading file {file_path}: {e}")
return []
def load_documents_and_attributes(docs_folder, attr_folder, s3_client=None, recursive=False):
"""
Load documents and merge them with their attributes.
Args:
docs_folder: Path to the documents folder
attr_folder: Path to the attributes folder
s3_client: S3 client for S3 paths
recursive: Whether to process folders recursively
Returns:
List of documents with their attributes merged in
"""
try:
# List all document files
logger.info(f"Finding document files in: {docs_folder}")
doc_files = list_jsonl_files(docs_folder, s3_client, recursive)
logger.info(f"Found {len(doc_files)} document files")
if not doc_files:
logger.warning(f"No document files found in {docs_folder}. Check the path and permissions.")
# List all attribute files
logger.info(f"Finding attribute files in: {attr_folder}")
attr_files = list_jsonl_files(attr_folder, s3_client, recursive)
logger.info(f"Found {len(attr_files)} attribute files")
if not attr_files:
logger.warning(f"No attribute files found in {attr_folder}. Check the path and permissions.")
# Create a mapping from document filename to attribute filename
# based on matching basenames
attr_file_map = {}
for attr_path in attr_files:
if attr_path.startswith("s3://"):
_, attr_key = parse_s3_path(attr_path)
basename = os.path.basename(attr_key)
else:
basename = os.path.basename(attr_path)
attr_file_map[basename] = attr_path
logger.debug(f"Mapped attribute file basename: {basename} -> {attr_path}")
# Load and merge documents with their attributes
merged_docs = []
matched_files = 0
docs_with_matched_attrs = 0
for doc_path in doc_files:
try:
if doc_path.startswith("s3://"):
_, doc_key = parse_s3_path(doc_path)
basename = os.path.basename(doc_key)
else:
basename = os.path.basename(doc_path)
# Load documents
docs = load_jsonl_file(doc_path, s3_client)
if not docs:
logger.warning(f"No documents loaded from {basename} (path: {doc_path})")
continue
logger.info(f"Loaded {len(docs)} documents from {basename}")
# Find matching attribute file
if basename in attr_file_map:
matched_files += 1
attr_path = attr_file_map[basename]
attrs = load_jsonl_file(attr_path, s3_client)
if not attrs:
logger.warning(f"No attributes loaded from {os.path.basename(attr_path)} (path: {attr_path})")
merged_docs.extend(docs)
continue
logger.info(f"Loaded {len(attrs)} attributes from {os.path.basename(attr_path)}")
# Create a mapping from document ID to attributes
attr_by_id = {attr["id"]: attr for attr in attrs if "id" in attr}
logger.info(f"Created mapping for {len(attr_by_id)} attribute entries by ID")
# Count documents with matched attributes
docs_matched_in_file = 0
# Merge documents with their attributes
for doc in docs:
if "id" in doc and doc["id"] in attr_by_id:
docs_matched_in_file += 1
# If the document doesn't have an 'attributes' field yet, create it
if "attributes" not in doc:
doc["attributes"] = {}
# If the attribute entry has an 'attributes' field, merge its contents in
if "attributes" in attr_by_id[doc["id"]]:
doc["attributes"].update(attr_by_id[doc["id"]]["attributes"])
else:
logger.debug(f"Attribute document {doc['id']} has no 'attributes' field")
docs_with_matched_attrs += docs_matched_in_file
logger.info(f"Matched attributes for {docs_matched_in_file}/{len(docs)} documents in {basename}")
merged_docs.extend(docs)
else:
logger.warning(f"No matching attribute file found for {basename}")
merged_docs.extend(docs)
except Exception as e:
logger.error(f"Error processing document file {doc_path}: {e}")
continue
logger.info(f"Total documents processed: {len(merged_docs)}")
logger.info(f"Files with matched attributes: {matched_files}/{len(doc_files)}")
logger.info(f"Documents with matched attributes: {docs_with_matched_attrs}")
# Check if no attributes were merged
if docs_with_matched_attrs == 0:
logger.warning("No documents had matching attributes! This may indicate a problem with file naming or ID matching.")
return merged_docs
except Exception as e:
logger.error(f"Error in load_documents_and_attributes: {e}")
raise
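# Illustrative call of the merge step above (bucket and prefixes are placeholders,
# not real paths); each returned document keeps its original fields plus any
# "attributes" entries merged from the matching attribute file:
#
#     ref_docs = load_documents_and_attributes(
#         "s3://bucket/workspace/documents",
#         "s3://bucket/workspace/attributes/model_a",
#         s3_client=s3_client,
#         recursive=True,
#     )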
def apply_rule(doc, rule):
"""
Apply a rule to determine if a document meets the PII criteria.
@ -214,11 +390,20 @@ def apply_simple_rule(doc, attribute_name, rule_type):
Returns:
True if the document matches the rule, False otherwise
"""
# Check if document has attributes
if "attributes" not in doc or not doc["attributes"]:
logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
return False
attributes = doc["attributes"]
# Check if the specific attribute exists
if attribute_name not in attributes:
logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
return False
if not attributes[attribute_name]:
logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
return False
# Extract the boolean values from the attribute spans
@ -226,12 +411,20 @@ def apply_simple_rule(doc, attribute_name, rule_type):
values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]
if not values:
logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
return False
# Apply the rule
if rule_type == "any":
result = any(values)
if result:
logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (found True in {len(values)} values)")
return result
elif rule_type == "all":
result = all(values)
if result:
logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (all {len(values)} values are True)")
return result
else:
raise ValueError(f"Unknown rule type: {rule_type}")
@ -567,35 +760,6 @@ def is_complex_expression(rule):
"""Check if the rule is a complex boolean expression."""
return isinstance(rule, ExpressionNode)
def calculate_iou(ref_ids, hyp_ids):
"""Calculate Intersection over Union of two sets of document IDs."""
@ -715,24 +879,19 @@ def get_expression_summary(expression):
return str(expression)
def compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule):
"""
Compare two sets of documents using the specified rules and calculate IoU.
Args:
ref_docs: List of reference documents
hyp_docs: List of hypothesis documents
ref_rule: Rule expression for reference (tuple or ExpressionNode)
hyp_rule: Rule expression for hypothesis (tuple or ExpressionNode)
Returns:
Dictionary with comparison results
"""
# Extract document IDs and create ID-to-document maps
ref_id_to_doc = {doc["id"]: doc for doc in ref_docs if "id" in doc}
hyp_id_to_doc = {doc["id"]: doc for doc in hyp_docs if "id" in doc}
@ -788,8 +947,6 @@ def compare_files(ref_path, hyp_path, ref_rule, hyp_rule, s3_client=None):
hyp_rule_str = get_expression_summary(hyp_rule)
return {
"ref_file": ref_path,
"hyp_file": hyp_path,
"total_docs": len(common_ids),
"ref_rule": ref_rule_str,
"hyp_rule": hyp_rule_str,
@ -850,9 +1007,16 @@ def main():
global args
args = parse_args()
# Set up logging based on arguments
if args.debug:
logger.setLevel(logging.DEBUG)
logger.debug("Debug logging enabled")
# Set up S3 client if needed
s3_client = None
if (args.docs_folder.startswith("s3://") or
args.ref_attr_folder.startswith("s3://") or
args.hyp_attr_folder.startswith("s3://")):
session = boto3.Session(profile_name=args.aws_profile) if args.aws_profile else boto3.Session()
s3_client = session.client("s3")
@ -870,88 +1034,45 @@ def main():
logger.info(f"Reference rule parsed as: {ref_rule_str}")
logger.info(f"Hypothesis rule parsed as: {hyp_rule_str}")
# Load documents and merge with attributes
logger.info("Loading documents and merging with reference attributes...")
ref_docs = load_documents_and_attributes(args.docs_folder, args.ref_attr_folder, s3_client, args.recursive)
logger.info(f"Finding JSONL files in hypothesis folder: {args.hyp_folder}")
hyp_files = list_jsonl_files(args.hyp_folder, s3_client, args.recursive)
logger.info("Loading documents and merging with hypothesis attributes...")
hyp_docs = load_documents_and_attributes(args.docs_folder, args.hyp_attr_folder, s3_client, args.recursive)
logger.info(f"Found {len(ref_files)} files in reference folder and {len(hyp_files)} files in hypothesis folder")
# Compare the documents
logger.info("Comparing documents using reference and hypothesis rules...")
comparison_result = compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule)
# Prepare overall statistics
overall_stats = {
"total_files": len(matched_files),
"total_docs": 0,
"ref_matches": 0,
"hyp_matches": 0,
"true_positives": 0,
"false_positives": 0,
"false_negatives": 0,
# Initialize rule stats counters
"ref_rule_stats": defaultdict(int),
"hyp_rule_stats": defaultdict(int)
"total_docs": comparison_result["total_docs"],
"ref_matches": comparison_result["ref_matches"],
"hyp_matches": comparison_result["hyp_matches"],
"true_positives": comparison_result["true_positives"],
"false_positives": comparison_result["false_positives"],
"false_negatives": comparison_result["false_negatives"],
"precision": comparison_result["precision"],
"recall": comparison_result["recall"],
"f1": comparison_result["f1"],
"iou": comparison_result["iou"],
"ref_rule_stats": comparison_result["ref_rule_stats"],
"hyp_rule_stats": comparison_result["hyp_rule_stats"]
}
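# For reference, the standard definitions behind the metric fields above:
#     precision = TP / (TP + FP)
#     recall    = TP / (TP + FN)
#     f1        = 2 * precision * recall / (precision + recall)
#     iou       = TP / (TP + FP + FN)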
# Prepare final output
output = {
"config": {
"ref_folder": args.ref_folder,
"hyp_folder": args.hyp_folder,
"docs_folder": args.docs_folder,
"ref_attr_folder": args.ref_attr_folder,
"hyp_attr_folder": args.hyp_attr_folder,
"ref_rule": args.ref_rule,
"ref_rule_parsed": ref_rule_str,
"hyp_rule": args.hyp_rule,
"hyp_rule_parsed": hyp_rule_str
},
"overall": overall_stats,
"file_results": results
"overall": overall_stats
}
# Save results
@ -960,6 +1081,9 @@ def main():
# Print summary
logger.info("\n--- COMPARISON SUMMARY ---")
logger.info(f"Documents Folder: {args.docs_folder}")
logger.info(f"Reference Attributes Folder: {args.ref_attr_folder}")
logger.info(f"Hypothesis Attributes Folder: {args.hyp_attr_folder}")
logger.info(f"Reference Rule Expression: {args.ref_rule}")
logger.info(f" Parsed as: {ref_rule_str}")
logger.info(f"Hypothesis Rule Expression: {args.hyp_rule}")
@ -984,4 +1108,29 @@ def main():
logger.info(f"Detailed results saved to: {args.output_file}")
if __name__ == "__main__":
main()
# Example commands with actual S3 paths:
"""
# Example for AI2 OE data with resume detection:
python scripts/pii_rule_comparison.py \
--docs-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/documents/ \
--ref-attr-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/attributes/chatgpt_pii_vision/ \
--hyp-attr-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/attributes/model_pii_tagging/ \
--ref-rule "gpt_4_1_contains_pii:any and not gpt_4_1_is_public_document:all" \
--hyp-rule "google_gemma-3-4b-it_is_resume_cv:any" \
--output-file pii_resume_comparison.json \
--recursive \
--debug
# Example for Dolma data with PII detection:
python scripts/pii_rule_comparison.py \
--docs-folder s3://allenai-dolma/documents/v1.5 \
--ref-attr-folder s3://allenai-dolma/attributes/v1.5/pii_detection/gpt4_1 \
--hyp-attr-folder s3://allenai-dolma/attributes/v1.5/pii_detection/custom_rule \
--ref-rule "contains_pii:any" \
--hyp-rule "(contains_email_addresses:any or contains_phone_numbers:any) and not false_positive:any" \
--output-file pii_detection_comparison.json \
--recursive \
--aws-profile dolma
"""