Tarive committed on
Commit
cb71d2d
·
verified ·
1 Parent(s): fdeb6a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import yaml
4
  import json
5
  from tokenizers import Tokenizer
 
6
 
7
  # --- 1. Load Custom Model Code ---
8
  # This import now works because we have the correct models/hrm/ structure
@@ -26,8 +27,24 @@ model_config.update({
26
  'vocab_size': tokenizer.get_vocab_size()
27
  })
28
  model = HierarchicalReasoningModel_ACTV1(config_dict=model_config)
29
- model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
30
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  print("Model loaded successfully!")
32
 
33
  # --- 4. Define the Inference Function ---
 
3
  import yaml
4
  import json
5
  from tokenizers import Tokenizer
6
+ from collections import OrderedDict
7
 
8
  # --- 1. Load Custom Model Code ---
9
  # This import now works because we have the correct models/hrm/ structure
 
27
  'vocab_size': tokenizer.get_vocab_size()
28
  })
29
  model = HierarchicalReasoningModel_ACTV1(config_dict=model_config)
30
+
31
+ # --- MODIFICATION: Clean the state dict keys before loading ---
32
+ # Load the original state dict
33
+ original_state_dict = torch.load('pytorch_model.bin', map_location='cpu')
34
+ # Create a new state dict with the corrected keys
35
+ new_state_dict = OrderedDict()
36
+ for k, v in original_state_dict.items():
37
+ if k.startswith('_orig_mod.model.'):
38
+ name = k[len('_orig_mod.model.'):] # remove the prefix
39
+ new_state_dict[name] = v
40
+ else:
41
+ new_state_dict[k] = v
42
+
43
+ # Load the cleaned state dict
44
+ model.load_state_dict(new_state_dict)
45
+ # --- END MODIFICATION ---
46
+
47
+ model.eval() # Set the model to evaluation mode
48
  print("Model loaded successfully!")
49
 
50
  # --- 4. Define the Inference Function ---