Transformers & Attention

The revolutionary architecture that enabled modern AI breakthroughs

🧠 What are Transformers?

Definition: A neural network architecture based on self-attention mechanisms that can process sequences in parallel rather than sequentially

Simple Analogy: Think of attention as being able to simultaneously focus on different parts of a sentence to understand its meaning, like how you can instantly understand "The cat that my neighbor owns is sleeping" by connecting "cat" with "sleeping" regardless of the distance between them.

Why Transformers Revolutionized AI

The Problem with Previous Approaches

RNNs (Recurrent Neural Networks)

  • Sequential processing: Had to process text word by word
  • Memory limitations: Struggled with long sequences due to vanishing gradients
  • Slow training: Couldn't parallelize computation effectively
  • Context loss: Earlier information got forgotten in long sequences

CNNs (Convolutional Neural Networks)

  • Local patterns only: Could only capture nearby relationships
  • Fixed window size: Limited context window
  • Not designed for sequences: Better suited for images than text

The Transformer Solution

Parallel Processing

  • All positions at once: Processes entire sequence simultaneously
  • Faster training: Much more efficient use of GPU parallelization
  • Better hardware utilization: Takes full advantage of modern computing

Global Context

  • Any-to-any connections: Every word can directly attend to every other word
  • Long-range dependencies: No degradation over distance
  • Rich representations: Captures complex relationships across the sequence

The Attention Mechanism

What is Attention?

Attention allows the model to focus on different parts of the input when processing each element. It's like having a spotlight that can illuminate relevant information.

Self-Attention Intuition

text
Sentence: "The animal didn't cross the street because it was too tired"

When processing "it":
- High attention to "animal" (subject reference)
- Low attention to "street" (less relevant)
- Medium attention to "tired" (descriptive context)

The model learns these relationships automatically!
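
To make this concrete, here is a toy sketch with made-up attention weights (not ones produced by a trained model): the new representation of "it" is simply a weighted average of the other words' value vectors, dominated here by "animal" and "tired".

python
import torch

# Hypothetical attention weights for "it" over the sentence above (they sum to 1.0)
words = ["The", "animal", "didn't", "cross", "the", "street",
         "because", "it", "was", "too", "tired"]
weights = torch.tensor([0.02, 0.45, 0.02, 0.03, 0.02, 0.05,
                        0.03, 0.10, 0.03, 0.05, 0.20])
values = torch.randn(len(words), 8)   # Stand-in value vectors, one per word
it_representation = weights @ values  # Weighted average = new representation of "it"
print(it_representation.shape)        # torch.Size([8])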

Mathematical Foundation

Query, Key, Value (QKV)

Every word gets three representations:

  • Query (Q): "What am I looking for?"
  • Key (K): "What do I represent?"
  • Value (V): "What information do I contain?"
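
In equation form, these three pieces combine into scaled dot-product attention, the formulation from "Attention Is All You Need":

$$\text{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

Dividing by the square root of the key dimension d_k keeps the dot products from growing so large that the softmax saturates; this is the scaling applied in the code below.
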
python
import torch
import torch.nn as nn
import math

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"
        
        # Linear transformations for Q, K, V
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)
        
    def forward(self, query, key, value, mask=None):
        N = query.shape[0]  # Batch size
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]
        
        # Split embedding into self.heads pieces
        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)
        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
        values = self.values(value).view(N, value_len, self.heads, self.head_dim)
        
        # Calculate raw attention scores (dot products of queries and keys)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)
        
        # Mask out positions that should not be attended to
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # Scale by the square root of the per-head dimension, then apply softmax
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
        
        # Apply attention to values
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # out shape: (N, query_len, heads * head_dim)
        
        out = self.fc_out(out)
        return out
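
A quick shape check of the module above (illustrative sizes). Passing the same tensor as query, key, and value is what makes it self-attention:

python
attn = SelfAttention(embed_size=256, heads=8)
x = torch.randn(2, 10, 256)   # (batch, seq_len, embed_size)
out = attn(x, x, x)           # query = key = value
print(out.shape)              # torch.Size([2, 10, 256])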

Attention Score Calculation

python
def scaled_dot_product_attention(Q, K, V, mask=None, temperature=1.0):
    """
    Scaled Dot-Product Attention
    
    Args:
        Q: Query tensor (..., seq_len_q, d_k)
        K: Key tensor (..., seq_len_k, d_k)
        V: Value tensor (..., seq_len_k, d_v)
        mask: Optional mask; positions where mask == 0 are ignored
        temperature: Scaling factor (usually sqrt(d_k))
    """
    # Calculate attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / temperature
    
    # Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Apply softmax to get attention weights
    attention_weights = torch.softmax(scores, dim=-1)
    
    # Apply attention weights to values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

# Example usage
batch_size, seq_len, d_model = 2, 10, 512
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V, temperature=math.sqrt(d_model))
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")

Multi-Head Attention

Instead of using one attention mechanism, transformers use multiple "attention heads" that can focus on different types of relationships.

python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # 1) Linear transformations and split into heads
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 2) Apply attention on all projected vectors in batch
        attention_output, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask=mask, temperature=math.sqrt(self.d_k)
        )
        
        # 3) Concatenate heads and put through final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        
        output = self.W_o(attention_output)
        
        return output, attention_weights

# Example: Different heads focus on different relationships
# Head 1: Subject-verb relationships
# Head 2: Adjective-noun relationships  
# Head 3: Long-range dependencies
# Head 4: Positional relationships
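
A quick sanity check of the shapes (illustrative sizes); the returned weights contain one attention map per head:

python
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out, weights = mha(x, x, x)
print(out.shape)       # torch.Size([2, 10, 512])
print(weights.shape)   # torch.Size([2, 8, 10, 10]) -- one (10 x 10) map per head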

Transformer Architecture

Complete Transformer Block

python
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerBlock, self).__init__()
        
        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, num_heads)
        
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attention_output, _ = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attention_output))
        
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x
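
Thanks to the residual connections, a block maps a (batch, seq_len, d_model) tensor to another tensor of the same shape, which is why blocks can be stacked freely. A minimal check with illustrative sizes:

python
block = TransformerBlock(d_model=512, num_heads=8, d_ff=2048)
x = torch.randn(2, 10, 512)
print(block(x).shape)   # torch.Size([2, 10, 512]) -- same shape in, same shape out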

Positional Encoding

Since transformers process all positions simultaneously, they need a way to understand the order of words.
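
The original transformer uses fixed sinusoidal encodings in which each dimension oscillates at a different frequency:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)$$

The code below implements exactly this pair of formulas and adds the result to the token embeddings.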

python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                           -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# Visualizing positional encodings
import matplotlib.pyplot as plt
import numpy as np

def plot_positional_encoding(d_model=512, max_len=100):
    pos_enc = PositionalEncoding(d_model, max_len)
    pe = pos_enc.pe.squeeze().numpy()
    
    plt.figure(figsize=(12, 8))
    plt.imshow(pe[:max_len, :50].T, cmap='RdYlBu', aspect='auto')
    plt.xlabel('Position')
    plt.ylabel('Embedding Dimension')
    plt.title('Positional Encoding Pattern')
    plt.colorbar()
    plt.show()

# plot_positional_encoding()

Complete Transformer Model

python
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout=0.1):
        super(Transformer, self).__init__()
        
        # Embedding layers
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = PositionalEncoding(d_model, max_len)
        
        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Output layer
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Token and position embeddings
        token_embeddings = self.token_embedding(x)
        x = self.dropout(self.position_embedding(token_embeddings))
        
        # Pass through transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)
        
        # Final layer norm and output projection
        x = self.ln_f(x)
        logits = self.head(x)
        
        return logits

# Example instantiation
model = Transformer(
    vocab_size=50000,
    d_model=512,
    num_heads=8,
    num_layers=6,
    d_ff=2048,
    max_len=1024,
    dropout=0.1
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Types of Transformers

Encoder-Only (BERT-style)

python
class TransformerEncoder(nn.Module):
    """BERT-style encoder for understanding tasks"""
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len):
        super(TransformerEncoder, self).__init__()
        
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = PositionalEncoding(d_model, max_len)
        
        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        
        self.ln_f = nn.LayerNorm(d_model)
        
    def forward(self, x, mask=None):
        # Embed tokens, then add positional encodings
        x = self.position_embedding(self.token_embedding(x))
        
        for block in self.encoder_blocks:
            x = block(x, mask)
        
        return self.ln_f(x)

# Use cases:
# - Text classification
# - Named entity recognition
# - Question answering
# - Sentence similarity

Decoder-Only (GPT-style)

python
class TransformerDecoder(nn.Module):
    """GPT-style decoder for generation tasks"""
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len):
        super(TransformerDecoder, self).__init__()
        
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = PositionalEncoding(d_model, max_len)
        
        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        
    def forward(self, x):
        seq_len = x.size(1)
        
        # Causal mask: each position may only attend to itself and earlier tokens
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).unsqueeze(0).unsqueeze(0)
        
        # Embed tokens, then add positional encodings
        x = self.position_embedding(self.token_embedding(x))
        
        for block in self.decoder_blocks:
            x = block(x, mask)
        
        x = self.ln_f(x)
        return self.head(x)

# Use cases:
# - Text generation
# - Language modeling
# - Code generation
# - Creative writing
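
To make the generation use case concrete, here is a minimal greedy decoding loop around the class above (tiny illustrative sizes and untrained weights, so the tokens are random; the point is the autoregressive pattern of feeding the growing sequence back in):

python
decoder = TransformerDecoder(vocab_size=1000, d_model=128, num_heads=4,
                             num_layers=2, d_ff=512, max_len=64)
tokens = torch.randint(0, 1000, (1, 5))   # A short "prompt" of token IDs

for _ in range(10):
    logits = decoder(tokens)                                  # (1, seq_len, vocab_size)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # Greedy pick at the last position
    tokens = torch.cat([tokens, next_token], dim=1)           # Append and feed back in

print(tokens.shape)   # torch.Size([1, 15])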

Encoder-Decoder (T5-style)

python
class TransformerEncoderDecoder(nn.Module):
    """T5-style encoder-decoder for sequence-to-sequence tasks"""
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len):
        super(TransformerEncoderDecoder, self).__init__()
        
        # Encoder
        self.encoder = TransformerEncoder(vocab_size, d_model, num_heads, num_layers//2, d_ff, max_len)
        
        # Decoder (the simplified TransformerDecoder above uses causal self-attention only)
        self.decoder = TransformerDecoder(vocab_size, d_model, num_heads, num_layers//2, d_ff, max_len)
        
    def forward(self, src, tgt):
        # Encode source
        encoder_output = self.encoder(src)
        
        # Decode target. Note: a full T5-style decoder attends to encoder_output
        # through a cross-attention sub-layer in every block (see the sketch below);
        # the simplified decoder above does not, so encoder_output goes unused here.
        decoder_output = self.decoder(tgt)
        
        return decoder_output

# Use cases:
# - Machine translation
# - Text summarization
# - Question answering
# - Data-to-text generation
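
For reference, this is roughly what a decoder block with genuine cross-attention looks like. It reuses the MultiHeadAttention module from earlier; the class name and layout are a sketch for illustration, not the implementation of any particular library:

python
class CrossAttentionDecoderBlock(nn.Module):
    """Decoder block with masked self-attention plus cross-attention to the encoder"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(CrossAttentionDecoderBlock, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
        # 1) Masked self-attention over the target sequence
        self_attn, _ = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn))
        
        # 2) Cross-attention: queries come from the decoder,
        #    keys and values come from the encoder output
        cross_attn, _ = self.cross_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attn))
        
        # 3) Position-wise feed-forward with residual connection
        x = self.norm3(x + self.dropout(self.feed_forward(x)))
        return x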

Attention Patterns and Interpretability

Visualizing Attention

python
def visualize_attention(model, tokenizer, text, layer=0, head=0):
    """Visualize attention patterns"""
    
    # Tokenize input
    inputs = tokenizer(text, return_tensors='pt')
    
    # Get attention weights
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
        attention = outputs.attentions[layer][0, head]  # [seq_len, seq_len]
    
    # Convert to numpy
    attention = attention.cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    
    # Plot heatmap
    plt.figure(figsize=(10, 8))
    plt.imshow(attention, cmap='Blues')
    plt.xticks(range(len(tokens)), tokens, rotation=45)
    plt.yticks(range(len(tokens)), tokens)
    plt.xlabel('Key (attending to)')
    plt.ylabel('Query (attending from)')
    plt.title(f'Attention Pattern - Layer {layer}, Head {head}')
    plt.colorbar()
    plt.tight_layout()
    plt.show()

# Example usage with a pre-trained model
from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

text = "The cat sat on the mat because it was comfortable"
# visualize_attention(model, tokenizer, text, layer=5, head=3)

Common Attention Patterns

python
def analyze_attention_patterns(attention_weights):
    """Analyze different types of attention patterns"""
    
    patterns = {}
    
    # 1. Diagonal patterns (local attention)
    diagonal_score = np.trace(attention_weights)
    patterns['local_attention'] = diagonal_score
    
    # 2. Broad attention (attending to many positions)
    entropy = -np.sum(attention_weights * np.log(attention_weights + 1e-9), axis=1)
    patterns['attention_entropy'] = np.mean(entropy)
    
    # 3. Focused attention (attending to few positions)
    max_attention = np.max(attention_weights, axis=1)
    patterns['max_attention'] = np.mean(max_attention)
    
    # 4. Beginning/end bias
    patterns['cls_attention'] = np.mean(attention_weights[:, 0])  # Attention to [CLS]
    patterns['sep_attention'] = np.mean(attention_weights[:, -1])  # Attention to [SEP]
    
    return patterns

# Different heads learn different patterns:
# - Some focus on syntax (subject-verb relationships)
# - Some focus on semantics (word meanings)
# - Some focus on position (nearby words)
# - Some focus on specific tokens ([CLS], [SEP])
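
A quick illustrative run on random, row-normalized weights (real attention maps would come from a model, e.g. via output_attentions=True as shown above):

python
rng = np.random.default_rng(0)
fake_weights = rng.random((12, 12))
fake_weights /= fake_weights.sum(axis=1, keepdims=True)   # Each query row sums to 1
print(analyze_attention_patterns(fake_weights))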

Advanced Transformer Techniques

Efficient Attention Mechanisms

Linear Attention

python
class LinearAttention(nn.Module):
    """Linear complexity attention mechanism"""
    def __init__(self, d_model):
        super(LinearAttention, self).__init__()
        self.d_model = d_model
        
    def forward(self, Q, K, V):
        # Apply feature map (e.g., ELU + 1)
        Q = torch.nn.functional.elu(Q) + 1
        K = torch.nn.functional.elu(K) + 1
        
        # Compute linear attention
        KV = torch.einsum('nld,nlm->ndm', K, V)
        Z = torch.einsum('nld->nd', K)
        
        output = torch.einsum('nld,ndm->nlm', Q, KV) / (torch.einsum('nld,nd->nl', Q, Z).unsqueeze(-1) + 1e-6)
        
        return output
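
The payoff is that the (seq_len × seq_len) score matrix is never materialized, so memory grows linearly with sequence length. A rough check with a longer sequence (illustrative sizes):

python
linear_attn = LinearAttention(d_model=64)
Q = K = V = torch.randn(2, 4096, 64)
print(linear_attn(Q, K, V).shape)   # torch.Size([2, 4096, 64])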

Sparse Attention

python
class SparseAttention(nn.Module):
    """Sparse attention for long sequences"""
    def __init__(self, d_model, sparse_pattern='strided'):
        super(SparseAttention, self).__init__()
        self.d_model = d_model
        self.sparse_pattern = sparse_pattern
        
    def create_sparse_mask(self, seq_len):
        # Fall back to full (dense) attention if the pattern is not recognized
        mask = torch.ones(seq_len, seq_len)
        
        if self.sparse_pattern == 'strided':
            # Each position attends to every 4th position plus a local window
            mask = torch.zeros(seq_len, seq_len)
            for i in range(seq_len):
                mask[i, ::4] = 1  # Stride of 4
                mask[i, max(0, i-64):i+1] = 1  # Local window of the previous 64 tokens
        
        return mask
        
    def forward(self, Q, K, V):
        seq_len = Q.size(1)
        mask = self.create_sparse_mask(seq_len).to(Q.device)
        
        # Apply sparse attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)
        scores = scores.masked_fill(mask == 0, -1e9)
        attention = torch.softmax(scores, dim=-1)
        
        return torch.matmul(attention, V)

Training Techniques

Gradient Accumulation

python
def train_transformer_with_accumulation(model, dataloader, optimizer, accumulation_steps=4):
    """Train with gradient accumulation for large effective batch sizes"""
    
    model.train()
    total_loss = 0
    optimizer.zero_grad()
    
    for step, batch in enumerate(dataloader):
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss / accumulation_steps  # Scale loss
        
        # Backward pass
        loss.backward()
        
        # Update weights every accumulation_steps
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        
        total_loss += loss.item() * accumulation_steps  # Undo the scaling for logging
    
    return total_loss / len(dataloader)

Learning Rate Scheduling

python
class TransformerLRScheduler:
    """Learning rate scheduler for transformers"""
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0
        
    def step(self):
        self.step_num += 1
        lr = self.get_lr()
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
    
    def get_lr(self):
        return (self.d_model ** -0.5) * min(
            self.step_num ** -0.5,
            self.step_num * (self.warmup_steps ** -1.5)
        )
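
To see the warmup-then-decay shape, the expression in get_lr can be evaluated directly for a few step counts:

python
# lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
d_model, warmup_steps = 512, 4000
for step in (1, 1000, 4000, 10000, 100000):
    lr = (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))
    print(f"step {step:>6}: lr = {lr:.2e}")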

Real-World Applications

Building a Simple Transformer for Text Classification

python
class TextClassificationTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, num_classes, max_len):
        super(TextClassificationTransformer, self).__init__()
        
        self.transformer = TransformerEncoder(vocab_size, d_model, num_heads, num_layers, d_model*4, max_len)
        self.classifier = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x, mask=None):
        # Expand a (batch, seq_len) padding mask to (batch, 1, 1, seq_len) so it
        # broadcasts over attention heads and query positions
        if mask is not None and mask.dim() == 2:
            mask = mask.unsqueeze(1).unsqueeze(2)
        
        # Get transformer representations
        hidden_states = self.transformer(x, mask)
        
        # Use [CLS] token representation for classification
        cls_representation = hidden_states[:, 0]  # First token
        
        # Apply dropout and classify
        cls_representation = self.dropout(cls_representation)
        logits = self.classifier(cls_representation)
        
        return logits

# Training loop
def train_classifier(model, train_loader, val_loader, epochs=10):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            optimizer.zero_grad()
            
            logits = model(batch['input_ids'], batch['attention_mask'])
            loss = criterion(logits, batch['labels'])
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        # Validation
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in val_loader:
                logits = model(batch['input_ids'], batch['attention_mask'])
                predictions = torch.argmax(logits, dim=1)
                
                correct += (predictions == batch['labels']).sum().item()
                total += batch['labels'].size(0)
        
        accuracy = correct / total
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, Acc={accuracy:.4f}")

Impact and Future Directions

Transformer Variants

  • BERT: Bidirectional encoder for understanding tasks
  • GPT: Autoregressive decoder for text generation
  • T5: Text-to-text transfer transformer (encoder-decoder)
  • BART: Denoising sequence-to-sequence autoencoder
  • RoBERTa: Robustly optimized BERT pretraining
  • ELECTRA: Sample-efficient pretraining via replaced-token detection
  • DeBERTa: Decoding-enhanced BERT with disentangled attention

Recent Innovations

  • Mixture of Experts: Scaling model capacity efficiently
  • Retrieval-Augmented: Combining parametric and non-parametric knowledge
  • Multimodal Transformers: Processing text, images, and audio together
  • Long Context: Handling very long sequences efficiently

Future Research Directions

  • Efficiency: Making transformers faster and more memory-efficient
  • Interpretability: Better understanding of what transformers learn
  • Generalization: Improving robustness and out-of-domain performance
  • Alignment: Making models safer and more aligned with human values
