feat(docker): implement supervisor and secure API endpoints

Add supervisord configuration to manage the Redis and Gunicorn processes
Replace direct process management with supervisord in the Docker images
Add JWT-secured and token-free API server variants
Implement JWT authentication for protected endpoints
Switch async dispatcher timing from datetime objects to time.time() timestamps
Add MX-record verification of email domains for token requests

BREAKING CHANGE: Server startup now uses supervisord instead of direct process management
UncleCode 2025-02-17 20:31:20 +08:00
parent 8bb799068e
commit 2864015469
12 changed files with 790 additions and 79 deletions
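
A rough sketch of how a client would exercise the new JWT flow once the container is up. The localhost:8000 host/port follows the gunicorn bind in the diffs below; the email address and token values are placeholders, and curl usage here is illustrative, not part of the commit. Endpoints and payloads come from the secure server variant added below.

# Request a token; /token rejects addresses whose domain has no MX records
curl -s -X POST http://localhost:8000/token \
     -H "Content-Type: application/json" \
     -d '{"email": "user@your-real-domain.com"}'   # placeholder email
# -> {"email": "...", "access_token": "<JWT>", "token_type": "bearer"}

# Call the protected QA endpoint (URL is passed without the http/https prefix)
TOKEN="<access_token from the previous response>"   # placeholder
curl -s "http://localhost:8000/llm/example.com?q=What+is+this+page+about" \
     -H "Authorization: Bearer $TOKEN"

# Batch crawl through the protected /crawl endpoint (default configs shown)
curl -s -X POST http://localhost:8000/crawl \
     -H "Authorization: Bearer $TOKEN" \
     -H "Content-Type: application/json" \
     -d '{"urls": ["https://example.com"], "browser_config": {}, "crawler_config": {}}'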

View File

@@ -2,7 +2,7 @@ FROM python:3.10-slim
# Set build arguments
ARG APP_HOME=/app
ARG GITHUB_REPO=https://github.com/yourusername/crawl4ai.git
ARG GITHUB_REPO=https://github.com/unclecode/crawl4ai.git
ARG GITHUB_BRANCH=main
ARG USE_LOCAL=true
@@ -37,6 +37,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
python3-dev \
libjpeg-dev \
redis-server \
supervisor \
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -102,6 +103,8 @@ fi' > /tmp/install.sh && chmod +x /tmp/install.sh
COPY . /tmp/project/
COPY deploy/docker/supervisord.conf .
COPY deploy/docker/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
@@ -148,19 +151,24 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
redis-cli ping > /dev/null && \
curl -f http://localhost:8000/health || exit 1'
COPY deploy/docker/docker-entrypoint.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/docker-entrypoint.sh
# COPY deploy/docker/docker-entrypoint.sh /usr/local/bin/
# RUN chmod +x /usr/local/bin/docker-entrypoint.sh
EXPOSE 6379
ENTRYPOINT ["docker-entrypoint.sh"]
# ENTRYPOINT ["docker-entrypoint.sh"]
# CMD service redis-server start && gunicorn \
# --bind 0.0.0.0:8000 \
# --workers 4 \
# --threads 2 \
# --timeout 120 \
# --graceful-timeout 30 \
# --log-level info \
# --worker-class uvicorn.workers.UvicornWorker \
# server:app
# ENTRYPOINT ["docker-entrypoint.sh"]
CMD ["supervisord", "-c", "supervisord.conf"]
CMD service redis-server start && gunicorn \
--bind 0.0.0.0:8000 \
--workers 4 \
--threads 2 \
--timeout 120 \
--graceful-timeout 30 \
--log-level info \
--worker-class uvicorn.workers.UvicornWorker \
server:app

View File

@@ -13,7 +13,7 @@ from rich.live import Live
from rich.table import Table
from rich.console import Console
from rich import box
from datetime import datetime, timedelta
from datetime import timedelta
from collections.abc import AsyncGenerator
import time
import psutil
@@ -96,7 +96,7 @@ class CrawlerMonitor:
self.display_mode = display_mode
self.stats: Dict[str, CrawlStats] = {}
self.process = psutil.Process()
self.start_time = datetime.now()
self.start_time = time.time()
self.live = Live(self._create_table(), refresh_per_second=2)
def start(self):
@@ -150,7 +150,7 @@ class CrawlerMonitor:
)
# Duration
duration = datetime.now() - self.start_time
duration = time.time() - self.start_time
# Create status row
table.add_column("Status", style="bold cyan")
@@ -192,7 +192,7 @@ class CrawlerMonitor:
)
table.add_row(
"[yellow]Runtime[/yellow]",
str(timedelta(seconds=int(duration.total_seconds()))),
str(timedelta(seconds=int(duration))),
"",
)
@@ -235,7 +235,7 @@ class CrawlerMonitor:
f"{self.process.memory_info().rss / (1024 * 1024):.1f}",
str(
timedelta(
seconds=int((datetime.now() - self.start_time).total_seconds())
seconds=int(time.time() - self.start_time)
)
),
f"{completed_count}{failed_count}",
@@ -250,7 +250,7 @@ class CrawlerMonitor:
key=lambda x: (
x.status != CrawlStatus.IN_PROGRESS,
x.status != CrawlStatus.QUEUED,
x.end_time or datetime.max,
x.end_time or float('inf'),
),
)[: self.max_visible_rows]
@@ -337,7 +337,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
config: CrawlerRunConfig,
task_id: str,
) -> CrawlerTaskResult:
start_time = datetime.now()
start_time = time.time()
error_message = ""
memory_usage = peak_memory = 0.0
@@ -370,7 +370,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
memory_usage=memory_usage,
peak_memory=peak_memory,
start_time=start_time,
end_time=datetime.now(),
end_time=time.time(),
error_message=error_message,
)
await self.result_queue.put(result)
@@ -392,7 +392,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
)
finally:
end_time = datetime.now()
end_time = time.time()
if self.monitor:
self.monitor.update_task(
task_id,
@@ -542,7 +542,7 @@ class SemaphoreDispatcher(BaseDispatcher):
task_id: str,
semaphore: asyncio.Semaphore = None,
) -> CrawlerTaskResult:
start_time = datetime.now()
start_time = time.time()
error_message = ""
memory_usage = peak_memory = 0.0
@@ -575,7 +575,7 @@ class SemaphoreDispatcher(BaseDispatcher):
memory_usage=memory_usage,
peak_memory=peak_memory,
start_time=start_time,
end_time=datetime.now(),
end_time=time.time(),
error_message=error_message,
)
@@ -595,7 +595,7 @@ class SemaphoreDispatcher(BaseDispatcher):
)
finally:
end_time = datetime.now()
end_time = time.time()
if self.monitor:
self.monitor.update_task(
task_id,

View File

@@ -1,3 +1,4 @@
from re import U
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
from enum import Enum
@@ -24,8 +25,8 @@ class CrawlerTaskResult:
result: "CrawlResult"
memory_usage: float
peak_memory: float
start_time: datetime
end_time: datetime
start_time: Union[datetime, float]
end_time: Union[datetime, float]
error_message: str = ""
@@ -100,8 +101,8 @@ class DispatchResult(BaseModel):
task_id: str
memory_usage: float
peak_memory: float
start_time: datetime
end_time: datetime
start_time: Union[datetime, float]
end_time: Union[datetime, float]
error_message: str = ""

View File

@@ -2,11 +2,10 @@ FROM python:3.10-slim
# Set build arguments
ARG APP_HOME=/app
ARG GITHUB_REPO=https://github.com/yourusername/crawl4ai.git
ARG GITHUB_BRANCH=main
ARG USE_LOCAL=false
ARG GITHUB_REPO=https://github.com/unclecode/crawl4ai.git
ARG GITHUB_BRANCH=next
ARG USE_LOCAL=true
# 🤓 Environment variables - because who doesn't love a good ENV party?
ENV PYTHONFAULTHANDLER=1 \
PYTHONHASHSEED=random \
PYTHONUNBUFFERED=1 \
@@ -14,20 +13,19 @@ ENV PYTHONFAULTHANDLER=1 \
PYTHONDONTWRITEBYTECODE=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1 \
PIP_DEFAULT_TIMEOUT=100 \
DEBIAN_FRONTEND=noninteractive
DEBIAN_FRONTEND=noninteractive \
REDIS_HOST=localhost \
REDIS_PORT=6379
# Other build arguments
ARG PYTHON_VERSION=3.10
ARG INSTALL_TYPE=default
ARG ENABLE_GPU=false
ARG TARGETARCH
# 🎯 Platform-specific labels - because even containers need ID badges
LABEL maintainer="unclecode"
LABEL description="🔥🕷️ Crawl4AI: Open-source LLM Friendly Web Crawler & scraper"
LABEL version="1.0"
# 📦 Installing system dependencies... please hold, your package is being delivered
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
curl \
@@ -38,9 +36,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config \
python3-dev \
libjpeg-dev \
redis-server \
supervisor \
&& rm -rf /var/lib/apt/lists/*
# 🎭 Playwright dependencies - because browsers need their vitamins too
RUN apt-get update && apt-get install -y --no-install-recommends \
libglib2.0-0 \
libnss3 \
@@ -65,7 +64,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libatspi2.0-0 \
&& rm -rf /var/lib/apt/lists/*
# 🎮 GPU support - because sometimes CPU just doesn't cut it
RUN if [ "$ENABLE_GPU" = "true" ] && [ "$TARGETARCH" = "amd64" ] ; then \
apt-get update && apt-get install -y --no-install-recommends \
nvidia-cuda-toolkit \
@@ -74,7 +72,6 @@ else \
echo "Skipping NVIDIA CUDA Toolkit installation (unsupported platform or GPU disabled)"; \
fi
# 🏗️ Platform-specific optimizations - because one size doesn't fit all
RUN if [ "$TARGETARCH" = "arm64" ]; then \
echo "🦾 Installing ARM-specific optimizations"; \
apt-get update && apt-get install -y --no-install-recommends \
@@ -85,11 +82,12 @@ elif [ "$TARGETARCH" = "amd64" ]; then \
apt-get update && apt-get install -y --no-install-recommends \
libomp-dev \
&& rm -rf /var/lib/apt/lists/*; \
else \
echo "Skipping platform-specific optimizations (unsupported platform)"; \
fi
WORKDIR ${APP_HOME}
# 🔄 Installation script - now with retry logic because sometimes Git needs a coffee break
RUN echo '#!/bin/bash\n\
if [ "$USE_LOCAL" = "true" ]; then\n\
echo "📦 Installing from local source..."\n\
@@ -103,14 +101,13 @@ else\n\
pip install --no-cache-dir /tmp/crawl4ai\n\
fi' > /tmp/install.sh && chmod +x /tmp/install.sh
# Copy local project if USE_LOCAL is true
COPY . /tmp/project/
# Copy and install other requirements
COPY deploy/docker/supervisord.conf .
COPY deploy/docker/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Install ML dependencies first for better layer caching
RUN if [ "$INSTALL_TYPE" = "all" ] ; then \
pip install --no-cache-dir \
torch \
@@ -123,7 +120,6 @@ RUN if [ "$INSTALL_TYPE" = "all" ] ; then \
python -m nltk.downloader punkt stopwords ; \
fi
# Install the package
RUN if [ "$INSTALL_TYPE" = "all" ] ; then \
pip install "/tmp/project/[all]" && \
python -m crawl4ai.model_loader ; \
@@ -136,7 +132,6 @@ RUN if [ "$INSTALL_TYPE" = "all" ] ; then \
pip install "/tmp/project" ; \
fi
# 🚀 Installation validation - trust but verify!
RUN pip install --no-cache-dir --upgrade pip && \
/tmp/install.sh && \
python -c "import crawl4ai; print('✅ crawl4ai is ready to rock!')" && \
@@ -144,10 +139,8 @@ RUN pip install --no-cache-dir --upgrade pip && \
RUN playwright install --with-deps chromium
# Copy application files
COPY deploy/docker/* ${APP_HOME}/
# 🏥 Health check - now with memory validation!
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD bash -c '\
MEM=$(free -m | awk "/^Mem:/{print \$2}"); \
@@ -155,20 +148,27 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
echo "⚠️ Warning: Less than 2GB RAM available! Your container might need a memory boost! 🚀"; \
exit 1; \
fi && \
redis-cli ping > /dev/null && \
curl -f http://localhost:8000/health || exit 1'
# Entrypoint script
COPY deploy/docker/docker-entrypoint.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/docker-entrypoint.sh
ENTRYPOINT ["docker-entrypoint.sh"]
# COPY deploy/docker/docker-entrypoint.sh /usr/local/bin/
# RUN chmod +x /usr/local/bin/docker-entrypoint.sh
EXPOSE 6379
# ENTRYPOINT ["docker-entrypoint.sh"]
# CMD service redis-server start && gunicorn \
# --bind 0.0.0.0:8000 \
# --workers 4 \
# --threads 2 \
# --timeout 120 \
# --graceful-timeout 30 \
# --log-level info \
# --worker-class uvicorn.workers.UvicornWorker \
# server:app
# ENTRYPOINT ["docker-entrypoint.sh"]
CMD ["supervisord", "-c", "supervisord.conf"]
# Default command - may the server be with you! 🚀
CMD ["gunicorn", \
"--bind", "0.0.0.0:8000", \
"--workers", "4", \
"--threads", "2", \
"--timeout", "120", \
"--graceful-timeout", "30", \
"--log-level", "info", \
"--worker-class", "uvicorn.workers.UvicornWorker", \
"server:app"]

View File

@@ -1,3 +1,4 @@
from math import e
import os
import json
import logging
@@ -14,6 +15,7 @@ from crawl4ai import (
LLMExtractionStrategy,
CacheMode
)
from crawl4ai.utils import perform_completion_with_backoff
from crawl4ai.content_filter_strategy import (
PruningContentFilter,
BM25ContentFilter,
@@ -33,6 +35,51 @@ from utils import (
logger = logging.getLogger(__name__)
async def handle_llm_qa(
url: str,
query: str,
config: dict
) -> str:
"""Process QA using LLM with crawled content as context."""
try:
# Extract base URL by finding last '?q=' occurrence
last_q_index = url.rfind('?q=')
if last_q_index != -1:
url = url[:last_q_index]
# Get markdown content
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(url)
if not result.success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result.error_message
)
content = result.markdown_v2.fit_markdown
# Create prompt and get LLM response
prompt = f"""Use the following content as context to answer the question.
Content:
{content}
Question: {query}
Answer:"""
response = perform_completion_with_backoff(
provider=config["llm"]["provider"],
prompt_with_variables=prompt,
api_token=os.environ.get(config["llm"].get("api_key_env", ""))
)
return response.choices[0].message.content
except Exception as e:
logger.error(f"QA processing error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
async def process_llm_extraction(
redis: aioredis.Redis,
config: dict,
@@ -44,9 +91,15 @@ async def process_llm_extraction(
) -> None:
"""Process LLM extraction in background."""
try:
# If config['llm'] has api_key then ignore the api_key_env
api_key = ""
if "api_key" in config["llm"]:
api_key = config["llm"]["api_key"]
else:
api_key = os.environ.get(config["llm"].get("api_key_env", None), "")
llm_strategy = LLMExtractionStrategy(
provider=config["llm"]["provider"],
api_token=os.environ.get(config["llm"].get("api_key_env", None), ""),
api_token=api_key,
instruction=instruction,
schema=json.loads(schema) if schema else None,
)

View File

@@ -11,6 +11,7 @@ app:
llm:
provider: "openai/gpt-4o-mini"
api_key_env: "OPENAI_API_KEY"
# api_key: sk-... # If you pass the API key directly then api_key_env will be ignored
# Redis Configuration
redis:

View File

@@ -5,3 +5,6 @@ gunicorn>=23.0.0
slowapi>=0.1.9
prometheus-fastapi-instrumentator>=7.0.2
redis>=5.2.1
jwt>=1.3.1
dnspython>=2.7.0
email-validator>=2.2.0

View File

@@ -0,0 +1,324 @@
import os
import sys
import time
import base64
from typing import List, Optional, Dict
from datetime import datetime, timedelta, timezone
from jwt import JWT, jwk_from_dict
from jwt.utils import get_int_from_datetime
from fastapi import FastAPI, HTTPException, Request, status, Depends, Query, Path
from fastapi.responses import StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter
from slowapi.util import get_remote_address
from prometheus_fastapi_instrumentator import Instrumentator
from redis import asyncio as aioredis
from pydantic import EmailStr
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from utils import FilterType, load_config, setup_logging, verify_email_domain
from api import (
handle_markdown_request,
handle_llm_qa
)
__version__ = "0.1.2"
class CrawlRequest(BaseModel):
urls: List[str] = Field(
min_length=1,
max_length=100,
json_schema_extra={
"items": {"type": "string", "maxLength": 2000, "pattern": "\\S"}
}
)
browser_config: Optional[Dict] = Field(
default_factory=dict,
example={"headless": True, "viewport": {"width": 1200}}
)
crawler_config: Optional[Dict] = Field(
default_factory=dict,
example={"stream": True, "cache_mode": "aggressive"}
)
# Load configuration and setup
config = load_config()
setup_logging(config)
# Initialize Redis
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
# Initialize rate limiter
limiter = Limiter(
key_func=get_remote_address,
default_limits=[config["rate_limiting"]["default_limit"]],
storage_uri=config["rate_limiting"]["storage_uri"]
)
app = FastAPI(
title=config["app"]["title"],
version=config["app"]["version"]
)
# Configure middleware
if config["security"]["enabled"]:
if config["security"]["https_redirect"]:
app.add_middleware(HTTPSRedirectMiddleware)
if config["security"]["trusted_hosts"] and config["security"]["trusted_hosts"] != ["*"]:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=config["security"]["trusted_hosts"]
)
# Prometheus instrumentation
if config["observability"]["prometheus"]["enabled"]:
Instrumentator().instrument(app).expose(app)
# -------------------------------
# JWT Token Authentication Setup
# -------------------------------
instance = JWT()
# Use a secret key for symmetric signing (HS256)
SECRET_KEY = os.environ.get("SECRET_KEY", "mysecret")
ACCESS_TOKEN_EXPIRE_MINUTES = 60
# FastAPI security scheme for extracting the Authorization header
security = HTTPBearer()
def get_jwk_from_secret(secret: str):
"""
Convert a simple secret string into a JWK object.
The secret is base64 URL-safe encoded (without padding) as required.
"""
secret_bytes = secret.encode('utf-8')
b64_secret = base64.urlsafe_b64encode(secret_bytes).rstrip(b'=').decode('utf-8')
return jwk_from_dict({"kty": "oct", "k": b64_secret})
def create_access_token(data: dict, expires_delta: timedelta = None):
"""
Create a JWT access token with an expiration.
"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + (expires_delta if expires_delta else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
to_encode.update({"exp": get_int_from_datetime(expire)})
# Convert the secret into a JWK object
signing_key = get_jwk_from_secret(SECRET_KEY)
encoded_jwt = instance.encode(to_encode, signing_key, alg='HS256')
return encoded_jwt
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""
Verify the JWT token extracted from the Authorization header.
"""
token = credentials.credentials
# Convert the secret into a JWK object for verification
verifying_key = get_jwk_from_secret(SECRET_KEY)
try:
payload = instance.decode(token, verifying_key, do_time_check=True, algorithms='HS256')
return payload
except Exception as e:
raise HTTPException(status_code=401, detail="Invalid or expired token")
# -------------------------------
# Endpoints
# -------------------------------
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
response = await call_next(request)
if config["security"]["enabled"]:
response.headers.update(config["security"]["headers"])
return response
class TokenRequest(BaseModel):
email: EmailStr
@app.post("/token")
async def get_token(request_data: TokenRequest):
"""
Minimal endpoint to generate a JWT token.
In a real-world scenario, you'd validate credentials here.
"""
# token = create_access_token({"sub": "user1"})
# return {"access_token": token, "token_type": "bearer"}
# Verify that the email domain likely exists (has MX records)
if not verify_email_domain(request_data.email):
raise HTTPException(
status_code=400,
detail="Email domain verification failed. Please use a valid email address."
)
token = create_access_token({"sub": request_data.email})
return {"email": request_data.email, "access_token": token, "token_type": "bearer"}
@app.get("/md/{url:path}")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def get_markdown(
request: Request,
url: str,
f: FilterType = FilterType.FIT,
q: Optional[str] = None,
c: Optional[str] = "0",
token_data: dict = Depends(verify_token)
):
"""Get markdown from URL with optional filtering."""
result = await handle_markdown_request(url, f, q, c, config)
return PlainTextResponse(result)
@app.get("/llm/{url:path}", description="URL should be without http/https prefix")
async def llm_endpoint(
request: Request,
url: str = Path(..., description="Domain and path without protocol"),
q: Optional[str] = Query(None, description="Question to ask about the page content"),
token_data: dict = Depends(verify_token)
):
"""QA endpoint that uses LLM with crawled content as context."""
if not q:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Query parameter 'q' is required"
)
# Ensure URL starts with http/https
if not url.startswith(('http://', 'https://')):
url = 'https://' + url
try:
answer = await handle_llm_qa(url, q, config)
return JSONResponse({"answer": answer})
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
@app.get("/schema")
async def get_schema():
"""Endpoint for client-side validation schema."""
from crawl4ai import BrowserConfig, CrawlerRunConfig
return {
"browser": BrowserConfig.model_json_schema(),
"crawler": CrawlerRunConfig.model_json_schema()
}
@app.get(config["observability"]["health_check"]["endpoint"])
async def health():
"""Health check endpoint."""
return {"status": "ok", "timestamp": time.time(), "version": __version__}
@app.get(config["observability"]["prometheus"]["endpoint"])
async def metrics():
"""Prometheus metrics endpoint."""
return RedirectResponse(url=config["observability"]["prometheus"]["endpoint"])
# -------------------------------
# Protected Endpoint Example: /crawl
# -------------------------------
@app.post("/crawl")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def crawl(request: Request, crawl_request: CrawlRequest, token_data: dict = Depends(verify_token)):
"""Handle crawl requests. Protected by JWT authentication."""
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
MemoryAdaptiveDispatcher,
RateLimiter
)
import asyncio
import logging
logger = logging.getLogger(__name__)
crawler = None
try:
if not crawl_request.urls:
logger.error("Empty URL list received")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one URL required"
)
browser_config = BrowserConfig.load(crawl_request.browser_config)
crawler_config = CrawlerRunConfig.load(crawl_request.crawler_config)
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
)
)
if crawler_config.stream:
crawler = AsyncWebCrawler(config=browser_config)
await crawler.start()
results_gen = await asyncio.wait_for(
crawler.arun_many(
urls=crawl_request.urls,
config=crawler_config,
dispatcher=dispatcher
),
timeout=config["crawler"]["timeouts"]["stream_init"]
)
from api import stream_results
return StreamingResponse(
stream_results(crawler, results_gen),
media_type='application/x-ndjson',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Stream-Status': 'active'
}
)
else:
async with AsyncWebCrawler(config=browser_config) as crawler:
results = await asyncio.wait_for(
crawler.arun_many(
urls=crawl_request.urls,
config=crawler_config,
dispatcher=dispatcher
),
timeout=config["crawler"]["timeouts"]["batch_process"]
)
return JSONResponse({
"success": True,
"results": [result.model_dump() for result in results]
})
except asyncio.TimeoutError as e:
logger.error(f"Operation timed out: {str(e)}")
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Processing timeout"
)
except Exception as e:
logger.error(f"Server error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
finally:
if crawler:
try:
await crawler.close()
except Exception as e:
logger.error(f"Final crawler cleanup error: {e}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"server-secure:app",
host=config["app"]["host"],
port=config["app"]["port"],
reload=config["app"]["reload"],
timeout_keep_alive=config["app"]["timeout_keep_alive"]
)

View File

@@ -0,0 +1,267 @@
import os
import sys
import time
from typing import List, Optional
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from redis import asyncio as aioredis
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import StreamingResponse, RedirectResponse
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter
from slowapi.util import get_remote_address
from prometheus_fastapi_instrumentator import Instrumentator
from fastapi.responses import PlainTextResponse
from fastapi.responses import JSONResponse
from fastapi.background import BackgroundTasks
from typing import Dict
from fastapi import Query, Path
import os
from utils import (
FilterType,
load_config,
setup_logging
)
from api import (
handle_markdown_request,
handle_llm_request,
handle_llm_qa
)
# Load configuration and setup
config = load_config()
setup_logging(config)
# Initialize Redis
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
# Initialize rate limiter
limiter = Limiter(
key_func=get_remote_address,
default_limits=[config["rate_limiting"]["default_limit"]],
storage_uri=config["rate_limiting"]["storage_uri"]
)
app = FastAPI(
title=config["app"]["title"],
version=config["app"]["version"]
)
# Configure middleware
if config["security"]["enabled"]:
if config["security"]["https_redirect"]:
app.add_middleware(HTTPSRedirectMiddleware)
if config["security"]["trusted_hosts"] and config["security"]["trusted_hosts"] != ["*"]:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=config["security"]["trusted_hosts"]
)
# Prometheus instrumentation
if config["observability"]["prometheus"]["enabled"]:
Instrumentator().instrument(app).expose(app)
class CrawlRequest(BaseModel):
urls: List[str] = Field(
min_length=1,
max_length=100,
json_schema_extra={
"items": {"type": "string", "maxLength": 2000, "pattern": "\\S"}
}
)
browser_config: Optional[Dict] = Field(
default_factory=dict,
example={"headless": True, "viewport": {"width": 1200}}
)
crawler_config: Optional[Dict] = Field(
default_factory=dict,
example={"stream": True, "cache_mode": "aggressive"}
)
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
response = await call_next(request)
if config["security"]["enabled"]:
response.headers.update(config["security"]["headers"])
return response
@app.get("/md/{url:path}")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def get_markdown(
request: Request,
url: str,
f: FilterType = FilterType.FIT,
q: Optional[str] = None,
c: Optional[str] = "0"
):
"""Get markdown from URL with optional filtering."""
result = await handle_markdown_request(url, f, q, c, config)
return PlainTextResponse(result)
@app.get("/llm/{url:path}", description="URL should be without http/https prefix")
async def llm_endpoint(
request: Request,
url: str = Path(..., description="Domain and path without protocol"),
q: Optional[str] = Query(None, description="Question to ask about the page content"),
):
"""QA endpoint that uses LLM with crawled content as context."""
if not q:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Query parameter 'q' is required"
)
# Ensure URL starts with http/https
if not url.startswith(('http://', 'https://')):
url = 'https://' + url
try:
answer = await handle_llm_qa(url, q, config)
return JSONResponse({"answer": answer})
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
# @app.get("/llm/{input:path}")
# @limiter.limit(config["rate_limiting"]["default_limit"])
# async def llm_endpoint(
# request: Request,
# background_tasks: BackgroundTasks,
# input: str,
# q: Optional[str] = None,
# s: Optional[str] = None,
# c: Optional[str] = "0"
# ):
# """Handle LLM extraction requests."""
# return await handle_llm_request(
# redis, background_tasks, request, input, q, s, c, config
# )
@app.get("/schema")
async def get_schema():
"""Endpoint for client-side validation schema."""
from crawl4ai import BrowserConfig, CrawlerRunConfig
return {
"browser": BrowserConfig().dump(),
"crawler": CrawlerRunConfig().dump()
}
@app.get(config["observability"]["health_check"]["endpoint"])
async def health():
"""Health check endpoint."""
return {"status": "ok", "timestamp": time.time()}
@app.get(config["observability"]["prometheus"]["endpoint"])
async def metrics():
"""Prometheus metrics endpoint."""
return RedirectResponse(url=config["observability"]["prometheus"]["endpoint"])
@app.post("/crawl")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def crawl(request: Request, crawl_request: CrawlRequest):
"""Handle crawl requests."""
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
MemoryAdaptiveDispatcher,
RateLimiter
)
import asyncio
import logging
logger = logging.getLogger(__name__)
crawler = None
try:
if not crawl_request.urls:
logger.error("Empty URL list received")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one URL required"
)
browser_config = BrowserConfig.load(crawl_request.browser_config)
crawler_config = CrawlerRunConfig.load(crawl_request.crawler_config)
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
)
)
if crawler_config.stream:
crawler = AsyncWebCrawler(config=browser_config)
await crawler.start()
results_gen = await asyncio.wait_for(
crawler.arun_many(
urls=crawl_request.urls,
config=crawler_config,
dispatcher=dispatcher
),
timeout=config["crawler"]["timeouts"]["stream_init"]
)
from api import stream_results
return StreamingResponse(
stream_results(crawler, results_gen),
media_type='application/x-ndjson',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Stream-Status': 'active'
}
)
else:
async with AsyncWebCrawler(config=browser_config) as crawler:
results = await asyncio.wait_for(
crawler.arun_many(
urls=crawl_request.urls,
config=crawler_config,
dispatcher=dispatcher
),
timeout=config["crawler"]["timeouts"]["batch_process"]
)
return JSONResponse({
"success": True,
"results": [result.model_dump() for result in results]
})
except asyncio.TimeoutError as e:
logger.error(f"Operation timed out: {str(e)}")
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Processing timeout"
)
except Exception as e:
logger.error(f"Server error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
finally:
if crawler:
try:
await crawler.close()
except Exception as e:
logger.error(f"Final crawler cleanup error: {e}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"server:app",
host=config["app"]["host"],
port=config["app"]["port"],
reload=config["app"]["reload"],
timeout_keep_alive=config["app"]["timeout_keep_alive"]
)

View File

@@ -18,6 +18,7 @@ from fastapi.responses import PlainTextResponse
from fastapi.responses import JSONResponse
from fastapi.background import BackgroundTasks
from typing import Dict
from fastapi import Query, Path
import os
from utils import (
@@ -27,7 +28,8 @@ from utils import (
)
from api import (
handle_markdown_request,
handle_llm_request
handle_llm_request,
handle_llm_qa
)
# Load configuration and setup
@@ -100,28 +102,56 @@ async def get_markdown(
result = await handle_markdown_request(url, f, q, c, config)
return PlainTextResponse(result)
@app.get("/llm/{input:path}")
@limiter.limit(config["rate_limiting"]["default_limit"])
@app.get("/llm/{url:path}", description="URL should be without http/https prefix")
async def llm_endpoint(
request: Request,
background_tasks: BackgroundTasks,
input: str,
q: Optional[str] = None,
s: Optional[str] = None,
c: Optional[str] = "0"
url: str = Path(..., description="Domain and path without protocol"),
q: Optional[str] = Query(None, description="Question to ask about the page content"),
):
"""Handle LLM extraction requests."""
return await handle_llm_request(
redis, background_tasks, request, input, q, s, c, config
"""QA endpoint that uses LLM with crawled content as context."""
if not q:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Query parameter 'q' is required"
)
# Ensure URL starts with http/https
if not url.startswith(('http://', 'https://')):
url = 'https://' + url
try:
answer = await handle_llm_qa(url, q, config)
return JSONResponse({"answer": answer})
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
# @app.get("/llm/{input:path}")
# @limiter.limit(config["rate_limiting"]["default_limit"])
# async def llm_endpoint(
# request: Request,
# background_tasks: BackgroundTasks,
# input: str,
# q: Optional[str] = None,
# s: Optional[str] = None,
# c: Optional[str] = "0"
# ):
# """Handle LLM extraction requests."""
# return await handle_llm_request(
# redis, background_tasks, request, input, q, s, c, config
# )
@app.get("/schema")
async def get_schema():
"""Endpoint for client-side validation schema."""
from crawl4ai import BrowserConfig, CrawlerRunConfig
return {
"browser": BrowserConfig.model_json_schema(),
"crawler": CrawlerRunConfig.model_json_schema()
"browser": BrowserConfig().dump(),
"crawler": CrawlerRunConfig().dump()
}
@app.get(config["observability"]["health_check"]["endpoint"])

View File

@@ -0,0 +1,12 @@
[supervisord]
nodaemon=true
[program:redis]
command=redis-server
autorestart=true
priority=10
[program:gunicorn]
command=gunicorn --bind 0.0.0.0:8000 --workers 4 --threads 2 --timeout 120 --graceful-timeout 30 --log-level info --worker-class uvicorn.workers.UvicornWorker server:app
autorestart=true
priority=20
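
With nodaemon=true, supervisord stays in the foreground as the container's main process; lower priority values start first, so redis-server (10) comes up before gunicorn (20), and autorestart=true brings either one back if it exits. A hedged build-and-run sketch follows; the image tag and Dockerfile path are assumptions, the published port matches the gunicorn bind above.

# image tag and Dockerfile path are assumptions, not part of this commit
docker build -t crawl4ai-server -f deploy/docker/Dockerfile .
docker run --rm -p 8000:8000 crawl4ai-server
# supervisord runs as PID 1, starts redis-server then gunicorn, and the
# HEALTHCHECK curls http://localhost:8000/health once both are up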

View File

@@ -1,3 +1,4 @@
import dns.resolver
import logging
import yaml
from datetime import datetime
@@ -52,3 +53,14 @@ def should_cleanup_task(created_at: str) -> bool:
def decode_redis_hash(hash_data: Dict[bytes, bytes]) -> Dict[str, str]:
"""Decode Redis hash data from bytes to strings."""
return {k.decode('utf-8'): v.decode('utf-8') for k, v in hash_data.items()}
def verify_email_domain(email: str) -> bool:
try:
domain = email.split('@')[1]
# Try to resolve MX records for the domain.
records = dns.resolver.resolve(domain, 'MX')
return True if records else False
except Exception as e:
return False
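
The MX lookup above is roughly the check you could reproduce from a shell with dig (assuming dig is available); a domain with no MX answer makes verify_email_domain return False and /token respond with 400.

dig +short MX gmail.com                  # non-empty answer -> domain passes
dig +short MX no-such-domain.invalid     # empty answer -> token request rejected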