import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModelForCausalLM, AutoTokenizer

# Import core SPA functionality
from spa import (
    SPALogitsProcessor,
    spa_tokenize,
    preprocess_anchors,
    create_default_attention_mask,
)


class SPAModel(nn.Module, PyTorchModelHubMixin):
    """
    Selective Prompt Anchoring (SPA) model with Hugging Face Hub integration.

    This model wraps a base LLM and provides the SPA functionality, with the
    ability to be shared and downloaded from the Hugging Face Hub.
    """

    def __init__(
        self,
        base_model_name="Qwen/Qwen3-0.6B",
        anchoring_strength=2,
        modulated_by_prob=True,
        use_attention_mask=True,
        device_map="auto",
        **kwargs
    ):
        super().__init__()

        # Store configuration parameters
        self.base_model_name = base_model_name
        self.anchoring_strength = anchoring_strength
        self.modulated_by_prob = modulated_by_prob
        self.use_attention_mask = use_attention_mask
        self.device_map = device_map

        # Load the base model and tokenizer. AutoModelForCausalLM (rather than the
        # headless AutoModel) is needed so the wrapped model exposes `generate()`
        # and produces the logits that SPA rescales.
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name, device_map=device_map, **kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        # Set a default pad token if the tokenizer does not define one
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            if hasattr(self.model, "config"):
                self.model.config.pad_token_id = self.model.config.eos_token_id

        # Determine device
        if hasattr(self.model, "device"):
            self.device = self.model.device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Pass through to the base model's forward method."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    def generate_with_spa(
        self,
        prompt,
        anchors=None,
        anchoring_strength=None,
        modulated_by_prob=None,
        use_attention_mask=None,
        max_new_tokens=100,
        min_new_tokens=1,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=50,
        stream=False,
        **kwargs
    ):
        """
        Generate text using Selective Prompt Anchoring.

        Args:
            prompt: Text or messages to generate from
            anchors: List of anchor strings to influence generation
            anchoring_strength: How much to weight the anchored version
            modulated_by_prob: Whether to modulate strength by token probability
            use_attention_mask: Whether to use attention masking for anchor tokens
            max_new_tokens: Maximum number of tokens to generate
            min_new_tokens: Minimum number of tokens to generate
            do_sample: Whether to use sampling for generation
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            top_k: Top-k sampling parameter
            stream: Whether to stream the output

        Returns:
            Generated text (or a streamer if stream=True)
        """
        # Use instance defaults if parameters are not provided (explicit None checks,
        # so falsy values such as 0 are not silently replaced by the defaults)
        anchoring_strength = anchoring_strength if anchoring_strength is not None else self.anchoring_strength
        modulated_by_prob = modulated_by_prob if modulated_by_prob is not None else self.modulated_by_prob
        use_attention_mask = use_attention_mask if use_attention_mask is not None else self.use_attention_mask

        # Default to an empty list if anchors are not provided
        if anchors is None:
            anchors = []

        # Preprocess anchors
        anchors = preprocess_anchors(anchors)

        # Tokenize with SPA: build the main (anchored) and auxiliary (anchor-masked) inputs
        main_inputs, aux_inputs, mask_token = spa_tokenize(
            prompt_with_anchors=prompt,
            global_anchors=anchors,
            tokenizer=self.tokenizer,
            device=self.device,
        )

        # Create the SPA logits processor
        spa_processor = SPALogitsProcessor(
            aux_model=self.model,
            aux_input_ids=aux_inputs,
            strength=anchoring_strength,
            modulated_by_prob=modulated_by_prob,
            use_attention_mask=use_attention_mask,
            mask_token=mask_token,
            tokenizer=self.tokenizer,
        )

        # Get attention mask
        attention_mask = create_default_attention_mask(main_inputs, device=self.device)

        # Set up generation kwargs
        generation_kwargs = {
            "input_ids": main_inputs,
            "attention_mask": attention_mask,
            "logits_processor": [spa_processor],
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            **kwargs,
        }

        if stream:
            import threading

            from transformers import TextIteratorStreamer

            # Set up the streamer
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_special_tokens=True,
                skip_prompt=True,
            )
            generation_kwargs["streamer"] = streamer

            # Start generation in a separate thread
            generation_thread = threading.Thread(
                target=self.model.generate,
                kwargs=generation_kwargs,
            )
            generation_thread.start()

            # Return the streamer for token-by-token output
            return streamer
        else:
            # Normal (non-streaming) generation
            output_sequences = self.model.generate(**generation_kwargs)

            # Decode only the newly generated tokens, skipping the prompt
            input_length = main_inputs.shape[1]
            new_tokens = output_sequences[0][input_length:]
            generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return generated_text


# Helper function to load models directly from the Hub
def load_spa_model(
    model_name="magic-yuantian/selective-prompt-anchoring",
    base_model_name="meta-llama/Llama-3.1-8B-Instruct",
    **kwargs
):
    """
    Load a SPAModel from the Hugging Face Hub, or create a new one.

    Args:
        model_name: Name or path of the SPA model on the Hub
        base_model_name: The base model to use (if creating a new model)
        **kwargs: Additional arguments to pass to from_pretrained or __init__

    Returns:
        A SPAModel instance
    """
    try:
        # Try to load from the Hub
        model = SPAModel.from_pretrained(model_name, **kwargs)
        return model
    except Exception as e:
        print(f"Error loading model from hub: {e}")
        print(f"Creating a new SPAModel with base model {base_model_name}")

        # Fall back to creating a new model from the base checkpoint
        model = SPAModel(base_model_name=base_model_name, **kwargs)
        return model
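
# A minimal usage sketch (illustrative only): the base model name, prompt, and
# anchor strings below are placeholder examples, not values prescribed by SPA.
# It shows the two ways generate_with_spa can be called: plain generation and
# streaming via the returned TextIteratorStreamer. SPAModel(...) builds a fresh
# wrapper; load_spa_model(...) would instead try the Hub checkpoint first.
if __name__ == "__main__":
    spa_model = SPAModel(base_model_name="Qwen/Qwen3-0.6B")

    # Plain (non-streaming) generation: returns the decoded completion as a string
    text = spa_model.generate_with_spa(
        prompt="Write a Python function that checks whether a string is a palindrome.",
        anchors=["ignore case", "ignore non-alphanumeric characters"],
        anchoring_strength=3,
        max_new_tokens=200,
    )
    print(text)

    # Streaming generation: the call returns immediately with a streamer that
    # yields decoded text chunks as they are produced by the background thread
    streamer = spa_model.generate_with_spa(
        prompt="Summarize what selective prompt anchoring does.",
        anchors=["one paragraph"],
        stream=True,
        max_new_tokens=100,
    )
    for chunk in streamer:
        print(chunk, end="", flush=True)
    print()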