mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-07-29 11:50:15 +00:00

* refine mcp package * temp work * update image url handle logic * code review update * update comments * add the wrong deleted content * update readme * Fix bug * Add known limitations in documentation * Polish doc * local ocr stderr log * aistudio structure image logic refine * update code review for stdout * update code review for stdout * Polish installation doc * Fix typing and polish docs * Fix bugs * Fix bugs * Fix docs * Refine * Fix * Polish docs --------- Co-authored-by: Bobholamovic <mhlin425@whu.edu.cn>
894 lines
32 KiB
Python
894 lines
32 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
# TODO:
|
||
# 1. Reuse `httpx` client.
|
||
# 2. Use `contextvars` to manage MCP context objects.
|
||
# 3. Implement structured logging, log stack traces, and log operation timing.
|
||
# 4. Report progress for long-running operations.
|
||
|
||
import abc
|
||
import asyncio
|
||
import base64
|
||
import contextlib
|
||
import io
|
||
import json
|
||
import re
|
||
from pathlib import PurePath
|
||
from queue import Queue
|
||
from threading import Thread
|
||
from typing import Any, Callable, Dict, List, NoReturn, Optional, Type, Union
|
||
from urllib.parse import urlparse
|
||
|
||
import httpx
|
||
import magic
|
||
import numpy as np
|
||
from fastmcp import Context, FastMCP
|
||
from mcp.types import ImageContent, TextContent
|
||
from PIL import Image as PILImage
|
||
from typing_extensions import Literal, Self, assert_never
|
||
|
||
# PaddleOCR is an optional dependency: local inference is only possible when
# the `paddleocr` package is importable; otherwise only the service-backed
# modes ("aistudio" / "self_hosted") can be used.
try:
    from paddleocr import PaddleOCR, PPStructureV3

    LOCAL_OCR_AVAILABLE = True
except ImportError:
    LOCAL_OCR_AVAILABLE = False
|
||
|
||
|
||
# Output format selector shared by all tools: "simple" returns human-readable
# text/markdown, "detailed" returns JSON with coordinates and confidences.
OutputMode = Literal["simple", "detailed"]
|
||
|
||
|
||
def _is_file_path(s: str) -> bool:
|
||
try:
|
||
PurePath(s)
|
||
return True
|
||
except Exception:
|
||
return False
|
||
|
||
|
||
def _is_base64(s: str) -> bool:
|
||
pattern = r"^[A-Za-z0-9+/]+={0,2}$"
|
||
return bool(re.fullmatch(pattern, s))
|
||
|
||
|
||
def _is_url(s: str) -> bool:
|
||
if not (s.startswith("http://") or s.startswith("https://")):
|
||
return False
|
||
result = urlparse(s)
|
||
return all([result.scheme, result.netloc]) and result.scheme in ("http", "https")
|
||
|
||
|
||
def _infer_file_type_from_bytes(data: bytes) -> Optional[str]:
    """Classify raw bytes by sniffed MIME type.

    Returns:
        "image" for any image/* MIME type, "pdf" for application/pdf,
        and None for every other (unsupported) type.
    """
    mime_type = magic.from_buffer(data, mime=True)
    if mime_type == "application/pdf":
        return "pdf"
    if mime_type.startswith("image/"):
        return "image"
    return None
|
||
|
||
|
||
def get_str_with_max_len(obj: object, max_len: int) -> str:
    """Return `str(obj)`, truncated to `max_len` characters with a "..." suffix.

    Strings of length `max_len` or less are returned unchanged.
    """
    text = str(obj)
    return text if len(text) <= max_len else text[:max_len] + "..."
