300 lines
9.6 KiB
Python
300 lines
9.6 KiB
Python
import requests
|
|
import json
|
|
import time
|
|
import sys
|
|
import base64
|
|
import os
|
|
from typing import Dict, Any
|
|
|
|
|
|
class Crawl4AiTester:
    """Small client for a Crawl4AI HTTP server: submit a crawl job, then poll it.

    Uses two endpoints under ``base_url``: ``POST /crawl`` to enqueue a job
    and ``GET /task/{task_id}`` to read the job's status.
    """

    def __init__(self, base_url: str = "http://localhost:11235"):
        """Store the server root URL.

        Args:
            base_url: Scheme/host/port of the Crawl4AI server. A trailing
                ``/`` is removed so ``f"{base_url}/crawl"`` never contains
                a double slash.
        """
        self.base_url = base_url.rstrip("/")

    def submit_and_wait(
        self,
        request_data: Dict[str, Any],
        timeout: int = 300,
        poll_interval: float = 2,
        request_timeout: float = 60,
    ) -> Dict[str, Any]:
        """Submit a crawl job and block until it completes, fails, or times out.

        Args:
            request_data: JSON-serialisable payload POSTed to ``/crawl``.
            timeout: Maximum number of seconds to keep polling for the task.
            poll_interval: Seconds to sleep between status polls.
            request_timeout: Per-HTTP-request timeout in seconds. Without it a
                stalled connection could block forever and *timeout* would
                never be checked again.

        Returns:
            The status dict from ``/task/{task_id}`` once its ``"status"``
            field equals ``"completed"``.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
            TimeoutError: If the task is not completed within *timeout* seconds.
            Exception: If the server reports the task as ``"failed"``.
        """
        # Submit crawl job.
        response = requests.post(
            f"{self.base_url}/crawl", json=request_data, timeout=request_timeout
        )
        # Surface HTTP errors directly rather than as a KeyError on "task_id".
        response.raise_for_status()
        task_id = response.json()["task_id"]
        print(f"Task ID: {task_id}")

        # Poll for result. monotonic() is unaffected by wall-clock changes.
        start_time = time.monotonic()
        while True:
            if time.monotonic() - start_time > timeout:
                raise TimeoutError(
                    f"Task {task_id} did not complete within {timeout} seconds"
                )

            result = requests.get(
                f"{self.base_url}/task/{task_id}", timeout=request_timeout
            )
            result.raise_for_status()
            status = result.json()

            if status["status"] == "failed":
                print("Task failed:", status.get("error"))
                raise Exception(f"Task failed: {status.get('error')}")

            if status["status"] == "completed":
                return status

            time.sleep(poll_interval)
|
|
|
|
|
|
def test_docker_deployment(version="basic"):
    """Smoke-test a running Crawl4AI Docker container.

    Polls ``/health`` up to 5 times (5 s apart, 10 s per-request timeout)
    until it answers with a successful status. Then it runs the basic crawl
    test. If the service never becomes healthy, it exits the process with
    status 1.

    Args:
        version: Image flavour label. Currently it is only printed in the
            banner, because the version-specific tests below are disabled.
    """
    tester = Crawl4AiTester()
    print(f"Testing Crawl4AI Docker {version} version")

    # Health check with timeout and retry
    max_retries = 5
    for i in range(max_retries):
        try:
            health = requests.get(f"{tester.base_url}/health", timeout=10)
            # A container that is still starting may answer with an error status;
            # raising here routes that into the retry path instead of proceeding.
            health.raise_for_status()
            print("Health check:", health.json())
            break
        except requests.exceptions.RequestException:
            if i == max_retries - 1:
                print(f"Failed to connect after {max_retries} attempts")
                sys.exit(1)
            print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
            time.sleep(5)

    # Test cases based on version
    test_basic_crawl(tester)

    # NOTE: the tests below are currently disabled; re-enable as needed.
    # if version in ["full", "transformer"]:
    #     test_cosine_extraction(tester)

    # test_js_execution(tester)
    # test_css_selector(tester)
    # test_structured_extraction(tester)
    # test_llm_extraction(tester)
    # test_llm_with_ollama(tester)
    # test_screenshot(tester)
|
|
|
|
|
|
def test_basic_crawl(tester: Crawl4AiTester):
    """Crawl one news page and check that it succeeds with non-empty markdown."""
    print("\n=== Testing Basic Crawl ===")
    request = {"urls": "https://www.nbcnews.com/business", "priority": 10}

    result = tester.submit_and_wait(request)
    crawl = result["result"]
    # Check success first: if the crawl failed, markdown may be missing/None and
    # len() would mask the real failure behind a TypeError.
    assert crawl["success"], "Basic crawl reported failure"

    markdown = crawl["markdown"]
    print(f"Basic crawl result length: {len(markdown)}")
    assert len(markdown) > 0
