import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModelForCausalLM, AutoTokenizer

from spa import SPALogitsProcessor, spa_tokenize, preprocess_anchors, create_default_attention_mask


class SPAModel(nn.Module, PyTorchModelHubMixin):
    """
    Selective Prompt Anchoring (SPA) model with Hugging Face Hub integration.

    This class wraps a base LLM, provides SPA-based generation, and can be
    shared on and downloaded from the Hugging Face Hub.
    """

    def __init__(
        self,
        base_model_name="Qwen/Qwen3-0.6B",
        anchoring_strength=2,
        modulated_by_prob=True,
        use_attention_mask=True,
        device_map="auto",
        **kwargs
    ):
        super().__init__()

        self.base_model_name = base_model_name
        self.anchoring_strength = anchoring_strength
        self.modulated_by_prob = modulated_by_prob
        self.use_attention_mask = use_attention_mask
        self.device_map = device_map

        # A model with a language-modeling head is required, since
        # generate_with_spa() relies on self.model.generate().
        self.model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map=device_map, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        # Some tokenizers ship without a pad token; fall back to the EOS token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            if hasattr(self.model, "config"):
                self.model.config.pad_token_id = self.model.config.eos_token_id

        if hasattr(self.model, "device"):
            self.device = self.model.device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Pass through to the base model's forward method."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    def generate_with_spa(
        self,
        prompt,
        anchors=None,
        anchoring_strength=None,
        modulated_by_prob=None,
        use_attention_mask=None,
        max_new_tokens=100,
        min_new_tokens=1,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=50,
        stream=False,
        **kwargs
    ):
        """
        Generate text using Selective Prompt Anchoring.

        Args:
            prompt: Text or messages to generate from
            anchors: List of anchor strings to influence generation
            anchoring_strength: How much to weight the anchored version
            modulated_by_prob: Whether to modulate strength by token probability
            use_attention_mask: Whether to use attention masking for anchor tokens
            max_new_tokens: Maximum number of tokens to generate
            min_new_tokens: Minimum number of tokens to generate
            do_sample: Whether to use sampling for generation
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            top_k: Top-k sampling parameter
            stream: Whether to stream the output

        Returns:
            Generated text (or a streamer if stream=True)
        """
        # Fall back to the defaults set at construction time. Explicit None checks
        # are used so that a caller-supplied value of 0 or False is not overridden.
        if anchoring_strength is None:
            anchoring_strength = self.anchoring_strength
        if modulated_by_prob is None:
            modulated_by_prob = self.modulated_by_prob
        if use_attention_mask is None:
            use_attention_mask = self.use_attention_mask

        if anchors is None:
            anchors = []

        # Normalize the anchor list before tokenization.
        anchors = preprocess_anchors(anchors)

        # Tokenize the prompt twice: the main inputs keep the full prompt, while the
        # auxiliary inputs have the anchor text replaced by mask_token.
        main_inputs, aux_inputs, mask_token = spa_tokenize(
            prompt_with_anchors=prompt,
            global_anchors=anchors,
            tokenizer=self.tokenizer,
            device=self.device
        )

        # The SPA logits processor runs the auxiliary (anchor-masked) pass alongside
        # the main pass and reweights the logits according to anchoring_strength.
        spa_processor = SPALogitsProcessor(
            aux_model=self.model,
            aux_input_ids=aux_inputs,
            strength=anchoring_strength,
            modulated_by_prob=modulated_by_prob,
            use_attention_mask=use_attention_mask,
            mask_token=mask_token,
            tokenizer=self.tokenizer
        )

        attention_mask = create_default_attention_mask(main_inputs, device=self.device)

        generation_kwargs = {
            "input_ids": main_inputs,
            "attention_mask": attention_mask,
            "logits_processor": [spa_processor],
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            **kwargs
        }

        if stream:
            from transformers import TextIteratorStreamer
            import threading

            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_special_tokens=True,
                skip_prompt=True
            )
            generation_kwargs["streamer"] = streamer

            # Run generation in a background thread so the caller can iterate over
            # the streamer while tokens are still being produced.
            generation_thread = threading.Thread(
                target=self.model.generate,
                kwargs=generation_kwargs
            )
            generation_thread.start()

            return streamer
        else:
            output_sequences = self.model.generate(**generation_kwargs)

            # Decode only the newly generated tokens, excluding the prompt.
            input_length = main_inputs.shape[1]
            new_tokens = output_sequences[0][input_length:]
            generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            return generated_text


def load_spa_model(
    model_name="magic-yuantian/selective-prompt-anchoring",
    base_model_name="meta-llama/Llama-3.1-8B-Instruct",
    **kwargs
):
    """
    Load a SPAModel from the Hugging Face Hub or create a new one.

    Args:
        model_name: Name or path of the SPA model on the Hub
        base_model_name: The base model to use (if creating a new model)
        **kwargs: Additional arguments to pass to from_pretrained or __init__

    Returns:
        A SPAModel instance
    """
    try:
        model = SPAModel.from_pretrained(model_name, **kwargs)
        return model
    except Exception as e:
        # If the Hub model cannot be loaded, build a fresh SPAModel around the base model.
        print(f"Error loading model from hub: {e}")
        print(f"Creating a new SPAModel with base model {base_model_name}")
        model = SPAModel(base_model_name=base_model_name, **kwargs)
        return model
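

# Minimal usage sketch. Assumptions: the prompt and anchor strings below are
# illustrative placeholders (not values required by SPA), and the class defaults
# above are used for everything not set explicitly.
if __name__ == "__main__":
    spa_model = SPAModel(base_model_name="Qwen/Qwen3-0.6B")

    # Non-streaming generation: the anchors mark prompt text whose influence
    # on decoding is amplified by anchoring_strength.
    text = spa_model.generate_with_spa(
        prompt="Write a short poem about the ocean.",
        anchors=["ocean"],
        anchoring_strength=2,
        max_new_tokens=50,
    )
    print(text)

    # Streaming generation: iterate over the returned streamer as tokens arrive.
    streamer = spa_model.generate_with_spa(
        prompt="Explain selective prompt anchoring in one sentence.",
        anchors=["one sentence"],
        stream=True,
        max_new_tokens=50,
    )
    for chunk in streamer:
        print(chunk, end="", flush=True)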