Spaces:
Running
Running
Upload 7 files
Browse files- README.md +10 -6
- all_config.yaml +35 -0
- app.py +107 -0
- hrm_act_v1.py +288 -0
- losses.py +101 -0
- pytorch_model.bin +3 -0
- requirements.txt +5 -0
README.md
CHANGED
@@ -1,12 +1,16 @@
|
|
1 |
---
|
2 |
-
title: HRM
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: HRM Grant Abstract Optimizer
|
3 |
+
emoji: 🎯
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.0.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
# HRM Grant Abstract Optimizer
|
13 |
+
|
14 |
+
A Hierarchical Reasoning Model fine-tuned for optimizing scientific grant abstracts.
|
15 |
+
|
16 |
+
This Space demonstrates that the model checkpoint and configuration load correctly, and it provides a placeholder testing interface for the 27M-parameter HRM trained on grant abstract optimization. Full inference is not yet implemented.
|
all_config.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
arch:
|
2 |
+
H_cycles: 2
|
3 |
+
H_layers: 4
|
4 |
+
L_cycles: 2
|
5 |
+
L_layers: 4
|
6 |
+
expansion: 4
|
7 |
+
halt_exploration_prob: 0.1
|
8 |
+
halt_max_steps: 16
|
9 |
+
hidden_size: 512
|
10 |
+
loss:
|
11 |
+
loss_type: stablemax_cross_entropy
|
12 |
+
name: losses@ACTLossHead
|
13 |
+
name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
|
14 |
+
num_heads: 8
|
15 |
+
pos_encodings: rope
|
16 |
+
puzzle_emb_ndim: 128
|
17 |
+
beta1: 0.9
|
18 |
+
beta2: 0.95
|
19 |
+
checkpoint_every_eval: true
|
20 |
+
checkpoint_path: checkpoints/Abstract_optimizer_processed ACT-torch/HierarchicalReasoningModel_ACTV1
|
21 |
+
ambrosial-orca
|
22 |
+
data_path: data/abstract_optimizer_processed
|
23 |
+
epochs: 20000
|
24 |
+
eval_interval: 1000
|
25 |
+
eval_save_outputs: []
|
26 |
+
global_batch_size: 16
|
27 |
+
lr: 0.0001
|
28 |
+
lr_min_ratio: 1.0
|
29 |
+
lr_warmup_steps: 2000
|
30 |
+
project_name: Abstract_optimizer_processed ACT-torch
|
31 |
+
puzzle_emb_lr: 0.01
|
32 |
+
puzzle_emb_weight_decay: 0.1
|
33 |
+
run_name: HierarchicalReasoningModel_ACTV1 ambrosial-orca
|
34 |
+
seed: 0
|
35 |
+
weight_decay: 0.1
|
app.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import yaml
|
4 |
+
import os
|
5 |
+
|
6 |
+
def load_model():
    """Load the training config and the raw checkpoint from the working directory.

    Reads ``all_config.yaml`` with ``yaml.safe_load`` and ``pytorch_model.bin``
    with ``torch.load`` (tensors mapped to CPU).

    Returns:
        A ``(config, checkpoint, status)`` tuple. On success ``config`` is the
        parsed YAML, ``checkpoint`` is whatever object ``torch.load`` returned
        and ``status`` is a success message. On any failure the first two
        items are ``None`` and ``status`` contains the error text.
    """
    try:
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open('all_config.yaml', 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # SECURITY NOTE: torch.load unpickles the file. Only load checkpoints
        # from a trusted source; consider weights_only=True if the checkpoint
        # is a plain state dict and the installed torch version supports it.
        checkpoint = torch.load('pytorch_model.bin', map_location='cpu')
    except Exception as e:
        # Deliberate catch-all at the app boundary: the UI shows the message
        # instead of the Space crashing on startup.
        return None, None, f"❌ Error loading model: {str(e)}"

    return config, checkpoint, "✅ Model loaded successfully!"
|
19 |
+
|
20 |
+
def test_model_info(config, checkpoint):
|
21 |
+
"""Display model information"""
|
22 |
+
if config is None or checkpoint is None:
|
23 |
+
return "Model not loaded"
|
24 |
+
|
25 |
+
info = f"""
|
26 |
+
**Model Architecture**: {config['arch']['name']}
|
27 |
+
**Hidden Size**: {config['arch']['hidden_size']}
|
28 |
+
**H Layers**: {config['arch']['H_layers']}
|
29 |
+
**L Layers**: {config['arch']['L_layers']}
|
30 |
+
**Parameters in Checkpoint**: {len(checkpoint)}
|
31 |
+
**Model Purpose**: Grant Abstract Optimization
|
32 |
+
|
33 |
+
**Training Details**:
|
34 |
+
- Steps: 492,500 (final checkpoint)
|
35 |
+
- Batch Size: {config['global_batch_size']}
|
36 |
+
- Learning Rate: {config['lr']}
|
37 |
+
"""
|
38 |
+
return info
|
39 |
+
|
40 |
+
def placeholder_inference(draft_abstract, grant_type):
    """Return a status message in place of real inference.

    No model is run: the function only echoes a preview of the input and
    explains what is still missing for full inference.

    Args:
        draft_abstract: Text from the "Draft Abstract" box; ``None`` is treated
            as an empty string.
        grant_type: Value selected in the "Grant Type" dropdown.

    Returns:
        A multi-line string containing the input preview, the grant type and
        the placeholder notice.
    """
    text = draft_abstract or ""
    preview = text[:100]
    # Only mark the preview as truncated when something was actually cut off.
    if len(text) > 100:
        preview += "..."

    lines = [
        "",
        f"**Input Abstract**: {preview}",
        "",
        f"**Grant Type**: {grant_type}",
        "",
        "**Status**: Model checkpoint loaded successfully!",
        "",
        "⚠️ **Note**: Full inference requires the original training pipeline with tokenizer and preprocessing code.",
        "This demo shows that the model weights are accessible and the architecture is properly configured.",
        "",
        "**Next Steps**:",
        "1. Integrate with original training codebase",
        "2. Load tokenizer and preprocessing pipeline",
        "3. Implement full inference function",
        "",
    ]
    return "\n".join(lines)
|
57 |
+
|
58 |
+
# Load model on startup
# This runs at import time, so the config and checkpoint are read once per
# process. On failure both are None and load_status carries the error text.
config, checkpoint, load_status = load_model()

# Create Gradio interface
# Components are laid out in the order they are created inside each context.
with gr.Blocks(title="HRM Grant Abstract Optimizer") as demo:
    gr.Markdown("# 🎯 Hierarchical Reasoning Model for Grant Abstract Optimization")
    gr.Markdown("A specialized 27M-parameter model for transforming draft grant abstracts into funding-worthy versions.")

    with gr.Tab("Model Info"):
        gr.Markdown("## Model Status")
        gr.Markdown(load_status)

        # The architecture summary is only rendered when the config was loaded.
        if config is not None:
            model_info = test_model_info(config, checkpoint)
            gr.Markdown(model_info)

    with gr.Tab("Test Interface"):
        gr.Markdown("## Abstract Optimization Demo")
        gr.Markdown("*Note: This is a demonstration interface. Full inference requires integration with the training pipeline.*")

        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                draft_input = gr.Textbox(
                    label="Draft Abstract",
                    placeholder="Enter your sub-optimal grant abstract here...",
                    lines=8,
                    value="Our study will investigate protein interactions in cancer cells. We believe this research could be important for understanding disease mechanisms."
                )
                grant_type = gr.Dropdown(
                    choices=["R01", "F32", "K99", "R21", "R15"],
                    label="Grant Type",
                    value="R01"
                )
                optimize_btn = gr.Button("Optimize Abstract", variant="primary")

            # Right column: read-only output box.
            with gr.Column():
                output = gr.Textbox(
                    label="Optimized Abstract",
                    lines=10,
                    interactive=False
                )

        # The button calls the placeholder, not the model; see placeholder_inference.
        optimize_btn.click(
            fn=placeholder_inference,
            inputs=[draft_input, grant_type],
            outputs=output
        )

if __name__ == "__main__":
    demo.launch()
|
hrm_act_v1.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, List, Dict, Optional
|
2 |
+
from dataclasses import dataclass
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
from pydantic import BaseModel
|
9 |
+
|
10 |
+
from models.common import trunc_normal_init_
|
11 |
+
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
12 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
    """Recurrent state passed between successive calls of the inner model.

    Both tensors are allocated by ``empty_carry`` with shape
    ``(batch_size, seq_len + puzzle_emb_len, hidden_size)``.
    """
    # State updated by the H_level reasoning module.
    z_H: torch.Tensor
    # State updated by the L_level reasoning module.
    z_L: torch.Tensor
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
    """State carried between successive calls of the ACT wrapper.

    Field order matters: the wrapper's ``forward`` constructs this class
    positionally.
    """
    # Hidden states of the inner model.
    inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry

    # Per-sequence step counter (int32); reset to 0 for sequences that were halted.
    steps: torch.Tensor
    # Per-sequence bool flag; halted sequences are reset and refilled on the next call.
    halted: torch.Tensor

    # Batch tensors currently being processed, keyed like the incoming batch.
    current_data: Dict[str, torch.Tensor]
|
29 |
+
|
30 |
+
|
31 |
+
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    """Validated hyperparameters for the HRM ACT v1 model."""
    # Forwarded to the sparse puzzle embedding as its batch size.
    batch_size: int
    # Token sequence length, excluding the puzzle-embedding prefix positions.
    seq_len: int
    # Width of the puzzle embedding; 0 disables it.
    puzzle_emb_ndim: int = 0
    # First argument of the puzzle embedding table — presumably its row count; confirm in CastedSparseEmbedding.
    num_puzzle_identifiers: int
    vocab_size: int

    # Number of outer (H) and inner (L) iterations per forward call.
    H_cycles: int
    L_cycles: int

    # Number of blocks in the H_level and L_level modules.
    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    # "rope" or "learned"; any other value raises NotImplementedError in the inner model.
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    # A sequence is always halted once its step count reaches halt_max_steps.
    halt_max_steps: int
    # Probability of drawing a random minimum number of steps during training.
    halt_exploration_prob: float

    # Name of a torch dtype attribute, resolved with getattr(torch, ...).
    forward_dtype: str = "bfloat16"
|
58 |
+
|
59 |
+
|
60 |
+
class HierarchicalReasoningModel_ACTV1Block(nn.Module):
    """Post-norm transformer block: non-causal self-attention followed by a SwiGLU MLP."""

    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()

        width = config.hidden_size
        heads = config.num_heads

        # Key/value head count equals the query head count.
        self.self_attn = Attention(
            hidden_size=width,
            head_dim=width // heads,
            num_heads=heads,
            num_key_value_heads=heads,
            causal=False,
        )
        self.mlp = SwiGLU(hidden_size=width, expansion=config.expansion)
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply both sub-layers, each as ``rms_norm(x + sublayer(x))``."""
        # Attention sub-layer: residual add, then normalise.
        attended = self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states)
        hidden_states = rms_norm(hidden_states + attended, variance_epsilon=self.norm_eps)

        # Feed-forward sub-layer: residual add, then normalise.
        projected = self.mlp(hidden_states)
        return rms_norm(hidden_states + projected, variance_epsilon=self.norm_eps)
|
84 |
+
|
85 |
+
|
86 |
+
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
    """A stack of blocks applied in order after adding an injected input."""

    def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
        super().__init__()

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        """Add ``input_injection`` to the state, then run every layer.

        Extra keyword arguments are forwarded unchanged to each layer.
        """
        state = hidden_states + input_injection
        for block in self.layers:
            state = block(hidden_states=state, **kwargs)

        return state
|
100 |
+
|
101 |
+
|
102 |
+
class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
    """Core HRM network: embeddings, the H and L reasoning modules, and output heads.

    One ``forward`` call runs ``H_cycles * L_cycles`` L-updates and ``H_cycles``
    H-updates. All but the final L-update and final H-update run under
    ``torch.no_grad``; only the last of each contributes gradients.
    """

    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        # Number of sequence positions occupied by the puzzle embedding.
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            raise NotImplementedError()

        # Reasoning Layers
        self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
        self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        # Registered as buffers (not parameters) so they move with the module
        # across devices without being optimised.
        h_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
        self.register_buffer('H_init', h_init_tensor)

        l_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
        self.register_buffer('L_init', l_init_tensor)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Build the scaled input embedding, with the puzzle embedding prepended when enabled."""
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            # Pad to a whole number of hidden_size-wide positions.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        """Allocate an uninitialised carry for ``batch_size`` sequences.

        The tensors are created on the device of the ``H_init`` buffer so that
        ``reset_carry`` never mixes devices, regardless of the ambient default
        device. Contents are undefined until ``reset_carry`` overwrites them.
        """
        shape = (batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size)
        device = self.H_init.device
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(shape, dtype=self.forward_dtype, device=device),
            z_L=torch.empty(shape, dtype=self.forward_dtype, device=device),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
        """Replace the state of flagged sequences with the initial H/L vectors.

        ``reset_flag`` holds one bool per sequence and is broadcast over the
        position and hidden dimensions.
        """
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run one reasoning segment.

        Args:
            carry: Hidden states from the previous segment.
            batch: Must contain ``"inputs"`` and ``"puzzle_identifiers"``.

        Returns:
            ``(new_carry, logits, (q_halt_logits, q_continue_logits))`` where
            ``new_carry`` is detached and ``logits`` excludes the puzzle
            embedding positions.
        """
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        with torch.no_grad():
            z_H, z_L = carry.z_H, carry.z_L

            for _H_step in range(self.config.H_cycles):
                for _L_step in range(self.config.L_cycles):
                    # Skip the very last L-update; it is done with grad below.
                    if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
                        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)

                # Skip the last H-update; it is done with grad below.
                if not (_H_step == self.config.H_cycles - 1):
                    z_H = self.H_level(z_H, z_L, **seq_info)

        assert not z_H.requires_grad and not z_L.requires_grad

        # 1-step grad
        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.H_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]

        # Q head
        # Halt/continue logits are read from the first sequence position.
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
219 |
+
|
220 |
+
|
221 |
+
class HierarchicalReasoningModel_ACTV1(nn.Module):
    """ACT wrapper.

    Wraps the inner model with per-sequence step counting and halting. Each
    ``forward`` call advances every sequence by one segment; halted sequences
    are reset and refilled from the incoming batch on the next call.
    """

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
        self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        """The inner model's puzzle embedding (only exists when ``puzzle_emb_ndim > 0``)."""
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Create the starting carry for ``batch``.

        Every tensor is placed on the device of ``batch["inputs"]`` so the
        carry can be combined with the batch in ``forward`` without relying on
        an ambient ``torch.device`` context.
        """
        batch_size = batch["inputs"].shape[0]
        device = batch["inputs"].device

        # Empty is expected: it is reset in the first pass because all sequences start halted.
        empty = self.inner.empty_carry(batch_size)
        inner_carry = HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=empty.z_H.to(device),
            z_L=empty.z_L.to(device),
        )

        return HierarchicalReasoningModel_ACTV1Carry(
            inner_carry=inner_carry,

            steps=torch.zeros((batch_size, ), dtype=torch.int32, device=device),
            halted=torch.ones((batch_size, ), dtype=torch.bool, device=device),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """Advance all sequences by one segment.

        Returns:
            ``(new_carry, outputs)``. ``outputs`` has ``"logits"``,
            ``"q_halt_logits"`` and ``"q_continue_logits"``, plus
            ``"target_q_continue"`` when training with ``halt_max_steps > 1``.
        """
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        # Halted sequences take their data from the new batch; others keep theirs.
        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # if training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
                halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                # With probability halt_exploration_prob a sequence gets a random
                # minimum step count in [2, halt_max_steps]; otherwise the minimum is 0.
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)

                halted = halted & (new_steps >= min_halt_steps)

                # Compute target Q
                # NOTE: No replay buffer and target networks for computing target Q-value.
                # As batch_size is large, there're many parallel envs.
                # Similar concept as PQN https://arxiv.org/abs/2407.04811
                next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]

                outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
losses.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Tuple, Dict, Sequence, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
# Label value excluded from the loss and accuracy masks in ACTLossHead.
IGNORE_LABEL_ID = -100
|
9 |
+
|
10 |
+
|
11 |
+
def s(x, epsilon=1e-30):
    """Elementwise stablemax transform: ``x + 1`` for ``x >= 0``, else ``1 / (1 - x + epsilon)``."""
    is_negative = x < 0
    negative_branch = 1 / (1 - x + epsilon)
    non_negative_branch = x + 1
    return torch.where(is_negative, negative_branch, non_negative_branch)
|
17 |
+
|
18 |
+
|
19 |
+
def log_stablemax(x, dim=-1):
    """Return the log of ``s(x)`` normalised to sum to one along ``dim``."""
    transformed = s(x)
    normaliser = torch.sum(transformed, dim=dim, keepdim=True)
    return torch.log(transformed / normaliser)
|
22 |
+
|
23 |
+
|
24 |
+
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
    """Per-position cross entropy using stablemax instead of softmax.

    Logits are cast to float64 before normalisation. Positions whose label
    equals ``ignore_index`` yield a loss of 0. The result has the shape of
    ``labels``; no reduction is applied.
    """
    keep = labels != ignore_index
    log_probs = log_stablemax(logits.to(torch.float64), dim=-1)

    # Ignored positions get a placeholder class 0 so gather stays in range.
    safe_labels = torch.where(keep, labels, 0)
    gather_index = safe_labels.to(torch.long).unsqueeze(-1)
    picked = torch.gather(log_probs, index=gather_index, dim=-1).squeeze(-1)

    return -torch.where(keep, picked, 0)
|
32 |
+
|
33 |
+
|
34 |
+
def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
    """Per-position softmax cross entropy with no reduction.

    Args:
        logits: Tensor whose last dimension indexes classes.
        labels: Integer tensor matching ``logits`` without its last dimension.
        ignore_index: Label value whose positions contribute a loss of 0.

    Returns:
        A float32 tensor shaped like ``labels``.
    """
    # Cast logits to f32
    # Flatten logits. reshape (not view) is used because the logits may be a
    # non-contiguous slice, on which view raises.
    flat_logits = logits.to(torch.float32).reshape(-1, logits.shape[-1])
    flat_labels = labels.to(torch.long).reshape(-1)
    per_token = F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index, reduction="none")
    return per_token.reshape(labels.shape)
|
38 |
+
|
39 |
+
|
40 |
+
class ACTLossHead(nn.Module):
    """Wraps a model and turns its outputs into a training loss plus metrics."""

    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        # loss_type names a module-level loss function in this file.
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        """Delegate to the wrapped model's ``initial_carry``."""
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        """Run the model once and compute the loss.

        Returns:
            ``(new_carry, loss, metrics, detached_outputs, all_halted)`` where
            ``loss`` is ``lm_loss + 0.5 * (q_halt_loss + q_continue_loss)``,
            ``detached_outputs`` holds the requested ``return_keys`` present in
            the model outputs, and ``all_halted`` is true when every sequence
            in the carry is halted.
        """
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]
        q_halt_logits = outputs["q_halt_logits"]

        # Correctness
        with torch.no_grad():
            mask = labels != IGNORE_LABEL_ID
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            predictions = torch.argmax(outputs["logits"], dim=-1)
            is_correct = mask & (predictions == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            # Only halted sequences with at least one labelled position are counted.
            valid_metrics = new_carry.halted & (loss_counts > 0)
            token_accuracy = (is_correct.to(torch.float32) / loss_divisor).sum(-1)
            halt_prediction_matches = (q_halt_logits >= 0) == seq_is_correct

            metrics = {
                "count": valid_metrics.sum(),

                "accuracy": torch.where(valid_metrics, token_accuracy, 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & halt_prediction_matches).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        # FIXME: Assuming the batch is always full
        per_token_loss = self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID)
        lm_loss = (per_token_loss / loss_divisor).sum()

        halt_target = seq_is_correct.to(q_halt_logits.dtype)
        q_halt_loss = F.binary_cross_entropy_with_logits(q_halt_logits, halt_target, reduction="sum")

        metrics["lm_loss"] = lm_loss.detach()
        metrics["q_halt_loss"] = q_halt_loss.detach()

        # Q continue (bootstrapping target loss)
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        total_loss = lm_loss + 0.5 * (q_halt_loss + q_continue_loss)
        return new_carry, total_loss, metrics, detached_outputs, new_carry.halted.all()
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7739c5146e1cf12a55fb2ce44c69ea099307ea6db383f0935c3966eedf0e203f
|
3 |
+
size 174644621
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
gradio
|
3 |
+
pydantic
|
4 |
+
PyYAML
|
5 |
+
transformers
|