FlashKAT: Understanding and Addressing Performance Bottlenecks in the Kolmogorov-Arnold Transformer
Abstract
FlashKAT speeds up training of the Kolmogorov-Arnold Transformer and reduces rounding errors in its backward pass by minimizing gradient accumulation with atomic adds.
The Kolmogorov-Arnold Network (KAN) has been gaining popularity as an alternative to the multi-layer perceptron (MLP) because of its greater expressiveness and interpretability. However, its higher computational cost and training instability can make the KAN orders of magnitude slower, limiting its applicability to larger-scale tasks. The recently proposed Kolmogorov-Arnold Transformer (KAT) leverages the Group-Rational KAN (GR-KAN) to achieve FLOPs similar to those of a traditional Transformer with MLPs. Unfortunately, despite the comparable FLOPs, our characterization reveals that the KAT is still 123x slower to train, indicating that there are performance bottlenecks beyond FLOPs. In this paper, we conduct a series of experiments to understand the root cause of the slowdown in KAT. We find that the slowdown can be isolated to memory stalls, specifically in the backward pass of GR-KAN, where they are caused by inefficient gradient accumulation. To address this memory bottleneck, we propose FlashKAT, which builds on a restructured kernel that minimizes both gradient accumulation with atomic adds and accesses to slow memory. Evaluations demonstrate that FlashKAT achieves a training speedup of 86.5x over the state-of-the-art KAT while reducing rounding errors in the coefficient gradients. Our code is available at https://github.com/OSU-STARLAB/FlashKAT.
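The link between fewer accumulation steps and smaller rounding error in the coefficient gradients can be illustrated with a small sketch. It is not the FlashKAT kernel: it only assumes that many per-element float32 contributions are summed into a single coefficient gradient, and compares adding them one at a time (as a chain of atomic adds into one accumulator would) with first reducing fixed-size blocks and then combining the partial sums. The helper names, the block size of 256, and the constant contribution of 0.1 are hypothetical choices made for illustration.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sequential_sum_f32(values):
    """Add every value into one float32 accumulator, one element at a time."""
    acc = 0.0
    for v in values:
        acc = f32(acc + v)  # each add is rounded to float32
    return acc


def blocked_sum_f32(values, block_size=256):
    """Reduce each block to a partial sum first, then combine the partials."""
    partials = [
        sequential_sum_f32(values[i:i + block_size])
        for i in range(0, len(values), block_size)
    ]
    return sequential_sum_f32(partials)


# Hypothetical per-element gradient contributions to a single coefficient.
grads = [f32(0.1)] * 100_000

exact = math.fsum(grads)  # correctly rounded float64 reference sum
seq_err = abs(sequential_sum_f32(grads) - exact)
blk_err = abs(blocked_sum_f32(grads) - exact)

print(f"sequential error: {seq_err:.6f}")
print(f"blocked error:    {blk_err:.6f}")
```

In this sketch the blocked reduction performs roughly one addition into the shared accumulator per 256 elements instead of one per element, and its error relative to the float64 reference is smaller than that of the element-by-element sum. Whether the same margin holds for real gradients depends on their magnitudes and ordering.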