linoyts HF Staff committed on
Commit
07bec8d
·
verified ·
1 Parent(s): 861b772

Update optimization_utils.py

Browse files
Files changed (1) hide show
  1. optimization_utils.py +3 -28
optimization_utils.py CHANGED
@@ -98,35 +98,10 @@ def capture_component_call(
98
  captured_call.kwargs = e.kwargs
99
 
100
 
101
- # def drain_module_parameters(module: torch.nn.Module):
102
- # state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
103
- # state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
104
- # module.load_state_dict(state_dict, assign=True)
105
- # for name, param in state_dict.items():
106
- # meta = state_dict_meta[name]
107
- # param.data = torch.Tensor([]).to(**meta)
108
-
109
  def drain_module_parameters(module: torch.nn.Module):
110
- state_dict_meta = {
111
- name: {'device': tensor.device, 'dtype': tensor.dtype}
112
- for name, tensor in module.state_dict().items()
113
- }
114
-
115
- state_dict = {}
116
- for name, tensor in module.state_dict().items():
117
- try:
118
- param = torch.nn.Parameter(torch.empty_like(tensor, device='cpu'))
119
- except NotImplementedError:
120
- # Fallback: dequantize (or convert) if empty_like isn't implemented
121
- param = torch.nn.Parameter(tensor.dequantize().to('cpu') if hasattr(tensor, 'dequantize') else tensor.to('cpu'))
122
- state_dict[name] = param
123
-
124
  module.load_state_dict(state_dict, assign=True)
125
-
126
  for name, param in state_dict.items():
127
  meta = state_dict_meta[name]
128
- try:
129
- param.data = torch.Tensor([]).to(**meta)
130
- except NotImplementedError:
131
- # Fallback for quantized tensors
132
- param.data = (param.dequantize().to(**meta) if hasattr(param, 'dequantize') else torch.Tensor([]).to(**meta))
 
98
  captured_call.kwargs = e.kwargs
99
 
100
 
 
 
 
 
 
 
 
 
101
  def drain_module_parameters(module: torch.nn.Module):
102
+ state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
103
+ state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
 
 
 
 
 
 
 
 
 
 
 
 
104
  module.load_state_dict(state_dict, assign=True)
 
105
  for name, param in state_dict.items():
106
  meta = state_dict_meta[name]
107
+ param.data = torch.Tensor([]).to(**meta)