import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import ModelCard, DatasetCard, model_info, dataset_info
import logging
from typing import Tuple
import spaces
from cachetools.func import ttl_cache
import os
import json

# Enable hf_transfer for faster downloads from the Hub (requires the hf_transfer package)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
MODEL_NAME = "davanstrien/Smol-Hub-tldr"
model = None
tokenizer = None
device = None
CACHE_TTL = 6 * 60 * 60  # 6 hours in seconds
CACHE_MAXSIZE = 100


def load_model():
    """Load the summarization model and tokenizer onto the available device."""
    global model, tokenizer, device
    logger.info("Loading model and tokenizer...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model = model.to(device)
        model.eval()
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False


def get_card_info(hub_id: str, repo_type: str = "auto") -> Tuple[str, str]:
    """Get card information for a Hugging Face hub_id."""
    model_exists = False
    dataset_exists = False
    model_text = None
    dataset_text = None

    # Handle based on repo type
    if repo_type == "auto":
        # Try getting the model card
        try:
            model_info(hub_id)  # confirm a model repo exists
            card = ModelCard.load(hub_id)
            model_exists = True
            model_text = card.text
        except Exception as e:
            logger.debug(f"No model card found for {hub_id}: {e}")
        # Try getting the dataset card
        try:
            dataset_info(hub_id)  # confirm a dataset repo exists
            card = DatasetCard.load(hub_id)
            dataset_exists = True
            dataset_text = card.text
        except Exception as e:
            logger.debug(f"No dataset card found for {hub_id}: {e}")
    elif repo_type == "model":
        try:
            model_info(hub_id)
            card = ModelCard.load(hub_id)
            model_exists = True
            model_text = card.text
        except Exception as e:
            logger.error(f"Failed to get model card for {hub_id}: {e}")
            raise ValueError(f"Could not find model with id {hub_id}")
    elif repo_type == "dataset":
        try:
            dataset_info(hub_id)
            card = DatasetCard.load(hub_id)
            dataset_exists = True
            dataset_text = card.text
        except Exception as e:
            logger.error(f"Failed to get dataset card for {hub_id}: {e}")
            raise ValueError(f"Could not find dataset with id {hub_id}")
    else:
        raise ValueError(
            f"Invalid repo_type: {repo_type}. Must be 'auto', 'model', or 'dataset'"
        )
    # Handle the possible combinations
    if model_exists and dataset_exists:
        return "both", (model_text, dataset_text)
    elif model_exists:
        return "model", model_text
    elif dataset_exists:
        return "dataset", dataset_text
    else:
        raise ValueError(f"Could not find model or dataset with id {hub_id}")


@spaces.GPU
def _generate_summary_gpu(card_text: str, card_type: str) -> str:
    """Internal function that runs on GPU."""
    # Prefix marker selecting the card type; the exact marker tokens the model
    # was trained with are assumed here.
    prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"

    # Format input according to the chat template
    messages = [{"role": "user", "content": f"{prefix}{card_text[:5000]}"}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = inputs.to(device)

    # Generate with optimized settings
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=60,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.4,
            do_sample=True,
            use_cache=True,
        )

    # Extract and clean up the summary
    input_length = inputs.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=False)

    # Keep only the text between the summary marker tokens (token names assumed,
    # matching the prefix markers above); fall back to the full response otherwise
    try:
        summary = response.split("<CARD_SUMMARY>")[-1].split("</CARD_SUMMARY>")[0].strip()
    except IndexError:
        summary = response.strip()

    return summary


@ttl_cache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
def generate_summary(card_text: str, card_type: str) -> str:
    """Cached wrapper around _generate_summary_gpu with a TTL."""
    return _generate_summary_gpu(card_text, card_type)


def summarize(hub_id: str = "", repo_type: str = "auto") -> str:
    """Interface function for Gradio. Returns a JSON string."""
    try:
        if hub_id:
            # Fetch card information for the specified repo_type
            card_type, card_text = get_card_info(hub_id, repo_type)
            if card_type == "both":
                model_text, dataset_text = card_text
                model_summary = generate_summary(model_text, "model")
                dataset_summary = generate_summary(dataset_text, "dataset")
                return json.dumps(
                    {
                        "type": "both",
                        "hub_id": hub_id,
                        "model_summary": model_summary,
                        "dataset_summary": dataset_summary,
                    }
                )
            else:
                summary = generate_summary(card_text, card_type)
                return json.dumps(
                    {"summary": summary, "type": card_type, "hub_id": hub_id}
                )
        else:
            return json.dumps({"error": "Hub ID must be provided"})
    except Exception as e:
        return json.dumps({"error": str(e)})


def create_interface():
    """Build the Gradio interface."""
    interface = gr.Interface(
        fn=summarize,
        inputs=[
            gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
            gr.Radio(
                choices=["auto", "model", "dataset"],
                value="auto",
                label="Repository Type",
                info="Choose 'auto' to detect automatically, or specify the repository type",
            ),
        ],
        outputs=gr.JSON(label="Output"),
        title="Hugging Face Hub TLDR Generator",
        description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
    )
    return interface


if __name__ == "__main__":
    if load_model():
        interface = create_interface()
        interface.launch()
    else:
        print("Failed to load model. Please check the logs for details.")
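

# Example of querying the running app programmatically with gradio_client
# (a sketch, not part of the app itself: the URL assumes the default local
# launch() address, and api_name="/predict" is the default endpoint name for a
# single gr.Interface):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict("openai-community/gpt2", "model", api_name="/predict")
#   print(result)  # summary payload, e.g. with "summary", "type", and "hub_id" fields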