DeepDream-MLX / mlx_googlenet.py
NickMystic's picture
Upload folder using huggingface_hub
bdff6f4 verified
"""
Minimal GoogLeNet (Inception V1) in MLX, up to inception4e.
Loads weights from a torchvision-exported npz (see export_googlenet_npz.py).
"""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
def _conv_bn(in_ch, out_ch, kernel_size, stride=1, padding=0):
return nn.Sequential(
nn.Conv2d(
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False,
),
nn.BatchNorm(out_ch, eps=1e-3, momentum=0.1),
nn.ReLU(),
)
class Inception(nn.Module):
def __init__(self, in_ch, ch1, ch3r, ch3, ch5r, ch5, pool_proj):
super().__init__()
self.branch1 = _conv_bn(in_ch, ch1, 1)
self.branch2_1 = _conv_bn(in_ch, ch3r, 1)
self.branch2_2 = _conv_bn(ch3r, ch3, 3, padding=1)
self.branch3_1 = _conv_bn(in_ch, ch5r, 1)
# The reference torchvision GoogLeNet uses a 3x3 conv here (not 5x5)
self.branch3_2 = _conv_bn(ch5r, ch5, 3, padding=1)
self.branch4_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
self.branch4_2 = _conv_bn(in_ch, pool_proj, 1)
def __call__(self, x):
b1 = self.branch1(x)
b2 = self.branch2_2(self.branch2_1(x))
b3 = self.branch3_2(self.branch3_1(x))
b4 = self.branch4_2(self.branch4_pool(x))
return mx.concatenate([b1, b2, b3, b4], axis=-1)
class GoogLeNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = _conv_bn(3, 64, 7, stride=2, padding=3)
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
self.conv2 = _conv_bn(64, 64, 1)
self.conv3 = _conv_bn(64, 192, 3, padding=1)
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
def forward_with_endpoints(self, x):
endpoints = {}
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.maxpool2(x)
x = self.inception3a(x)
endpoints["inception3a"] = x
x = self.inception3b(x)
endpoints["inception3b"] = x
x = self.maxpool3(x)
x = self.inception4a(x)
endpoints["inception4a"] = x
x = self.inception4b(x)
endpoints["inception4b"] = x
x = self.inception4c(x)
endpoints["inception4c"] = x
x = self.inception4d(x)
endpoints["inception4d"] = x
x = self.inception4e(x)
endpoints["inception4e"] = x
x = self.maxpool4(x)
x = self.inception5a(x)
endpoints["inception5a"] = x
x = self.inception5b(x)
endpoints["inception5b"] = x
return x, endpoints
def __call__(self, x):
_, endpoints = self.forward_with_endpoints(x)
return endpoints
def load_npz(self, path: str):
data = np.load(path)
def load_weight(key, target_module, param_name="weight", transpose=False):
# Check for standard float16/32 key
if key in data:
w = data[key]
# Check for int8 quantized key
elif f"{key}_int8" in data:
w_int8 = data[f"{key}_int8"]
scale = data[f"{key}_scale"]
# Dequantize
w = w_int8.astype(scale.dtype) * scale
else:
raise ValueError(f"Missing key {key} (or {key}_int8) in npz")
# Transpose for Conv2d weights if needed (PyTorch [O,I,H,W] -> MLX [O,H,W,I])
if transpose and w.ndim == 4:
w = np.transpose(w, (0, 2, 3, 1))
# Assign to module
target_module[param_name] = mx.array(w)
def load_conv_bn(prefix, seq_mod: nn.Sequential):
conv = seq_mod.layers[0]
bn = seq_mod.layers[1]
load_weight(f"{prefix}.conv.weight", conv, transpose=True)
load_weight(f"{prefix}.bn.weight", bn)
load_weight(f"{prefix}.bn.bias", bn, param_name="bias")
load_weight(f"{prefix}.bn.running_mean", bn, param_name="running_mean")
load_weight(f"{prefix}.bn.running_var", bn, param_name="running_var")
load_conv_bn("conv1", self.conv1)
load_conv_bn("conv2", self.conv2)
load_conv_bn("conv3", self.conv3)
def load_inception(prefix, module: Inception):
load_conv_bn(f"{prefix}.branch1", module.branch1)
load_conv_bn(f"{prefix}.branch2.0", module.branch2_1)
load_conv_bn(f"{prefix}.branch2.1", module.branch2_2)
load_conv_bn(f"{prefix}.branch3.0", module.branch3_1)
load_conv_bn(f"{prefix}.branch3.1", module.branch3_2)
load_conv_bn(f"{prefix}.branch4.1", module.branch4_2)
load_inception("inception3a", self.inception3a)
load_inception("inception3b", self.inception3b)
load_inception("inception4a", self.inception4a)
load_inception("inception4b", self.inception4b)
load_inception("inception4c", self.inception4c)
load_inception("inception4d", self.inception4d)
load_inception("inception4e", self.inception4e)
load_inception("inception5a", self.inception5a)
load_inception("inception5b", self.inception5b)