File size: 7,594 Bytes
1c87faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9308dfe
1c87faa
9308dfe
1c87faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9308dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c87faa
 
 
 
 
 
 
 
 
 
 
 
 
 
9308dfe
 
1c87faa
 
9308dfe
 
 
1c87faa
 
 
9308dfe
1c87faa
9308dfe
1c87faa
 
 
 
 
9308dfe
 
 
 
 
 
1c87faa
 
9308dfe
1c87faa
 
9308dfe
1c87faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import safetensors
import torch
import torch.nn as nn

from contextlib import contextmanager
from typing import Callable, List


@contextmanager
def safetensors_open(safetensors_file: str):
    """
    Open a safetensors file and yield a callable for reading its tensors.

    The file is opened through `safetensors.safe_open` with the "pt"
    framework and stays open for as long as the caller remains inside the
    `with` block. The yielded callable takes a tensor name and returns the
    tensor stored under it. It also carries a `keys` attribute, itself a
    zero-argument callable, that returns the names reported by the open file.
    """
    opener = safetensors.safe_open(safetensors_file, framework="pt")
    with opener as handle:  # pyright: ignore

        def get_keys() -> List[str]:
            return handle.keys()

        def get_tensor(name: str) -> torch.Tensor:
            return handle.get_tensor(name)

        # Expose the key listing on the getter so callers only need to pass
        # a single object around.
        get_tensor.keys = get_keys

        yield get_tensor


def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
    """Copy tensors returned by ``get_tensor`` into the parameters of ``model``.

    The model is first cast to ``torch.bfloat16``. A list pairing each
    checkpoint tensor name with the model parameter that receives it is then
    assembled (fixed entries first, followed by one group per vision block and
    one group per text block), and every pair is filled with an in-place copy.
    The two region feature tensors are handled last because their checkpoint
    tensors are transposed before being copied.

    Args:
        get_tensor: Callable mapping a checkpoint tensor name to its tensor.
        model: Module exposing ``vision``, ``text`` and ``region`` sub-modules
            with the entries referenced below.
    """
    model = model.to(dtype=torch.bfloat16)

    vision, region = model.vision, model.region

    # (checkpoint name, destination parameter) pairs, in copy order.
    targets = [
        (
            "vision_encoder.encoder.model.visual.patch_embed.linear.weight",
            vision["patch_emb"].weight,
        ),
        (
            "vision_encoder.encoder.model.visual.patch_embed.linear.bias",
            vision["patch_emb"].bias,
        ),
        ("vision_encoder.encoder.model.visual.pos_embed", vision.pos_emb),
        ("vision_encoder.encoder.model.visual.norm.weight", vision["post_ln"].weight),
        ("vision_encoder.encoder.model.visual.norm.bias", vision["post_ln"].bias),
        ("vision_encoder.projection.mlp.fc1.weight", vision["proj_mlp"]["fc1"].weight),
        ("vision_encoder.projection.mlp.fc1.bias", vision["proj_mlp"]["fc1"].bias),
        ("vision_encoder.projection.mlp.fc2.weight", vision["proj_mlp"]["fc2"].weight),
        ("vision_encoder.projection.mlp.fc2.bias", vision["proj_mlp"]["fc2"].bias),
        ("text_model.transformer.embd.wte.weight", model.text.wte),
        ("text_model.lm_head.ln.weight", model.text["post_ln"].weight),
        ("text_model.lm_head.ln.bias", model.text["post_ln"].bias),
        ("text_model.lm_head.linear.weight", model.text["lm_head"].weight),
        ("text_model.lm_head.linear.bias", model.text["lm_head"].bias),
        ("region_model.coordinate_encoder.weight", region["coord_encoder"].weight),
        ("region_model.coordinate_encoder.bias", region["coord_encoder"].bias),
        (
            "region_model.coordinate_decoder.fc1.weight",
            region["coord_decoder"]["fc1"].weight,
        ),
        (
            "region_model.coordinate_decoder.fc1.bias",
            region["coord_decoder"]["fc1"].bias,
        ),
        (
            "region_model.coordinate_decoder.fc2.weight",
            region["coord_decoder"]["fc2"].weight,
        ),
        (
            "region_model.coordinate_decoder.fc2.bias",
            region["coord_decoder"]["fc2"].bias,
        ),
        ("region_model.size_encoder.weight", region["size_encoder"].weight),
        ("region_model.size_encoder.bias", region["size_encoder"].bias),
        ("region_model.size_decoder.fc1.weight", region["size_decoder"]["fc1"].weight),
        ("region_model.size_decoder.fc1.bias", region["size_decoder"]["fc1"].bias),
        ("region_model.size_decoder.fc2.weight", region["size_decoder"]["fc2"].weight),
        ("region_model.size_decoder.fc2.bias", region["size_decoder"]["fc2"].bias),
    ]

    # Vision transformer blocks: checkpoint "norm1"/"norm2" map to "ln1"/"ln2".
    for i, blk in enumerate(model.vision["blocks"]):
        prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"
        targets += [
            (f"{prefix}.norm1.weight", blk["ln1"].weight),
            (f"{prefix}.norm1.bias", blk["ln1"].bias),
            (f"{prefix}.norm2.weight", blk["ln2"].weight),
            (f"{prefix}.norm2.bias", blk["ln2"].bias),
            (f"{prefix}.attn.qkv.weight", blk["attn"]["qkv"].weight),
            (f"{prefix}.attn.qkv.bias", blk["attn"]["qkv"].bias),
            (f"{prefix}.attn.proj.weight", blk["attn"]["proj"].weight),
            (f"{prefix}.attn.proj.bias", blk["attn"]["proj"].bias),
            (f"{prefix}.mlp.fc1.weight", blk["mlp"]["fc1"].weight),
            (f"{prefix}.mlp.fc1.bias", blk["mlp"]["fc1"].bias),
            (f"{prefix}.mlp.fc2.weight", blk["mlp"]["fc2"].weight),
            (f"{prefix}.mlp.fc2.bias", blk["mlp"]["fc2"].bias),
        ]

    # Text transformer blocks: checkpoint "mixer.Wqkv"/"mixer.out_proj" map to
    # the block's "attn" -> "qkv"/"proj" entries.
    for i, blk in enumerate(model.text["blocks"]):
        prefix = f"text_model.transformer.h.{i}"
        targets += [
            (f"{prefix}.ln.weight", blk["ln"].weight),
            (f"{prefix}.ln.bias", blk["ln"].bias),
            (f"{prefix}.mixer.Wqkv.weight", blk["attn"]["qkv"].weight),
            (f"{prefix}.mixer.Wqkv.bias", blk["attn"]["qkv"].bias),
            (f"{prefix}.mixer.out_proj.weight", blk["attn"]["proj"].weight),
            (f"{prefix}.mixer.out_proj.bias", blk["attn"]["proj"].bias),
            (f"{prefix}.mlp.fc1.weight", blk["mlp"]["fc1"].weight),
            (f"{prefix}.mlp.fc1.bias", blk["mlp"]["fc1"].bias),
            (f"{prefix}.mlp.fc2.weight", blk["mlp"]["fc2"].weight),
            (f"{prefix}.mlp.fc2.bias", blk["mlp"]["fc2"].bias),
        ]

    # Nothing is written until every destination has been resolved above.
    for name, param in targets:
        param.data.copy_(get_tensor(name))

    # These two checkpoint tensors are transposed before being copied in.
    transposed = (
        ("coord_features", "region_model.coordinate_features.weight"),
        ("size_features", "region_model.size_features.weight"),
    )
    for attr, name in transposed:
        getattr(region, attr).data.copy_(get_tensor(name).T)


