Exploring Transformer Architectures

Since their introduction in the 2017 paper “Attention Is All You Need,” transformer architectures have revolutionized natural language processing and beyond. In this post, I’ll explore how transformers work and implement a simple version from scratch.

The Core Idea: Attention

The key innovation of transformers is the self-attention mechanism, which allows the model to focus on different parts of the input sequence when producing each element of the output sequence.

Unlike RNNs or LSTMs, transformers process the entire sequence in parallel, which enables much more efficient training on modern hardware.
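At its core, scaled dot-product attention is just a few tensor operations. Here is a minimal single-head sketch of the mechanism described above (the tensor shapes in the comments are illustrative assumptions); the multi-head module below builds on the same idea:

import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, seq_len, d_k)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5      # (batch, seq_len, seq_len)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)            # each query's weights sum to 1
    return weights @ v                                 # weighted sum of the values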

Building a Simple Transformer

Let’s implement the multi-head self-attention module at the core of a transformer encoder block:

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"
        
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
        
    def forward(self, values, keys, queries, mask=None):
        # Get batch size
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
        
        # Split the embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the per-head linear projections defined in __init__
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Scaled dot-product attention scores: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # Normalize the scores over the key dimension
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)

        # Weighted sum of the values, then merge the heads back together
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        
        out = self.fc_out(out)
        return out
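
To sanity-check the module, here’s a quick smoke test with random tensors (the sizes are arbitrary):

embed_size, heads = 256, 8
attention = SelfAttention(embed_size, heads)

x = torch.randn(2, 10, embed_size)   # (batch, seq_len, embed_size)
out = attention(x, x, x)             # self-attention: values, keys, and queries are all x
print(out.shape)                     # torch.Size([2, 10, 256])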

Scaling Challenges

While the concept is elegant, training large transformer models presents significant challenges:

  • Memory requirements grow quadratically with sequence length (quantified in the sketch after this list)
  • Numerical stability issues require careful implementation
  • Pre-training requires massive datasets
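
To put the first point in perspective, here’s a back-of-envelope estimate of the attention score matrices alone, before counting weights, activations, or optimizer state (batch size and head count are illustrative):

def attention_scores_bytes(seq_len, heads, batch_size, bytes_per_elem=4):
    # One (seq_len x seq_len) score matrix per head, per example, in fp32
    return batch_size * heads * seq_len * seq_len * bytes_per_elem

for n in (512, 2048, 8192):
    gb = attention_scores_bytes(n, heads=8, batch_size=1) / 1e9
    print(f"seq_len={n}: ~{gb:.2f} GB for the attention scores alone")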

Recent Innovations

Some exciting developments in transformer research include:

  • Sparse Attention - Reducing complexity from O(n²) to O(n log n)
  • Parameter Sharing - Techniques like ALBERT that reduce model size
  • Efficient Fine-tuning - Methods like LoRA and adapter tuning
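
To give a feel for the LoRA idea, here is a minimal sketch of a linear layer with a frozen base weight and a small trainable low-rank update (the rank, scaling, and initialization here are illustrative assumptions, not the reference implementation):

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        for p in self.base.parameters():
            p.requires_grad_(False)                      # freeze the pre-trained layer
        self.lora_a = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        # Frozen base projection plus the scaled low-rank update (B @ A)
        return self.base(x) + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)

Because only lora_a and lora_b receive gradients, the trainable parameter count for the update drops from in_features × out_features to rank × (in_features + out_features).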

Beyond NLP

Transformers aren’t just for text anymore. They’re now being applied to:

  • Computer vision (ViT)
  • Audio processing
  • Protein folding
  • Multi-modal learning
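
For example, ViT-style models turn an image into a sequence of patch embeddings so the same attention blocks apply unchanged. Here is a rough sketch of that patchify step (the image size, patch size, and embedding width are assumptions):

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, embed_size=256):
        super().__init__()
        # A strided convolution cuts the image into non-overlapping patches and
        # projects each patch to the embedding dimension in a single step.
        self.proj = nn.Conv2d(in_channels, embed_size,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (batch, 3, 224, 224)
        x = self.proj(x)                     # (batch, embed_size, 14, 14)
        return x.flatten(2).transpose(1, 2)  # (batch, 196, embed_size), ready for attention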

Practical Tips

If you’re working with transformers, here are some practical tips:

  1. Start with pre-trained models rather than training from scratch
  2. Use gradient accumulation to reach effective batch sizes larger than your GPU memory allows
  3. Mixed precision training cuts memory use and speeds up training on modern GPUs (tips 2 and 3 are sketched together after this list)
  4. Carefully monitor training stability, especially for larger models
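
Tips 2 and 3 combine naturally. Here is a rough sketch of a training loop that uses gradient accumulation together with PyTorch’s automatic mixed precision; the model, data loader, loss function, and optimizer are placeholders assumed to be defined elsewhere:

import torch

scaler = torch.cuda.amp.GradScaler()
accum_steps = 4   # effective batch size = loader batch size * accum_steps

for step, (inputs, targets) in enumerate(loader):
    with torch.cuda.amp.autocast():            # forward pass in mixed precision
        loss = loss_fn(model(inputs), targets) / accum_steps
    scaler.scale(loss).backward()              # accumulate scaled gradients
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)                 # unscale gradients and apply the update
        scaler.update()
        optimizer.zero_grad()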

I’m fascinated by how transformers continue to evolve and will be sharing more experiments as I explore this technology further.