linoyts HF Staff commited on
Commit
d3f7631
·
verified ·
1 Parent(s): 988720a

Update optimization_utils.py

Browse files
Files changed (1) hide show
  1. optimization_utils.py +9 -0
optimization_utils.py CHANGED
@@ -96,3 +96,12 @@ def capture_component_call(
96
  except CapturedCallException as e:
97
  captured_call.args = e.args
98
  captured_call.kwargs = e.kwargs
 
 
 
 
 
 
 
 
 
 
96
  except CapturedCallException as e:
97
  captured_call.args = e.args
98
  captured_call.kwargs = e.kwargs
99
+
100
+
101
+ def drain_module_parameters(module: torch.nn.Module):
102
+ state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
103
+ state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
104
+ module.load_state_dict(state_dict, assign=True)
105
+ for name, param in state_dict.items():
106
+ meta = state_dict_meta[name]
107
+ param.data = torch.Tensor([]).to(**meta)