|
|
|
|
|
|
def test_js_execution(tester: Crawl4AiTester):
    """Crawl with injected JavaScript (click "Load More") and a wait_for selector.

    Asserts the crawl succeeded and prints the length of the returned markdown.
    """
    print("\n=== Testing JS Execution ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 8,
        "js_code": [
            "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
        ],
        # Wait until at least ten tease cards exist before capturing the page.
        "wait_for": "article.tease-card:nth-child(10)",
        "crawler_params": {"headless": True},
    }

    result = tester.submit_and_wait(request)
    crawl = result["result"]
    # Assert success before touching markdown, which may be absent on failure.
    assert crawl["success"], "JS execution crawl reported failure"
    print(f"JS execution result length: {len(crawl['markdown'])}")
|
|
|
|
|
|
def test_css_selector(tester: Crawl4AiTester):
    """Crawl restricted to a CSS selector and check that the crawl succeeds."""
    print("\n=== Testing CSS Selector ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 7,
        "css_selector": ".wide-tease-item__description",
        "crawler_params": {"headless": True},
        "extra": {"word_count_threshold": 10},
    }

    result = tester.submit_and_wait(request)
    crawl = result["result"]
    # Assert success before touching markdown, which may be absent on failure.
    assert crawl["success"], "CSS selector crawl reported failure"
    print(f"CSS selector result length: {len(crawl['markdown'])}")
|
|
|
|
|
|
def test_structured_extraction(tester: Crawl4AiTester):
    """Extract crypto name/symbol/price rows from Coinbase with a ``json_css`` schema.

    Asserts the crawl succeeded and that at least one row was extracted.
    It then prints the first row.
    """
    print("\n=== Testing Structured Extraction ===")
    schema = {
        "name": "Coinbase Crypto Prices",
        # Each element matching baseSelector becomes one extracted item.
        "baseSelector": ".cds-tableRow-t45thuk",
        "fields": [
            {
                "name": "crypto",
                "selector": "td:nth-child(1) h2",
                "type": "text",
            },
            {
                "name": "symbol",
                "selector": "td:nth-child(1) p",
                "type": "text",
            },
            {
                "name": "price",
                "selector": "td:nth-child(2)",
                "type": "text",
            },
        ],
    }

    request = {
        "urls": "https://www.coinbase.com/explore",
        "priority": 9,
        "extraction_config": {"type": "json_css", "params": {"schema": schema}},
    }

    result = tester.submit_and_wait(request)
    crawl = result["result"]
    # Check success before parsing: on failure extracted_content may be absent.
    assert crawl["success"], "Structured extraction crawl reported failure"

    extracted = json.loads(crawl["extracted_content"])
    print(f"Extracted {len(extracted)} items")
    # Assert non-empty before indexing so an empty result fails clearly
    # instead of raising IndexError.
    assert len(extracted) > 0, "Structured extraction returned no items"
    print("Sample item:", json.dumps(extracted[0], indent=2))
|
|
|
|
|
|
def test_llm_extraction(tester: Crawl4AiTester):
    """Best-effort LLM extraction of OpenAI model pricing via ``openai/gpt-4o-mini``.

    Sends the client's ``OPENAI_API_KEY`` environment variable as
    ``api_token`` (which may be ``None``). Any failure is printed rather than
    raised, so this test never aborts the run.
    """
    print("\n=== Testing LLM Extraction ===")
    schema = {
        "type": "object",
        "properties": {
            "model_name": {
                "type": "string",
                "description": "Name of the OpenAI model.",
            },
            "input_fee": {
                "type": "string",
                "description": "Fee for input token for the OpenAI model.",
            },
            "output_fee": {
                "type": "string",
                "description": "Fee for output token for the OpenAI model.",
            },
        },
        "required": ["model_name", "input_fee", "output_fee"],
    }

    request = {
        "urls": "https://openai.com/api/pricing",
        "priority": 8,
        "extraction_config": {
            "type": "llm",
            "params": {
                "provider": "openai/gpt-4o-mini",
                "api_token": os.getenv("OPENAI_API_KEY"),
                "schema": schema,
                "extraction_type": "schema",
                "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
            },
        },
        "crawler_params": {"word_count_threshold": 1},
    }

    try:
        result = tester.submit_and_wait(request)
        crawl = result["result"]
        # Check success before parsing so a failed crawl is reported as such,
        # not as a JSON decode error on missing content.
        assert crawl["success"], "crawl reported failure"
        extracted = json.loads(crawl["extracted_content"])
        print(f"Extracted {len(extracted)} model pricing entries")
        if extracted:
            print("Sample entry:", json.dumps(extracted[0], indent=2))
        else:
            print("No model pricing entries were extracted")
    except Exception as e:
        # Deliberately best-effort: an LLM/API-key problem must not abort the run.
        print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