|
||
|
||
|
||
class _EngineWrapper:
    """Runs a synchronous inference engine on a dedicated worker thread.

    Exposes an async `call` API to the event loop. Requests are serialized
    through a queue, so at most one inference runs at a time and the
    (non-thread-safe) engine is only ever touched from the worker thread.
    Must be constructed from within a running event loop.
    """

    def __init__(self, engine: Any) -> None:
        self._engine = engine
        self._queue: Queue = Queue()
        self._closed = False
        # Captured here so the worker thread can complete futures on the
        # correct loop via `call_soon_threadsafe`.
        self._loop = asyncio.get_running_loop()
        # Non-daemon thread: `close()` joins it explicitly during shutdown.
        self._thread = Thread(target=self._worker, daemon=False)
        self._thread.start()

    @property
    def engine(self) -> Any:
        # The wrapped engine instance (used by callers to pass bound methods
        # such as `engine.predict` to `call`).
        return self._engine

    async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
        # Schedule `func(*args, **kwargs)` on the worker thread and await its
        # result; re-raises whatever exception `func` raised.
        if self._closed:
            raise RuntimeError("Engine wrapper has already been closed")
        fut = self._loop.create_future()
        self._queue.put((func, args, kwargs, fut))
        return await fut

    async def close(self) -> None:
        # Signal the worker to exit (the `None` sentinel) and join the thread
        # in an executor so the event loop is not blocked while waiting.
        if not self._closed:
            self._queue.put(None)
            await self._loop.run_in_executor(None, self._thread.join)
            self._closed = True

    def _worker(self) -> None:
        # Worker loop: drain the queue until the `None` sentinel arrives.
        while not self._closed:
            item = self._queue.get()
            if item is None:
                break
            func, args, kwargs, fut = item
            try:
                # FIXME: PaddleX currently writes to stdout erroneously when
                # downloading files, which conflicts with the MCP server's use
                # of stdout. As a temporary workaround, we use
                # `redirect_stdout`, but since this redirection is global, it
                # should not be used inside a worker thread - it may
                # unintentionally interfere with the MCP server's normal stdout
                # behavior. Although we haven't observed any issues in testing
                # so far, this workaround should be removed once the PaddleX bug
                # is fixed.
                with contextlib.redirect_stdout(io.StringIO()):
                    result = func(*args, **kwargs)
                self._loop.call_soon_threadsafe(fut.set_result, result)
            except Exception as e:
                self._loop.call_soon_threadsafe(fut.set_exception, e)
            finally:
                self._queue.task_done()
