442 lines
14 KiB
Python
Raw Normal View History

import os
import json
import asyncio
from typing import List, Tuple
import logging
from typing import Optional, AsyncGenerator
from urllib.parse import unquote
from fastapi import HTTPException, Request, status
from fastapi.background import BackgroundTasks
from fastapi.responses import JSONResponse
from redis import asyncio as aioredis
from crawl4ai import (
AsyncWebCrawler,
CrawlerRunConfig,
LLMExtractionStrategy,
CacheMode,
BrowserConfig,
MemoryAdaptiveDispatcher,
RateLimiter
)
from crawl4ai.utils import perform_completion_with_backoff
from crawl4ai.content_filter_strategy import (
PruningContentFilter,
BM25ContentFilter,
LLMContentFilter
)
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy
from utils import (
TaskStatus,
FilterType,
get_base_url,
is_task_id,
should_cleanup_task,
decode_redis_hash
)
logger = logging.getLogger(__name__)
async def handle_llm_qa(
url: str,
query: str,
config: dict
) -> str:
"""Process QA using LLM with crawled content as context."""
try:
# Extract base URL by finding last '?q=' occurrence
last_q_index = url.rfind('?q=')
if last_q_index != -1:
url = url[:last_q_index]
# Get markdown content
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(url)
if not result.success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result.error_message
)
content = result.markdown_v2.fit_markdown
# Create prompt and get LLM response
prompt = f"""Use the following content as context to answer the question.
Content:
{content}
Question: {query}
Answer:"""
response = perform_completion_with_backoff(
provider=config["llm"]["provider"],
prompt_with_variables=prompt,
api_token=os.environ.get(config["llm"].get("api_key_env", ""))
)
return response.choices[0].message.content
except Exception as e:
logger.error(f"QA processing error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
async def process_llm_extraction(
redis: aioredis.Redis,
config: dict,
task_id: str,
url: str,
instruction: str,
schema: Optional[str] = None,
cache: str = "0"
) -> None:
"""Process LLM extraction in background."""
try:
# If config['llm'] has api_key then ignore the api_key_env
api_key = ""
if "api_key" in config["llm"]:
api_key = config["llm"]["api_key"]
else:
api_key = os.environ.get(config["llm"].get("api_key_env", None), "")
llm_strategy = LLMExtractionStrategy(
provider=config["llm"]["provider"],
api_token=api_key,
instruction=instruction,
schema=json.loads(schema) if schema else None,
)
cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url=url,
config=CrawlerRunConfig(
extraction_strategy=llm_strategy,
scraping_strategy=LXMLWebScrapingStrategy(),
cache_mode=cache_mode
)
)
if not result.success:
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.FAILED,
"error": result.error_message
})
return
try:
content = json.loads(result.extracted_content)
except json.JSONDecodeError:
content = result.extracted_content
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.COMPLETED,
"result": json.dumps(content)
})
except Exception as e:
logger.error(f"LLM extraction error: {str(e)}", exc_info=True)
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.FAILED,
"error": str(e)
})
async def handle_markdown_request(
url: str,
filter_type: FilterType,
query: Optional[str] = None,
cache: str = "0",
config: Optional[dict] = None
) -> str:
"""Handle markdown generation requests."""
try:
decoded_url = unquote(url)
if not decoded_url.startswith(('http://', 'https://')):
decoded_url = 'https://' + decoded_url
if filter_type == FilterType.RAW:
md_generator = DefaultMarkdownGenerator()
else:
content_filter = {
FilterType.FIT: PruningContentFilter(),
FilterType.BM25: BM25ContentFilter(user_query=query or ""),
FilterType.LLM: LLMContentFilter(
provider=config["llm"]["provider"],
api_token=os.environ.get(config["llm"].get("api_key_env", None), ""),
instruction=query or "Extract main content"
)
}[filter_type]
md_generator = DefaultMarkdownGenerator(content_filter=content_filter)
cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url=decoded_url,
config=CrawlerRunConfig(
markdown_generator=md_generator,
scraping_strategy=LXMLWebScrapingStrategy(),
cache_mode=cache_mode
)
)
if not result.success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result.error_message
)
return (result.markdown_v2.raw_markdown
if filter_type == FilterType.RAW
else result.markdown_v2.fit_markdown)
except Exception as e:
logger.error(f"Markdown error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
async def handle_llm_request(
redis: aioredis.Redis,
background_tasks: BackgroundTasks,
request: Request,
input_path: str,
query: Optional[str] = None,
schema: Optional[str] = None,
cache: str = "0",
config: Optional[dict] = None
) -> JSONResponse:
"""Handle LLM extraction requests."""
base_url = get_base_url(request)
try:
if is_task_id(input_path):
return await handle_task_status(
redis, input_path, base_url
)
if not query:
return JSONResponse({
"message": "Please provide an instruction",
"_links": {
"example": {
"href": f"{base_url}/llm/{input_path}?q=Extract+main+content",
"title": "Try this example"
}
}
})
return await create_new_task(
redis,
background_tasks,
input_path,
query,
schema,
cache,
base_url,
config
)
except Exception as e:
logger.error(f"LLM endpoint error: {str(e)}", exc_info=True)
return JSONResponse({
"error": str(e),
"_links": {
"retry": {"href": str(request.url)}
}
}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
async def handle_task_status(
redis: aioredis.Redis,
task_id: str,
base_url: str
) -> JSONResponse:
"""Handle task status check requests."""
task = await redis.hgetall(f"task:{task_id}")
if not task:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Task not found"
)
task = decode_redis_hash(task)
response = create_task_response(task, task_id, base_url)
if task["status"] in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
if should_cleanup_task(task["created_at"]):
await redis.delete(f"task:{task_id}")
return JSONResponse(response)
async def create_new_task(
redis: aioredis.Redis,
background_tasks: BackgroundTasks,
input_path: str,
query: str,
schema: Optional[str],
cache: str,
base_url: str,
config: dict
) -> JSONResponse:
"""Create and initialize a new task."""
decoded_url = unquote(input_path)
if not decoded_url.startswith(('http://', 'https://')):
decoded_url = 'https://' + decoded_url
from datetime import datetime
task_id = f"llm_{int(datetime.now().timestamp())}_{id(background_tasks)}"
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.PROCESSING,
"created_at": datetime.now().isoformat(),
"url": decoded_url
})
background_tasks.add_task(
process_llm_extraction,
redis,
config,
task_id,
decoded_url,
query,
schema,
cache
)
return JSONResponse({
"task_id": task_id,
"status": TaskStatus.PROCESSING,
"url": decoded_url,
"_links": {
"self": {"href": f"{base_url}/llm/{task_id}"},
"status": {"href": f"{base_url}/llm/{task_id}"}
}
})
def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
"""Create response for task status check."""
response = {
"task_id": task_id,
"status": task["status"],
"created_at": task["created_at"],
"url": task["url"],
"_links": {
"self": {"href": f"{base_url}/llm/{task_id}"},
"refresh": {"href": f"{base_url}/llm/{task_id}"}
}
}
if task["status"] == TaskStatus.COMPLETED:
response["result"] = json.loads(task["result"])
elif task["status"] == TaskStatus.FAILED:
response["error"] = task["error"]
return response
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
"""Stream results with heartbeats and completion markers."""
import json
from utils import datetime_handler
try:
async for result in results_gen:
try:
result_dict = result.model_dump()
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
data = json.dumps(result_dict, default=datetime_handler) + "\n"
yield data.encode('utf-8')
except Exception as e:
logger.error(f"Serialization error: {e}")
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
yield (json.dumps(error_response) + "\n").encode('utf-8')
yield json.dumps({"status": "completed"}).encode('utf-8')
except asyncio.CancelledError:
logger.warning("Client disconnected during streaming")
finally:
try:
await crawler.close()
except Exception as e:
logger.error(f"Crawler cleanup error: {e}")
async def handle_crawl_request(
urls: List[str],
browser_config: dict,
crawler_config: dict,
config: dict
) -> dict:
"""Handle non-streaming crawl requests."""
try:
browser_config = BrowserConfig.load(browser_config)
crawler_config = CrawlerRunConfig.load(crawler_config)
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
)
)
async with AsyncWebCrawler(config=browser_config) as crawler:
results = await crawler.arun_many(
urls=urls,
config=crawler_config,
dispatcher=dispatcher
)
return {
"success": True,
"results": [result.model_dump() for result in results]
}
except Exception as e:
logger.error(f"Crawl error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
async def handle_stream_crawl_request(
urls: List[str],
browser_config: dict,
crawler_config: dict,
config: dict
) -> Tuple[AsyncWebCrawler, AsyncGenerator]:
"""Handle streaming crawl requests."""
try:
browser_config = BrowserConfig.load(browser_config)
browser_config.verbose = True
crawler_config = CrawlerRunConfig.load(crawler_config)
crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
)
)
crawler = AsyncWebCrawler(config=browser_config)
await crawler.start()
results_gen = await crawler.arun_many(
urls=urls,
config=crawler_config,
dispatcher=dispatcher
)
return crawler, results_gen
except Exception as e:
if 'crawler' in locals():
await crawler.close()
logger.error(f"Stream crawl error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)