# Model-loading utilities for Crawl4AI (device selection, batch sizing, model downloads).
import argparse
import os
import shutil
import subprocess
import tarfile
import urllib.request
from functools import lru_cache
from pathlib import Path

from crawl4ai.config import MODEL_REPO_BRANCH

from .model_loader import *
# Absolute, symlink-resolved path of the directory containing this file.
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
def get_available_memory(device):
    """Return the total memory, in bytes, of the given torch device.

    Args:
        device: A ``torch.device`` to query.

    Returns:
        int: Total device memory in bytes. Returns 0 for CPU/unknown device
        types (CPU batch sizing is handled separately by the caller).
    """
    import torch
    if device.type == 'cuda':
        return torch.cuda.get_device_properties(device).total_memory
    elif device.type == 'mps':
        # torch exposes no capacity query for MPS; assume 48GB of unified
        # memory. (Previous comment claimed 8GB, which contradicted the code.)
        return 48 * 1024 ** 3
    else:
        return 0
def calculate_batch_size(device):
    """Choose an inference batch size appropriate for the device's memory.

    Args:
        device: A ``torch.device``.

    Returns:
        int: Suggested batch size (16 for CPU and unknown devices; 32-256
        for cuda/mps depending on total memory).
    """
    available_memory = get_available_memory(device)
    if device.type == 'cpu':
        return 16
    elif device.type in ('cuda', 'mps'):
        # Thresholds sit ~1GB below nominal card sizes so that e.g. a "32GB"
        # device (which reports slightly less than 32GiB) still qualifies.
        if available_memory >= 31 * 1024 ** 3:    # ~32GB and up
            return 256
        elif available_memory >= 15 * 1024 ** 3:  # ~16GB to 32GB
            return 128
        elif available_memory >= 8 * 1024 ** 3:   # 8GB to 16GB
            return 64
        else:
            return 32
    else:
        return 16  # Default batch size
def get_device():
    """Return the best available torch device, preferring cuda, then mps, then cpu."""
    import torch
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')
def set_model_device(model):
    """Move *model* onto the preferred device and return ``(model, device)``."""
    target = get_device()
    model.to(target)
    return model, target
def get_home_folder():
    """Return the crawl4ai data directory, creating it (plus cache/ and models/) if needed.

    The base directory defaults to the user's home but can be overridden via
    the CRAWL4_AI_BASE_DIRECTORY environment variable.
    """
    base = os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home())
    home_folder = os.path.join(base, ".crawl4ai")
    for sub in ("", "cache", "models"):
        os.makedirs(os.path.join(home_folder, sub), exist_ok=True)
    return home_folder
def load_bert_base_uncased():
    """Load the bert-base-uncased tokenizer and model.

    Returns:
        tuple: ``(tokenizer, model)`` with the model in eval mode and moved
        to the best available device.
    """
    # Only the BERT classes are needed; the unused Auto* imports were dropped.
    from transformers import BertTokenizer, BertModel
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
    model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
    model.eval()
    model, _ = set_model_device(model)
    return tokenizer, model
def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
    """Load the Hugging Face model for embedding.

    Args:
        model_name (str, optional): The model name to load. Defaults to "BAAI/bge-small-en-v1.5".

    Returns:
        tuple: The tokenizer and model, with the model in eval mode on the
        best available device.
    """
    # Only the Auto* classes are needed; the unused Bert* imports were dropped.
    from transformers import AutoTokenizer, AutoModel
    tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
    model = AutoModel.from_pretrained(model_name, resume_download=None)
    model.eval()
    model, _ = set_model_device(model)
    return tokenizer, model
def load_text_classifier():
    """Build a text-classification pipeline from the NYT-news RoBERTa topic model.

    Returns:
        transformers.Pipeline: a ready-to-call "text-classification" pipeline.
    """
    # The unused `import torch` was removed; device placement is handled by
    # set_model_device.
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from transformers import pipeline
    tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
    model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
    model.eval()
    model, _ = set_model_device(model)
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
    return pipe
def load_text_multilabel_classifier():
    """Load the multi-label tweet-topic classifier.

    Returns:
        tuple: ``(_classifier, device)`` where ``_classifier(texts,
        threshold=0.5, max_length=64)`` returns a list of label-name lists,
        one per input text.
    """
    # Dead commented-out device-selection code and the unused numpy import
    # were removed; device placement is handled by set_model_device.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from scipy.special import expit
    import torch

    MODEL = "cardiffnlp/tweet-topic-21-multi"
    tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
    model.eval()
    model, device = set_model_device(model)
    class_mapping = model.config.id2label

    def _classifier(texts, threshold=0.5, max_length=64):
        tokens = tokenizer(texts, return_tensors='pt', padding=True,
                           truncation=True, max_length=max_length)
        # Move tokens to the same device as the model.
        tokens = {key: val.to(device) for key, val in tokens.items()}
        with torch.no_grad():
            output = model(**tokens)
        # Sigmoid over logits: each label is predicted independently.
        scores = expit(output.logits.detach().cpu().numpy())
        predictions = (scores >= threshold) * 1
        batch_labels = []
        for prediction in predictions:
            labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1]
            batch_labels.append(labels)
        return batch_labels

    return _classifier, device
