linoyts HF Staff commited on
Commit
861b772
·
verified ·
1 Parent(s): 5376c20

Update optimization_utils.py

Browse files
Files changed (1) hide show
  1. optimization_utils.py +28 -3
optimization_utils.py CHANGED
@@ -98,10 +98,35 @@ def capture_component_call(
98
  captured_call.kwargs = e.kwargs
99
 
100
 
 
 
 
 
 
 
 
 
101
  def drain_module_parameters(module: torch.nn.Module):
102
- state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
103
- state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
 
 
 
 
 
 
 
 
 
 
 
 
104
  module.load_state_dict(state_dict, assign=True)
 
105
  for name, param in state_dict.items():
106
  meta = state_dict_meta[name]
107
- param.data = torch.Tensor([]).to(**meta)
 
 
 
 
 
98
  captured_call.kwargs = e.kwargs
99
 
100
 
101
+ # def drain_module_parameters(module: torch.nn.Module):
102
+ # state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
103
+ # state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
104
+ # module.load_state_dict(state_dict, assign=True)
105
+ # for name, param in state_dict.items():
106
+ # meta = state_dict_meta[name]
107
+ # param.data = torch.Tensor([]).to(**meta)
108
+
109
  def drain_module_parameters(module: torch.nn.Module):
110
+ state_dict_meta = {
111
+ name: {'device': tensor.device, 'dtype': tensor.dtype}
112
+ for name, tensor in module.state_dict().items()
113
+ }
114
+
115
+ state_dict = {}
116
+ for name, tensor in module.state_dict().items():
117
+ try:
118
+ param = torch.nn.Parameter(torch.empty_like(tensor, device='cpu'))
119
+ except NotImplementedError:
120
+ # Fallback: dequantize (or convert) if empty_like isn't implemented
121
+ param = torch.nn.Parameter(tensor.dequantize().to('cpu') if hasattr(tensor, 'dequantize') else tensor.to('cpu'))
122
+ state_dict[name] = param
123
+
124
  module.load_state_dict(state_dict, assign=True)
125
+
126
  for name, param in state_dict.items():
127
  meta = state_dict_meta[name]
128
+ try:
129
+ param.data = torch.Tensor([]).to(**meta)
130
+ except NotImplementedError:
131
+ # Fallback for quantized tensors
132
+ param.data = (param.dequantize().to(**meta) if hasattr(param, 'dequantize') else torch.Tensor([]).to(**meta))