fix: omnigen2
cf4796c
import torch.nn.functional as F


def swiglu(x, y):
    # SwiGLU gating on pre-projected halves: silu(x) * y.
    # SiLU is computed in float32 for numerical stability, then cast
    # back to the input dtype before multiplying by the gate y.
    return F.silu(x.float(), inplace=False).to(x.dtype) * y
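A quick sanity check (an illustrative example, not part of this file) showing that the float32 upcast inside `swiglu` does not leak into the output dtype, and that the result matches the plain `x * sigmoid(x) * y` formula:

```python
import torch
import torch.nn.functional as F


def swiglu(x, y):
    # Same function as above: SiLU in float32, cast back, gated by y.
    return F.silu(x.float(), inplace=False).to(x.dtype) * y


x = torch.randn(2, 4, dtype=torch.float16)
y = torch.randn(2, 4, dtype=torch.float16)
out = swiglu(x, y)

# Output keeps the half-precision dtype of the inputs.
assert out.dtype == torch.float16
assert out.shape == (2, 4)

# Agrees with the definition silu(x) = x * sigmoid(x), up to fp16 rounding.
ref = (x.float() * torch.sigmoid(x.float())).to(x.dtype) * y
assert torch.allclose(out, ref, atol=1e-3)
```

Doing the activation in float32 avoids the reduced range of fp16/bf16 inside `sigmoid`, at the cost of a temporary upcast; the final multiply stays in the original dtype.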