crawl4ai/main.py


import asyncio
import os
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Depends, Security
from pydantic import BaseModel, HttpUrl, Field
from typing import Optional, List, Dict, Any, Union
import psutil
import time
import uuid
import math
import logging
from enum import Enum
from dataclasses import dataclass, field
from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
from crawl4ai.config import MIN_WORD_THRESHOLD
from crawl4ai.extraction_strategy import (
LLMExtractionStrategy,
CosineStrategy,
JsonCssExtractionStrategy,
)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class TaskStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"


class CrawlerType(str, Enum):
BASIC = "basic"
LLM = "llm"
COSINE = "cosine"
JSON_CSS = "json_css"


class ExtractionConfig(BaseModel):
type: CrawlerType
params: Dict[str, Any] = {}


class ChunkingStrategy(BaseModel):
type: str
params: Dict[str, Any] = {}


class ContentFilter(BaseModel):
type: str = "bm25"
params: Dict[str, Any] = {}


class CrawlRequest(BaseModel):
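    """Request body accepted by the /crawl, /crawl_sync and /crawl_direct endpoints."""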
urls: Union[HttpUrl, List[HttpUrl]]
word_count_threshold: int = MIN_WORD_THRESHOLD
extraction_config: Optional[ExtractionConfig] = None
chunking_strategy: Optional[ChunkingStrategy] = None
content_filter: Optional[ContentFilter] = None
js_code: Optional[List[str]] = None
wait_for: Optional[str] = None
css_selector: Optional[str] = None
screenshot: bool = False
magic: bool = False
extra: Optional[Dict[str, Any]] = {}
session_id: Optional[str] = None
cache_mode: Optional[CacheMode] = CacheMode.ENABLED
priority: int = Field(default=5, ge=1, le=10)
ttl: Optional[int] = 3600
crawler_params: Dict[str, Any] = {}


@dataclass
class TaskInfo:
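    """In-memory record of a submitted crawl task and its outcome."""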
id: str
status: TaskStatus
result: Optional[Union[CrawlResult, List[CrawlResult]]] = None
error: Optional[str] = None
    created_at: float = field(default_factory=time.time)  # per-instance timestamp, not evaluated at class definition
    ttl: int = 3600
    request: Optional[CrawlRequest] = None  # populated by CrawlerService.submit_task


class ResourceMonitor:
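    """Estimates how many crawl tasks can run concurrently from current memory and CPU headroom."""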
def __init__(self, max_concurrent_tasks: int = 10):
self.max_concurrent_tasks = max_concurrent_tasks
self.memory_threshold = 0.85
self.cpu_threshold = 0.90
self._last_check = 0
self._check_interval = 1 # seconds
self._last_available_slots = max_concurrent_tasks
async def get_available_slots(self) -> int:
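        """Return the number of currently available task slots, re-sampling system load at most once per second."""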
current_time = time.time()
if current_time - self._last_check < self._check_interval:
return self._last_available_slots
mem_usage = psutil.virtual_memory().percent / 100
cpu_usage = psutil.cpu_percent() / 100
memory_factor = max(
0, (self.memory_threshold - mem_usage) / self.memory_threshold
)
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
self._last_available_slots = math.floor(
self.max_concurrent_tasks * min(memory_factor, cpu_factor)
)
self._last_check = current_time
return self._last_available_slots


class TaskManager:
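    """Tracks task state and schedules task IDs on two queues (priority > 5 goes to the
    high-priority queue); completed or failed tasks are purged after their TTL by a cleanup loop.
    """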
def __init__(self, cleanup_interval: int = 300):
self.tasks: Dict[str, TaskInfo] = {}
self.high_priority = asyncio.PriorityQueue()
self.low_priority = asyncio.PriorityQueue()
self.cleanup_interval = cleanup_interval
self.cleanup_task = None
async def start(self):
self.cleanup_task = asyncio.create_task(self._cleanup_loop())
async def stop(self):
if self.cleanup_task:
self.cleanup_task.cancel()
try:
await self.cleanup_task
except asyncio.CancelledError:
pass
async def add_task(self, task_id: str, priority: int, ttl: int) -> None:
task_info = TaskInfo(id=task_id, status=TaskStatus.PENDING, ttl=ttl)
self.tasks[task_id] = task_info
queue = self.high_priority if priority > 5 else self.low_priority
await queue.put((-priority, task_id)) # Negative for proper priority ordering
async def get_next_task(self) -> Optional[str]:
try:
# Try high priority first
_, task_id = await asyncio.wait_for(self.high_priority.get(), timeout=0.1)
return task_id
except asyncio.TimeoutError:
try:
# Then try low priority
_, task_id = await asyncio.wait_for(
self.low_priority.get(), timeout=0.1
)
return task_id
except asyncio.TimeoutError:
return None
def update_task(
self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
):
if task_id in self.tasks:
task_info = self.tasks[task_id]
task_info.status = status
task_info.result = result
task_info.error = error
def get_task(self, task_id: str) -> Optional[TaskInfo]:
return self.tasks.get(task_id)
async def _cleanup_loop(self):
while True:
try:
await asyncio.sleep(self.cleanup_interval)
current_time = time.time()
expired_tasks = [
task_id
for task_id, task in self.tasks.items()
if current_time - task.created_at > task.ttl
and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
]
for task_id in expired_tasks:
del self.tasks[task_id]
except Exception as e:
logger.error(f"Error in cleanup loop: {e}")


class CrawlerPool:
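    """Keeps up to max_size live AsyncWebCrawler instances, reusing the least recently
    used one when the pool is full and evicting crawlers idle for more than 10 minutes.
    """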
def __init__(self, max_size: int = 10):
self.max_size = max_size
self.active_crawlers: Dict[AsyncWebCrawler, float] = {}
self._lock = asyncio.Lock()
async def acquire(self, **kwargs) -> AsyncWebCrawler:
async with self._lock:
# Clean up inactive crawlers
current_time = time.time()
inactive = [
crawler
for crawler, last_used in self.active_crawlers.items()
if current_time - last_used > 600 # 10 minutes timeout
]
for crawler in inactive:
await crawler.__aexit__(None, None, None)
del self.active_crawlers[crawler]
# Create new crawler if needed
if len(self.active_crawlers) < self.max_size:
crawler = AsyncWebCrawler(**kwargs)
await crawler.__aenter__()
self.active_crawlers[crawler] = current_time
return crawler
# Reuse least recently used crawler
crawler = min(self.active_crawlers.items(), key=lambda x: x[1])[0]
self.active_crawlers[crawler] = current_time
return crawler
async def release(self, crawler: AsyncWebCrawler):
async with self._lock:
if crawler in self.active_crawlers:
self.active_crawlers[crawler] = time.time()
async def cleanup(self):
async with self._lock:
for crawler in list(self.active_crawlers.keys()):
await crawler.__aexit__(None, None, None)
self.active_crawlers.clear()


class CrawlerService:
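    """Ties together the resource monitor, task manager and crawler pool, and runs the background queue processor."""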
def __init__(self, max_concurrent_tasks: int = 10):
self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
self.task_manager = TaskManager()
self.crawler_pool = CrawlerPool(max_concurrent_tasks)
self._processing_task = None
async def start(self):
await self.task_manager.start()
self._processing_task = asyncio.create_task(self._process_queue())
async def stop(self):
if self._processing_task:
self._processing_task.cancel()
try:
await self._processing_task
except asyncio.CancelledError:
pass
await self.task_manager.stop()
await self.crawler_pool.cleanup()
def _create_extraction_strategy(self, config: ExtractionConfig):
if not config:
return None
if config.type == CrawlerType.LLM:
return LLMExtractionStrategy(**config.params)
elif config.type == CrawlerType.COSINE:
return CosineStrategy(**config.params)
elif config.type == CrawlerType.JSON_CSS:
return JsonCssExtractionStrategy(**config.params)
return None
async def submit_task(self, request: CrawlRequest) -> str:
task_id = str(uuid.uuid4())
await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)
# Store request data with task
self.task_manager.tasks[task_id].request = request
return task_id
async def _process_queue(self):
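        """Background loop: pull the next queued task, run it on a pooled crawler, and record the result or error."""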
while True:
try:
available_slots = await self.resource_monitor.get_available_slots()
                if available_slots <= 0:
                    await asyncio.sleep(1)
                    continue
task_id = await self.task_manager.get_next_task()
if not task_id:
await asyncio.sleep(1)
continue
task_info = self.task_manager.get_task(task_id)
if not task_info:
continue
request = task_info.request
self.task_manager.update_task(task_id, TaskStatus.PROCESSING)
try:
crawler = await self.crawler_pool.acquire(**request.crawler_params)
extraction_strategy = self._create_extraction_strategy(
request.extraction_config
)
if isinstance(request.urls, list):
results = await crawler.arun_many(
urls=[str(url) for url in request.urls],
                            word_count_threshold=request.word_count_threshold,
extraction_strategy=extraction_strategy,
js_code=request.js_code,
wait_for=request.wait_for,
css_selector=request.css_selector,
screenshot=request.screenshot,
magic=request.magic,
session_id=request.session_id,
cache_mode=request.cache_mode,
**request.extra,
)
else:
results = await crawler.arun(
                            url=str(request.urls),
                            word_count_threshold=request.word_count_threshold,
                            extraction_strategy=extraction_strategy,
js_code=request.js_code,
wait_for=request.wait_for,
css_selector=request.css_selector,
screenshot=request.screenshot,
magic=request.magic,
session_id=request.session_id,
cache_mode=request.cache_mode,
**request.extra,
)
await self.crawler_pool.release(crawler)
self.task_manager.update_task(
task_id, TaskStatus.COMPLETED, results
)
except Exception as e:
logger.error(f"Error processing task {task_id}: {str(e)}")
self.task_manager.update_task(
task_id, TaskStatus.FAILED, error=str(e)
)
except Exception as e:
logger.error(f"Error in queue processing: {str(e)}")
await asyncio.sleep(1)


app = FastAPI(title="Crawl4AI API")
# CORS configuration
origins = ["*"] # Allow all origins
app.add_middleware(
CORSMiddleware,
allow_origins=origins, # List of origins that are allowed to make requests
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
# API token security
security = HTTPBearer()
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
if not CRAWL4AI_API_TOKEN:
return credentials # No token verification if CRAWL4AI_API_TOKEN is not set
if credentials.credentials != CRAWL4AI_API_TOKEN:
raise HTTPException(status_code=401, detail="Invalid token")
return credentials
def secure_endpoint():
"""Returns security dependency only if CRAWL4AI_API_TOKEN is set"""
return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
# Check if site directory exists
if os.path.exists(__location__ + "/site"):
# Mount the site directory as a static directory
    app.mount("/mkdocs", StaticFiles(directory=__location__ + "/site", html=True), name="mkdocs")
site_templates = Jinja2Templates(directory=__location__ + "/site")
crawler_service = CrawlerService()


@app.on_event("startup")
async def startup_event():
await crawler_service.start()


@app.on_event("shutdown")
async def shutdown_event():
await crawler_service.stop()


@app.get("/")
def read_root():
if os.path.exists(__location__ + "/site"):
return RedirectResponse(url="/mkdocs")
# Return a json response
return {"message": "Crawl4AI API service is running"}


@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl(request: CrawlRequest) -> Dict[str, str]:
task_id = await crawler_service.submit_task(request)
return {"task_id": task_id}


@app.get(
"/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def get_task_status(task_id: str):
task_info = crawler_service.task_manager.get_task(task_id)
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
response = {
"status": task_info.status,
"created_at": task_info.created_at,
}
if task_info.status == TaskStatus.COMPLETED:
# Convert CrawlResult to dict for JSON response
if isinstance(task_info.result, list):
response["results"] = [result.dict() for result in task_info.result]
else:
response["result"] = task_info.result.dict()
elif task_info.status == TaskStatus.FAILED:
response["error"] = task_info.error
return response


@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
task_id = await crawler_service.submit_task(request)
# Wait up to 60 seconds for task completion
for _ in range(60):
task_info = crawler_service.task_manager.get_task(task_id)
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
if task_info.status == TaskStatus.COMPLETED:
# Return same format as /task/{task_id} endpoint
if isinstance(task_info.result, list):
return {
"status": task_info.status,
"results": [result.dict() for result in task_info.result],
}
return {"status": task_info.status, "result": task_info.result.dict()}
if task_info.status == TaskStatus.FAILED:
raise HTTPException(status_code=500, detail=task_info.error)
await asyncio.sleep(1)
# If we get here, task didn't complete within timeout
raise HTTPException(status_code=408, detail="Task timed out")


@app.post(
"/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
try:
crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
extraction_strategy = crawler_service._create_extraction_strategy(
request.extraction_config
)
try:
if isinstance(request.urls, list):
results = await crawler.arun_many(
urls=[str(url) for url in request.urls],
extraction_strategy=extraction_strategy,
js_code=request.js_code,
wait_for=request.wait_for,
css_selector=request.css_selector,
screenshot=request.screenshot,
magic=request.magic,
cache_mode=request.cache_mode,
session_id=request.session_id,
**request.extra,
)
return {"results": [result.dict() for result in results]}
else:
result = await crawler.arun(
url=str(request.urls),
extraction_strategy=extraction_strategy,
js_code=request.js_code,
wait_for=request.wait_for,
css_selector=request.css_selector,
screenshot=request.screenshot,
magic=request.magic,
cache_mode=request.cache_mode,
session_id=request.session_id,
**request.extra,
)
return {"result": result.dict()}
finally:
await crawler_service.crawler_pool.release(crawler)
except Exception as e:
logger.error(f"Error in direct crawl: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
available_slots = await crawler_service.resource_monitor.get_available_slots()
memory = psutil.virtual_memory()
return {
"status": "healthy",
"available_slots": available_slots,
"memory_usage": memory.percent,
"cpu_usage": psutil.cpu_percent(),
}
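

# Example usage (a sketch, assuming the service is running locally on port 11235 and
# CRAWL4AI_API_TOKEN is unset, so no Authorization header is required):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:11235/crawl",
#       json={"urls": "https://example.com", "priority": 8},
#   )
#   task_id = resp.json()["task_id"]
#   status = requests.get(f"http://localhost:11235/task/{task_id}").json()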


if __name__ == "__main__":
import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=11235)