# hf-issue-36653 / modules.py
import torch


class ResBlock(torch.nn.Module):
    """Residual block: a linear layer followed by ReLU, with a skip connection."""

    def __init__(self, dim: int):
        super().__init__()
        self.fc = torch.nn.Linear(dim, dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Linear projection (dim -> dim); output shape matches the input
        z = self.fc(x)
        # Nonlinearity
        z = self.relu(z)
        # Residual (skip) connection: add the input back to the transformed features
        return x + z
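

# Minimal usage sketch (not part of the original file): push a random batch
# through the block to check shapes. The batch size (2) and feature dim (4)
# below are arbitrary illustrative choices.
if __name__ == "__main__":
    block = ResBlock(dim=4)
    x = torch.randn(2, 4)
    out = block(x)
    # The skip connection requires input and output dims to match,
    # so the output shape equals the input shape: torch.Size([2, 4]).
    print(out.shape)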