mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-30 09:24:25 +00:00
Decent resume/cv tagging
This commit is contained in:
parent
1f66b96ffd
commit
66d293c178
@ -71,7 +71,10 @@ metrics = MetricsKeeper(window=60 * 5)
|
|||||||
|
|
||||||
|
|
||||||
class PIIClassification(BaseModel):
|
class PIIClassification(BaseModel):
|
||||||
is_resume_or_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv.")
|
primary_language: str = Field(..., description="Primary language as a two-letter code")
|
||||||
|
document_type: str = Field(..., description="Basic summary of document type classification")
|
||||||
|
is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
|
||||||
|
contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
|
||||||
|
|
||||||
|
|
||||||
async def _process_single_page(page_text: str) -> PIIClassification:
|
async def _process_single_page(page_text: str) -> PIIClassification:
|
||||||
@ -90,7 +93,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
|||||||
"type": "text",
|
"type": "text",
|
||||||
"text": (
|
"text": (
|
||||||
f"{text}\n\n-----------\n"
|
f"{text}\n\n-----------\n"
|
||||||
"Given the text above, determine if the text above is a resume (résumé) or CV. Answer in a simple JSON block."
|
"Given the text above, determine what type of document it is, and if it's a resume/CV. answer in JSON. The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'contains_pii': bool}"
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -109,12 +112,12 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"SGLang network error: {e!s}")
|
logger.warning(f"SGLang network error: {e!s}")
|
||||||
metrics.add_metrics(sglang_errors=1)
|
metrics.add_metrics(sglang_errors=1)
|
||||||
return PIIClassification(is_resume_or_cv=None)
|
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||||
|
|
||||||
if status != 200:
|
if status != 200:
|
||||||
logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
|
logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
|
||||||
metrics.add_metrics(sglang_errors=1)
|
metrics.add_metrics(sglang_errors=1)
|
||||||
return PIIClassification(is_resume_or_cv=None)
|
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||||
|
|
||||||
# ---------- Parse base JSON --------------------------------------------
|
# ---------- Parse base JSON --------------------------------------------
|
||||||
try:
|
try:
|
||||||
@ -122,7 +125,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
|
logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
|
||||||
metrics.add_metrics(sglang_errors=1)
|
metrics.add_metrics(sglang_errors=1)
|
||||||
return PIIClassification(is_resume_or_cv=None)
|
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||||
|
|
||||||
# Token accounting if available
|
# Token accounting if available
|
||||||
if "usage" in base:
|
if "usage" in base:
|
||||||
@ -137,12 +140,12 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
|||||||
except (KeyError, IndexError, AttributeError) as e:
|
except (KeyError, IndexError, AttributeError) as e:
|
||||||
logger.warning(f"Missing fields in SGLang response: {e!s}")
|
logger.warning(f"Missing fields in SGLang response: {e!s}")
|
||||||
metrics.add_metrics(sglang_errors=1)
|
metrics.add_metrics(sglang_errors=1)
|
||||||
return PIIClassification(is_resume_or_cv=None)
|
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||||
|
|
||||||
if not isinstance(content, str):
|
if not isinstance(content, str):
|
||||||
logger.warning("SGLang `content` is not a string; treating as error.")
|
logger.warning("SGLang `content` is not a string; treating as error.")
|
||||||
metrics.add_metrics(sglang_errors=1)
|
metrics.add_metrics(sglang_errors=1)
|
||||||
return PIIClassification(is_resume_or_cv=None)
|
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
|
pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
|
||||||
@ -150,7 +153,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.warning(f"Unable to parse pii classification object: {e!s}")
|
logger.warning(f"Unable to parse pii classification object: {e!s}")
|
||||||
metrics.add_metrics(sglang_errors=1)
|
metrics.add_metrics(sglang_errors=1)
|
||||||
return PIIClassification(is_resume_or_cv=None)
|
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||||
|
|
||||||
|
|
||||||
# Manual simple implementation of HTTP Post
|
# Manual simple implementation of HTTP Post
|
||||||
@ -258,7 +261,7 @@ async def process_dolma_document(args, dolma_doc, sem):
|
|||||||
async with sem:
|
async with sem:
|
||||||
pii_class = await _process_single_page(page_text)
|
pii_class = await _process_single_page(page_text)
|
||||||
|
|
||||||
result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_or_cv])
|
result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
|
||||||
else:
|
else:
|
||||||
result_attributes[key_name].append([start_pos, end_pos, None])
|
result_attributes[key_name].append([start_pos, end_pos, None])
|
||||||
|
|
||||||
@ -400,6 +403,7 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
|
|||||||
str(SGLANG_SERVER_PORT),
|
str(SGLANG_SERVER_PORT),
|
||||||
"--log-level-http",
|
"--log-level-http",
|
||||||
"warning",
|
"warning",
|
||||||
|
"--mem-fraction-static", "0.40"
|
||||||
]
|
]
|
||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
@ -493,6 +497,8 @@ async def sglang_server_host(model_name_or_path, args, semaphore):
|
|||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
retry = 0
|
retry = 0
|
||||||
|
|
||||||
|
await asyncio.sleep(1000000)
|
||||||
|
|
||||||
while retry < MAX_RETRIES:
|
while retry < MAX_RETRIES:
|
||||||
await sglang_server_task(model_name_or_path, args, semaphore)
|
await sglang_server_task(model_name_or_path, args, semaphore)
|
||||||
logger.warning("SGLang server task ended")
|
logger.warning("SGLang server task ended")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user