File size: 6,079 Bytes
3c8f058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdff6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c8f058
 
 
 
bdff6f4
 
 
 
 
 
 
3c8f058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
Minimal GoogLeNet (Inception V1) feature extractor in MLX, through inception5b
(no auxiliary or classifier heads).
Loads weights from a torchvision-exported npz (see export_googlenet_npz.py).
"""

import mlx.core as mx
import mlx.nn as nn
import numpy as np


def _conv_bn(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Build a bias-free Conv2d -> BatchNorm -> ReLU stack.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels (also the BatchNorm feature count).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied by the convolution.

    Returns:
        An ``nn.Sequential`` whose ``layers[0]`` is the convolution and
        ``layers[1]`` is the batch norm; weight loading indexes them that way,
        so the order must not change.
    """
    # The convolution has no bias because the following BatchNorm supplies the shift.
    conv = nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    norm = nn.BatchNorm(out_ch, eps=1e-3, momentum=0.1)
    activation = nn.ReLU()
    return nn.Sequential(conv, norm, activation)


class Inception(nn.Module):
    """One Inception module: four parallel branches joined on the channel axis.

    Branches (every conv is a Conv-BN-ReLU stack from ``_conv_bn``):
      1. 1x1 conv -> ``ch1`` channels
      2. 1x1 reduce to ``ch3r`` -> 3x3 conv (pad 1) -> ``ch3`` channels
      3. 1x1 reduce to ``ch5r`` -> 3x3 conv (pad 1) -> ``ch5`` channels
      4. 3x3 max pool (stride 1, pad 1) -> 1x1 conv -> ``pool_proj`` channels

    Every branch keeps the spatial size, so the output carries
    ``ch1 + ch3 + ch5 + pool_proj`` channels along the last axis.
    """

    def __init__(self, in_ch, ch1, ch3r, ch3, ch5r, ch5, pool_proj):
        super().__init__()
        # Attribute names are read verbatim by GoogLeNet.load_npz; keep them.
        # Branch 1: plain 1x1 projection.
        self.branch1 = _conv_bn(in_ch, ch1, 1)

        # Branch 2: 1x1 bottleneck followed by a 3x3 conv.
        self.branch2_1 = _conv_bn(in_ch, ch3r, 1)
        self.branch2_2 = _conv_bn(ch3r, ch3, 3, padding=1)

        # Branch 3: nominally the "5x5" path, but the reference torchvision
        # GoogLeNet uses a 3x3 conv here, so the weights expect 3x3 as well.
        self.branch3_1 = _conv_bn(in_ch, ch5r, 1)
        self.branch3_2 = _conv_bn(ch5r, ch5, 3, padding=1)

        # Branch 4: size-preserving max pool, then a 1x1 projection.
        self.branch4_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.branch4_2 = _conv_bn(in_ch, pool_proj, 1)

    def __call__(self, x):
        # Channels are last, so the branch outputs are stacked along axis -1.
        outputs = [
            self.branch1(x),
            self.branch2_2(self.branch2_1(x)),
            self.branch3_2(self.branch3_1(x)),
            self.branch4_2(self.branch4_pool(x)),
        ]
        return mx.concatenate(outputs, axis=-1)