|
|
|
|
|
def test_llm_with_ollama(tester: Crawl4AiTester):
    """Best-effort LLM extraction of article title/summary/topics via ``ollama/llama2``.

    Any failure is printed rather than raised, so this test never aborts the run.
    """
    print("\n=== Testing LLM with Ollama ===")
    schema = {
        "type": "object",
        "properties": {
            "article_title": {
                "type": "string",
                "description": "The main title of the news article",
            },
            "summary": {
                "type": "string",
                "description": "A brief summary of the article content",
            },
            "main_topics": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Main topics or themes discussed in the article",
            },
        },
    }

    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 8,
        "extraction_config": {
            "type": "llm",
            "params": {
                "provider": "ollama/llama2",
                "schema": schema,
                "extraction_type": "schema",
                "instruction": "Extract the main article information including title, summary, and main topics.",
            },
        },
        "extra": {"word_count_threshold": 1},
        "crawler_params": {"verbose": True},
    }

    try:
        result = tester.submit_and_wait(request)
        crawl = result["result"]
        # Check success before parsing so a failed crawl is reported as such,
        # not as a JSON decode error on missing content.
        assert crawl["success"], "crawl reported failure"
        extracted = json.loads(crawl["extracted_content"])
        print("Extracted content:", json.dumps(extracted, indent=2))
    except Exception as e:
        # Deliberately best-effort: a missing local Ollama must not abort the run.
        print(f"Ollama extraction test failed: {str(e)}")
|
|
|
|
|
|
def test_cosine_extraction(tester: Crawl4AiTester):
    """Best-effort cosine-similarity clustering of a business news page.

    Any failure is printed rather than raised, so this test never aborts the run.
    """
    print("\n=== Testing Cosine Extraction ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 8,
        "extraction_config": {
            "type": "cosine",
            "params": {
                "semantic_filter": "business finance economy",
                "word_count_threshold": 10,
                "max_dist": 0.2,
                "top_k": 3,
            },
        },
    }

    try:
        result = tester.submit_and_wait(request)
        crawl = result["result"]
        # Check success before parsing so a failed crawl is reported as such.
        assert crawl["success"], "crawl reported failure"
        extracted = json.loads(crawl["extracted_content"])
        print(f"Extracted {len(extracted)} text clusters")
        # Guard the index so an empty result is reported, not an IndexError.
        if extracted:
            print("First cluster tags:", extracted[0]["tags"])
        else:
            print("No text clusters were extracted")
    except Exception as e:
        # Deliberately best-effort: optional extraction backends may be absent.
        print(f"Cosine extraction test failed: {str(e)}")
|
|
|
|
|
|
def test_screenshot(tester: Crawl4AiTester):
    """Request a page screenshot and, if one is returned, save it to disk.

    The ``screenshot`` field is base64-decoded and written to
    ``test_screenshot.jpg`` in the current working directory.
    """
    print("\n=== Testing Screenshot ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 5,
        "screenshot": True,
        "crawler_params": {"headless": True},
    }

    result = tester.submit_and_wait(request)
    crawl = result["result"]
    # Assert success first so a failed crawl is not reported as a missing key.
    assert crawl["success"], "Screenshot crawl reported failure"

    # .get(): tolerate the key being absent instead of raising KeyError.
    screenshot_b64 = crawl.get("screenshot")
    print("Screenshot captured:", bool(screenshot_b64))

    if screenshot_b64:
        # Save screenshot
        # NOTE(review): the image format is whatever the server encodes;
        # the ".jpg" extension is assumed, not verified here.
        screenshot_data = base64.b64decode(screenshot_b64)
        with open("test_screenshot.jpg", "wb") as f:
            f.write(screenshot_data)
        print("Screenshot saved as test_screenshot.jpg")
|
|
|
|
|
|
if __name__ == "__main__":
    # The first command-line argument, if given, names the image flavour;
    # otherwise the "basic" flavour is assumed.
    cli_args = sys.argv[1:]
    version = cli_args[0] if cli_args else "basic"
    test_docker_deployment(version)
|