snowclipsed committed
Commit e15c30f · 1 Parent(s): 8f87bef

remove is_quantized completely

Files changed (1)
  1. weights.py  +17 -41
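The net effect, taken straight from the hunks below:

# From the diff below: the internal loader's signature shrinks.
# Before:
#     _load_weights(get_tensor, model, is_quantized=False)
# After (the flag and the BitBLAS ".qweight" branch are removed):
#     _load_weights(get_tensor, model)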
weights.py CHANGED
@@ -6,9 +6,6 @@ import re
 from contextlib import contextmanager
 from typing import Callable, List
 
-from .text import build_text_model
-from .config import TextConfig
-
 
 # Our custom linear has a module named linear, so we add linear to the name
 def add_linear_to_key(k: str) -> str:
@@ -46,7 +43,6 @@ def safetensors_open(safetensors_file: str):
 def _load_weights(
     get_tensor: Callable[[str], torch.Tensor],
     model: nn.Module,
-    is_quantized: bool = False,
 ) -> None:
     """Internal function to load weights using a tensor getter function."""
     model = model.to(dtype=torch.float16)
@@ -111,42 +107,23 @@ def _load_weights(
         }
     )
 
-    if not is_quantized:
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.weight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
-    else:  # quantized path: BitBLAS expects quantized weights under ".qweight" keys
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.qweight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
+    for i in range(len(model.text["blocks"])):
+        prefix = f"text_model.transformer.h.{i}"
+        blk = model.text["blocks"][i]
+        weight_map.update(
+            {
+                f"{prefix}.ln.weight": blk["ln"].weight,
+                f"{prefix}.ln.bias": blk["ln"].bias,
+                f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
+                f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
+                f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+            }
+        )
 
     for key, tensor in weight_map.items():
         tensor.data.copy_(get_tensor(key))
@@ -175,7 +152,6 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
     _load_weights(
         lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
         model,
-        # is_quantized,
     )
 
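With the flag removed, _load_weights(get_tensor, model) always maps plain .weight checkpoint keys and casts everything to float16. For orientation, a minimal sketch of how a tensor getter can be supplied, using safe_open from the safetensors package (a real API); the file path and the missing name mapping are placeholders, and the repo's own safetensors_open wrapper, visible in the second hunk's header, presumably plays this role:

import torch
from safetensors import safe_open  # real safetensors API

# Hedged sketch, not the repo's code: open a checkpoint and hand
# _load_weights the Callable[[str], torch.Tensor] it expects.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:

    def get_tensor(name: str) -> torch.Tensor:
        # Cast to float16 to match the loader, which moves the whole
        # model to float16 before copying tensors in.
        return f.get_tensor(name).to(dtype=torch.float16)

    _load_weights(get_tensor, model)  # no is_quantized argument anymore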
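And the public entry point after this commit, as a hedged usage sketch: build_model() and the import path are hypothetical stand-ins, and the module passed in must expose the model.text["blocks"] structure the loader walks.

import torch.nn as nn

from weights import load_weights_from_safetensors  # import path assumed

model: nn.Module = build_model()  # hypothetical constructor, not from this repo

# One code path for every checkpoint now: weights load as float16.
# BitBLAS-style ".qweight" checkpoints are no longer handled by this loader.
load_weights_from_safetensors("model.safetensors", model)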