def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
    """Load weights from a safetensors file into a MoondreamModel instance.

    The checkpoint layout is detected by probing for a marker key:

    * If ``vision.blocks.0.attn.proj.bias`` is present (with or without a
      leading ``model.``), every tensor is read, ``"model."`` is removed from
      its name, and the result is passed to ``model.load_state_dict`` with
      ``strict=False``.
    * Otherwise the file is routed through ``_load_weights``. ``"._orig_mod"``
      is stripped from the stored names for lookup purposes, and each tensor
      is cast to ``torch.bfloat16`` as it is read.

    Args:
        weights_file: Path to the ``.safetensors`` file.
        model: Module that receives the weights.

    Raises:
        KeyError: In the ``_load_weights`` path, if a tensor name requested by
            ``_load_weights`` is not present in the file.
    """
    with safetensors_open(weights_file) as get_tensor:
        # Read the key list once; it is needed for both the layout probe and
        # the name mapping.
        keys = list(get_tensor.keys())
        key_set = set(keys)

        if (
            "vision.blocks.0.attn.proj.bias" in key_set
            or "model.vision.blocks.0.attn.proj.bias" in key_set
        ):
            # The file is already open here, so reuse this handle instead of
            # opening the same file a second time.
            # NOTE(review): replace() removes every occurrence of "model.",
            # not only a leading prefix — confirm no state-dict key contains
            # it mid-name.
            tensors = {k.replace("model.", ""): get_tensor(k) for k in keys}
            model.load_state_dict(tensors, strict=False)
        else:
            # Wrap the get_tensor function to handle key normalization
            name_map = {k.replace("._orig_mod", ""): k for k in keys}
            _load_weights(
                lambda x: get_tensor(name_map[x]).to(dtype=torch.bfloat16), model
            )


def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
    """Load weights from a PyTorch file into a MoondreamModel instance.

    The file is read with ``torch.load(..., weights_only=True)`` and mapped
    onto the device a freshly created empty tensor lands on. If the loaded
    mapping contains ``vision.blocks.0.attn.proj.bias`` it is handed to
    ``model.load_state_dict`` with ``strict=False`` and the missing and
    unexpected keys are printed. Otherwise ``"._orig_mod"`` is stripped from
    every name, every tensor is cast to ``torch.bfloat16``, and the result is
    loaded through ``_load_weights``.

    Args:
        weights_file: Path to the PyTorch checkpoint.
        model: Module that receives the weights.
    """
    target_device = str(torch.empty(0).device)
    state = torch.load(weights_file, map_location=target_device, weights_only=True)

    if "vision.blocks.0.attn.proj.bias" not in state:
        # Rebind the same name so the originally loaded mapping is not kept
        # referenced by this function while the weights are being copied.
        state = {
            name.replace("._orig_mod", ""): value.to(dtype=torch.bfloat16)
            for name, value in state.items()
        }
        _load_weights(lambda x: state[x], model)
        return

    missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)


def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
    """
    Load weights from a safetensors or PyTorch file into a MoondreamModel instance.

    The loader is chosen from the file name: a path ending in ".safetensors"
    goes through `load_weights_from_safetensors`, anything else through
    `load_weights_from_pt`. Afterwards every parameter's data is replaced by
    a contiguous version of itself.

    Args:
        weights_file: Path to weights file (either .safetensors or .pt)
        model: MoondreamModel instance to load weights into
    """
    loader = (
        load_weights_from_safetensors
        if weights_file.endswith(".safetensors")
        else load_weights_from_pt
    )
    loader(weights_file, model)

    # Ensure each parameter is backed by contiguous storage.
    for p in model.parameters():
        p.data = p.data.contiguous()