def load_nltk_punkt():
    """Ensure NLTK's 'punkt' tokenizer data is installed, downloading it on first use.

    Returns:
        The nltk data path for 'tokenizers/punkt'.
    """
    import nltk
    resource = 'tokenizers/punkt'
    try:
        return nltk.data.find(resource)
    except LookupError:
        nltk.download('punkt')
    return nltk.data.find(resource)
def load_spacy_model():
    """Load the custom 'reuters' spaCy model, cloning it from the crawl4ai repo on first use.

    Returns:
        The loaded spaCy Language object, or None if the download or load failed.
    """
    import spacy
    name = "models/reuters"
    home_folder = get_home_folder()
    model_folder = Path(home_folder) / name

    # Download only when the model directory is missing or empty.
    if not (model_folder.exists() and any(model_folder.iterdir())):
        repo_url = "https://github.com/unclecode/crawl4ai.git"
        branch = MODEL_REPO_BRANCH
        repo_folder = Path(home_folder) / "crawl4ai"
        print("[LOG] ⏬ Downloading Spacy model for the first time...")

        # Remove a stale clone (and any partial model folder) from a previous attempt.
        if repo_folder.exists():
            try:
                shutil.rmtree(repo_folder)
                if model_folder.exists():
                    shutil.rmtree(model_folder)
            except PermissionError:
                print("[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:")
                print(f"- {repo_folder}")
                print(f"- {model_folder}")
                return None
        try:
            # Clone the model branch of the repository; git output is suppressed.
            subprocess.run(
                ["git", "clone", "-b", branch, repo_url, str(repo_folder)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True
            )
            # Create the models directory if it doesn't exist.
            models_folder = Path(home_folder) / "models"
            models_folder.mkdir(parents=True, exist_ok=True)
            # dirs_exist_ok=True fixes a crash when model_folder exists but is
            # empty: the emptiness check above lets that state through, and
            # copytree would otherwise raise FileExistsError.
            source_folder = repo_folder / "models" / "reuters"
            shutil.copytree(source_folder, model_folder, dirs_exist_ok=True)
            # Remove the cloned repository now that the model is in place.
            shutil.rmtree(repo_folder)
            print("[LOG] ✅ Spacy Model downloaded successfully")
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while cloning the repository: {e}")
            return None
        except Exception as e:
            print(f"An error occurred: {e}")
            return None

    try:
        return spacy.load(str(model_folder))
    except Exception as e:
        print(f"Error loading spacy model: {e}")
        return None
def download_all_models(remove_existing=False):
    """Download all models required for Crawl4AI.

    Args:
        remove_existing (bool): When True, delete previously downloaded
            model folders before re-downloading.
    """
    # Dead commented-out download steps (BERT, BGE, ONNX) were removed.
    if remove_existing:
        print("[LOG] Removing existing models...")
        home_folder = get_home_folder()
        # Child folder listed first; removing the parent last sweeps anything left.
        model_folders = [
            os.path.join(home_folder, "models/reuters"),
            os.path.join(home_folder, "models"),
        ]
        for folder in model_folders:
            if Path(folder).exists():
                shutil.rmtree(folder)
        print("[LOG] Existing models removed.")

    # Load each model once to trigger its download/cache.
    print("[LOG] Downloading text classifier...")
    _, device = load_text_multilabel_classifier()
    print(f"[LOG] Text classifier loaded on {device}")
    print("[LOG] Downloading custom NLTK Punkt model...")
    load_nltk_punkt()
    print("[LOG] ✅ All models downloaded successfully.")
def main():
    """CLI entry point: parse arguments and download every required model."""
    print("[LOG] Welcome to the Crawl4AI Model Downloader!")
    print("[LOG] This script will download all the models required for Crawl4AI.")
    parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
    parser.add_argument(
        '--remove-existing',
        action='store_true',
        help="Remove existing models before downloading",
    )
    options = parser.parse_args()
    download_all_models(remove_existing=options.remove_existing)
# Allow running this module directly as a model-downloader script.
if __name__ == "__main__":
    main()