Tarive commited on
Commit
b829e8f
·
verified ·
1 Parent(s): 12758be

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +10 -6
  2. all_config.yaml +35 -0
  3. app.py +107 -0
  4. hrm_act_v1.py +288 -0
  5. losses.py +101 -0
  6. pytorch_model.bin +3 -0
  7. requirements.txt +5 -0
README.md CHANGED
@@ -1,12 +1,16 @@
1
  ---
2
- title: HRM Anchoring Bias Model
3
- emoji: 🏢
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.39.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
+ title: HRM Grant Abstract Optimizer
3
+ emoji: 🎯
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.0.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # HRM Grant Abstract Optimizer
13
+
14
+ A Hierarchical Reasoning Model fine-tuned for optimizing scientific grant abstracts.
15
+
16
+ This Space demonstrates the model loading and provides a testing interface for the 27M-parameter HRM trained on grant abstract optimization.
all_config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 2
3
+ H_layers: 4
4
+ L_cycles: 2
5
+ L_layers: 4
6
+ expansion: 4
7
+ halt_exploration_prob: 0.1
8
+ halt_max_steps: 16
9
+ hidden_size: 512
10
+ loss:
11
+ loss_type: stablemax_cross_entropy
12
+ name: losses@ACTLossHead
13
+ name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
14
+ num_heads: 8
15
+ pos_encodings: rope
16
+ puzzle_emb_ndim: 128
17
+ beta1: 0.9
18
+ beta2: 0.95
19
+ checkpoint_every_eval: true
20
+ checkpoint_path: checkpoints/Abstract_optimizer_processed ACT-torch/HierarchicalReasoningModel_ACTV1
21
+ ambrosial-orca
22
+ data_path: data/abstract_optimizer_processed
23
+ epochs: 20000
24
+ eval_interval: 1000
25
+ eval_save_outputs: []
26
+ global_batch_size: 16
27
+ lr: 0.0001
28
+ lr_min_ratio: 1.0
29
+ lr_warmup_steps: 2000
30
+ project_name: Abstract_optimizer_processed ACT-torch
31
+ puzzle_emb_lr: 0.01
32
+ puzzle_emb_weight_decay: 0.1
33
+ run_name: HierarchicalReasoningModel_ACTV1 ambrosial-orca
34
+ seed: 0
35
+ weight_decay: 0.1
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import yaml
4
+ import os
5
+
6
+ def load_model():
7
+ """Load the HRM model and config"""
8
+ try:
9
+ # Load config
10
+ with open('all_config.yaml', 'r') as f:
11
+ config = yaml.safe_load(f)
12
+
13
+ # Load checkpoint
14
+ checkpoint = torch.load('pytorch_model.bin', map_location='cpu')
15
+
16
+ return config, checkpoint, "✅ Model loaded successfully!"
17
+ except Exception as e:
18
+ return None, None, f"❌ Error loading model: {str(e)}"
19
+
20
+ def test_model_info(config, checkpoint):
21
+ """Display model information"""
22
+ if config is None or checkpoint is None:
23
+ return "Model not loaded"
24
+
25
+ info = f"""
26
+ **Model Architecture**: {config['arch']['name']}
27
+ **Hidden Size**: {config['arch']['hidden_size']}
28
+ **H Layers**: {config['arch']['H_layers']}
29
+ **L Layers**: {config['arch']['L_layers']}
30
+ **Parameters in Checkpoint**: {len(checkpoint)}
31
+ **Model Purpose**: Grant Abstract Optimization
32
+
33
+ **Training Details**:
34
+ - Steps: 492,500 (final checkpoint)
35
+ - Batch Size: {config['global_batch_size']}
36
+ - Learning Rate: {config['lr']}
37
+ """
38
+ return info
39
+
40
+ def placeholder_inference(draft_abstract, grant_type):
41
+ """Placeholder for actual inference (requires full training pipeline)"""
42
+ return f"""
43
+ **Input Abstract**: {draft_abstract[:100]}...
44
+
45
+ **Grant Type**: {grant_type}
46
+
47
+ **Status**: Model checkpoint loaded successfully!
48
+
49
+ ⚠️ **Note**: Full inference requires the original training pipeline with tokenizer and preprocessing code.
50
+ This demo shows that the model weights are accessible and the architecture is properly configured.
51
+
52
+ **Next Steps**:
53
+ 1. Integrate with original training codebase
54
+ 2. Load tokenizer and preprocessing pipeline
55
+ 3. Implement full inference function
56
+ """
57
+
58
+ # Load model on startup
59
+ config, checkpoint, load_status = load_model()
60
+
61
+ # Create Gradio interface
62
+ with gr.Blocks(title="HRM Grant Abstract Optimizer") as demo:
63
+ gr.Markdown("# 🎯 Hierarchical Reasoning Model for Grant Abstract Optimization")
64
+ gr.Markdown("A specialized 27M-parameter model for transforming draft grant abstracts into funding-worthy versions.")
65
+
66
+ with gr.Tab("Model Info"):
67
+ gr.Markdown("## Model Status")
68
+ gr.Markdown(load_status)
69
+
70
+ if config is not None:
71
+ model_info = test_model_info(config, checkpoint)
72
+ gr.Markdown(model_info)
73
+
74
+ with gr.Tab("Test Interface"):
75
+ gr.Markdown("## Abstract Optimization Demo")
76
+ gr.Markdown("*Note: This is a demonstration interface. Full inference requires integration with the training pipeline.*")
77
+
78
+ with gr.Row():
79
+ with gr.Column():
80
+ draft_input = gr.Textbox(
81
+ label="Draft Abstract",
82
+ placeholder="Enter your sub-optimal grant abstract here...",
83
+ lines=8,
84
+ value="Our study will investigate protein interactions in cancer cells. We believe this research could be important for understanding disease mechanisms."
85
+ )
86
+ grant_type = gr.Dropdown(
87
+ choices=["R01", "F32", "K99", "R21", "R15"],
88
+ label="Grant Type",
89
+ value="R01"
90
+ )
91
+ optimize_btn = gr.Button("Optimize Abstract", variant="primary")
92
+
93
+ with gr.Column():
94
+ output = gr.Textbox(
95
+ label="Optimized Abstract",
96
+ lines=10,
97
+ interactive=False
98
+ )
99
+
100
+ optimize_btn.click(
101
+ fn=placeholder_inference,
102
+ inputs=[draft_input, grant_type],
103
+ outputs=output
104
+ )
105
+
106
+ if __name__ == "__main__":
107
+ demo.launch()
hrm_act_v1.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Dict, Optional
2
+ from dataclasses import dataclass
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ from models.common import trunc_normal_init_
11
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
12
+ from models.sparse_embedding import CastedSparseEmbedding
13
+
14
+
15
+ @dataclass
16
+ class HierarchicalReasoningModel_ACTV1InnerCarry:
17
+ z_H: torch.Tensor
18
+ z_L: torch.Tensor
19
+
20
+
21
+ @dataclass
22
+ class HierarchicalReasoningModel_ACTV1Carry:
23
+ inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
24
+
25
+ steps: torch.Tensor
26
+ halted: torch.Tensor
27
+
28
+ current_data: Dict[str, torch.Tensor]
29
+
30
+
31
+ class HierarchicalReasoningModel_ACTV1Config(BaseModel):
32
+ batch_size: int
33
+ seq_len: int
34
+ puzzle_emb_ndim: int = 0
35
+ num_puzzle_identifiers: int
36
+ vocab_size: int
37
+
38
+ H_cycles: int
39
+ L_cycles: int
40
+
41
+ H_layers: int
42
+ L_layers: int
43
+
44
+ # Transformer config
45
+ hidden_size: int
46
+ expansion: float
47
+ num_heads: int
48
+ pos_encodings: str
49
+
50
+ rms_norm_eps: float = 1e-5
51
+ rope_theta: float = 10000.0
52
+
53
+ # Halting Q-learning config
54
+ halt_max_steps: int
55
+ halt_exploration_prob: float
56
+
57
+ forward_dtype: str = "bfloat16"
58
+
59
+
60
+ class HierarchicalReasoningModel_ACTV1Block(nn.Module):
61
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
62
+ super().__init__()
63
+
64
+ self.self_attn = Attention(
65
+ hidden_size=config.hidden_size,
66
+ head_dim=config.hidden_size // config.num_heads,
67
+ num_heads=config.num_heads,
68
+ num_key_value_heads=config.num_heads,
69
+ causal=False
70
+ )
71
+ self.mlp = SwiGLU(
72
+ hidden_size=config.hidden_size,
73
+ expansion=config.expansion,
74
+ )
75
+ self.norm_eps = config.rms_norm_eps
76
+
77
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
78
+ # Post Norm
79
+ # Self Attention
80
+ hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
81
+ # Fully Connected
82
+ hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
83
+ return hidden_states
84
+
85
+
86
+ class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
87
+ def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
88
+ super().__init__()
89
+
90
+ self.layers = torch.nn.ModuleList(layers)
91
+
92
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
93
+ # Input injection (add)
94
+ hidden_states = hidden_states + input_injection
95
+ # Layers
96
+ for layer in self.layers:
97
+ hidden_states = layer(hidden_states=hidden_states, **kwargs)
98
+
99
+ return hidden_states
100
+
101
+
102
+ class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
103
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
104
+ super().__init__()
105
+ self.config = config
106
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
107
+
108
+ # I/O
109
+ self.embed_scale = math.sqrt(self.config.hidden_size)
110
+ embed_init_std = 1.0 / self.embed_scale
111
+
112
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
113
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
114
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
115
+
116
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
117
+ if self.config.puzzle_emb_ndim > 0:
118
+ # Zero init puzzle embeddings
119
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
120
+ batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
121
+
122
+ # LM Blocks
123
+ if self.config.pos_encodings == "rope":
124
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
125
+ max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
126
+ base=self.config.rope_theta)
127
+ elif self.config.pos_encodings == "learned":
128
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
129
+ else:
130
+ raise NotImplementedError()
131
+
132
+ # Reasoning Layers
133
+ self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
134
+ self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
135
+
136
+ # --- CORRECTED CODE BLOCK ---
137
+ # Initial states
138
+ h_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
139
+ self.register_buffer('H_init', h_init_tensor)
140
+
141
+ l_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
142
+ self.register_buffer('L_init', l_init_tensor)
143
+ # --- END OF CORRECTION ---
144
+
145
+ # Q head special init
146
+ # Init Q to (almost) zero for faster learning during bootstrapping
147
+ with torch.no_grad():
148
+ self.q_head.weight.zero_()
149
+ self.q_head.bias.fill_(-5) # type: ignore
150
+
151
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
152
+ # Token embedding
153
+ embedding = self.embed_tokens(input.to(torch.int32))
154
+
155
+ # Puzzle embeddings
156
+ if self.config.puzzle_emb_ndim > 0:
157
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
158
+
159
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
160
+ if pad_count > 0:
161
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
162
+
163
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
164
+
165
+ # Position embeddings
166
+ if self.config.pos_encodings == "learned":
167
+ # scale by 1/sqrt(2) to maintain forward variance
168
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
169
+
170
+ # Scale
171
+ return self.embed_scale * embedding
172
+
173
+ def empty_carry(self, batch_size: int):
174
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
175
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
176
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
177
+ )
178
+
179
+ def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
180
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
181
+ z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
182
+ z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
183
+ )
184
+
185
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
186
+ seq_info = dict(
187
+ cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
188
+ )
189
+
190
+ # Input encoding
191
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
192
+
193
+ # Forward iterations
194
+ with torch.no_grad():
195
+ z_H, z_L = carry.z_H, carry.z_L
196
+
197
+ for _H_step in range(self.config.H_cycles):
198
+ for _L_step in range(self.config.L_cycles):
199
+ if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
200
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
201
+
202
+ if not (_H_step == self.config.H_cycles - 1):
203
+ z_H = self.H_level(z_H, z_L, **seq_info)
204
+
205
+ assert not z_H.requires_grad and not z_L.requires_grad
206
+
207
+ # 1-step grad
208
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
209
+ z_H = self.H_level(z_H, z_L, **seq_info)
210
+
211
+ # LM Outputs
212
+ new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
213
+ output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
214
+
215
+ # Q head
216
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
217
+
218
+ return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
219
+
220
+
221
+ class HierarchicalReasoningModel_ACTV1(nn.Module):
222
+ """ACT wrapper."""
223
+
224
+ def __init__(self, config_dict: dict):
225
+ super().__init__()
226
+ self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
227
+ self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
228
+
229
+ @property
230
+ def puzzle_emb(self):
231
+ return self.inner.puzzle_emb
232
+
233
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
234
+ batch_size = batch["inputs"].shape[0]
235
+
236
+ return HierarchicalReasoningModel_ACTV1Carry(
237
+ inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
238
+
239
+ steps=torch.zeros((batch_size, ), dtype=torch.int32),
240
+ halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
241
+
242
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
243
+ )
244
+
245
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
246
+ # Update data, carry (removing halted sequences)
247
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
248
+
249
+ new_steps = torch.where(carry.halted, 0, carry.steps)
250
+
251
+ new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
252
+
253
+ # Forward inner model
254
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
255
+
256
+ outputs = {
257
+ "logits": logits,
258
+ "q_halt_logits": q_halt_logits,
259
+ "q_continue_logits": q_continue_logits
260
+ }
261
+
262
+ with torch.no_grad():
263
+ # Step
264
+ new_steps = new_steps + 1
265
+ is_last_step = new_steps >= self.config.halt_max_steps
266
+
267
+ halted = is_last_step
268
+
269
+ # if training, and ACT is enabled
270
+ if self.training and (self.config.halt_max_steps > 1):
271
+ # Halt signal
272
+ # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
273
+ halted = halted | (q_halt_logits > q_continue_logits)
274
+
275
+ # Exploration
276
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
277
+
278
+ halted = halted & (new_steps >= min_halt_steps)
279
+
280
+ # Compute target Q
281
+ # NOTE: No replay buffer and target networks for computing target Q-value.
282
+ # As batch_size is large, there're many parallel envs.
283
+ # Similar concept as PQN https://arxiv.org/abs/2407.04811
284
+ next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
285
+
286
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
287
+
288
+ return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
losses.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ valid_mask = labels != ignore_index
28
+ transformed_labels = torch.where(valid_mask, labels, 0)
29
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
30
+
31
+ return -torch.where(valid_mask, prediction_logprobs, 0)
32
+
33
+
34
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
35
+ # Cast logits to f32
36
+ # Flatten logits
37
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
38
+
39
+
40
+ class ACTLossHead(nn.Module):
41
+ def __init__(self, model: nn.Module, loss_type: str):
42
+ super().__init__()
43
+ self.model = model
44
+ self.loss_fn = globals()[loss_type]
45
+
46
+ def initial_carry(self, *args, **kwargs):
47
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
48
+
49
+ def forward(
50
+ self,
51
+ return_keys: Sequence[str],
52
+ # Model args
53
+ **model_kwargs,
54
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
55
+ # Model logits
56
+ # B x SeqLen x D
57
+ new_carry, outputs = self.model(**model_kwargs)
58
+ labels = new_carry.current_data["labels"]
59
+
60
+ # Correctness
61
+ with torch.no_grad():
62
+ mask = labels != IGNORE_LABEL_ID
63
+ loss_counts = mask.sum(-1)
64
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
65
+
66
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
67
+ seq_is_correct = is_correct.sum(-1) == loss_counts
68
+
69
+ # Metrics (halted)
70
+ valid_metrics = new_carry.halted & (loss_counts > 0)
71
+ metrics = {
72
+ "count": valid_metrics.sum(),
73
+
74
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
75
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
76
+
77
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
78
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
79
+ }
80
+
81
+ # Losses
82
+ # FIXME: Assuming the batch is always full
83
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
84
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
85
+
86
+ metrics.update({
87
+ "lm_loss": lm_loss.detach(),
88
+ "q_halt_loss": q_halt_loss.detach(),
89
+ })
90
+
91
+ # Q continue (bootstrapping target loss)
92
+ q_continue_loss = 0
93
+ if "target_q_continue" in outputs:
94
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
95
+
96
+ metrics["q_continue_loss"] = q_continue_loss.detach()
97
+
98
+ # Filter outputs for return
99
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
100
+
101
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7739c5146e1cf12a55fb2ce44c69ea099307ea6db383f0935c3966eedf0e203f
3
+ size 174644621
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ pydantic
4
+ PyYAML
5
+ transformers