|
||
|
||
|
||
class PipelineHandler(abc.ABC):
    """Abstract base class for pipeline handlers.

    A handler owns one pipeline and runs it either locally (in-process
    PaddleOCR engine on a worker thread) or remotely (HTTP service), and it
    follows a strict lifecycle: initialized -> started -> stopped.
    """

    def __init__(
        self,
        pipeline: str,
        ppocr_source: str,
        pipeline_config: Optional[str],
        device: Optional[str],
        server_url: Optional[str],
        aistudio_access_token: Optional[str],
        timeout: Optional[int],
    ) -> None:
        """Initialize the pipeline handler.

        Args:
            pipeline: Pipeline name.
            ppocr_source: Source of PaddleOCR functionality ("local",
                "aistudio", or "self_hosted").
            pipeline_config: Path to pipeline configuration.
            device: Device to run inference on.
            server_url: Base URL for service mode.
            aistudio_access_token: AI Studio access token.
            timeout: Read timeout in seconds for HTTP requests.

        Raises:
            ValueError: If `ppocr_source` is not a recognized value.
            RuntimeError: If local mode is requested but PaddleOCR is not
                installed, or the local engine fails to construct.
        """
        self._pipeline = pipeline
        # Map the user-facing source onto the two execution modes.
        if ppocr_source == "local":
            self._mode = "local"
        elif ppocr_source in ("aistudio", "self_hosted"):
            self._mode = "service"
        else:
            raise ValueError(f"Unknown PaddleOCR source {repr(ppocr_source)}")
        self._ppocr_source = ppocr_source
        self._pipeline_config = pipeline_config
        self._device = device
        self._server_url = server_url
        self._aistudio_access_token = aistudio_access_token
        # NOTE: `or` means an explicit timeout of 0 also falls back to 60s.
        self._timeout = timeout or 60

        if self._mode == "local":
            if not LOCAL_OCR_AVAILABLE:
                raise RuntimeError("PaddleOCR is not locally available")
            try:
                # Engine construction is eager (may download models); failures
                # are wrapped so callers see a single error type.
                self._engine = self._create_local_engine()
            except Exception as e:
                raise RuntimeError(
                    f"Failed to create PaddleOCR engine: {str(e)}"
                ) from e

        self._status: Literal["initialized", "started", "stopped"] = "initialized"

    async def start(self) -> None:
        # Idempotent while "started"; restarting after stop is not supported.
        if self._status == "initialized":
            if self._mode == "local":
                # The wrapper must be created inside a running event loop.
                self._engine_wrapper = _EngineWrapper(self._engine)
            self._status = "started"
        elif self._status == "started":
            pass
        elif self._status == "stopped":
            raise RuntimeError("Pipeline handler has already been stopped")
        else:
            assert_never(self._status)

    async def stop(self) -> None:
        # Idempotent while "stopped"; stopping before start is an error.
        if self._status == "initialized":
            raise RuntimeError("Pipeline handler has not been started")
        elif self._status == "started":
            if self._mode == "local":
                await self._engine_wrapper.close()
            self._status = "stopped"
        elif self._status == "stopped":
            pass
        else:
            assert_never(self._status)

    async def __aenter__(self) -> Self:
        await self.start()
        return self

    async def __aexit__(
        self,
        exc_type: Any,
        exc_val: Any,
        exc_tb: Any,
    ) -> None:
        await self.stop()

    @abc.abstractmethod
    def register_tools(self, mcp: FastMCP) -> None:
        """Register tools with the MCP server.

        Args:
            mcp: The `FastMCP` instance.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _create_local_engine(self) -> Any:
        """Create the local OCR engine.

        Returns:
            The OCR engine instance.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _get_service_endpoint(self) -> str:
        """Get the service endpoint.

        Returns:
            Service endpoint path.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _transform_local_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Transform keyword arguments for local execution.

        Args:
            kwargs: Keyword arguments.

        Returns:
            Transformed keyword arguments.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _transform_service_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Transform keyword arguments for service execution.

        Args:
            kwargs: Keyword arguments.

        Returns:
            Transformed keyword arguments.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def _parse_local_result(
        self, local_result: Dict, ctx: Context
    ) -> Dict[str, Any]:
        """Parse raw result from local engine into a unified format.

        Args:
            local_result: Raw result from local engine.
            ctx: MCP context.

        Returns:
            Parsed result in unified format.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def _parse_service_result(
        self, service_result: Dict[str, Any], ctx: Context
    ) -> Dict[str, Any]:
        """Parse raw result from the service into a unified format.

        Args:
            service_result: Raw result from the service.
            ctx: MCP context.

        Returns:
            Parsed result in unified format.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def _log_completion_stats(self, result: Dict[str, Any], ctx: Context) -> None:
        """Log statistics after processing completion.

        Args:
            result: Processing result.
            ctx: MCP context.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def _format_output(
        self,
        result: Dict[str, Any],
        detailed: bool,
        ctx: Context,
        **kwargs: Any,
    ) -> Union[str, List[Union[TextContent, ImageContent]]]:
        """Format output into simple or detailed format.

        Args:
            result: Processing result.
            detailed: Whether to use detailed format.
            ctx: MCP context.
            **kwargs: Additional arguments.

        Returns:
            Formatted output in requested format.
        """
        raise NotImplementedError

    async def _predict_with_local_engine(
        self, processed_input: Union[str, np.ndarray], ctx: Context, **kwargs: Any
    ) -> Dict:
        # Runs the local engine's `predict` on the worker thread. Requires
        # `start()` to have been called first (which creates the wrapper).
        if not hasattr(self, "_engine_wrapper"):
            raise RuntimeError("Engine wrapper has not been initialized")
        return await self._engine_wrapper.call(
            self._engine_wrapper.engine.predict, processed_input, **kwargs
        )
