diff --git a/scripts/tagging_pipeline_v2.py b/scripts/tagging_pipeline_v2.py index 69097ea..e14344a 100644 --- a/scripts/tagging_pipeline_v2.py +++ b/scripts/tagging_pipeline_v2.py @@ -85,6 +85,7 @@ class PIIClassification(BaseModel): is_correspondence_or_letter: Optional[bool] is_public_order: Optional[bool] is_court_notice: Optional[bool] + is_completion_certificate: Optional[bool] contains_pii: Optional[bool] = Field(..., description="True if document contains PII") @@ -103,7 +104,7 @@ async def _process_single_page(page_text: str) -> PIIClassification: "type": "text", "text": ( f"{text}\n\n-----------\n" - "Given the text above, determine what type of document it is. Answer in JSON. The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'is_academic_paper': bool, 'is_textbook': bool, 'is_news_article': bool, 'is_test_or_quiz': bool, 'is_homework_assignment': bool, 'is_class_syllabus': bool, 'is_meeting_minutes': bool, 'is_legal_contract': bool, 'is_form': bool, 'is_correspondence_or_letter': bool, 'is_public_order': bool, 'is_court_notice': bool, 'contains_pii': bool}" + "Given the text above, determine what type of document it is. Answer in JSON. The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'is_academic_paper': bool, 'is_textbook': bool, 'is_news_article': bool, 'is_test_or_quiz': bool, 'is_homework_assignment': bool, 'is_class_syllabus': bool, 'is_meeting_minutes': bool, 'is_legal_contract': bool, 'is_form': bool, 'is_correspondence_or_letter': bool, 'is_public_order': bool, 'is_court_notice': bool, 'is_completion_certificate': bool, 'contains_pii': bool}" ), } ], @@ -265,7 +266,9 @@ async def process_dolma_document(args, dolma_doc, sem): # Take first 5000 characters of the document sample_text = text[:5000] - + text_length = len(text) + span_end = min(5000, text_length) + # Process the sample with the semaphore to limit concurrent requests async with sem: pii_class = await _process_single_page(sample_text) @@ -275,9 +278,12 @@ async def process_dolma_document(args, dolma_doc, sem): key_name = f"{prefix}_{field_name}" attribute_value = getattr(pii_class, field_name) - # Add the classification result to all pages - for start_pos, end_pos, page_num in page_numbers: - result_attributes[key_name].append([start_pos, end_pos, attribute_value]) + # Create a span from 0 to min(5000, len(text)) with the attribute value + result_attributes[key_name].append([0, span_end, attribute_value]) + + # If the document is longer than 5000 characters, add a null span for the rest + if text_length > 5000: + result_attributes[key_name].append([span_end, text_length, None]) return result_attributes else: