diff --git a/crawl4ai/train.py b/crawl4ai/train.py
deleted file mode 100644
index f7e7c1a..0000000
--- a/crawl4ai/train.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import spacy
-from spacy.training import Example
-import random
-import nltk
-from nltk.corpus import reuters
-import torch
-
-def save_spacy_model_as_torch(nlp, model_dir="models/reuters"):
-    # Extract the TextCategorizer component
-    textcat = nlp.get_pipe("textcat_multilabel")
-
-    # Convert the weights to a PyTorch state dictionary
-    state_dict = {name: torch.tensor(param.data) for name, param in textcat.model.named_parameters()}
-
-    # Save the state dictionary
-    torch.save(state_dict, f"{model_dir}/model_weights.pth")
-
-    # Extract and save the vocabulary
-    vocab = extract_vocab(nlp)
-    with open(f"{model_dir}/vocab.txt", "w") as vocab_file:
-        for word, idx in vocab.items():
-            vocab_file.write(f"{word}\t{idx}\n")
-
-    print(f"Model weights and vocabulary saved to: {model_dir}")
-
-def extract_vocab(nlp):
-    # Extract vocabulary from the SpaCy model
-    vocab = {word: i for i, word in enumerate(nlp.vocab.strings)}
-    return vocab
-
-nlp = spacy.load("models/reuters")
-save_spacy_model_as_torch(nlp, model_dir="models")
-
-def train_and_save_reuters_model(model_dir="models/reuters"):
-    # Ensure the Reuters corpus is downloaded
-    nltk.download('reuters')
-    nltk.download('punkt')
-    if not reuters.fileids():
-        print("Reuters corpus not found.")
-        return
-
-    # Load a blank English spaCy model
-    nlp = spacy.blank("en")
-
-    # Create a TextCategorizer with the ensemble model for multi-label classification
-    textcat = nlp.add_pipe("textcat_multilabel")
-
-    # Add labels to text classifier
-    for label in reuters.categories():
-        textcat.add_label(label)
-
-    # Prepare training data
-    train_examples = []
-    for fileid in reuters.fileids():
-        categories = reuters.categories(fileid)
-        text = reuters.raw(fileid)
-        cats = {label: label in categories for label in reuters.categories()}
-        # Prepare spacy Example objects
-        doc = nlp.make_doc(text)
-        example = Example.from_dict(doc, {'cats': cats})
-        train_examples.append(example)
-
-    # Initialize the text categorizer with the example objects
-    nlp.initialize(lambda: train_examples)
-
-    # Train the model
-    random.seed(1)
-    spacy.util.fix_random_seed(1)
-    for i in range(5): # Adjust iterations for better accuracy
-        random.shuffle(train_examples)
-        losses = {}
-        # Create batches of data
-        batches = spacy.util.minibatch(train_examples, size=8)
-        for batch in batches:
-            nlp.update(batch, drop=0.2, losses=losses)
-        print(f"Losses at iteration {i}: {losses}")
-
-    # Save the trained model
-    nlp.to_disk(model_dir)
-    print(f"Model saved to: {model_dir}")
-
-def train_model(model_dir, additional_epochs=0):
-    # Load the model if it exists, otherwise start with a blank model
-    try:
-        nlp = spacy.load(model_dir)
-        print("Model loaded from disk.")
-    except IOError:
-        print("No existing model found. Starting with a new model.")
-        nlp = spacy.blank("en")
-        textcat = nlp.add_pipe("textcat_multilabel")
-        for label in reuters.categories():
-            textcat.add_label(label)
-
-    # Prepare training data
-    train_examples = []
-    for fileid in reuters.fileids():
-        categories = reuters.categories(fileid)
-        text = reuters.raw(fileid)
-        cats = {label: label in categories for label in reuters.categories()}
-        doc = nlp.make_doc(text)
-        example = Example.from_dict(doc, {'cats': cats})
-        train_examples.append(example)
-
-    # Initialize the model if it was newly created
-    if 'textcat_multilabel' not in nlp.pipe_names:
-        nlp.initialize(lambda: train_examples)
-    else:
-        print("Continuing training with existing model.")
-
-    # Train the model
-    random.seed(1)
-    spacy.util.fix_random_seed(1)
-    num_epochs = 5 + additional_epochs
-    for i in range(num_epochs):
-        random.shuffle(train_examples)
-        losses = {}
-        batches = spacy.util.minibatch(train_examples, size=8)
-        for batch in batches:
-            nlp.update(batch, drop=0.2, losses=losses)
-        print(f"Losses at iteration {i}: {losses}")
-
-    # Save the trained model
-    nlp.to_disk(model_dir)
-    print(f"Model saved to: {model_dir}")
-
-def load_model_and_predict(model_dir, text, tok_k = 3):
-    # Load the trained model from the specified directory
-    nlp = spacy.load(model_dir)
-
-    # Process the text with the loaded model
-    doc = nlp(text)
-
-    # gee top 3 categories
-    top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
-    print(f"Top {tok_k} categories:")
-
-    return top_categories
-
-if __name__ == "__main__":
-    train_and_save_reuters_model()
-    train_model("models/reuters", additional_epochs=5)
-    model_directory = "reuters_model_10"
-    print(reuters.categories())
-    example_text = "Apple Inc. is reportedly buying a startup for $1 billion"
-    r =load_model_and_predict(model_directory, example_text)
-    print(r)
\ No newline at end of file
diff --git a/crawl4ai/web_crawler.back.py b/crawl4ai/web_crawler.back.py
deleted file mode 100644
index af78f12..0000000
--- a/crawl4ai/web_crawler.back.py
+++ /dev/null
@@ -1,357 +0,0 @@
-import os, time
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-from pathlib import Path
-
-from .models import UrlModel, CrawlResult
-from .database import init_db, get_cached_url, cache_url, DB_PATH, flush_db
-from .utils import *
-from .chunking_strategy import *
-from .extraction_strategy import *
-from .crawler_strategy import *
-from typing import List
-from concurrent.futures import ThreadPoolExecutor
-from .config import *
-
-
-class WebCrawler:
-    def __init__(
-        self,
-        # db_path: str = None,
-        crawler_strategy: CrawlerStrategy = None,
-        always_by_pass_cache: bool = False,
-        verbose: bool = False,
-    ):
-        # self.db_path = db_path
-        self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
-        self.always_by_pass_cache = always_by_pass_cache
-
-        # Create the .crawl4ai folder in the user's home directory if it doesn't exist
-        self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
-        os.makedirs(self.crawl4ai_folder, exist_ok=True)
-        os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
-
-        # If db_path is not provided, use the default path
-        # if not db_path:
-        #     self.db_path = f"{self.crawl4ai_folder}/crawl4ai.db"
-
-        # flush_db()
-        init_db()
-
-        self.ready = False
-
-    def warmup(self):
-        print("[LOG] 🌤️ Warming up the WebCrawler")
-        result = self.run(
-            url='https://crawl4ai.uccode.io/',
-            word_count_threshold=5,
-            extraction_strategy= NoExtractionStrategy(),
-            bypass_cache=False,
-            verbose = False
-        )
-        self.ready = True
-        print("[LOG] 🌞 WebCrawler is ready to crawl")
-
-    def fetch_page(
-        self,
-        url_model: UrlModel,
-        provider: str = DEFAULT_PROVIDER,
-        api_token: str = None,
-        extract_blocks_flag: bool = True,
-        word_count_threshold=MIN_WORD_THRESHOLD,
-        css_selector: str = None,
-        screenshot: bool = False,
-        use_cached_html: bool = False,
-        extraction_strategy: ExtractionStrategy = None,
-        chunking_strategy: ChunkingStrategy = RegexChunking(),
-        **kwargs,
-    ) -> CrawlResult:
-        return self.run(
-            url_model.url,
-            word_count_threshold,
-            extraction_strategy or NoExtractionStrategy(),
-            chunking_strategy,
-            bypass_cache=url_model.forced,
-            css_selector=css_selector,
-            screenshot=screenshot,
-            **kwargs,
-        )
-        pass
-
-    def run_old(
-        self,
-        url: str,
-        word_count_threshold=MIN_WORD_THRESHOLD,
-        extraction_strategy: ExtractionStrategy = None,
-        chunking_strategy: ChunkingStrategy = RegexChunking(),
-        bypass_cache: bool = False,
-        css_selector: str = None,
-        screenshot: bool = False,
-        user_agent: str = None,
-        verbose=True,
-        **kwargs,
-    ) -> CrawlResult:
-        if user_agent:
-            self.crawler_strategy.update_user_agent(user_agent)
-        extraction_strategy = extraction_strategy or NoExtractionStrategy()
-        extraction_strategy.verbose = verbose
-        # Check if extraction strategy is an instance of ExtractionStrategy if not raise an error
-        if not isinstance(extraction_strategy, ExtractionStrategy):
-            raise ValueError("Unsupported extraction strategy")
-        if not isinstance(chunking_strategy, ChunkingStrategy):
-            raise ValueError("Unsupported chunking strategy")
-
-        # make sure word_count_threshold is not lesser than MIN_WORD_THRESHOLD
-        if word_count_threshold < MIN_WORD_THRESHOLD:
-            word_count_threshold = MIN_WORD_THRESHOLD
-
-        # Check cache first
-        if not bypass_cache and not self.always_by_pass_cache:
-            cached = get_cached_url(url)
-            if cached:
-                return CrawlResult(
-                    **{
-                        "url": cached[0],
-                        "html": cached[1],
-                        "cleaned_html": cached[2],
-                        "markdown": cached[3],
-                        "extracted_content": cached[4],
-                        "success": cached[5],
-                        "media": json.loads(cached[6] or "{}"),
-                        "links": json.loads(cached[7] or "{}"),
-                        "metadata": json.loads(cached[8] or "{}"), # "metadata": "{}
-                        "screenshot": cached[9],
-                        "error_message": "",
-                    }
-                )
-
-        # Initialize WebDriver for crawling
-        t = time.time()
-        if kwargs.get("js", None):
-            self.crawler_strategy.js_code = kwargs.get("js")
-        html = self.crawler_strategy.crawl(url)
-        base64_image = None
-        if screenshot:
-            base64_image = self.crawler_strategy.take_screenshot()
-        success = True
-        error_message = ""
-        # Extract content from HTML
-        try:
-            result = get_content_of_website(url, html, word_count_threshold, css_selector=css_selector)
-            metadata = extract_metadata(html)
-            if result is None:
-                raise ValueError(f"Failed to extract content from the website: {url}")
-        except InvalidCSSSelectorError as e:
-            raise ValueError(str(e))
-
-        cleaned_html = result.get("cleaned_html", "")
-        markdown = result.get("markdown", "")
-        media = result.get("media", [])
-        links = result.get("links", [])
-
-        # Print a profession LOG style message, show time taken and say crawling is done
-        if verbose:
-            print(
-                f"[LOG] 🚀 Crawling done for {url}, success: {success}, time taken: {time.time() - t} seconds"
-            )
-
-        extracted_content = []
-        if verbose:
-            print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}")
-        t = time.time()
-        # Split markdown into sections
-        sections = chunking_strategy.chunk(markdown)
-        # sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
-
-        extracted_content = extraction_strategy.run(
-            url, sections,
-        )
-        extracted_content = json.dumps(extracted_content)
-
-        if verbose:
-            print(
                f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t} seconds."
-            )
-
-        # Cache the result
-        cleaned_html = beautify_html(cleaned_html)
-        cache_url(
-            url,
-            html,
-            cleaned_html,
-            markdown,
-            extracted_content,
-            success,
-            json.dumps(media),
-            json.dumps(links),
-            json.dumps(metadata),
-            screenshot=base64_image,
-        )
-
-        return CrawlResult(
-            url=url,
-            html=html,
-            cleaned_html=cleaned_html,
-            markdown=markdown,
-            media=media,
-            links=links,
-            metadata=metadata,
-            screenshot=base64_image,
-            extracted_content=extracted_content,
-            success=success,
-            error_message=error_message,
-        )
-
-    def fetch_pages(
-        self,
-        url_models: List[UrlModel],
-        provider: str = DEFAULT_PROVIDER,
-        api_token: str = None,
-        extract_blocks_flag: bool = True,
-        word_count_threshold=MIN_WORD_THRESHOLD,
-        use_cached_html: bool = False,
-        css_selector: str = None,
-        screenshot: bool = False,
-        extraction_strategy: ExtractionStrategy = None,
-        chunking_strategy: ChunkingStrategy = RegexChunking(),
-        **kwargs,
-    ) -> List[CrawlResult]:
-        extraction_strategy = extraction_strategy or NoExtractionStrategy()
-        def fetch_page_wrapper(url_model, *args, **kwargs):
-            return self.fetch_page(url_model, *args, **kwargs)
-
-        with ThreadPoolExecutor() as executor:
-            results = list(
-                executor.map(
-                    fetch_page_wrapper,
-                    url_models,
-                    [provider] * len(url_models),
-                    [api_token] * len(url_models),
-                    [extract_blocks_flag] * len(url_models),
-                    [word_count_threshold] * len(url_models),
-                    [css_selector] * len(url_models),
-                    [screenshot] * len(url_models),
-                    [use_cached_html] * len(url_models),
-                    [extraction_strategy] * len(url_models),
-                    [chunking_strategy] * len(url_models),
-                    *[kwargs] * len(url_models),
-                )
-            )
-
-        return results
-
-    def run(
-        self,
-        url: str,
-        word_count_threshold=MIN_WORD_THRESHOLD,
-        extraction_strategy: ExtractionStrategy = None,
-        chunking_strategy: ChunkingStrategy = RegexChunking(),
-        bypass_cache: bool = False,
-        css_selector: str = None,
-        screenshot: bool = False,
-        user_agent: str = None,
-        verbose=True,
-        **kwargs,
-    ) -> CrawlResult:
-        extraction_strategy = extraction_strategy or NoExtractionStrategy()
-        extraction_strategy.verbose = verbose
-        if not isinstance(extraction_strategy, ExtractionStrategy):
-            raise ValueError("Unsupported extraction strategy")
-        if not isinstance(chunking_strategy, ChunkingStrategy):
-            raise ValueError("Unsupported chunking strategy")
-
-        if word_count_threshold < MIN_WORD_THRESHOLD:
-            word_count_threshold = MIN_WORD_THRESHOLD
-
-        # Check cache first
-        cached = None
-        extracted_content = None
-        if not bypass_cache and not self.always_by_pass_cache:
-            cached = get_cached_url(url)
-
-        if cached:
-            html = cached[1]
-            extracted_content = cached[2]
-            if screenshot:
-                screenshot = cached[9]
-
-        else:
-            if user_agent:
-                self.crawler_strategy.update_user_agent(user_agent)
-            html = self.crawler_strategy.crawl(url)
-            if screenshot:
-                screenshot = self.crawler_strategy.take_screenshot()
-
-        return self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot, verbose, bool(cached), **kwargs)
-
-    def process_html(
-        self,
-        url: str,
-        html: str,
-        extracted_content: str,
-        word_count_threshold: int,
-        extraction_strategy: ExtractionStrategy,
-        chunking_strategy: ChunkingStrategy,
-        css_selector: str,
-        screenshot: bool,
-        verbose: bool,
-        is_cached: bool,
-        **kwargs,
-    ) -> CrawlResult:
-        t = time.time()
-        # Extract content from HTML
-        try:
-            result = get_content_of_website(url, html, word_count_threshold, css_selector=css_selector)
-            metadata = extract_metadata(html)
-            if result is None:
-                raise ValueError(f"Failed to extract content from the website: {url}")
-        except InvalidCSSSelectorError as e:
-            raise ValueError(str(e))
-
-        cleaned_html = result.get("cleaned_html", "")
-        markdown = result.get("markdown", "")
-        media = result.get("media", [])
-        links = result.get("links", [])
-
-        if verbose:
-            print(f"[LOG] 🚀 Crawling done for {url}, success: True, time taken: {time.time() - t} seconds")
-
-        if extracted_content is None:
-            if verbose:
-                print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}")
-
-            sections = chunking_strategy.chunk(markdown)
-            extracted_content = extraction_strategy.run(url, sections)
-            extracted_content = json.dumps(extracted_content)
-
-        if verbose:
-            print(f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t} seconds.")
-
-        screenshot = None if not screenshot else screenshot
-
-        if not is_cached:
-            cache_url(
-                url,
-                html,
-                cleaned_html,
-                markdown,
-                extracted_content,
-                True,
-                json.dumps(media),
-                json.dumps(links),
-                json.dumps(metadata),
-                screenshot=screenshot,
-            )
-
-        return CrawlResult(
-            url=url,
-            html=html,
-            cleaned_html=cleaned_html,
-            markdown=markdown,
-            media=media,
-            links=links,
-            metadata=metadata,
-            screenshot=screenshot,
-            extracted_content=extracted_content,
-            success=True,
-            error_message="",
-        )
\ No newline at end of file
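
The removed web_crawler.back.py is a stale backup of the synchronous WebCrawler, so its public surface mirrors the retained module. For reference, a minimal usage sketch against the interface shown in the diff above (warmup() plus run() with bypass_cache and screenshot flags, returning a CrawlResult with markdown, links, and success fields) could look like the following. The import path crawl4ai.web_crawler and the example URL are assumptions; everything else comes from the deleted code.

# Minimal sketch, assuming the retained crawl4ai.web_crawler module keeps the
# WebCrawler/CrawlResult interface shown in the deleted backup above.
from crawl4ai.web_crawler import WebCrawler  # assumed import path

crawler = WebCrawler(verbose=True)
crawler.warmup()  # primes the crawler, as the backup's warmup() does

result = crawler.run(
    url="https://example.com",   # hypothetical target URL
    word_count_threshold=5,
    bypass_cache=True,           # skip the local cache, per the backup's run() flag
    screenshot=False,
)

if result.success:
    print(result.markdown[:500])  # markdown produced from the cleaned HTML
    print(result.links)           # links extracted from the page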