|
||
|
||
|
||
class SimpleInferencePipelineHandler(PipelineHandler):
    """Base class for simple inference pipeline handlers.

    Implements the common single-request flow: normalize the input, run
    inference (local engine or HTTP service), parse the raw result into a
    unified dict, and format it for the MCP client.
    """

    async def process(
        self,
        input_data: str,
        output_mode: OutputMode,
        ctx: Context,
        file_type: Optional[str] = None,
        infer_kwargs: Optional[Dict[str, Any]] = None,
        format_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[str, List[Union[TextContent, ImageContent]]]:
        """Process input data through the pipeline.

        Args:
            input_data: Input data (file path, URL, or Base64).
            output_mode: Output mode ("simple" or "detailed").
            ctx: MCP context.
            file_type: File type for URLs ("image", "pdf", or None for
                auto-detection).
            infer_kwargs: Additional arguments for performing pipeline inference.
            format_kwargs: Additional arguments for formatting the output.

        Returns:
            Processed result in the requested output format.
        """
        infer_kwargs = infer_kwargs or {}
        format_kwargs = format_kwargs or {}
        try:
            await ctx.info(
                f"Starting {self._pipeline} processing (source: {self._ppocr_source})"
            )

            if self._mode == "local":
                processed_input = self._process_input_for_local(input_data, file_type)
                infer_kwargs = self._transform_local_kwargs(infer_kwargs)
                raw_result = await self._predict_with_local_engine(
                    processed_input, ctx, **infer_kwargs
                )
                result = await self._parse_local_result(raw_result, ctx)
            else:
                processed_input, inferred_file_type = self._process_input_for_service(
                    input_data, file_type
                )
                infer_kwargs = self._transform_service_kwargs(infer_kwargs)
                raw_result = await self._call_service(
                    processed_input, inferred_file_type, ctx, **infer_kwargs
                )
                result = await self._parse_service_result(raw_result, ctx)

            await self._log_completion_stats(result, ctx)
            return await self._format_output(
                result, output_mode == "detailed", ctx, **format_kwargs
            )

        except Exception as e:
            await ctx.error(f"{self._pipeline} processing failed: {str(e)}")
            # Default implementation re-raises; subclasses may override.
            self._handle_error(e, output_mode)

    def _process_input_for_local(
        self, input_data: str, file_type: Optional[str]
    ) -> Union[str, np.ndarray]:
        """Normalize the input for the local engine.

        Base64 images are decoded into a BGR `np.ndarray`; paths and URLs are
        passed through unchanged (the engine loads them itself).
        """
        # TODO: Use `file_type` to handle more cases.
        if _is_base64(input_data):
            if input_data.startswith("data:"):
                base64_data = input_data.split(",", 1)[1]
            else:
                base64_data = input_data
            try:
                image_bytes = base64.b64decode(base64_data)
                file_type = _infer_file_type_from_bytes(image_bytes)
                if file_type != "image":
                    raise ValueError("Currently, only images can be passed via Base64.")
                image_pil = PILImage.open(io.BytesIO(image_bytes))
                image_arr = np.array(image_pil.convert("RGB"))
                # RGB -> BGR, as expected by the (OpenCV-style) engine input.
                return np.ascontiguousarray(image_arr[..., ::-1])
            except Exception as e:
                raise ValueError(f"Failed to decode Base64 image: {e}")
        elif _is_file_path(input_data) or _is_url(input_data):
            return input_data
        else:
            raise ValueError("Invalid input data format")

    def _process_input_for_service(
        self, input_data: str, file_type: Optional[str]
    ) -> tuple[str, Optional[str]]:
        """Normalize the input for service mode.

        Returns:
            A `(payload, file_type)` pair where `payload` is a URL or Base64
            string and `file_type` is "image", "pdf", or None (unknown).
        """
        if _is_url(input_data):
            # URLs are forwarded as-is; normalize the caller-provided type,
            # treating various "unknown" spellings as None.
            norm_ft = None
            if isinstance(file_type, str):
                if file_type.lower() in ("none", "null", "unknown", ""):
                    norm_ft = None
                else:
                    norm_ft = file_type.lower()
            return input_data, norm_ft
        elif _is_base64(input_data):
            try:
                if input_data.startswith("data:"):
                    base64_data = input_data.split(",", 1)[1]
                else:
                    base64_data = input_data
                bytes_ = base64.b64decode(base64_data)
                file_type_str = _infer_file_type_from_bytes(bytes_)
                if file_type_str is None:
                    raise ValueError(
                        "Unsupported file type in Base64 data. "
                        "Only image files (JPEG, PNG, etc.) and PDF documents are supported."
                    )
                return input_data, file_type_str
            except Exception as e:
                raise ValueError(f"Failed to decode Base64 data: {e}")
        elif _is_file_path(input_data):
            try:
                with open(input_data, "rb") as f:
                    bytes_ = f.read()
                # Keep the original path in `input_data` so the error message
                # below shows the path, not a huge Base64 blob.
                encoded = base64.b64encode(bytes_).decode("ascii")
                file_type_str = _infer_file_type_from_bytes(bytes_)
                if file_type_str is None:
                    raise ValueError(
                        f"Unsupported file type for '{input_data}'. "
                        "Only image files (JPEG, PNG, etc.) and PDF documents are supported."
                    )
                return encoded, file_type_str
            except Exception as e:
                raise ValueError(f"Failed to read file: {e}")
        else:
            raise ValueError("Invalid input data format")

    async def _call_service(
        self,
        processed_input: str,
        file_type: Optional[str],
        ctx: Context,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """POST the prepared payload to the pipeline's service endpoint.

        Raises:
            RuntimeError: On missing configuration, HTTP failure, or an
                unparseable response body.
        """
        if not self._server_url:
            raise RuntimeError("Server URL not configured")

        endpoint = self._get_service_endpoint()
        url = f"{self._server_url.rstrip('/')}/{endpoint.lstrip('/')}"

        payload = self._prepare_service_payload(processed_input, file_type, **kwargs)
        headers = {"Content-Type": "application/json"}

        if self._ppocr_source == "aistudio":
            if not self._aistudio_access_token:
                raise RuntimeError("Missing AI Studio access token")
            headers["Authorization"] = f"token {self._aistudio_access_token}"

        try:
            # Only the read timeout is user-configurable; inference can be slow.
            timeout = httpx.Timeout(
                connect=30.0, read=self._timeout, write=30.0, pool=30.0
            )
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(url, json=payload, headers=headers)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPError as e:
            raise RuntimeError(f"HTTP request failed: {type(e).__name__}: {str(e)}")
        except json.JSONDecodeError as e:
            raise RuntimeError(f"Invalid service response: {str(e)}")

    def _prepare_service_payload(
        self, processed_input: str, file_type: Optional[str], **kwargs: Any
    ) -> Dict[str, Any]:
        """Build the JSON body for the service request.

        The service's `fileType` field uses 0 for PDF, 1 for image, and null
        for "let the server auto-detect".
        """
        payload: Dict[str, Any] = {"file": processed_input, **kwargs}
        if file_type == "image":
            payload["fileType"] = 1
        elif file_type == "pdf":
            payload["fileType"] = 0
        else:
            payload["fileType"] = None

        return payload

    def _handle_error(self, exc: Exception, output_mode: OutputMode) -> NoReturn:
        # Hook for subclasses to convert errors into user-facing output;
        # the default behavior is to propagate the exception.
        raise exc
|
||
|
||
|
||
class OCRHandler(SimpleInferencePipelineHandler):
    """Handler for the OCR pipeline; registers the "ocr" MCP tool."""

    def register_tools(self, mcp: FastMCP) -> None:
        # NOTE: the closure's docstring is the tool description shown to MCP
        # clients, so its wording is user-facing and kept verbatim.
        @mcp.tool("ocr")
        async def _ocr(
            input_data: str,
            output_mode: OutputMode = "simple",
            file_type: Optional[str] = None,
            *,
            ctx: Context,
        ) -> Union[str, List[Union[TextContent, ImageContent]]]:
            """Extracts text from images and PDFs. Accepts file path, URL, or Base64.

            Args:
                input_data: The file to process (file path, URL, or Base64 string).
                output_mode: The desired output format.
                    - "simple": (Default) Clean, readable text suitable for most use cases.
                    - "detailed": A JSON output including text, confidence, and precise bounding box coordinates. Only use this when coordinates are specifically required.
                file_type: File type. This parameter is REQUIRED when `input_data` is a URL and should be omitted for other types.
                    - "image": For image files
                    - "pdf": For PDF documents
                    - None: For unknown file types
            """
            await ctx.info(
                f"--- OCR tool received `input_data`: {get_str_with_max_len(input_data, 50)} ---"
            )
            return await self.process(input_data, output_mode, ctx, file_type)

    def _create_local_engine(self) -> Any:
        return PaddleOCR(
            paddlex_config=self._pipeline_config,
            device=self._device,
        )

    def _get_service_endpoint(self) -> str:
        return "ocr"

    def _transform_local_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # NOTE: incoming kwargs are intentionally discarded; document
        # unwarping and orientation classification are always disabled.
        return {
            "use_doc_unwarping": False,
            "use_doc_orientation_classify": False,
        }

    def _transform_service_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Service-mode counterpart of `_transform_local_kwargs` (camelCase keys).
        return {
            "useDocUnwarping": False,
            "useDocOrientationClassify": False,
        }

    async def _parse_local_result(self, local_result: Dict, ctx: Context) -> Dict:
        # NOTE(review): `local_result` is a sequence of per-page results;
        # only the first entry is used here — presumably single-image input.
        result = local_result[0]
        texts = result["rec_texts"]
        scores = result["rec_scores"]
        boxes = result["rec_boxes"]

        clean_texts, confidences, instances = [], [], []

        for i, text in enumerate(texts):
            # Skip empty/whitespace-only recognitions.
            if text and text.strip():
                conf = scores[i] if i < len(scores) else 0
                clean_texts.append(text.strip())
                confidences.append(conf)
                instance = {
                    "text": text.strip(),
                    "confidence": round(conf, 3),
                    # `rec_boxes` entries are numpy arrays locally.
                    "bbox": boxes[i].tolist(),
                }
                instances.append(instance)

        return {
            "text": "\n".join(clean_texts),
            "confidence": sum(confidences) / len(confidences) if confidences else 0,
            "instances": instances,
        }

    async def _parse_service_result(self, service_result: Dict, ctx: Context) -> Dict:
        # Service responses wrap the payload in "result"; tolerate both shapes.
        result_data = service_result.get("result", service_result)
        ocr_results = result_data.get("ocrResults")

        all_texts, all_confidences, instances = [], [], []

        # One entry per page; flatten all pages into a single text/instances set.
        for ocr_result in ocr_results:
            pruned = ocr_result["prunedResult"]

            texts = pruned["rec_texts"]
            scores = pruned["rec_scores"]
            boxes = pruned["rec_boxes"]

            for i, text in enumerate(texts):
                if text and text.strip():
                    conf = scores[i] if i < len(scores) else 0
                    all_texts.append(text.strip())
                    all_confidences.append(conf)
                    instance = {
                        "text": text.strip(),
                        "confidence": round(conf, 3),
                        # Already plain lists in the JSON response.
                        "bbox": boxes[i],
                    }
                    instances.append(instance)

        return {
            "text": "\n".join(all_texts),
            "confidence": (
                sum(all_confidences) / len(all_confidences) if all_confidences else 0
            ),
            "instances": instances,
        }

    async def _log_completion_stats(self, result: Dict, ctx: Context) -> None:
        text_length = len(result["text"])
        instance_count = len(result["instances"])
        await ctx.info(
            f"OCR completed: {text_length} characters, {instance_count} text instances"
        )

    async def _format_output(
        self,
        result: Dict,
        detailed: bool,
        ctx: Context,
        **kwargs: Any,
    ) -> Union[str, List[Union[TextContent, ImageContent]]]:
        # Empty result: report "no text" in the mode-appropriate shape.
        if not result["text"].strip():
            return (
                "❌ No text detected"
                if not detailed
                else json.dumps({"error": "No text detected"}, ensure_ascii=False)
            )

        if detailed:
            return json.dumps(result, ensure_ascii=False, indent=2)
        else:
            confidence = result["confidence"]
            instance_count = len(result["instances"])

            output = result["text"]
            if confidence > 0:
                output += f"\n\n📊 Confidence: {(confidence * 100):.1f}% | {instance_count} text instances"

            return output
|
||
|
||
|
||
class PPStructureV3Handler(SimpleInferencePipelineHandler):
    """Handler for PP-StructureV3; registers the "pp_structurev3" MCP tool."""

    def register_tools(self, mcp: FastMCP) -> None:
        # NOTE: the closure's docstring is the tool description shown to MCP
        # clients, so its wording is user-facing and kept verbatim.
        @mcp.tool("pp_structurev3")
        async def _pp_structurev3(
            input_data: str,
            output_mode: OutputMode = "simple",
            file_type: Optional[str] = None,
            return_images: bool = True,
            *,
            ctx: Context,
        ) -> Union[str, List[Union[TextContent, ImageContent]]]:
            """Extracts structured markdown from complex documents (images/PDFs), including tables, formulas, etc. Accepts file path, URL, or Base64.

            Args:
                input_data: The file to process (file path, URL, or Base64 string).
                output_mode: The desired output format.
                    - "simple": (Default) Clean, readable markdown with embedded images. Best for most use cases.
                    - "detailed": JSON data about document structure, plus markdown. Only use this when coordinates are specifically required.
                file_type: File type. This parameter is REQUIRED when `input_data` is a URL and should be omitted for other types.
                    - "image": For image files
                    - "pdf": For PDF documents
                    - None: For unknown file types
                return_images: Whether to return the images extracted from the document.
            """
            return await self.process(
                input_data,
                output_mode,
                ctx,
                file_type,
                format_kwargs={"return_images": return_images},
            )

    def _create_local_engine(self) -> Any:
        return PPStructureV3(
            paddlex_config=self._pipeline_config,
            device=self._device,
        )

    def _get_service_endpoint(self) -> str:
        return "layout-parsing"

    def _transform_local_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # NOTE: incoming kwargs are intentionally discarded; document
        # unwarping and orientation classification are always disabled.
        return {
            "use_doc_unwarping": False,
            "use_doc_orientation_classify": False,
        }

    def _transform_service_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Service-mode counterpart of `_transform_local_kwargs` (camelCase keys).
        return {
            "useDocUnwarping": False,
            "useDocOrientationClassify": False,
        }

    async def _parse_local_result(self, local_result: Dict, ctx: Context) -> Dict:
        """Parse per-page local results into the unified markdown dict."""
        markdown_parts = []
        all_images_mapping = {}
        detailed_results = []

        for result in local_result:
            markdown = result.markdown
            text = markdown["markdown_texts"]
            markdown_parts.append(text)
            images = markdown["markdown_images"]
            processed_images = {}
            for img_key, img_data in images.items():
                with io.BytesIO() as buffer:
                    img_data.save(buffer, format="JPEG")
                    # Decode to str so `images_mapping` holds Base64 strings,
                    # consistent with the service path and with the str-typed
                    # `ImageContent.data` field (previously this stored bytes).
                    processed_images[img_key] = base64.b64encode(
                        buffer.getvalue()
                    ).decode("ascii")
            all_images_mapping.update(processed_images)
            detailed_results.append(result)

        return {
            # TODO: Page concatenation can be done better via `pipeline.concatenate_markdown_pages`
            "markdown": "\n".join(markdown_parts),
            "pages": len(local_result),
            "images_mapping": all_images_mapping,
            "detailed_results": detailed_results,
        }

    async def _parse_service_result(self, service_result: Dict, ctx: Context) -> Dict:
        """Parse the service response into the unified markdown dict."""
        # Service responses wrap the payload in "result"; tolerate both shapes.
        result_data = service_result.get("result", service_result)
        layout_results = result_data.get("layoutParsingResults")

        if not layout_results:
            return {
                "markdown": "",
                "pages": 0,
                "images_mapping": {},
                "detailed_results": [],
            }

        markdown_parts = []
        all_images_mapping = {}
        detailed_results = []

        for res in layout_results:
            markdown_parts.append(res["markdown"]["text"])
            images = res["markdown"]["images"]
            processed_images = {}
            for img_key, img_data in images.items():
                # The service may return image URLs or inline Base64; normalize
                # everything to Base64 strings.
                processed_images[img_key] = await self._process_image_data(
                    img_data, ctx
                )
            all_images_mapping.update(processed_images)
            detailed_results.append(res["prunedResult"])

        return {
            "markdown": "\n".join(markdown_parts),
            "pages": len(layout_results),
            "images_mapping": all_images_mapping,
            "detailed_results": detailed_results,
        }

    async def _process_image_data(self, img_data: str, ctx: Context) -> str:
        """Normalize one image reference (URL or Base64) to a Base64 string.

        Failures are logged and the original value is returned best-effort
        rather than aborting the whole document.
        """
        if _is_url(img_data):
            try:
                timeout = httpx.Timeout(connect=30.0, read=30.0, write=30.0, pool=30.0)
                async with httpx.AsyncClient(timeout=timeout) as client:
                    response = await client.get(img_data)
                    response.raise_for_status()
                    img_bytes = response.content
                    return base64.b64encode(img_bytes).decode("ascii")
            except Exception as e:
                await ctx.error(f"Failed to download image from URL {img_data}: {e}")
                return img_data
        elif _is_base64(img_data):
            return img_data
        else:
            await ctx.error(
                f"Unknown image data format: {get_str_with_max_len(img_data, 50)}"
            )
            return img_data

    async def _log_completion_stats(self, result: Dict, ctx: Context) -> None:
        page_count = result["pages"]
        await ctx.info(f"Layout parsing completed: {page_count} pages")

    async def _format_output(
        self,
        result: Dict,
        detailed: bool,
        ctx: Context,
        **kwargs: Any,
    ) -> Union[str, List[Union[TextContent, ImageContent]]]:
        """Format the unified dict into MCP content (markdown +/- images/JSON)."""
        if not result["markdown"].strip():
            return (
                "❌ No document content detected"
                if not detailed
                else json.dumps({"error": "No content detected"}, ensure_ascii=False)
            )

        markdown_text = result["markdown"]
        images_mapping = result.get("images_mapping", {})

        if kwargs.get("return_images"):
            # Split the markdown around <img> tags and interleave ImageContent.
            content_list = self._parse_markdown_with_images(
                markdown_text, images_mapping
            )
        else:
            content_list = [TextContent(type="text", text=markdown_text)]

        if detailed:
            if "detailed_results" in result and result["detailed_results"]:
                for detailed_result in result["detailed_results"]:
                    content_list.append(
                        TextContent(
                            type="text",
                            # `default=str`: local results contain non-JSON
                            # objects (numpy arrays, result wrappers).
                            text=json.dumps(
                                detailed_result,
                                ensure_ascii=False,
                                indent=2,
                                default=str,
                            ),
                        )
                    )

        return content_list

    def _parse_markdown_with_images(
        self, markdown_text: str, images_mapping: Dict[str, str]
    ) -> List[Union[TextContent, ImageContent]]:
        """Parse markdown text and return mixed list of text and images."""
        if not images_mapping:
            return [TextContent(type="text", text=markdown_text)]

        content_list = []
        img_pattern = r'<img[^>]+src="([^"]+)"[^>]*>'
        last_pos = 0

        for match in re.finditer(img_pattern, markdown_text):
            # Emit the text between the previous image and this one.
            text_before = markdown_text[last_pos : match.start()]
            if text_before.strip():
                content_list.append(TextContent(type="text", text=text_before))

            img_src = match.group(1)
            if img_src in images_mapping:
                content_list.append(
                    ImageContent(
                        type="image",
                        data=images_mapping[img_src],
                        mimeType="image/jpeg",
                    )
                )

            last_pos = match.end()

        remaining_text = markdown_text[last_pos:]
        if remaining_text.strip():
            content_list.append(TextContent(type="text", text=remaining_text))

        # Fall back to the raw markdown if nothing survived the split.
        return content_list or [TextContent(type="text", text=markdown_text)]
|
||
|
||
|
||
# Registry mapping public pipeline names to their handler classes; extended
# here when new pipelines are supported.
_PIPELINE_HANDLERS: Dict[str, Type[PipelineHandler]] = {
    "OCR": OCRHandler,
    "PP-StructureV3": PPStructureV3Handler,
}
|
||
|
||
|
||
def create_pipeline_handler(
    pipeline: str, /, *args: Any, **kwargs: Any
) -> PipelineHandler:
    """Instantiate the handler registered for `pipeline`.

    Args:
        pipeline: Public pipeline name (a key of `_PIPELINE_HANDLERS`).
        *args: Positional arguments forwarded to the handler constructor.
        **kwargs: Keyword arguments forwarded to the handler constructor.

    Returns:
        The constructed `PipelineHandler` instance.

    Raises:
        ValueError: If `pipeline` is not a registered pipeline name.
    """
    handler_cls = _PIPELINE_HANDLERS.get(pipeline)
    if handler_cls is None:
        raise ValueError(f"Unknown pipeline {repr(pipeline)}")
    return handler_cls(pipeline, *args, **kwargs)
|