DoctorChaos committed on
Commit a2a2f0e · verified · 1 Parent(s): d82fdf0

Upload spa_hf.py with huggingface_hub

Files changed (1)
  1. spa_hf.py +197 -0
spa_hf.py ADDED
@@ -0,0 +1,197 @@
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Import core SPA functionality
+from spa import SPALogitsProcessor, spa_tokenize, preprocess_anchors, create_default_attention_mask
+
+class SPAModel(nn.Module, PyTorchModelHubMixin):
+    """
+    Selective Prompt Anchoring (SPA) model with Hugging Face Hub integration.
+
+    This model wraps a base LLM and provides the SPA functionality with
+    the ability to be shared and downloaded from the Hugging Face Hub.
+    """
+
+    def __init__(
+        self,
+        base_model_name="Qwen/Qwen3-0.6B",
+        anchoring_strength=2,
+        modulated_by_prob=True,
+        use_attention_mask=True,
+        device_map="auto",
+        **kwargs
+    ):
+        super().__init__()
+
+        # Store configuration parameters
+        self.base_model_name = base_model_name
+        self.anchoring_strength = anchoring_strength
+        self.modulated_by_prob = modulated_by_prob
+        self.use_attention_mask = use_attention_mask
+        self.device_map = device_map
+
+        # Load the base model and tokenizer; AutoModelForCausalLM provides
+        # the LM head that .generate() requires (a bare AutoModel does not)
+        self.model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map=device_map, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+
+        # Set a default pad token if needed
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+            if hasattr(self.model, "config"):
+                self.model.config.pad_token_id = self.model.config.eos_token_id
+
+        # Determine device
+        if hasattr(self.model, "device"):
+            self.device = self.model.device
+        else:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def forward(self, input_ids, attention_mask=None, **kwargs):
+        """Pass through to the base model's forward method."""
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+
+    def generate_with_spa(
+        self,
+        prompt,
+        anchors=None,
+        anchoring_strength=None,
+        modulated_by_prob=None,
+        use_attention_mask=None,
+        max_new_tokens=100,
+        min_new_tokens=1,
+        do_sample=True,
+        temperature=0.7,
+        top_p=0.95,
+        top_k=50,
+        stream=False,
+        **kwargs
+    ):
+        """
+        Generate text using Selective Prompt Anchoring.
+
+        Args:
+            prompt: Text or messages to generate from
+            anchors: List of anchor strings to influence generation
+            anchoring_strength: How much to weight the anchored version
+            modulated_by_prob: Whether to modulate strength by token probability
+            use_attention_mask: Whether to use attention masking for anchor tokens
+            max_new_tokens: Maximum number of tokens to generate
+            min_new_tokens: Minimum number of tokens to generate
+            do_sample: Whether to use sampling for generation
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+            top_k: Top-k sampling parameter
+            stream: Whether to stream the output
+
+        Returns:
+            Generated text (or a TextIteratorStreamer if stream=True)
+        """
+        # Use instance defaults if parameters are not provided
+        # (explicit None checks so falsy values like 0 are respected)
+        anchoring_strength = anchoring_strength if anchoring_strength is not None else self.anchoring_strength
+        modulated_by_prob = modulated_by_prob if modulated_by_prob is not None else self.modulated_by_prob
+        use_attention_mask = use_attention_mask if use_attention_mask is not None else self.use_attention_mask
+
+        # Default to an empty list if anchors are not provided
+        if anchors is None:
+            anchors = []
+
+        # Preprocess anchors
+        anchors = preprocess_anchors(anchors)
+
+        # Tokenize with SPA
+        main_inputs, aux_inputs, mask_token = spa_tokenize(
+            prompt_with_anchors=prompt,
+            global_anchors=anchors,
+            tokenizer=self.tokenizer,
+            device=self.device
+        )
+
+        # Create SPA logits processor
+        spa_processor = SPALogitsProcessor(
+            aux_model=self.model,
+            aux_input_ids=aux_inputs,
+            strength=anchoring_strength,
+            modulated_by_prob=modulated_by_prob,
+            use_attention_mask=use_attention_mask,
+            mask_token=mask_token,
+            tokenizer=self.tokenizer
+        )
+
+        # Get attention mask
+        attention_mask = create_default_attention_mask(main_inputs, device=self.device)
+
+        # Set up generation kwargs
+        generation_kwargs = {
+            "input_ids": main_inputs,
+            "attention_mask": attention_mask,
+            "logits_processor": [spa_processor],
+            "min_new_tokens": min_new_tokens,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": do_sample,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            **kwargs
+        }
+
+        if stream:
+            from transformers import TextIteratorStreamer
+            import threading
+
+            # Set up streamer
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                skip_special_tokens=True,
+                skip_prompt=True
+            )
+            generation_kwargs["streamer"] = streamer
+
+            # Start generation in a separate thread
+            generation_thread = threading.Thread(
+                target=self.model.generate,
+                kwargs=generation_kwargs
+            )
+            generation_thread.start()
+
+            # Return the streamer for token-by-token output
+            return streamer
+        else:
+            # Normal generation (non-streaming)
+            output_sequences = self.model.generate(**generation_kwargs)
+
+            # Decode only the newly generated tokens
+            input_length = main_inputs.shape[1]
+            new_tokens = output_sequences[0][input_length:]
+            generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+            return generated_text
+
+# Helper function to load models directly from the Hub
+def load_spa_model(
+    model_name="magic-yuantian/selective-prompt-anchoring",
+    base_model_name="meta-llama/Llama-3.1-8B-Instruct",
+    **kwargs
+):
+    """
+    Load a SPAModel from the Hugging Face Hub or create a new one.
+
+    Args:
+        model_name: Name or path of the SPA model in the Hub
+        base_model_name: The base model to use (if creating a new model)
+        **kwargs: Additional arguments to pass to from_pretrained or __init__
+
+    Returns:
+        A SPAModel instance
+    """
+    try:
+        # Try to load from the Hub
+        model = SPAModel.from_pretrained(model_name, **kwargs)
+        return model
+    except Exception as e:
+        print(f"Error loading model from hub: {e}")
+        print(f"Creating a new SPAModel with base model {base_model_name}")
+        # Fall back to creating a new model locally
+        model = SPAModel(base_model_name=base_model_name, **kwargs)
+        return model
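
A minimal usage sketch for spa_hf.py, assuming the core `spa` package imported at the top of the file is on the Python path; the prompt and anchor strings below are illustrative placeholders, not values from the repository.

    from spa_hf import load_spa_model

    # Load from the Hub if possible; otherwise fall back to a fresh
    # SPAModel built on the given base model.
    model = load_spa_model(base_model_name="Qwen/Qwen3-0.6B")

    # Non-streaming generation returns the decoded completion as a string.
    text = model.generate_with_spa(
        prompt="Summarize the key idea of selective prompt anchoring.",
        anchors=["selective prompt anchoring"],  # hypothetical anchor list
        anchoring_strength=2,
        max_new_tokens=64,
    )
    print(text)

    # Streaming generation returns a TextIteratorStreamer; iterate over it
    # while model.generate() runs on a background thread.
    streamer = model.generate_with_spa(
        prompt="Explain how anchoring affects decoding.",
        anchors=["anchoring"],  # hypothetical anchor list
        stream=True,
    )
    for chunk in streamer:
        print(chunk, end="", flush=True)

Returning the streamer while generation runs on a background thread is what lets callers consume tokens as they are produced; iteration ends when the generation thread finishes.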