# crawl4ai/main.py

import os
import importlib
import json
import asyncio
from functools import lru_cache
import logging
logging.basicConfig(level=logging.DEBUG)
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.exceptions import RequestValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import FileResponse
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, HttpUrl
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.database import get_total_count, clear_db

# Configuration
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
current_requests = 0
lock = asyncio.Lock()
app = FastAPI()
# CORS configuration
origins = ["*"] # Allow all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # List of origins that are allowed to make requests
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
# Mount the pages directory as a static directory
app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
site_templates = Jinja2Templates(directory=__location__ + "/site")
templates = Jinja2Templates(directory=__location__ + "/pages")
@lru_cache()
def get_crawler():
    # Initialize a single WebCrawler instance; lru_cache keeps it shared across requests
    crawler = WebCrawler(verbose=True)
    crawler.warmup()
    return crawler
class CrawlRequest(BaseModel):
    urls: List[str]
    include_raw_html: Optional[bool] = False
    bypass_cache: bool = False
    extract_blocks: bool = True
    word_count_threshold: Optional[int] = 5
    extraction_strategy: Optional[str] = "NoExtractionStrategy"
    extraction_strategy_args: Optional[dict] = {}
    chunking_strategy: Optional[str] = "RegexChunking"
    chunking_strategy_args: Optional[dict] = {}
    css_selector: Optional[str] = None
    screenshot: Optional[bool] = False
    user_agent: Optional[str] = None
    verbose: Optional[bool] = True
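# Illustrative example of a /crawl request body (hypothetical values; the field names
# come from CrawlRequest above):
#
#   {
#       "urls": ["https://example.com"],
#       "include_raw_html": false,
#       "word_count_threshold": 5,
#       "extraction_strategy": "NoExtractionStrategy",
#       "chunking_strategy": "RegexChunking",
#       "screenshot": false
#   }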
@app.get("/")
def read_root():
return RedirectResponse(url="/mkdocs")
@app.get("/old", response_class=HTMLResponse)
async def read_index(request: Request):
partials_dir = os.path.join(__location__, "pages", "partial")
partials = {}
for filename in os.listdir(partials_dir):
if filename.endswith(".html"):
with open(os.path.join(partials_dir, filename), "r", encoding="utf8") as file:
partials[filename[:-5]] = file.read()
return templates.TemplateResponse("index.html", {"request": request, **partials})
@app.get("/total-count")
async def get_total_url_count():
count = get_total_count()
2024-05-09 19:10:25 +08:00
return JSONResponse(content={"count": count})
@app.get("/clear-db")
async def clear_database():
# clear_db()
return JSONResponse(content={"message": "Database cleared."})
def import_strategy(module_name: str, class_name: str, *args, **kwargs):
    try:
        module = importlib.import_module(module_name)
        strategy_class = getattr(module, class_name)
        return strategy_class(*args, **kwargs)
    except ImportError:
        print("ImportError: Module not found.")
        raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
    except AttributeError:
        print("AttributeError: Class not found.")
        raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
@app.post("/crawl")
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
2024-05-09 19:10:25 +08:00
global current_requests
async with lock:
if current_requests >= MAX_CONCURRENT_REQUESTS:
raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
current_requests += 1
try:
logging.debug("[LOG] Loading extraction and chunking strategies...")
crawl_request.extraction_strategy_args['verbose'] = True
crawl_request.chunking_strategy_args['verbose'] = True
extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)
2024-05-09 19:10:25 +08:00
# Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner
logging.debug("[LOG] Running the WebCrawler...")
2024-05-09 19:10:25 +08:00
with ThreadPoolExecutor() as executor:
loop = asyncio.get_event_loop()
futures = [
loop.run_in_executor(
executor,
get_crawler().run,
str(url),
crawl_request.word_count_threshold,
extraction_strategy,
chunking_strategy,
crawl_request.bypass_cache,
crawl_request.css_selector,
crawl_request.screenshot,
crawl_request.user_agent,
crawl_request.verbose
)
for url in crawl_request.urls
2024-05-09 19:10:25 +08:00
]
results = await asyncio.gather(*futures)
# if include_raw_html is False, remove the raw HTML content from the results
if not crawl_request.include_raw_html:
2024-05-09 19:10:25 +08:00
for result in results:
result.html = None
return {"results": [result.model_dump() for result in results]}
2024-05-09 19:10:25 +08:00
finally:
async with lock:
current_requests -= 1
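# Illustrative call (not part of the app), assuming the server runs with the default
# host/port from the __main__ block below:
#
#   curl -X POST http://localhost:8888/crawl \
#        -H "Content-Type: application/json" \
#        -d '{"urls": ["https://example.com"], "include_raw_html": false}'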
@app.get("/strategies/extraction", response_class=JSONResponse)
async def get_extraction_strategies():
with open(f"{__location__}/docs/extraction_strategies.json", "r") as file:
return JSONResponse(content=file.read())
@app.get("/strategies/chunking", response_class=JSONResponse)
async def get_chunking_strategies():
with open(f"{__location__}/docs/chunking_strategies.json", "r") as file:
return JSONResponse(content=file.read())
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8888)