A Review of Transformer Variants

machine-learning, deep-learning, language-modelling, llms, review

In this article, I will explore various alternatives to transformers, considering their architectural improvements, computational efficiency, and performance results across different benchmarks. I intend to continually update this post with new models in the future. If you believe there are any models or important points that should be included or any corrections that need to be made, please feel free to contact me.

Transformer #

Space: O(T^2 + Td) Time: O(T^2 d)

Traditional sequential models, like recurrent neural networks (RNNs) and long short-term memory networks (LSTMs), faced challenges in effectively capturing long-range dependencies and parallelizing computations. The Transformer architecture addresses these issues by relying on self-attention mechanisms.

At the core of the Transformer is the self-attention mechanism. Unlike traditional approaches, where each element in a sequence is processed one at a time, self-attention allows the model to weigh the importance of different elements relative to each other. This enables capturing relationships between distant words in a sentence.

Transformer has some limitations and constraints in terms of computation and storage. The Transformer is based on dot-product attention that computes softmax(Q*K.t), which is computationally heavy, and it needs to store a KV cache that is also heavy in memory at inference. This is a limiting factor, especially in problems with extended context sizes. Transformers’ space complexity increases quadratically with the increasing context size.
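To make the cost concrete, here is a minimal single-head sketch of dot-product attention (variable names are illustrative; the causal mask and the KV cache of a real decoder are omitted). The (T, T) score matrix is the source of the quadratic growth:

```python
import torch

def dot_product_attention(q, k, v):
    # q, k, v: (T, d) projections of the input sequence.
    scores = q @ k.T / k.shape[-1] ** 0.5   # (T, T) matrix: quadratic in T
    weights = torch.softmax(scores, dim=-1)
    return weights @ v                      # (T, d)
```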

The Transformer is a key component of the current LLM revolution, and researchers are actively seeking alternatives to address its limitations. Several alternatives have been proposed, but none has yet been as successful as the original model. Nevertheless, considering the scale of state-of-the-art LLMs and the high cost of training these models, even a slight improvement can have a significant impact.

RWKV #

Space: O(Td) Time: O(Td)

👩‍💻 Code 📎 Paper

RWKV is a new approach that combines the advantages of RNNs and Transformers while mitigating their known limitations. It introduces several key strategies that allow it to capture local and long-range dependencies. RWKV offers a promising and viable solution for handling tasks involving large-scale models with billions of parameters, exhibiting competitive performance at a fraction of the computational cost. If you're interested in improving Transformers' memory and computational complexity in natural language processing tasks, RWKV is worth exploring. However, in my experiments, it did not perform as well as Transformers in audio tasks.


The aim of RWKV is to harness the advantages of both RNNs and Transformers while addressing their shortcomings. In comparison to RNNs, RWKV provides more efficient parallelizable training and improved performance in capturing long-range dependencies. This is achieved by eliminating the reliance on a single vector to transmit the context between different time steps.

Compared to Transformers, RWKV offers linear attention and constant computational and memory complexity during inference, making it more efficient for large-scale models.

There are two primary components of an RWKV block: time-mixing and channel-mixing. Time-mixing operates by using linear interpolation to blend the current input with the input from the previous time step. This process effectively combines and controls the information in the input channels. The time-mixing block is composed of three equations that compute the values of r, k, and v at each time step, which are then used to calculate the WKV term that plays the role of the Transformer's attention. Essentially, as t increases, the output becomes dependent on an increasingly long history.
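As a rough illustration, below is a simplified recurrent sketch of time-mixing based on my reading of the paper's equations. Variable names are mine, and the numerical-stability tricks of the official implementation are omitted:

```python
import torch

def rwkv_time_mixing(x, W_r, W_k, W_v, W_o, mu_r, mu_k, mu_v, w, u):
    # x: (T, d); W_*: (d, d); mu_*, w, u: (d,) learned per-channel parameters.
    T, d = x.shape
    x_prev = torch.zeros(d)
    num = torch.zeros(d)   # decayed sum of exp(k_i) * v_i over past steps
    den = torch.zeros(d)   # decayed sum of exp(k_i) over past steps
    outs = []
    for t in range(T):
        xt = x[t]
        # Token shift: linear interpolation of current and previous inputs.
        r = (mu_r * xt + (1 - mu_r) * x_prev) @ W_r
        k = (mu_k * xt + (1 - mu_k) * x_prev) @ W_k
        v = (mu_v * xt + (1 - mu_v) * x_prev) @ W_v
        # WKV: decayed average of past values plus a bonus u for the current token.
        wkv = (num + torch.exp(u + k) * v) / (den + torch.exp(u + k))
        outs.append((torch.sigmoid(r) * wkv) @ W_o)   # receptance gates the output
        num = torch.exp(-w) * num + torch.exp(k) * v
        den = torch.exp(-w) * den + torch.exp(k)
        x_prev = xt
    return torch.stack(outs)
```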


Channel mixing aids in capturing local information effectively. The channel-mixing block comprises three equations that compute the values of r, k, and o at each time step. The sigmoid of the receptance r is used as a "forget gate" to eliminate unnecessary historical information. The final output vector is then computed by multiplying the sigmoid of r with a linear projection of k passed through a squared ReLU activation, max(k, 0)^2.
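A corresponding sketch of channel-mixing, again with illustrative names and without the dimension expansion used in the actual model:

```python
import torch

def rwkv_channel_mixing(x, W_r, W_k, W_v, mu_r, mu_k):
    # x: (T, d); W_*: (d, d); mu_*: (d,).
    T, d = x.shape
    x_prev = torch.cat([torch.zeros(1, d), x[:-1]])   # token shift
    r = (mu_r * x + (1 - mu_r) * x_prev) @ W_r
    k = (mu_k * x + (1 - mu_k) * x_prev) @ W_k
    v = torch.relu(k) ** 2 @ W_v                      # squared ReLU, then projection
    return torch.sigmoid(r) * v                       # receptance as a "forget gate"
```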


RWKV has certain limitations. For instance, it may struggle with tasks that require recalling information over a long context. This is because RWKV relies on a limited window between time steps, whereas Transformers have access to all the information at each step through attention. Another limitation is the prominence of prompt engineering in RWKV. The linear attention mechanism restricts the extent to which prompt information is passed on to the model. Empirical evidence supports this: when prompts were adjusted to better suit RWKV, there was a significant increase in performance, with the F1 measure improving from 44.2% to 74.8%.

The results have demonstrated that RWKV has delivered impressive performance and surpassed other models in certain tasks. Nevertheless, when tasks require a greater emphasis on context, RWKV’s performance tends to decline.

The RWKV model is an outstanding example of an open-source project, with the paper crediting many contributors. It is impressive to observe the significant influence that open-source research has had in advancing innovative AI solutions at scale. Efforts are already underway to address the limitations of RWKV. You can join their Discord if you want to get involved in the development process.

Hyena #

Time: O(NdT(log T + d)) s.t. N is the number of projections Space: O(Td)

👩‍💻 Code 📎 Paper 📎 Blogpost

Hyena addresses the Transformer's limitation that its attention operator becomes computationally expensive with longer sequences and therefore cannot access a significant amount of context. Hyena offers a subquadratic alternative to attention by combining long convolutions with data-controlled gating. In various tasks involving recall and reasoning with sequences containing thousands to hundreds of thousands of tokens, Hyena has demonstrated significant improvements in accuracy. It achieves Transformer-level quality while reducing the required training compute by 20% at a sequence length of 2K. Notably, Hyena operators are also faster, offering twice the speed of highly optimized attention operators.


Hyena first projects the input into a set of vectors v, x_1, ..., x_n, where v acts like the value vector in attention. It then processes them recurrently with learnable long-convolution filters h_1, ..., h_n, applying a multiplicative gating interaction to the projected vectors, similar to the gates in LSTMs. This gating controls the information flow through the recurrence.


Hyena applies an implicit long convolution to the gated input, using a set of Hyena filters that are parametrized by a feedforward network. This convolution is used to capture long-range dependencies in the input.


Below is the overall Hyena operator in Python as described in the blog post:

def hyena_filter(t):
    # implicit filter: a window function times an FFN applied to positional encodings
    return window(t) * ffn(t) * positional_encoding(t)

x, v = input_projections(u)   # x: gating projections, v: "value" projection
for o in range(hyena_orders):
    h = hyena_filter(L)       # long conv filter of length L, parameterized via an MLP
    v = x[o] * fftconv(h, v)  # elem-wise mult & FFT-based convolution

Regarding language modeling, Hyena is compared to GPT-Neo and RWKV. Hyena outperforms them in few-shot learning, but RWKV achieves better zero-shot accuracy on SuperGLUE tasks. Moreover, Hyena performs on par with a Transformer on language modeling with the WikiText-103 dataset.

Regarding runtime, the cross-over point between Hyena and attention occurs at a sequence length of 2048, and between Hyena and FlashAttention somewhere between 4096 and 8192.


My 2 cents: Hyena is an interesting approach for extending input length through scalable computing. Nonetheless, further investigations on a larger scale are necessary to confirm its efficacy as a viable alternative to the Transformer model. For now, the RWKV model offers better value in terms of both complexity and performance. However, if the goal is to tackle lengthy context problems, Hyena could be a promising choice.

Attention Free Transformer #

Time: AFT-simple O(Td), AFT-full O(T^2d) Space: O(Td)

📎 Paper 👩‍💻 Code (unofficial)

Attention Free Transformer (AFT) eliminates the need for dot product self-attention, making it scalable with long inputs and large model sizes. AFT takes advantage of locality and spatial weight sharing while maintaining global connectivity, resulting in excellent efficiency. The paper presents experiments on autoregressive modeling tasks and image recognition, demonstrating competitive performance compared to other models.

AFT replaces the heavy attention matrix with a weighted average over the values that is combined with the queries through element-wise multiplication. In an Attention Free Transformer (AFT) layer, learned position biases are added to the keys; the exponentiated keys are then used to weight the values, and the result is multiplied element-wise with the sigmoid of the query. Thus, it avoids the computationally heavy softmax(Q*K.t) operation of a Transformer. "AFT can be interpreted as performing implicit attention with as many heads as feature dimensions, where the attention matrices take a factorized form."
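As a minimal sketch, here is the non-causal AFT-simple variant (no position biases); names are illustrative:

```python
import torch

def aft_simple(q, k, v):
    # q, k, v: (T, d). A softmax over the time axis turns the keys into
    # per-feature weights, the values are averaged with those weights, and
    # sigmoid(q) gates the result. No (T, T) attention matrix is formed.
    weights = torch.softmax(k, dim=0)     # (T, d)
    context = (weights * v).sum(dim=0)    # (d,)
    return torch.sigmoid(q) * context     # (T, d) via broadcasting
```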


There are four different versions of AFT. The first version is AFT-simple, which does not utilize position encoding. The second version is AFT-full, which includes regular position encoding. The third version is AFT-local, incorporating a learned set of relative position biases within a specified window. The fourth version is AFT-conv, which utilizes depth-wise separable convolution and is proposed especially for image tasks.

Figure: AFT-conv formulation.

In terms of results, the paper shows that AFT achieves comparable or better accuracy than traditional Transformers on various autoregressive modeling tasks and image recognition tasks while using much smaller memory footprints. AFT also outperforms other efficient Transformer variants such as Linformer and Performer. The paper also demonstrates the effectiveness of AFT on variable-length inputs and shows that it is well-suited for pretraining and finetuning workflows in vision tasks.

In general, AFT shows potential as a substitute for conventional Transformers. It substantially reduces computational requirements and memory usage, all while maintaining high performance. Moreover, AFT serves as the foundation for the development of both Hyena and RWKV.

Retentive Network #

Time: O(Td(b + h)) s.t. b is the chunk size and h is the head dimension Space: O(T)

📎 Paper 👩‍💻 Official Code 👩‍💻 Code 1 👩‍💻 Code 2

RetNet borrows recurrent inference from RNNs and parallel training from the Transformer, combining them to achieve an efficient model. Recurrent models facilitate O(1) inference as they do not require modeling the relationship between each input and every other input in the sequence. RetNet applies chunk-wise recurrence to alleviate the representational bottleneck of RNNs and effectively model longer context.

Figure: Difference between Transformer and RetNet.

RetNet introduces a novel approach to replace the softmax operation utilized in self-attention with a Hadamard product. By leveraging a newly introduced D-matrix and incorporating a GroupNorm operation, the relative attention weights assigned to each token in the input sequence are determined. Traditionally, softmax plays a crucial role in capturing long-term dependencies and contributes to the remarkable performance of Transformers.

In RetNet, training and inference use different flows that result in the same computation. The training phase uses a parallel formulation, while the inference phase uses a recurrent formulation. I suggest you check this post by Shantanu Chandra, who explains how things work better than the paper does.
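To illustrate the equivalence, here is a minimal single-head sketch of retention in both forms, without the GroupNorm, scaling, and multi-scale decay rates of the full model (names are mine):

```python
import torch

def retention_parallel(q, k, v, gamma):
    # q, k: (T, d); v: (T, d_v). Parallel form used during training.
    T = q.shape[0]
    diff = torch.arange(T)[:, None] - torch.arange(T)[None, :]
    # Causal decay matrix: D[i, j] = gamma^(i - j) for i >= j, else 0.
    D = (gamma ** diff.clamp(min=0).float()) * (diff >= 0)
    return (q @ k.T * D) @ v

def retention_recurrent(q, k, v, gamma):
    # Same result, computed step by step with an O(d * d_v) state (inference).
    state = torch.zeros(k.shape[1], v.shape[1])
    outs = []
    for t in range(q.shape[0]):
        state = gamma * state + k[t].unsqueeze(1) * v[t].unsqueeze(0)  # outer product
        outs.append(q[t] @ state)
    return torch.stack(outs)
```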

Figure: Training and inference computation.

Compared to attention-free transformers and RWKV, RetNet retains the element-wise interactions in the sequence through the retention operation. It keeps a high-dimensional state of the encoded sequence information, which they claim contributes to the model's performance.

Results show that beyond ~2.7B parameters, RetNet achieves lower perplexity and outperforms the Transformer. Most of the results are reported with the 6.7B model. At this scale, RetNet is significantly better than the Transformer in zero-shot and few-shot learning.

RetNet replaces the KV cache of Transformers with recurrence and saves memory. Also, chunk-wise retention makes inference significantly scalable with increasing batch size and input length.

They also show that RetNet is computationally far more efficient than the Transformer and almost on par with Transformer + FlashAttention 1 (a comparison with FlashAttention-2 is still needed). Results show that it uses 3.4x less memory, achieves 8.4x higher throughput, and has 15.6x lower latency compared to a Transformer model.

When compared to the other Transformer alternatives, RetNet outperforms all the different models by a big margin on language modeling.

Figure: Comparison with the other models.


Longnet #

Time: O(Td) Space: O((T/r) log(T/r) d) s.t. r is the attention dilation rate

📎 Paper 👩‍💻 Code

LONGNET is designed to tackle longer sequence lengths. It can handle sequences with over 1 billion tokens while maintaining good performance on shorter sequences. This is accomplished through dilated attention, which enhances the model's ability to attend to distant tokens. LONGNET offers advantages such as linear computational complexity and the ability to be trained in a distributed fashion over long sequences. Experiments confirm its effectiveness.


To simplify the self-attention layers, LONGNET utilizes dilated attention. This approach involves dividing the input sequence into segments and dilating each segment at a specific rate. By doing so, the model is able to leverage different segment lengths and dilation rates to improve its modeling abilities. The outputs for each (segment length, dilation rate) pair are then combined through a weighted sum, with the weights determined by the softmax denominators of each output. The combination of segmentation and dilated attention strikes a balance between considering the global context and maintaining efficiency, as dilation serves as an efficient approximation of the attention matrix.
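Here is a rough, non-causal sketch of a single (segment length, dilation rate) branch; the weighted combination over branches and the per-head offsets are left out:

```python
import torch

def dilated_attention_branch(q, k, v, segment_len, dilation):
    # q, k, v: (T, d). Within each segment, only every `dilation`-th position
    # participates in attention, keeping the overall cost linear in T.
    T, d = q.shape
    out = torch.zeros_like(q)
    for start in range(0, T, segment_len):
        idx = torch.arange(start, min(start + segment_len, T), dilation)
        qs, ks, vs = q[idx], k[idx], v[idx]
        weights = torch.softmax(qs @ ks.T / d ** 0.5, dim=-1)
        out[idx] = weights @ vs
    return out
```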


LONGNET employs two more tricks for better modeling. It uses different dilation rates in each attention head for more diversity, and it gradually increases the segment lengths and dilation rates in successive layers, allowing it to process extremely long input sequences with an increasingly large receptive field in later layers.

To train LONGNET on 1 billion tokens, distributed training is necessary. LONGNET divides the inputs into segments, which are then distributed across different GPUs. These segments are processed simultaneously, with a constant communication overhead.

They tested the model on the Stack dataset, a source-code collection covering over 300 programming languages. They showed that LONGNET outperforms a vanilla Transformer model by a large margin in both perplexity and computation. They were able to train LONGNET with a 32K context size, whereas the Transformer could only be trained with 16K.


My 2 cents: Consider using LONGNET when processing long contexts or streaming outputs.


MegaByte #

Time: O(T^(4/3) d) Space: O(T log Td)

📎 Paper 👩‍💻 Code


MEGABYTE is a "multiscale decoder architecture that enables end-to-end differentiable modeling of sequences of over one million bytes". MEGABYTE models data directly from raw byte values, which requires the ability to capture a lengthy context effectively. To achieve this, it divides sequences into patches and employs a local model within each patch, while also incorporating a global model between patches. By doing so, MEGABYTE enables sub-quadratic self-attention, facilitates larger feedforward layers without incurring additional computational cost, and enhances parallelism during decoding. As a result, MEGABYTE delivers improved efficiency for both training and generation.

MEGABYTE offers several advantages, including sub-quadratic self-attention, per-patch feedforward layers, and parallel decoding. The sub-quadratic self-attention is achieved by dividing the input into smaller "patches", which reduces the self-attention cost to O(T^(4/3) d).

They note that in a Transformer, the feedforward layers consume about 98% of the FLOPs. MEGABYTE addresses this by replacing multiple passes through these layers with a single pass through a larger linear layer per patch.

Furthermore, the use of patches also introduces a level of parallelism. As a result, they found that their 1.5B parameter model is 40% faster than a 350M Transformer model.

The MEGABYTE system is composed of three main components: a patch embedder that encodes a patch by concatenating the embeddings of its bytes, a large global Transformer that contextualizes the patch representations, and a smaller local Transformer that predicts the bytes within each patch.
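Below is a compact, non-causal sketch of how these components fit together; the sizes and module choices are hypothetical, and the shifting and masking needed for actual autoregressive training are omitted:

```python
import torch
import torch.nn as nn

V, P, D_G, D_L = 256, 8, 512, 128            # byte vocab, patch size, global/local dims

byte_emb = nn.Embedding(V, D_L)
patch_embed = nn.Linear(P * D_L, D_G)        # patch embedder: concat byte embeddings
global_model = nn.TransformerEncoderLayer(D_G, nhead=8, batch_first=True)
to_local = nn.Linear(D_G, D_L)
local_model = nn.TransformerEncoderLayer(D_L, nhead=4, batch_first=True)
head = nn.Linear(D_L, V)

bytes_in = torch.randint(0, V, (1, 4 * P))   # (batch, T) raw byte ids, T divisible by P
B, T = bytes_in.shape
x = byte_emb(bytes_in)                       # (B, T, D_L)
patches = x.view(B, T // P, P * D_L)         # group bytes into patches
g = global_model(patch_embed(patches))       # contextualize across patches
local_in = x.view(B * (T // P), P, D_L) + to_local(g).view(B * (T // P), 1, D_L)
logits = head(local_model(local_in))         # per-byte predictions within each patch
```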

MEGABYTE is applied to language modeling, image modeling, and audio modeling. The cool thing is that it is trained on raw byte values (hence the name). It is compared to PerceiverAR and a Transformer baseline. In all tasks, it outperforms both and is competitive with models that use tokenizers to discretize the input.

The ablation analysis shows the importance of having both local and global models. If one of these components is absent, there is a notable decline in performance.

My 2 cents: I find learning from raw bytes and utilizing multi-stage transformers intriguing. This approach can potentially revolutionize large language models (LLMs). By eliminating tokenization models, we can bridge the gap between computers and models, paving the way for a new generation of LLM-based operating systems.

In addition, I'd like to try MEGABYTE for text-to-speech. I believe it is well-suited to learn local and global relations better than Transformers for TTS.

Edit: Looks like UniAudio did it.

Noteworthy Mentions #

Here are a few other noteworthy models that I won’t delve into further since they have yet to gain much traction in the community or are simple tricks that don’t require much explanation.

Multi-Query Attention #

📎Paper 👩‍💻Code

Using shared key and value vectors among attention heads reduces the memory overhead at inference by reducing the size of the KV cache.
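A minimal sketch of the idea (illustrative shapes; masking and caching are not shown):

```python
import torch

def multi_query_attention(q, k, v, n_heads):
    # q: (T, n_heads * d_head); k, v: (T, d_head), i.e. one key/value head
    # shared by all query heads, so the KV cache shrinks by a factor of n_heads.
    T, d_head = k.shape
    q = q.view(T, n_heads, d_head).transpose(0, 1)   # (n_heads, T, d_head)
    scores = q @ k.T / d_head ** 0.5                 # (n_heads, T, T)
    out = torch.softmax(scores, dim=-1) @ v          # (n_heads, T, d_head)
    return out.transpose(0, 1).reshape(T, n_heads * d_head)
```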

Linformer #

📎 Paper 👩‍💻 Code

A linear time self-attention is achieved by breaking down the scaled dot-product attention into multiple smaller attentions using linear projections. Together, these operations create a low-rank factorization of the original attention mechanism.
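A minimal single-head sketch of the projection trick, with illustrative names:

```python
import torch

def linformer_attention(q, k, v, E, F):
    # q, k, v: (T, d); E, F: (k_dim, T) learned projections along the sequence
    # axis. The attention matrix becomes (T, k_dim) instead of (T, T).
    d = q.shape[-1]
    k_proj, v_proj = E @ k, F @ v                  # (k_dim, d)
    scores = q @ k_proj.T / d ** 0.5               # (T, k_dim)
    return torch.softmax(scores, dim=-1) @ v_proj  # (T, d)
```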

Roformer #

📎 Paper 👩‍💻 Code

“Rotary Position Embedding, or RoPE, is a type of position embedding which encodes absolute positional information with rotation matrix and naturally incorporates explicit relative position dependency in self-attention formulation.”
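A minimal sketch of applying the rotation to a (T, d) tensor of queries or keys, assuming the standard base of 10000 for the frequencies:

```python
import torch

def apply_rope(x):
    # x: (T, d), d even. Consecutive feature pairs are rotated by an angle
    # proportional to the position, so dot products depend on relative offsets.
    T, d = x.shape
    pos = torch.arange(T, dtype=torch.float)[:, None]                    # (T, 1)
    freqs = 10000 ** (-torch.arange(0, d, 2, dtype=torch.float) / d)     # (d/2,)
    angles = pos * freqs                                                 # (T, d/2)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * torch.cos(angles) - x2 * torch.sin(angles)
    out[:, 1::2] = x1 * torch.sin(angles) + x2 * torch.cos(angles)
    return out   # applied to both queries and keys before the dot product
```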

One Wide Feedforward is All You Need #

📎 Paper

It is suggested that some of the Feedforward Networks (FFNs) in the Transformer are redundant. As a result, the FFN is removed from the Transformer decoder and a single FFN is shared across the encoder. Even though this change causes a small decrease in accuracy, scaling the model back to its original size leads to improved accuracy and reduced latency. They report an 18.5% speed-up using this technique.

Performer #

Time: O(Td^2 log d) Space: O(Td log d + d^2 log d)

📎 Paper 👩‍💻 Code

Performer can "estimate" regular dot-product attention using an approach called Fast Attention Via positive Orthogonal Random features (FAVOR+). FAVOR+ combines low-rank approximation, matrix factorization, and matrix decomposition, making the space and time complexity close to linear.

Reformer #

Time: O(T log Td) Space: O(T log T + Td)

📎Paper 👩‍💻Code (unofficial)

The Reformer model incorporates three techniques to improve efficiency. First, it uses "reversible residuals" to reduce memory consumption by storing only a single copy of the intermediate activations, from which the activations of earlier layers can be recomputed using the model parameters. This minimizes the memory overhead. Second, it splits values into chunks, saving memory in the feedforward layers and making inference more efficient. Lastly, Reformer uses locality-sensitive hashing to approximate the attention matrix for a more efficient runtime.

Monarch Mixer #

👩‍💻Blog 👩‍💻Code

Monarch Mixer uses Monarch matrices to build a model that is sub-quadratic in both sequence length and model dimension. The idea is to replace the major elements of a Transformer with Monarch matrices, a class of structured matrices that "generalize the FFT and are sub-quadratic, hardware-efficient, and expressive." In Monarch Mixer, layers built up from Monarch matrices mix across the sequence (replacing the attention operation) and across the model dimension (replacing the dense MLP).

Conformers #

📎 Paper 👩‍💻 Code (unofficial)

The Conformer is a variant designed for speech recognition. While the Transformer excels at capturing global relationships, it is less effective than convolutional layers in capturing local information. To address this, the Conformer augments the Transformer model by adding convolutional layers between the attention module and the final feedforward layer. As a result, the Conformer achieves significantly better performance than previous Transformer and CNN-based models, setting new state-of-the-art on ASR.

Efficient Streaming LMs with Attention Sinks #

📎 Paper

This looks similar to LONGNET, but they keep a set of learnable tokens, called sinks, at the beginning of the sequence, observing that this improves stability and performance even when the attention computation is windowed.
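A rough sketch of the cache-eviction idea, treating the first few tokens as sinks (names and policy are illustrative):

```python
import torch

def evict_kv_cache(keys, values, n_sink=4, window=1024):
    # keys, values: (T, d) cached so far. Keep the first `n_sink` sink tokens
    # plus the most recent `window` tokens; drop everything in between.
    T = keys.shape[0]
    if T <= n_sink + window:
        return keys, values
    keep = torch.cat([torch.arange(n_sink), torch.arange(T - window, T)])
    return keys[keep], values[keep]
```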

Simplifying Transformer Blocks #

📎 Paper 👨‍💻 Code


The proposed model removes the attention sub-block's skip connection, the value and projection parameters, the MLP skip connection, and the normalization layers. They report 15% faster training and 15% fewer parameters with no loss in performance. One caveat is that they only experimented with relatively small models (200-300M parameters), and it is not clear whether this approach scales to larger models. The experiments also cover only a limited number of tasks. However, it is easy to implement and worth trying, at least at the model scales and for the tasks mentioned in the paper.