import os
import json
import asyncio
from typing import List, Tuple
from functools import partial

import logging
from typing import Optional, AsyncGenerator
from urllib.parse import unquote

from fastapi import HTTPException, Request, status
from fastapi.background import BackgroundTasks
from fastapi.responses import JSONResponse
from redis import asyncio as aioredis

from crawl4ai import (
    AsyncWebCrawler,
    CrawlerRunConfig,
    LLMExtractionStrategy,
    CacheMode,
    BrowserConfig,
    MemoryAdaptiveDispatcher,
    RateLimiter,
    LLMConfig
)
from crawl4ai.utils import perform_completion_with_backoff
from crawl4ai.content_filter_strategy import (
    PruningContentFilter,
    BM25ContentFilter,
    LLMContentFilter
)
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy

from utils import (
    TaskStatus,
    FilterType,
    get_base_url,
    is_task_id,
    should_cleanup_task,
    decode_redis_hash
)

import psutil, time

logger = logging.getLogger(__name__)


# --- Helper to get current process memory (RSS) in MB ---
def _get_memory_mb():
    try:
        return psutil.Process().memory_info().rss / (1024 * 1024)
    except Exception as e:
        logger.warning(f"Could not get memory info: {e}")
        return None


async def handle_llm_qa(
    url: str,
    query: str,
    config: dict
) -> str:
    """Process QA using LLM with crawled content as context."""
    try:
        if not url.startswith(('http://', 'https://')):
            url = 'https://' + url
        # Strip a trailing '?q=...' query so only the page URL is crawled
        last_q_index = url.rfind('?q=')
        if last_q_index != -1:
            url = url[:last_q_index]

        # Get markdown content
        async with AsyncWebCrawler() as crawler:
            result = await crawler.arun(url)
            if not result.success:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=result.error_message
                )
            content = result.markdown.fit_markdown or result.markdown.raw_markdown

        # Create prompt and get LLM response
        prompt = f"""Use the following content as context to answer the question.
Content:
{content}

Question: {query}

Answer:"""

        response = perform_completion_with_backoff(
            provider=config["llm"]["provider"],
            prompt_with_variables=prompt,
            api_token=os.environ.get(config["llm"].get("api_key_env", ""))
        )

        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"QA processing error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )


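# Illustrative usage sketch (not part of the original module): how a /llm/{url}?q=... style
# endpoint might call handle_llm_qa. The config shape below is an assumption based on the keys
# read above ("provider", "api_key_env"); the real server config may differ.
#
#     qa_config = {"llm": {"provider": "openai/gpt-4o-mini", "api_key_env": "OPENAI_API_KEY"}}
#     answer = await handle_llm_qa("docs.crawl4ai.com", "What does this page cover?", qa_config)

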
async def process_llm_extraction(
    redis: aioredis.Redis,
    config: dict,
    task_id: str,
    url: str,
    instruction: str,
    schema: Optional[str] = None,
    cache: str = "0"
) -> None:
    """Process LLM extraction in background."""
    try:
        # An explicit api_key in config['llm'] takes precedence over api_key_env
        api_key = ""
        if "api_key" in config["llm"]:
            api_key = config["llm"]["api_key"]
        else:
            api_key = os.environ.get(config["llm"].get("api_key_env", ""), "")
        llm_strategy = LLMExtractionStrategy(
            llm_config=LLMConfig(
                provider=config["llm"]["provider"],
                api_token=api_key
            ),
            instruction=instruction,
            schema=json.loads(schema) if schema else None,
        )

        cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY

        async with AsyncWebCrawler() as crawler:
            result = await crawler.arun(
                url=url,
                config=CrawlerRunConfig(
                    extraction_strategy=llm_strategy,
                    scraping_strategy=LXMLWebScrapingStrategy(),
                    cache_mode=cache_mode
                )
            )

        if not result.success:
            await redis.hset(f"task:{task_id}", mapping={
                "status": TaskStatus.FAILED,
                "error": result.error_message
            })
            return

        try:
            content = json.loads(result.extracted_content)
        except json.JSONDecodeError:
            content = result.extracted_content
        await redis.hset(f"task:{task_id}", mapping={
            "status": TaskStatus.COMPLETED,
            "result": json.dumps(content)
        })

    except Exception as e:
        logger.error(f"LLM extraction error: {str(e)}", exc_info=True)
        await redis.hset(f"task:{task_id}", mapping={
            "status": TaskStatus.FAILED,
            "error": str(e)
        })


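# Illustrative sketch (not part of the original module): after process_llm_extraction finishes,
# the Redis hash "task:{task_id}" created by create_new_task holds roughly:
#
#     {"status": "COMPLETED", "created_at": "...", "url": "...", "result": "<JSON string>"}
#   or, on failure:
#     {"status": "FAILED",    "created_at": "...", "url": "...", "error": "<message>"}
#
# The exact status strings depend on how TaskStatus is defined in utils.

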
async def handle_markdown_request(
    url: str,
    filter_type: FilterType,
    query: Optional[str] = None,
    cache: str = "0",
    config: Optional[dict] = None
) -> str:
    """Handle markdown generation requests."""
    try:
        decoded_url = unquote(url)
        if not decoded_url.startswith(('http://', 'https://')):
            decoded_url = 'https://' + decoded_url

        if filter_type == FilterType.RAW:
            md_generator = DefaultMarkdownGenerator()
        else:
            content_filter = {
                FilterType.FIT: PruningContentFilter(),
                FilterType.BM25: BM25ContentFilter(user_query=query or ""),
                FilterType.LLM: LLMContentFilter(
                    llm_config=LLMConfig(
                        provider=config["llm"]["provider"],
                        api_token=os.environ.get(config["llm"].get("api_key_env", ""), ""),
                    ),
                    instruction=query or "Extract main content"
                )
            }[filter_type]
            md_generator = DefaultMarkdownGenerator(content_filter=content_filter)

        cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY

        async with AsyncWebCrawler() as crawler:
            result = await crawler.arun(
                url=decoded_url,
                config=CrawlerRunConfig(
                    markdown_generator=md_generator,
                    scraping_strategy=LXMLWebScrapingStrategy(),
                    cache_mode=cache_mode
                )
            )

            if not result.success:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=result.error_message
                )

            return (result.markdown.raw_markdown
                    if filter_type == FilterType.RAW
                    else result.markdown.fit_markdown)

    except Exception as e:
        logger.error(f"Markdown error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )


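# Illustrative usage sketch (not part of the original module). app_config is a placeholder for
# the server config dict that supplies config["llm"] when the LLM filter is selected.
#
#     md = await handle_markdown_request(
#         url="example.com/article",
#         filter_type=FilterType.FIT,   # pruning-based "fit" markdown
#         cache="1",                    # "1" reuses cached crawls, "0" writes cache only
#         config=app_config,
#     )

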
async def handle_llm_request(
    redis: aioredis.Redis,
    background_tasks: BackgroundTasks,
    request: Request,
    input_path: str,
    query: Optional[str] = None,
    schema: Optional[str] = None,
    cache: str = "0",
    config: Optional[dict] = None
) -> JSONResponse:
    """Handle LLM extraction requests."""
    base_url = get_base_url(request)

    try:
        if is_task_id(input_path):
            return await handle_task_status(
                redis, input_path, base_url
            )

        if not query:
            return JSONResponse({
                "message": "Please provide an instruction",
                "_links": {
                    "example": {
                        "href": f"{base_url}/llm/{input_path}?q=Extract+main+content",
                        "title": "Try this example"
                    }
                }
            })

        return await create_new_task(
            redis,
            background_tasks,
            input_path,
            query,
            schema,
            cache,
            base_url,
            config
        )

    except Exception as e:
        logger.error(f"LLM endpoint error: {str(e)}", exc_info=True)
        return JSONResponse({
            "error": str(e),
            "_links": {
                "retry": {"href": str(request.url)}
            }
        }, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)


async def handle_task_status(
    redis: aioredis.Redis,
    task_id: str,
    base_url: str
) -> JSONResponse:
    """Handle task status check requests."""
    task = await redis.hgetall(f"task:{task_id}")
    if not task:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Task not found"
        )

    task = decode_redis_hash(task)
    response = create_task_response(task, task_id, base_url)

    if task["status"] in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
        if should_cleanup_task(task["created_at"]):
            await redis.delete(f"task:{task_id}")

    return JSONResponse(response)


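# Illustrative flow sketch (not part of the original module): based on the _links built above
# and in create_new_task, the /llm endpoint is used in two steps. Exact routes depend on the
# server module that wires these handlers up.
#
#     1. GET {base_url}/llm/example.com%2Fpage?q=Extract+the+main+points
#        -> {"task_id": "llm_...", "status": "PROCESSING", ...}
#     2. GET {base_url}/llm/llm_...   (poll until status is COMPLETED or FAILED)
#        -> {"task_id": "llm_...", "status": "COMPLETED", "result": {...}, ...}

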
async def create_new_task(
    redis: aioredis.Redis,
    background_tasks: BackgroundTasks,
    input_path: str,
    query: str,
    schema: Optional[str],
    cache: str,
    base_url: str,
    config: dict
) -> JSONResponse:
    """Create and initialize a new task."""
    decoded_url = unquote(input_path)
    if not decoded_url.startswith(('http://', 'https://')):
        decoded_url = 'https://' + decoded_url

    from datetime import datetime
    task_id = f"llm_{int(datetime.now().timestamp())}_{id(background_tasks)}"

    await redis.hset(f"task:{task_id}", mapping={
        "status": TaskStatus.PROCESSING,
        "created_at": datetime.now().isoformat(),
        "url": decoded_url
    })

    background_tasks.add_task(
        process_llm_extraction,
        redis,
        config,
        task_id,
        decoded_url,
        query,
        schema,
        cache
    )

    return JSONResponse({
        "task_id": task_id,
        "status": TaskStatus.PROCESSING,
        "url": decoded_url,
        "_links": {
            "self": {"href": f"{base_url}/llm/{task_id}"},
            "status": {"href": f"{base_url}/llm/{task_id}"}
        }
    })


def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
    """Create response for task status check."""
    response = {
        "task_id": task_id,
        "status": task["status"],
        "created_at": task["created_at"],
        "url": task["url"],
        "_links": {
            "self": {"href": f"{base_url}/llm/{task_id}"},
            "refresh": {"href": f"{base_url}/llm/{task_id}"}
        }
    }

    if task["status"] == TaskStatus.COMPLETED:
        response["result"] = json.loads(task["result"])
    elif task["status"] == TaskStatus.FAILED:
        response["error"] = task["error"]

    return response


async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
    """Stream crawl results as NDJSON, with per-result memory info and a final completion marker."""
    from utils import datetime_handler

    try:
        async for result in results_gen:
            try:
                server_memory_mb = _get_memory_mb()
                result_dict = result.model_dump()
                result_dict['server_memory_mb'] = server_memory_mb
                logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
                data = json.dumps(result_dict, default=datetime_handler) + "\n"
                yield data.encode('utf-8')
            except Exception as e:
                logger.error(f"Serialization error: {e}")
                error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
                yield (json.dumps(error_response) + "\n").encode('utf-8')

        yield json.dumps({"status": "completed"}).encode('utf-8')

    except asyncio.CancelledError:
        logger.warning("Client disconnected during streaming")
    finally:
        # The crawler comes from the shared pool, so it is intentionally not closed here.
        # try:
        #     await crawler.close()
        # except Exception as e:
        #     logger.error(f"Crawler cleanup error: {e}")
        pass


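# Illustrative sketch of the stream produced above (not part of the original module): one JSON
# object per line, each a CrawlResult dump plus "server_memory_mb", ending with a marker line.
# The field set inside each result dict comes from crawl4ai's CrawlResult model.
#
#     {"url": "https://example.com", "success": true, ..., "server_memory_mb": 412.3}
#     {"url": "https://example.org", "success": true, ..., "server_memory_mb": 415.8}
#     {"status": "completed"}

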
async def handle_crawl_request(
    urls: List[str],
    browser_config: dict,
    crawler_config: dict,
    config: dict
) -> dict:
    """Handle non-streaming crawl requests."""
    start_mem_mb = _get_memory_mb()  # <--- Get memory before
    start_time = time.time()
    mem_delta_mb = None
    peak_mem_mb = start_mem_mb

    try:
        urls = [('https://' + url) if not url.startswith(('http://', 'https://')) else url for url in urls]
        browser_config = BrowserConfig.load(browser_config)
        crawler_config = CrawlerRunConfig.load(crawler_config)

        dispatcher = MemoryAdaptiveDispatcher(
            memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
            rate_limiter=RateLimiter(
                base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
            ) if config["crawler"]["rate_limiter"]["enabled"] else None
        )

        from crawler_pool import get_crawler
        crawler = await get_crawler(browser_config)

        # crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config)
        # await crawler.start()

        base_config = config["crawler"]["base_config"]
        # Apply base_config key-value pairs to the crawler config via hasattr/setattr
        for key, value in base_config.items():
            if hasattr(crawler_config, key):
                setattr(crawler_config, key, value)

        results = []
        func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many")
        partial_func = partial(func,
                               urls[0] if len(urls) == 1 else urls,
                               config=crawler_config,
                               dispatcher=dispatcher)
        results = await partial_func()

        # await crawler.close()

        end_mem_mb = _get_memory_mb()  # <--- Get memory after
        end_time = time.time()

        if start_mem_mb is not None and end_mem_mb is not None:
            mem_delta_mb = end_mem_mb - start_mem_mb  # <--- Calculate delta
            peak_mem_mb = max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb)  # <--- Track peak memory
            logger.info(f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB")

        return {
            "success": True,
            "results": [result.model_dump() for result in results],
            "server_processing_time_s": end_time - start_time,
            "server_memory_delta_mb": mem_delta_mb,
            "server_peak_memory_mb": peak_mem_mb
        }

    except Exception as e:
        logger.error(f"Crawl error: {str(e)}", exc_info=True)
        if 'crawler' in locals() and crawler.ready:  # Crawler was initialized and started
            # The crawler instance comes from the shared pool, so it is not closed here;
            # the commented-out close below documents the previous standalone behavior.
            # try:
            #     await crawler.close()
            # except Exception as close_e:
            #     logger.error(f"Error closing crawler during exception handling: {close_e}")
            logger.warning("Crawler was still active when the error occurred; leaving pooled instance open")

        # Measure memory even on error if possible
        end_mem_mb_error = _get_memory_mb()
        if start_mem_mb is not None and end_mem_mb_error is not None:
            mem_delta_mb = end_mem_mb_error - start_mem_mb

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=json.dumps({  # Send structured error
                "error": str(e),
                "server_memory_delta_mb": mem_delta_mb,
                "server_peak_memory_mb": max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0)
            })
        )


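# Illustrative usage sketch (not part of the original module). The dicts below are simplified
# assumptions: browser_config / crawler_config must be in whatever serialized form
# BrowserConfig.load / CrawlerRunConfig.load accept, and app_config stands in for the server
# config that provides the "crawler" section read above.
#
#     payload = await handle_crawl_request(
#         urls=["example.com", "https://example.org"],
#         browser_config={"headless": True},
#         crawler_config={"stream": False},
#         config=app_config,
#     )
#     # payload["results"] is a list of CrawlResult dumps plus timing and memory fields.

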
async def handle_stream_crawl_request(
    urls: List[str],
    browser_config: dict,
    crawler_config: dict,
    config: dict
) -> Tuple[AsyncWebCrawler, AsyncGenerator]:
    """Handle streaming crawl requests."""
    try:
        browser_config = BrowserConfig.load(browser_config)
        # browser_config.verbose = True  # enable for debugging; keep False for production stress testing
        browser_config.verbose = False
        crawler_config = CrawlerRunConfig.load(crawler_config)
        crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
        crawler_config.stream = True

        dispatcher = MemoryAdaptiveDispatcher(
            memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
            rate_limiter=RateLimiter(
                base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
            )
        )

        from crawler_pool import get_crawler
        crawler = await get_crawler(browser_config)

        # crawler = AsyncWebCrawler(config=browser_config)
        # await crawler.start()

        results_gen = await crawler.arun_many(
            urls=urls,
            config=crawler_config,
            dispatcher=dispatcher
        )

        return crawler, results_gen

    except Exception as e:
        # If a crawler was obtained before the failure, note it; it comes from the shared
        # pool, so it is not closed here (the commented-out close documents the old behavior).
        if 'crawler' in locals() and crawler.ready:
            # try:
            #     await crawler.close()
            # except Exception as close_e:
            #     logger.error(f"Error closing crawler during stream setup exception: {close_e}")
            logger.warning("Crawler was still active when stream setup failed; leaving pooled instance open")
        logger.error(f"Stream crawl error: {str(e)}", exc_info=True)
        # Raising HTTPException here prevents a streaming response from starting
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
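# Illustrative endpoint sketch (not part of the original module) showing how the streaming
# pieces above could be wired together. The route path, request shape, and app/config names
# are assumptions; the real server module may differ.
#
#     from fastapi.responses import StreamingResponse
#
#     @app.post("/crawl/stream")                      # hypothetical route
#     async def crawl_stream(body: dict):
#         crawler, gen = await handle_stream_crawl_request(
#             urls=body["urls"],
#             browser_config=body.get("browser_config", {}),
#             crawler_config=body.get("crawler_config", {}),
#             config=app_config,                      # hypothetical server config
#         )
#         return StreamingResponse(
#             stream_results(crawler, gen),
#             media_type="application/x-ndjson",
#         )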