Transformers & Attention
The revolutionary architecture that enabled modern AI breakthroughs
What are Transformers?
Definition: A neural network architecture based on self-attention mechanisms that can process sequences in parallel rather than sequentially
Simple Analogy: Think of attention as being able to simultaneously focus on different parts of a sentence to understand its meaning, like how you can instantly understand "The cat that my neighbor owns is sleeping" by connecting "cat" with "sleeping" regardless of the distance between them.
Why Transformers Revolutionized AI
The Problem with Previous Approaches
RNNs (Recurrent Neural Networks)
- Sequential processing: Had to process text word by word
- Memory limitations: Struggled with long sequences due to vanishing gradients
- Slow training: Couldn't parallelize computation effectively
- Context loss: Earlier information got forgotten in long sequences
CNNs (Convolutional Neural Networks)
- Local patterns only: Could only capture nearby relationships
- Fixed window size: Limited context window
- Not designed for sequences: Better suited for images than text
The Transformer Solution
Parallel Processing
- All positions at once: Processes entire sequence simultaneously
- Faster training: Much more efficient use of GPU parallelization
- Better hardware utilization: Takes full advantage of modern accelerators such as GPUs and TPUs (see the parallelism sketch after this list)
Global Context
- Any-to-any connections: Every word can directly attend to every other word
- Long-range dependencies: No degradation over distance
- Rich representations: Captures complex relationships across the sequence
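To make these advantages concrete, here is a minimal sketch (toy tensor sizes chosen purely for illustration): an RNN cell must step through the sequence one position at a time, while self-attention scores every pair of positions in a single batched matrix multiplication.
import torch
import torch.nn as nn
seq_len, d_model = 16, 32
x = torch.randn(1, seq_len, d_model)  # (batch, seq_len, d_model)
# RNN: positions must be processed one after another (sequential dependency)
rnn_cell = nn.RNNCell(d_model, d_model)
h = torch.zeros(1, d_model)
for t in range(seq_len):  # 16 dependent steps, cannot be parallelized
    h = rnn_cell(x[:, t], h)
# Self-attention: all pairwise scores computed in one batched matrix multiplication
scores = x @ x.transpose(-2, -1) / d_model ** 0.5  # (1, seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)
out = weights @ x  # every position attends to every other position in parallel
print(out.shape)  # torch.Size([1, 16, 32])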
The Attention Mechanism
What is Attention?
Attention allows the model to focus on different parts of the input when processing each element. It's like having a spotlight that can illuminate relevant information.
Self-Attention Intuition
Sentence: "The animal didn't cross the street because it was too tired"
When processing "it":
- High attention to "animal" (subject reference)
- Low attention to "street" (less relevant)
- Medium attention to "tired" (descriptive context)
The model learns these relationships automatically!
Mathematical Foundation
Query, Key, Value (QKV)
Every word gets three representations:
- Query (Q): "What am I looking for?"
- Key (K): "What do I represent?"
- Value (V): "What information do I contain?"
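These three projections are combined by scaled dot-product attention, defined in the original "Attention Is All You Need" paper as:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
where d_k is the dimensionality of the keys. The implementation below follows this formula, handling multiple heads by splitting the embedding dimension.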
import torch
import torch.nn as nn
import math
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (
self.head_dim * heads == embed_size
), "Embedding size needs to be divisible by heads"
# Linear transformations for Q, K, V
self.queries = nn.Linear(embed_size, embed_size, bias=False)
self.keys = nn.Linear(embed_size, embed_size, bias=False)
self.values = nn.Linear(embed_size, embed_size, bias=False)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, query, key, value, mask=None):
N = query.shape[0] # Batch size
value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]
# Split embedding into self.heads pieces
queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)
keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
values = self.values(value).view(N, value_len, self.heads, self.head_dim)
# Calculate attention scores
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# queries shape: (N, query_len, heads, heads_dim)
# keys shape: (N, key_len, heads, heads_dim)
# energy shape: (N, heads, query_len, key_len)
        # Mask out positions that should not be attended to (padding or future tokens)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        # Scale by the square root of the per-head dimension, then apply softmax
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)
# Apply attention to values
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
N, query_len, self.heads * self.head_dim
)
# attention shape: (N, heads, query_len, key_len)
# values shape: (N, value_len, heads, heads_dim)
# out shape: (N, query_len, heads*heads_dim)
out = self.fc_out(out)
        return out
Attention Score Calculation
def scaled_dot_product_attention(Q, K, V, mask=None, temperature=1.0):
"""
Scaled Dot-Product Attention
Args:
Q: Query matrix (batch_size, seq_len, d_model)
K: Key matrix (batch_size, seq_len, d_model)
V: Value matrix (batch_size, seq_len, d_model)
mask: Optional mask to ignore certain positions
        temperature: Scaling factor (typically the square root of the key dimension)
"""
# Calculate attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / temperature
# Apply mask if provided
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Apply softmax to get attention weights
attention_weights = torch.softmax(scores, dim=-1)
# Apply attention weights to values
output = torch.matmul(attention_weights, V)
return output, attention_weights
# Example usage
batch_size, seq_len, d_model = 2, 10, 512
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
output, weights = scaled_dot_product_attention(Q, K, V, temperature=math.sqrt(d_model))
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")Multi-Head Attention β
Instead of using one attention mechanism, transformers use multiple "attention heads" that can focus on different types of relationships.
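Formally, each head runs scaled dot-product attention on its own learned projections, and the results are concatenated and projected back to the model dimension:
MultiHead(Q, K, V) = Concat(head_1, …, head_h) W^O, where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
In the implementation below, the per-head projection matrices are packed into single d_model × d_model linear layers and then reshaped into heads.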
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1) Linear transformations and split into heads
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 2) Apply attention on all projected vectors in batch
attention_output, attention_weights = scaled_dot_product_attention(
Q, K, V, mask=mask, temperature=math.sqrt(self.d_k)
)
# 3) Concatenate heads and put through final linear layer
attention_output = attention_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
output = self.W_o(attention_output)
return output, attention_weights
# Example: Different heads focus on different relationships
# Head 1: Subject-verb relationships
# Head 2: Adjective-noun relationships
# Head 3: Long-range dependencies
# Head 4: Positional relationships
Transformer Architecture
Complete Transformer Block
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super(TransformerBlock, self).__init__()
# Multi-head attention
self.attention = MultiHeadAttention(d_model, num_heads)
# Feed-forward network
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
# Layer normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-attention with residual connection
attention_output, _ = self.attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attention_output))
# Feed-forward with residual connection
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
        return x
Positional Encoding
Since transformers process all positions simultaneously, they need a way to understand the order of words.
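The standard sinusoidal encoding from the original paper is:
PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
Each dimension oscillates at a different frequency, so every position receives a unique, smoothly varying pattern the model can use to infer order.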
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1)]
# Visualizing positional encodings
import matplotlib.pyplot as plt
import numpy as np
def plot_positional_encoding(d_model=512, max_len=100):
pos_enc = PositionalEncoding(d_model, max_len)
pe = pos_enc.pe.squeeze().numpy()
plt.figure(figsize=(12, 8))
plt.imshow(pe[:max_len, :50].T, cmap='RdYlBu', aspect='auto')
plt.xlabel('Position')
plt.ylabel('Embedding Dimension')
plt.title('Positional Encoding Pattern')
plt.colorbar()
plt.show()
# plot_positional_encoding()
Complete Transformer Model
class Transformer(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout=0.1):
super(Transformer, self).__init__()
# Embedding layers
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = PositionalEncoding(d_model, max_len)
# Transformer blocks
self.transformer_blocks = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
# Output layer
self.ln_f = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, vocab_size, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Token and position embeddings
token_embeddings = self.token_embedding(x)
x = self.dropout(self.position_embedding(token_embeddings))
# Pass through transformer blocks
for transformer in self.transformer_blocks:
x = transformer(x, mask)
# Final layer norm and output projection
x = self.ln_f(x)
logits = self.head(x)
return logits
# Example instantiation
model = Transformer(
vocab_size=50000,
d_model=512,
num_heads=8,
num_layers=6,
d_ff=2048,
max_len=1024,
dropout=0.1
)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")Types of Transformers β
Encoder-Only (BERT-style)
class TransformerEncoder(nn.Module):
"""BERT-style encoder for understanding tasks"""
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len):
super(TransformerEncoder, self).__init__()
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = PositionalEncoding(d_model, max_len)
self.encoder_blocks = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff)
for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
        # PositionalEncoding adds the positional signal to the token embeddings
        x = self.position_embedding(self.token_embedding(x))
for block in self.encoder_blocks:
x = block(x, mask)
return self.ln_f(x)
# Use cases:
# - Text classification
# - Named entity recognition
# - Question answering
# - Sentence similarity
Decoder-Only (GPT-style)
class TransformerDecoder(nn.Module):
"""GPT-style decoder for generation tasks"""
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len):
super(TransformerDecoder, self).__init__()
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = PositionalEncoding(d_model, max_len)
self.decoder_blocks = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff)
for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, vocab_size)
def forward(self, x):
seq_len = x.size(1)
# Create causal mask (can't see future tokens)
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
        x = self.position_embedding(self.token_embedding(x))  # embed tokens, then add positions
for block in self.decoder_blocks:
x = block(x, mask)
x = self.ln_f(x)
return self.head(x)
# Use cases:
# - Text generation
# - Language modeling
# - Code generation
# - Creative writing
Encoder-Decoder (T5-style)
class TransformerEncoderDecoder(nn.Module):
"""T5-style encoder-decoder for sequence-to-sequence tasks"""
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len):
super(TransformerEncoderDecoder, self).__init__()
# Encoder
self.encoder = TransformerEncoder(vocab_size, d_model, num_heads, num_layers//2, d_ff, max_len)
        # Decoder (a full T5-style model adds cross-attention over the encoder output in every decoder block)
self.decoder = TransformerDecoder(vocab_size, d_model, num_heads, num_layers//2, d_ff, max_len)
def forward(self, src, tgt):
# Encode source
encoder_output = self.encoder(src)
        # Decode target; the simplified TransformerDecoder above has no cross-attention
        # sub-layer, so encoder_output is not consumed here and this sketch only
        # illustrates the overall encoder-decoder structure
        decoder_output = self.decoder(tgt)
return decoder_output
# Use cases:
# - Machine translation
# - Text summarization
# - Question answering
# - Data-to-text generation
Attention Patterns and Interpretability
Visualizing Attention
def visualize_attention(model, tokenizer, text, layer=0, head=0):
"""Visualize attention patterns"""
# Tokenize input
inputs = tokenizer(text, return_tensors='pt')
# Get attention weights
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
attention = outputs.attentions[layer][0, head] # [seq_len, seq_len]
# Convert to numpy
attention = attention.cpu().numpy()
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# Plot heatmap
plt.figure(figsize=(10, 8))
plt.imshow(attention, cmap='Blues')
plt.xticks(range(len(tokens)), tokens, rotation=45)
plt.yticks(range(len(tokens)), tokens)
plt.xlabel('Key (attending to)')
plt.ylabel('Query (attending from)')
plt.title(f'Attention Pattern - Layer {layer}, Head {head}')
plt.colorbar()
plt.tight_layout()
plt.show()
# Example usage with a pre-trained model
from transformers import BertModel, BertTokenizer
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text = "The cat sat on the mat because it was comfortable"
# visualize_attention(model, tokenizer, text, layer=5, head=3)
Common Attention Patterns
def analyze_attention_patterns(attention_weights):
"""Analyze different types of attention patterns"""
patterns = {}
# 1. Diagonal patterns (local attention)
diagonal_score = np.trace(attention_weights)
patterns['local_attention'] = diagonal_score
# 2. Broad attention (attending to many positions)
entropy = -np.sum(attention_weights * np.log(attention_weights + 1e-9), axis=1)
patterns['attention_entropy'] = np.mean(entropy)
# 3. Focused attention (attending to few positions)
max_attention = np.max(attention_weights, axis=1)
patterns['max_attention'] = np.mean(max_attention)
# 4. Beginning/end bias
patterns['cls_attention'] = np.mean(attention_weights[:, 0]) # Attention to [CLS]
patterns['sep_attention'] = np.mean(attention_weights[:, -1]) # Attention to [SEP]
return patterns
# Different heads learn different patterns:
# - Some focus on syntax (subject-verb relationships)
# - Some focus on semantics (word meanings)
# - Some focus on position (nearby words)
# - Some focus on specific tokens ([CLS], [SEP])
Advanced Transformer Techniques
Efficient Attention Mechanisms
Linear Attention
class LinearAttention(nn.Module):
"""Linear complexity attention mechanism"""
def __init__(self, d_model):
super(LinearAttention, self).__init__()
self.d_model = d_model
def forward(self, Q, K, V):
# Apply feature map (e.g., ELU + 1)
Q = torch.nn.functional.elu(Q) + 1
K = torch.nn.functional.elu(K) + 1
# Compute linear attention
KV = torch.einsum('nld,nlm->ndm', K, V)
Z = torch.einsum('nld->nd', K)
output = torch.einsum('nld,ndm->nlm', Q, KV) / (torch.einsum('nld,nd->nl', Q, Z).unsqueeze(-1) + 1e-6)
        return output
Sparse Attention
class SparseAttention(nn.Module):
"""Sparse attention for long sequences"""
def __init__(self, d_model, sparse_pattern='strided'):
super(SparseAttention, self).__init__()
self.d_model = d_model
self.sparse_pattern = sparse_pattern
def create_sparse_mask(self, seq_len):
if self.sparse_pattern == 'strided':
# Only attend to every k-th position
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
mask[i, ::4] = 1 # Stride of 4
mask[i, max(0, i-64):i+1] = 1 # Local window
return mask
def forward(self, Q, K, V):
seq_len = Q.size(1)
mask = self.create_sparse_mask(seq_len)
# Apply sparse attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)
scores = scores.masked_fill(mask == 0, -1e9)
attention = torch.softmax(scores, dim=-1)
        return torch.matmul(attention, V)
Training Techniques
Gradient Accumulation
def train_transformer_with_accumulation(model, dataloader, optimizer, accumulation_steps=4):
"""Train with gradient accumulation for large effective batch sizes"""
model.train()
total_loss = 0
optimizer.zero_grad()
for step, batch in enumerate(dataloader):
# Forward pass
outputs = model(**batch)
loss = outputs.loss / accumulation_steps # Scale loss
# Backward pass
loss.backward()
# Update weights every accumulation_steps
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
        total_loss += loss.item() * accumulation_steps  # un-scale the loss for logging
    return total_loss / len(dataloader)
Learning Rate Scheduling
class TransformerLRScheduler:
"""Learning rate scheduler for transformers"""
def __init__(self, optimizer, d_model, warmup_steps=4000):
self.optimizer = optimizer
self.d_model = d_model
self.warmup_steps = warmup_steps
self.step_num = 0
def step(self):
self.step_num += 1
lr = self.get_lr()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def get_lr(self):
return (self.d_model ** -0.5) * min(
self.step_num ** -0.5,
self.step_num * (self.warmup_steps ** -1.5)
        )
Real-World Applications
Building a Simple Transformer for Text Classification
class TextClassificationTransformer(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, num_layers, num_classes, max_len):
super(TextClassificationTransformer, self).__init__()
self.transformer = TransformerEncoder(vocab_size, d_model, num_heads, num_layers, d_model*4, max_len)
self.classifier = nn.Linear(d_model, num_classes)
self.dropout = nn.Dropout(0.1)
    def forward(self, x, mask=None):
        # Expand a (batch, seq_len) padding mask so it broadcasts over the
        # (batch, heads, query_len, key_len) attention scores
        if mask is not None and mask.dim() == 2:
            mask = mask.unsqueeze(1).unsqueeze(2)
        # Get transformer representations
        hidden_states = self.transformer(x, mask)
# Use [CLS] token representation for classification
cls_representation = hidden_states[:, 0] # First token
# Apply dropout and classify
cls_representation = self.dropout(cls_representation)
logits = self.classifier(cls_representation)
return logits
# Training loop
def train_classifier(model, train_loader, val_loader, epochs=10):
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
model.train()
total_loss = 0
for batch in train_loader:
optimizer.zero_grad()
logits = model(batch['input_ids'], batch['attention_mask'])
loss = criterion(logits, batch['labels'])
loss.backward()
optimizer.step()
total_loss += loss.item()
# Validation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in val_loader:
logits = model(batch['input_ids'], batch['attention_mask'])
predictions = torch.argmax(logits, dim=1)
correct += (predictions == batch['labels']).sum().item()
total += batch['labels'].size(0)
accuracy = correct / total
print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, Acc={accuracy:.4f}")Impact and Future Directions β
Transformer Variants
- BERT: Bidirectional encoder for understanding
- GPT: Autoregressive decoder for generation
- T5: Text-to-text transfer transformer
- BART: Denoising autoencoder
- RoBERTa: Robustly optimized BERT
- ELECTRA: Pre-trains the encoder with replaced-token detection instead of masked language modeling
- DeBERTa: Decoding-enhanced BERT with disentangled attention
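In practice these variants are usually loaded from pre-trained checkpoints rather than implemented from scratch. A rough sketch using the Hugging Face transformers library (the checkpoint names below are the standard public hub identifiers; adjust them to your task):
from transformers import AutoModel, AutoTokenizer
# Encoder-only (BERT-style): understanding tasks, embeddings
bert = AutoModel.from_pretrained("bert-base-uncased")
# Decoder-only (GPT-style): autoregressive text generation
gpt2 = AutoModel.from_pretrained("gpt2")
# Encoder-decoder (T5-style): sequence-to-sequence tasks
t5 = AutoModel.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Transformers process sequences in parallel.", return_tensors="pt")
outputs = bert(**inputs)
print(outputs.last_hidden_state.shape)  # (1, num_tokens, 768) for bert-base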
Recent Innovations
- Mixture of Experts: Scaling model capacity efficiently
- Retrieval-Augmented: Combining parametric and non-parametric knowledge
- Multimodal Transformers: Processing text, images, and audio together
- Long Context: Handling very long sequences efficiently
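As a rough illustration of the Mixture-of-Experts idea, here is a minimal top-1 routing sketch (a toy layer written for this page, not any specific production architecture; the class and parameter names are made up for illustration):
import torch
import torch.nn as nn
class ToyMoELayer(nn.Module):
    """Toy Mixture-of-Experts layer: each token is routed to its top-1 expert FFN."""
    def __init__(self, d_model, d_ff, num_experts=4):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts)  # gating network
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])
    def forward(self, x):  # x: (batch, seq_len, d_model)
        gate_probs = torch.softmax(self.router(x), dim=-1)  # (batch, seq_len, num_experts)
        top_prob, top_idx = gate_probs.max(dim=-1)           # top-1 expert per token
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            token_mask = top_idx == e                        # tokens routed to expert e
            if token_mask.any():
                out[token_mask] = expert(x[token_mask]) * top_prob[token_mask].unsqueeze(-1)
        return out  # capacity grows with num_experts, but only one expert runs per token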
Future Research Directions
- Efficiency: Making transformers faster and more memory-efficient
- Interpretability: Better understanding of what transformers learn
- Generalization: Improving robustness and out-of-domain performance
- Alignment: Making models safer and more aligned with human values
Next Steps:
- LLM Fundamentals: Understand how transformers enable large language models
- Training & Fine-tuning: Learn how to train and customize transformer models
- Vector Embeddings: See how transformers create meaningful representations