from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, HttpUrl
from typing import List, Optional
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.models import UrlModel
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
import chromedriver_autoinstaller
from functools import lru_cache
from crawl4ai.database import get_total_count, clear_db
import os
import uuid

# Import the CORS middleware
from fastapi.middleware.cors import CORSMiddleware


# Task management
tasks = {}
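# Each entry maps a task_id to a dict such as
# {"status": "pending" | "done" | "failed", "results": [...], "error": "..."}.
# The store is in-memory and process-local, so it is cleared when the server restarts.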

# Configuration
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
MAX_CONCURRENT_REQUESTS = 10  # Adjust this to change the maximum concurrent requests
current_requests = 0
lock = asyncio.Lock()
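# current_requests counts in-flight /crawl and /crawl_async requests. It is guarded by
# `lock` and acts as a simple admission limiter: requests beyond MAX_CONCURRENT_REQUESTS
# are rejected with HTTP 429.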

app = FastAPI()

# CORS configuration
origins = ["*"]  # Allow all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # List of origins that are allowed to make requests
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Mount the pages directory as a static directory
app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")

chromedriver_autoinstaller.install()  # Ensure chromedriver is installed


class UrlsInput(BaseModel):
    urls: List[HttpUrl]
    provider_model: str
    api_token: str
    include_raw_html: Optional[bool] = False
    forced: bool = False
    extract_blocks: bool = True
    word_count_threshold: Optional[int] = 5
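
# Example request body for /crawl and /crawl_async (illustrative values only; the
# provider_model string and token below are placeholders, not real settings):
# {
#     "urls": ["https://example.com"],
#     "provider_model": "openai/gpt-3.5-turbo",
#     "api_token": "<YOUR_API_TOKEN>",
#     "include_raw_html": false,
#     "forced": false,
#     "extract_blocks": true,
#     "word_count_threshold": 5
# }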


@lru_cache()
def get_crawler():
    # Initialize and return a WebCrawler instance
    return WebCrawler(db_path='crawler_data.db')
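# Because of lru_cache(), every call to get_crawler() returns the same WebCrawler
# instance, so all endpoints share a single crawler and its database file.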


@app.get("/", response_class=HTMLResponse)
async def read_index():
    with open(f"{__location__}/pages/index.html", "r") as file:
        html_content = file.read()
    return HTMLResponse(content=html_content, status_code=200)


@app.get("/total-count")
async def get_total_url_count():
    count = get_total_count(db_path='crawler_data.db')
    return JSONResponse(content={"count": count})


# Add endpoint to clear the database
@app.get("/clear-db")
async def clear_database():
    clear_db(db_path='crawler_data.db')
    return JSONResponse(content={"message": "Database cleared."})


@app.post("/crawl")
async def crawl_urls(urls_input: UrlsInput, request: Request):
    global current_requests

    # Raise an error if no API token is provided
    if not urls_input.api_token:
        raise HTTPException(status_code=401, detail="API token is required.")

    async with lock:
        if current_requests >= MAX_CONCURRENT_REQUESTS:
            raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
        current_requests += 1

    try:
        # Prepare URL models for crawling
        url_models = [UrlModel(url=url, forced=urls_input.forced) for url in urls_input.urls]

        # Use a ThreadPoolExecutor to run the synchronous WebCrawler without blocking the event loop
        with ThreadPoolExecutor() as executor:
            loop = asyncio.get_running_loop()
            futures = [
                loop.run_in_executor(
                    executor,
                    get_crawler().fetch_page,
                    url_model,
                    urls_input.provider_model,
                    urls_input.api_token,
                    urls_input.extract_blocks,
                    urls_input.word_count_threshold,
                )
                for url_model in url_models
            ]
            results = await asyncio.gather(*futures)

        # If include_raw_html is False, remove the raw HTML content from the results
        if not urls_input.include_raw_html:
            for result in results:
                result.html = None

        return {"results": [result.dict() for result in results]}
    finally:
        async with lock:
            current_requests -= 1
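
# Example client call (a minimal sketch, assuming the server is running locally on the
# port configured at the bottom of this file and that the `requests` package is
# installed; all values are placeholders):
#
#   import requests
#   payload = {
#       "urls": ["https://example.com"],
#       "provider_model": "openai/gpt-3.5-turbo",
#       "api_token": "<YOUR_API_TOKEN>",
#   }
#   print(requests.post("http://localhost:8000/crawl", json=payload).json())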


@app.post("/crawl_async")
async def crawl_urls_async(urls_input: UrlsInput, request: Request):
    global current_requests

    if not urls_input.api_token:
        raise HTTPException(status_code=401, detail="API token is required.")

    async with lock:
        if current_requests >= MAX_CONCURRENT_REQUESTS:
            raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
        current_requests += 1

    task_id = str(uuid.uuid4())
    tasks[task_id] = {"status": "pending", "results": None}

    try:
        url_models = [UrlModel(url=url, forced=urls_input.forced) for url in urls_input.urls]

        # Schedule the crawl in the background and return the task id immediately
        loop = asyncio.get_running_loop()
        loop.create_task(
            process_crawl_task(url_models, urls_input.provider_model, urls_input.api_token, task_id, urls_input.extract_blocks)
        )
        return {"task_id": task_id}
    finally:
        async with lock:
            current_requests -= 1
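# Note: the counter is released as soon as the background task has been scheduled, so
# MAX_CONCURRENT_REQUESTS limits admission of /crawl_async requests, not the number of
# crawls still running inside process_crawl_task.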


async def process_crawl_task(url_models, provider, api_token, task_id, extract_blocks_flag):
    try:
        with ThreadPoolExecutor() as executor:
            loop = asyncio.get_running_loop()
            futures = [
                loop.run_in_executor(executor, get_crawler().fetch_page, url_model, provider, api_token, extract_blocks_flag)
                for url_model in url_models
            ]
            results = await asyncio.gather(*futures)

        tasks[task_id] = {"status": "done", "results": results}
    except Exception as e:
        tasks[task_id] = {"status": "failed", "error": str(e)}


@app.get("/task/{task_id}")
async def get_task_status(task_id: str):
    task = tasks.get(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if task['status'] == 'done':
        return {
            "status": task['status'],
            "results": [result.dict() for result in task['results']]
        }
    elif task['status'] == 'failed':
        return {
            "status": task['status'],
            "error": task['error']
        }
    else:
        return {"status": task['status']}
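
# Typical async flow against the endpoints above (sketch):
#   1. POST /crawl_async with the same JSON body as /crawl -> {"task_id": "..."}
#   2. Poll GET /task/{task_id} until "status" becomes "done" (results included) or "failed".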


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)