class _CeilMaxPool2d(nn.Module):
    """Max pooling that reproduces PyTorch's ``ceil_mode=True`` output size.

    MLX's ``nn.MaxPool2d`` always rounds the output size down. torchvision's
    GoogLeNet pools with ``ceil_mode=True``. With floor rounding, every stage
    after ``conv1`` would be one pixel smaller than the reference (e.g. 55x55
    instead of 56x56 for a 224x224 input).

    To match, the trailing (bottom/right) edges are padded with ``-inf`` just
    far enough that floor-mode pooling produces the ceil-mode size. ``-inf``
    never wins a max, so the extra partial window pools only real values.

    Args:
        kernel_size: Square pooling window size.
        stride: Square pooling stride.
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self._kernel_size = kernel_size
        self._stride = stride
        self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0)

    def _extra_pad(self, size):
        """Return how many trailing ``-inf`` cells one spatial axis needs."""
        k, s = self._kernel_size, self._stride
        out = -(-(size - k) // s) + 1  # ceil((size - k) / s) + 1
        # Like PyTorch, drop a final window that would start past the input.
        if (out - 1) * s >= size:
            out -= 1
        return max(0, (out - 1) * s + k - size)

    def __call__(self, x):
        # Channels-last layout: the two spatial axes precede the channel axis.
        pad_h = self._extra_pad(x.shape[-3])
        pad_w = self._extra_pad(x.shape[-2])
        if pad_h or pad_w:
            pad_width = [(0, 0)] * (x.ndim - 3) + [(0, pad_h), (0, pad_w), (0, 0)]
            x = mx.pad(x, pad_width, constant_values=-float("inf"))
        return self.pool(x)


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception V1) convolutional trunk, stem through inception5b.

    Layer and attribute names mirror the torchvision weight naming used by
    ``load_npz``. There is no classifier head. Calling the model returns a
    dict of every inception block's output keyed by block name.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = _conv_bn(3, 64, 7, stride=2, padding=3)
        # torchvision pools with ceil_mode=True throughout the trunk. MLX has no
        # such option, so the ceil-mode wrapper keeps spatial sizes identical.
        self.maxpool1 = _CeilMaxPool2d(kernel_size=3, stride=2)

        self.conv2 = _conv_bn(64, 64, 1)
        self.conv3 = _conv_bn(64, 192, 3, padding=1)
        self.maxpool2 = _CeilMaxPool2d(kernel_size=3, stride=2)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = _CeilMaxPool2d(kernel_size=3, stride=2)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        # torchvision uses a 2x2 / stride-2 window here, not 3x3 with padding.
        self.maxpool4 = _CeilMaxPool2d(kernel_size=2, stride=2)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

    def forward_with_endpoints(self, x):
        """Run the trunk and collect each inception block's output.

        Args:
            x: Input images in channels-last layout with 3 channels (as
               required by ``conv1``).

        Returns:
            A tuple ``(features, endpoints)`` where ``features`` is the
            inception5b output and ``endpoints`` maps block names
            (``"inception3a"`` ... ``"inception5b"``) to their outputs.
        """
        endpoints = {}
        x = self.conv1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        endpoints["inception3a"] = x
        x = self.inception3b(x)
        endpoints["inception3b"] = x
        x = self.maxpool3(x)

        x = self.inception4a(x)
        endpoints["inception4a"] = x
        x = self.inception4b(x)
        endpoints["inception4b"] = x
        x = self.inception4c(x)
        endpoints["inception4c"] = x
        x = self.inception4d(x)
        endpoints["inception4d"] = x
        x = self.inception4e(x)
        endpoints["inception4e"] = x
        x = self.maxpool4(x)

        x = self.inception5a(x)
        endpoints["inception5a"] = x
        x = self.inception5b(x)
        endpoints["inception5b"] = x
        return x, endpoints

    def __call__(self, x):
        """Return only the endpoints dict from ``forward_with_endpoints``."""
        _, endpoints = self.forward_with_endpoints(x)
        return endpoints

    def load_npz(self, path: str):
        """Copy weights from a torchvision-style ``.npz`` archive into the model.

        Keys follow torchvision's naming, e.g. ``conv1.conv.weight`` or
        ``inception3a.branch2.0.bn.running_mean``. A tensor is stored either
        directly under its key, or as an int8 array under ``<key>_int8`` with
        a ``<key>_scale`` array. In the int8 case it is dequantized as
        ``int8 * scale``. Conv weights are transposed from PyTorch
        ``[O, I, H, W]`` to MLX ``[O, H, W, I]``.

        NOTE(review): this does not call ``self.eval()``. Confirm that callers
        do so before inference, so BatchNorm uses the loaded running statistics.

        Args:
            path: Path to the ``.npz`` file.

        Raises:
            ValueError: If a required key is missing, or if a stored array's
                shape does not match the parameter it would replace.
        """
        # Use the NpzFile as a context manager so the archive's file handle
        # is released once loading finishes.
        with np.load(path) as data:

            def load_weight(key, target_module, param_name="weight", transpose=False):
                # Prefer the plain float16/32 key; fall back to int8 + scale.
                if key in data:
                    w = data[key]
                elif f"{key}_int8" in data:
                    w_int8 = data[f"{key}_int8"]
                    scale = data[f"{key}_scale"]
                    # NOTE(review): assumes `scale` broadcasts against the int8
                    # array in its stored PyTorch layout; confirm with exporter.
                    w = w_int8.astype(scale.dtype) * scale
                else:
                    raise ValueError(f"Missing key {key} (or {key}_int8) in npz")

                # PyTorch [O,I,H,W] -> MLX [O,H,W,I] for conv kernels.
                if transpose and w.ndim == 4:
                    w = np.transpose(w, (0, 2, 3, 1))

                # Fail fast here rather than with an opaque error in the forward pass.
                expected = tuple(target_module[param_name].shape)
                if tuple(w.shape) != expected:
                    raise ValueError(
                        f"Shape mismatch for {key}: model expects {expected}, "
                        f"npz provides {tuple(w.shape)}"
                    )
                target_module[param_name] = mx.array(w)

            def load_conv_bn(prefix, seq_mod: nn.Sequential):
                # layers[0] is the conv and layers[1] the BatchNorm (see _conv_bn).
                conv = seq_mod.layers[0]
                bn = seq_mod.layers[1]

                load_weight(f"{prefix}.conv.weight", conv, transpose=True)

                load_weight(f"{prefix}.bn.weight", bn)
                load_weight(f"{prefix}.bn.bias", bn, param_name="bias")
                load_weight(f"{prefix}.bn.running_mean", bn, param_name="running_mean")
                load_weight(f"{prefix}.bn.running_var", bn, param_name="running_var")

            def load_inception(prefix, module: Inception):
                # torchvision nests branches 2-4 in Sequentials; branch4.0 is
                # the parameter-free max pool, so only branch4.1 is loaded.
                load_conv_bn(f"{prefix}.branch1", module.branch1)
                load_conv_bn(f"{prefix}.branch2.0", module.branch2_1)
                load_conv_bn(f"{prefix}.branch2.1", module.branch2_2)
                load_conv_bn(f"{prefix}.branch3.0", module.branch3_1)
                load_conv_bn(f"{prefix}.branch3.1", module.branch3_2)
                load_conv_bn(f"{prefix}.branch4.1", module.branch4_2)

            for name in ("conv1", "conv2", "conv3"):
                load_conv_bn(name, getattr(self, name))

            for name in (
                "inception3a",
                "inception3b",
                "inception4a",
                "inception4b",
                "inception4c",
                "inception4d",
                "inception4e",
                "inception5a",
                "inception5b",
            ):
                load_inception(name, getattr(self, name))