import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModelForCausalLM, AutoTokenizer
# Import core SPA functionality
from spa import SPALogitsProcessor, spa_tokenize, preprocess_anchors, create_default_attention_mask
class SPAModel(nn.Module, PyTorchModelHubMixin):
"""
Selective Prompt Anchoring (SPA) model with Hugging Face Hub integration.
This model wraps a base LLM and provides the SPA functionality with
the ability to be shared and downloaded from the Hugging Face Hub.
"""
def __init__(
self,
base_model_name="Qwen/Qwen3-0.6B",
anchoring_strength=2,
modulated_by_prob=True,
use_attention_mask=True,
device_map="auto",
**kwargs
):
super().__init__()
# Store configuration parameters
self.base_model_name = base_model_name
self.anchoring_strength = anchoring_strength
self.modulated_by_prob = modulated_by_prob
self.use_attention_mask = use_attention_mask
self.device_map = device_map
        # Load the base model and tokenizer. A causal-LM head is required because
        # SPA compares next-token logits and calls generate() on the wrapped model.
        self.model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map=device_map, **kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Set default pad token if needed
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if hasattr(self.model, "config"):
self.model.config.pad_token_id = self.model.config.eos_token_id
# Determine device
if hasattr(self.model, "device"):
self.device = self.model.device
else:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def forward(self, input_ids, attention_mask=None, **kwargs):
"""Pass through to the base model's forward method"""
return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
def generate_with_spa(
self,
prompt,
anchors=None,
anchoring_strength=None,
modulated_by_prob=None,
use_attention_mask=None,
max_new_tokens=100,
min_new_tokens=1,
do_sample=True,
temperature=0.7,
top_p=0.95,
top_k=50,
stream=False,
**kwargs
):
"""
Generate text using Selective Prompt Anchoring.
Args:
prompt: Text or messages to generate from
anchors: List of anchor strings to influence generation
anchoring_strength: How much to weight the anchored version
modulated_by_prob: Whether to modulate strength by token probability
use_attention_mask: Whether to use attention masking for anchor tokens
max_new_tokens: Maximum number of tokens to generate
min_new_tokens: Minimum number of tokens to generate
do_sample: Whether to use sampling for generation
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
stream: Whether to stream the output
Returns:
Generated text (or streamer if stream=True)
"""
# Use instance defaults if parameters are not provided
        anchoring_strength = anchoring_strength if anchoring_strength is not None else self.anchoring_strength
modulated_by_prob = modulated_by_prob if modulated_by_prob is not None else self.modulated_by_prob
use_attention_mask = use_attention_mask if use_attention_mask is not None else self.use_attention_mask
# Default to empty list if anchors not provided
if anchors is None:
anchors = []
# Preprocess anchors
anchors = preprocess_anchors(anchors)
# Tokenize with SPA
main_inputs, aux_inputs, mask_token = spa_tokenize(
prompt_with_anchors=prompt,
global_anchors=anchors,
tokenizer=self.tokenizer,
device=self.device
)
# Create SPA logits processor
spa_processor = SPALogitsProcessor(
aux_model=self.model,
aux_input_ids=aux_inputs,
strength=anchoring_strength,
modulated_by_prob=modulated_by_prob,
use_attention_mask=use_attention_mask,
mask_token=mask_token,
tokenizer=self.tokenizer
)
# Get attention mask
attention_mask = create_default_attention_mask(main_inputs, device=self.device)
# Set up generation kwargs
generation_kwargs = {
"input_ids": main_inputs,
"attention_mask": attention_mask,
"logits_processor": [spa_processor],
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
**kwargs
}
if stream:
from transformers import TextIteratorStreamer
import threading
# Set up streamer
streamer = TextIteratorStreamer(
self.tokenizer,
skip_special_tokens=True,
skip_prompt=True
)
generation_kwargs["streamer"] = streamer
# Start generation in a separate thread
generation_thread = threading.Thread(
target=self.model.generate,
kwargs=generation_kwargs
)
generation_thread.start()
# Return streamer for token-by-token output
return streamer
else:
# Normal generation (non-streaming)
output_sequences = self.model.generate(**generation_kwargs)
# Decode the output
input_length = main_inputs.shape[1]
new_tokens = output_sequences[0][input_length:]
generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
return generated_text
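

# A minimal sketch (this helper is an illustrative assumption, not part of the
# original API) of how the streaming path of generate_with_spa is consumed:
# with stream=True the method returns a transformers TextIteratorStreamer,
# which yields decoded text chunks while generation runs in a background thread.
def stream_spa_generation(model, prompt, anchors=None, **kwargs):
    """Yield text chunks from a streaming SPA generation (illustrative sketch)."""
    streamer = model.generate_with_spa(prompt, anchors=anchors, stream=True, **kwargs)
    for chunk in streamer:
        yield chunk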
# Create a helper function to load models directly from hub
def load_spa_model(
model_name="magic-yuantian/selective-prompt-anchoring",
base_model_name="meta-llama/Llama-3.1-8B-Instruct",
**kwargs
):
"""
Load a SPAModel from the Hugging Face Hub or create a new one.
Args:
model_name: Name or path of the SPA model in the Hub
base_model_name: The base model to use (if creating a new model)
**kwargs: Additional arguments to pass to from_pretrained or __init__
Returns:
A SPAModel instance
"""
try:
# Try to load from hub
model = SPAModel.from_pretrained(model_name, **kwargs)
return model
except Exception as e:
print(f"Error loading model from hub: {e}")
print(f"Creating a new SPAModel with base model {base_model_name}")
# Create a new model
model = SPAModel(base_model_name=base_model_name, **kwargs)
return model
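

if __name__ == "__main__":
    # Minimal usage sketch. The model name below is an assumption (it is the
    # class default); any causal LM available locally or on the Hub should work.
    spa_model = load_spa_model(base_model_name="Qwen/Qwen3-0.6B")

    # Non-streaming generation: anchors are strings whose influence on the
    # prompt is amplified by the SPA logits processor.
    result = spa_model.generate_with_spa(
        "Write a Python function that reverses a string.",
        anchors=["reverses a string"],
        anchoring_strength=2.0,
        max_new_tokens=64,
    )
    print(result)

    # Streaming generation: generate_with_spa returns a TextIteratorStreamer
    # that can be iterated as tokens are produced in a background thread.
    for chunk in spa_model.generate_with_spa(
        "Explain selective prompt anchoring in one sentence.",
        anchors=["selective prompt anchoring"],
        stream=True,
        max_new_tokens=64,
    ):
        print(chunk, end="", flush=True)
    print()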