import os
import importlib
import asyncio
from functools import lru_cache
import logging

logging.basicConfig(level=logging.DEBUG)

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.exceptions import RequestValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import FileResponse
from pydantic import BaseModel, HttpUrl
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional

from crawl4ai.web_crawler import WebCrawler
from crawl4ai.database import get_total_count, clear_db

import time
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# Load the .env file
from dotenv import load_dotenv

load_dotenv()


# Configuration
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
MAX_CONCURRENT_REQUESTS = 10  # Adjust this to change the maximum concurrent requests
current_requests = 0
lock = asyncio.Lock()

app = FastAPI()


# Initialize rate limiter
def rate_limit_key_func(request: Request):
    access_token = request.headers.get("access-token")
    if access_token == os.environ.get('ACCESS_TOKEN'):
        return None
    return get_remote_address(request)


limiter = Limiter(key_func=rate_limit_key_func)
app.state.limiter = limiter

# Dictionary to store last request times for each client
last_request_times = {}
last_rate_limit = {}
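
# As written, a request that carries the configured token in the "access-token"
# header is intended to skip rate limiting (both here and in the middleware
# below). A hypothetical call might look like this (token value illustrative):
#
#   curl -H "access-token: $ACCESS_TOKEN" http://localhost:8888/old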


def get_rate_limit():
    limit = os.environ.get('ACCESS_PER_MIN', "5")
    return f"{limit}/minute"


# Custom rate limit exceeded handler
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    if request.client.host not in last_rate_limit or time.time() - last_rate_limit[request.client.host] > 60:
        last_rate_limit[request.client.host] = time.time()
    retry_after = 60 - (time.time() - last_rate_limit[request.client.host])
    reset_at = time.time() + retry_after
    return JSONResponse(
        status_code=429,
        content={
            "detail": "Rate limit exceeded",
            "limit": str(exc.limit.limit),
            "retry_after": retry_after,
            "reset_at": reset_at,
            "message": f"You have exceeded the rate limit of {exc.limit.limit}."
        }
    )


app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
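
# A limited client receives a 429 response whose body roughly follows this shape
# (values are illustrative, not actual output):
#
# {
#     "detail": "Rate limit exceeded",
#     "limit": "5 per 1 minute",
#     "retry_after": 42.0,
#     "reset_at": 1720440000.0,
#     "message": "You have exceeded the rate limit of 5 per 1 minute."
# }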


# Middleware for token-based bypass and per-request limit
class RateLimitMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10))
        access_token = request.headers.get("access-token")
        if access_token == os.environ.get('ACCESS_TOKEN'):
            return await call_next(request)

        path = request.url.path
        if path in ["/crawl", "/old"]:
            client_ip = request.client.host
            current_time = time.time()

            # Check time since last request
            if client_ip in last_request_times:
                time_since_last_request = current_time - last_request_times[client_ip]
                if time_since_last_request < SPAN:
                    return JSONResponse(
                        status_code=429,
                        content={
                            "detail": "Too many requests",
                            "message": f"Rate limit exceeded. Please wait {SPAN} seconds between requests.",
                            "retry_after": max(0, SPAN - time_since_last_request),
                            "reset_at": current_time + max(0, SPAN - time_since_last_request),
                        }
                    )

            last_request_times[client_ip] = current_time

        return await call_next(request)


app.add_middleware(RateLimitMiddleware)
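
# The limits above are driven by environment variables; a hypothetical .env might
# contain (names taken from the code, values illustrative):
#
#   ACCESS_TOKEN=some-secret-token
#   ACCESS_PER_MIN=5
#   ACCESS_TIME_SPAN=10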


# CORS configuration
origins = ["*"]  # Allow all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # List of origins that are allowed to make requests
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Mount the pages directory as a static directory
app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
site_templates = Jinja2Templates(directory=__location__ + "/site")
templates = Jinja2Templates(directory=__location__ + "/pages")


@lru_cache()
def get_crawler():
    # Initialize and return a WebCrawler instance
    crawler = WebCrawler(verbose=True)
    crawler.warmup()
    return crawler


class CrawlRequest(BaseModel):
    urls: List[str]
    include_raw_html: Optional[bool] = False
    bypass_cache: bool = False
    extract_blocks: bool = True
    word_count_threshold: Optional[int] = 5
    extraction_strategy: Optional[str] = "NoExtractionStrategy"
    extraction_strategy_args: Optional[dict] = {}
    chunking_strategy: Optional[str] = "RegexChunking"
    chunking_strategy_args: Optional[dict] = {}
    css_selector: Optional[str] = None
    screenshot: Optional[bool] = False
    user_agent: Optional[str] = None
    verbose: Optional[bool] = True
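
# A minimal example body for this model (illustrative only; only "urls" is
# required, every other field falls back to the defaults above):
#
# {
#     "urls": ["https://example.com"],
#     "include_raw_html": false,
#     "extraction_strategy": "NoExtractionStrategy",
#     "chunking_strategy": "RegexChunking",
#     "screenshot": false
# }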


@app.get("/")
def read_root():
    return RedirectResponse(url="/mkdocs")


@app.get("/old", response_class=HTMLResponse)
@limiter.limit(get_rate_limit())
async def read_index(request: Request):
    partials_dir = os.path.join(__location__, "pages", "partial")
    partials = {}

    for filename in os.listdir(partials_dir):
        if filename.endswith(".html"):
            with open(os.path.join(partials_dir, filename), "r", encoding="utf8") as file:
                partials[filename[:-5]] = file.read()

    return templates.TemplateResponse("index.html", {"request": request, **partials})


@app.get("/total-count")
async def get_total_url_count():
    count = get_total_count()
    return JSONResponse(content={"count": count})


@app.get("/clear-db")
async def clear_database():
    # Clearing is currently disabled; uncomment the call below to enable it.
    # clear_db()
    return JSONResponse(content={"message": "Database cleared."})


def import_strategy(module_name: str, class_name: str, *args, **kwargs):
    try:
        module = importlib.import_module(module_name)
        strategy_class = getattr(module, class_name)
        return strategy_class(*args, **kwargs)
    except ImportError:
        print("ImportError: Module not found.")
        raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
    except AttributeError:
        print("AttributeError: Class not found.")
        raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
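
# Example usage (illustrative): resolve the default chunking strategy named in
# CrawlRequest by its module and class name.
#
#   chunker = import_strategy("crawl4ai.chunking_strategy", "RegexChunking")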


@app.post("/crawl")
@limiter.limit(get_rate_limit())
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
    logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
    global current_requests
    async with lock:
        if current_requests >= MAX_CONCURRENT_REQUESTS:
            raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
        current_requests += 1

    try:
        logging.debug("[LOG] Loading extraction and chunking strategies...")
        crawl_request.extraction_strategy_args['verbose'] = True
        crawl_request.chunking_strategy_args['verbose'] = True

        extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
        chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)

        # Use ThreadPoolExecutor to run the synchronous WebCrawler in an async manner
        logging.debug("[LOG] Running the WebCrawler...")
        with ThreadPoolExecutor() as executor:
            loop = asyncio.get_event_loop()
            futures = [
                loop.run_in_executor(
                    executor,
                    get_crawler().run,
                    str(url),
                    crawl_request.word_count_threshold,
                    extraction_strategy,
                    chunking_strategy,
                    crawl_request.bypass_cache,
                    crawl_request.css_selector,
                    crawl_request.screenshot,
                    crawl_request.user_agent,
                    crawl_request.verbose
                )
                for url in crawl_request.urls
            ]
            results = await asyncio.gather(*futures)

        # If include_raw_html is False, remove the raw HTML content from the results
        if not crawl_request.include_raw_html:
            for result in results:
                result.html = None

        return {"results": [result.model_dump() for result in results]}
    finally:
        async with lock:
            current_requests -= 1
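
# Example request against this endpoint (illustrative values; port matches the
# uvicorn call at the bottom of this file):
#
#   curl -X POST http://localhost:8888/crawl \
#        -H "Content-Type: application/json" \
#        -d '{"urls": ["https://example.com"], "include_raw_html": false}'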


@app.get("/strategies/extraction", response_class=JSONResponse)
async def get_extraction_strategies():
    with open(f"{__location__}/docs/extraction_strategies.json", "r") as file:
        return JSONResponse(content=file.read())


@app.get("/strategies/chunking", response_class=JSONResponse)
async def get_chunking_strategies():
    with open(f"{__location__}/docs/chunking_strategies.json", "r") as file:
        return JSONResponse(content=file.read())


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8888)