
Training & Fine-tuning

How to train, adapt, and customize large language models for specific tasks

🎯 What is Training vs Fine-tuning?

Training: Creating a model from scratch using large amounts of data to learn general language patterns

Fine-tuning: Taking a pre-trained model and adapting it for specific tasks or domains with smaller, targeted datasets

Simple Analogy: Training is like teaching someone to read and write from scratch, while fine-tuning is like teaching a literate person to become a specialized writer (poet, journalist, technical writer).

Pre-training: Building Foundation Models

Objectives and Data

Language Modeling Objective

  • Next Token Prediction: Given a sequence of tokens, predict the next one
  • Masked Language Modeling: Fill in randomly masked tokens (BERT-style; see the sketch after the next example)
  • Prefix LM: Predict the continuation given a prefix
python
# Example of next token prediction training
import torch
import torch.nn.functional as F

def compute_language_modeling_loss(model, input_ids):
    # Shift inputs so that position t predicts token t+1
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    
    # Forward pass
    logits = model(inputs)
    
    # Compute cross-entropy loss
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), 
        targets.view(-1), 
        ignore_index=-100
    )
    
    return loss

# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = compute_language_modeling_loss(model, batch['input_ids'])
    loss.backward()
    optimizer.step()

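The masked-language-modeling objective from the list above can be sketched similarly. This is a minimal version assuming a BERT-style model that returns per-position logits and a known mask token id; real implementations also replace some masked positions with random or unchanged tokens:

python
def compute_mlm_loss(model, input_ids, mask_token_id, mask_prob=0.15):
    # Choose random positions to mask
    mask = torch.rand(input_ids.shape, device=input_ids.device) < mask_prob
    
    # Only masked positions contribute to the loss
    labels = input_ids.clone()
    labels[~mask] = -100  # ignored by cross_entropy
    
    # Corrupt the input at the masked positions
    masked_inputs = input_ids.clone()
    masked_inputs[mask] = mask_token_id
    
    logits = model(masked_inputs)
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100
    )
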
Training Data Scale

  • Diversity: Web pages, books, articles, code, conversations
  • Volume: Hundreds of billions to trillions of tokens
  • Quality: Filtering, deduplication, safety screening (see the sketch below)
  • Preprocessing: Tokenization, format standardization
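
A minimal sketch of the quality step above: exact deduplication by content hash plus a crude length filter. The thresholds here are illustrative assumptions, not a production pipeline:

python
import hashlib

def filter_and_deduplicate(documents, min_words=20):
    """Drop very short documents and exact duplicates"""
    seen_hashes = set()
    kept = []
    
    for doc in documents:
        # Crude quality heuristic: skip near-empty documents
        if len(doc.split()) < min_words:
            continue
        
        # Exact deduplication via a content hash
        digest = hashlib.sha256(doc.strip().lower().encode('utf-8')).hexdigest()
        if digest in seen_hashes:
            continue
        
        seen_hashes.add(digest)
        kept.append(doc)
    
    return kept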

Training Infrastructure

Distributed Training

python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def setup_distributed_training():
    """Setup for multi-GPU training"""
    # Initialize process group
    dist.init_process_group(backend='nccl')
    
    # Set device
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    
    return local_rank

def create_distributed_model(model, local_rank):
    """Wrap model for distributed training"""
    model = model.to(local_rank)
    model = DistributedDataParallel(
        model, 
        device_ids=[local_rank],
        find_unused_parameters=False
    )
    return model

# Example training setup
local_rank = setup_distributed_training()
model = create_distributed_model(transformer_model, local_rank)

Memory Optimization

python
class GradientCheckpointing:
    """Trade compute for memory by recomputing activations"""
    
    @staticmethod
    def checkpoint_function(function, *args):
        return torch.utils.checkpoint.checkpoint(function, *args)

class MixedPrecisionTraining:
    """Use FP16 to reduce memory usage"""
    
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.scaler = torch.cuda.amp.GradScaler()
    
    def training_step(self, batch):
        with torch.cuda.amp.autocast():
            outputs = self.model(**batch)
            loss = outputs.loss
        
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        
        return loss

# ZeRO (Zero Redundancy Optimizer) for very large models
from deepspeed import initialize

model, optimizer, _, _ = initialize(
    model=transformer_model,
    optimizer=torch.optim.AdamW(transformer_model.parameters()),
    config={
        "zero_optimization": {
            "stage": 3,  # Partition optimizer states, gradients, and parameters
            "offload_optimizer": {"device": "cpu"},
            "offload_param": {"device": "cpu"}
        }
    }
)
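
The engine returned by initialize replaces the usual backward/optimizer calls. A sketch of the resulting loop, assuming an HF-style model that returns a loss:

python
for batch in dataloader:
    loss = model(**batch).loss
    model.backward(loss)  # DeepSpeed scales the loss and handles ZeRO partitioning
    model.step()          # optimizer step plus gradient clearing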

Fine-tuning Approaches

Full Fine-tuning

All model parameters are updated on task-specific data, typically at a much lower learning rate than pre-training.

python
class FullFineTuning:
    def __init__(self, model, learning_rate=1e-5):
        self.model = model
        self.optimizer = torch.optim.AdamW(
            model.parameters(), 
            lr=learning_rate,
            weight_decay=0.01
        )
    
    def fine_tune(self, train_loader, val_loader, epochs=3):
        """Full fine-tuning on task data"""
        
        for epoch in range(epochs):
            # Training
            self.model.train()
            total_loss = 0
            
            for batch in train_loader:
                self.optimizer.zero_grad()
                
                outputs = self.model(**batch)
                loss = outputs.loss
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                
                total_loss += loss.item()
            
            # Validation
            val_loss = self.evaluate(val_loader)
            print(f"Epoch {epoch+1}: Train Loss={total_loss/len(train_loader):.4f}, "
                  f"Val Loss={val_loss:.4f}")
    
    def evaluate(self, val_loader):
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
        
        return total_loss / len(val_loader)

Parameter-Efficient Fine-tuning (PEFT)

LoRA (Low-Rank Adaptation)

python
import torch.nn as nn

class LoRALayer(nn.Module):
    """Low-Rank Adaptation layer"""
    
    def __init__(self, original_layer, rank=16, alpha=32):
        super(LoRALayer, self).__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        
        # Freeze original parameters
        for param in self.original_layer.parameters():
            param.requires_grad = False
        
        # Add low-rank adaptation matrices
        in_features = original_layer.in_features
        out_features = original_layer.out_features
        
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        
    def forward(self, x):
        # Original computation
        original_output = self.original_layer(x)
        
        # LoRA adaptation
        lora_output = (x @ self.lora_A.T @ self.lora_B.T) * (self.alpha / self.rank)
        
        return original_output + lora_output

def apply_lora_to_model(model, target_modules=['q_proj', 'v_proj'], rank=16):
    """Apply LoRA to specific modules in the model"""
    
    # Snapshot the module list first: replacing modules while iterating
    # over the live named_modules() generator is unsafe
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and any(target in name for target in target_modules):
            # Walk to the parent module and swap in the LoRA wrapper
            parent = model
            for attr in name.split('.')[:-1]:
                parent = getattr(parent, attr)
            
            lora_layer = LoRALayer(module, rank=rank)
            setattr(parent, name.split('.')[-1], lora_layer)
    
    return model

# Example usage
lora_model = apply_lora_to_model(pretrained_model, rank=16)

# Only LoRA parameters will be trained
trainable_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in lora_model.parameters())
print(f"Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")

Adapters

python
class AdapterLayer(nn.Module):
    """Adapter layer for parameter-efficient fine-tuning"""
    
    def __init__(self, hidden_size, adapter_size=64):
        super(AdapterLayer, self).__init__()
        
        self.adapter_down = nn.Linear(hidden_size, adapter_size)
        self.adapter_up = nn.Linear(adapter_size, hidden_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        
        # Initialize with small weights
        nn.init.normal_(self.adapter_down.weight, std=0.01)
        nn.init.normal_(self.adapter_up.weight, std=0.01)
        nn.init.zeros_(self.adapter_down.bias)
        nn.init.zeros_(self.adapter_up.bias)
    
    def forward(self, x):
        # Adapter computation
        adapter_output = self.adapter_down(x)
        adapter_output = self.activation(adapter_output)
        adapter_output = self.dropout(adapter_output)
        adapter_output = self.adapter_up(adapter_output)
        
        # Residual connection
        return x + adapter_output

def add_adapters_to_transformer(model, adapter_size=64):
    """Add adapter layers to transformer blocks"""
    
    for layer in model.transformer.layers:
        # Add adapter after attention
        if hasattr(layer, 'attention'):
            original_forward = layer.attention.forward
            adapter = AdapterLayer(model.config.hidden_size, adapter_size)
            # Register the adapter so its parameters appear in model.parameters()
            layer.adapter = adapter
            
            # Bind loop variables as defaults: a plain closure would see only
            # the *last* adapter in the loop due to Python's late binding
            def new_forward(x, *args, _orig=original_forward, _adapter=adapter, **kwargs):
                attn_output = _orig(x, *args, **kwargs)
                return _adapter(attn_output)
            
            layer.attention.forward = new_forward
    
    return model
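
The function above adds adapters but leaves the base model trainable. For the technique to be parameter-efficient, everything except the adapters should be frozen; this snippet relies on the adapters being registered as layer.adapter above:

python
model = add_adapters_to_transformer(model, adapter_size=64)

# Freeze everything except the adapter parameters
for name, param in model.named_parameters():
    param.requires_grad = 'adapter' in name

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable:,}")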

Prefix Tuning

python
class PrefixTuning(nn.Module):
    """Prefix tuning for conditional generation"""
    
    def __init__(self, config, prefix_length=20):
        super(PrefixTuning, self).__init__()
        self.prefix_length = prefix_length
        self.num_layers = config.num_layers
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        
        # Learnable prefix parameters
        self.prefix_embeddings = nn.Parameter(
            torch.randn(prefix_length, config.hidden_size) * 0.01
        )
        
        # Project to key-value pairs for each layer
        self.prefix_projections = nn.ModuleList([
            nn.Linear(config.hidden_size, 2 * config.hidden_size)
            for _ in range(self.num_layers)
        ])
    
    def get_prefix_states(self, batch_size):
        """Get prefix key-value states for all layers"""
        prefix_states = []
        
        for projection in self.prefix_projections:
            # Project prefix embeddings
            projected = projection(self.prefix_embeddings)  # [prefix_length, 2*hidden_size]
            
            # Split into key and value
            key, value = projected.chunk(2, dim=-1)
            
            # Reshape for multi-head attention
            key = key.view(self.prefix_length, self.num_heads, self.head_dim)
            value = value.view(self.prefix_length, self.num_heads, self.head_dim)
            
            # Expand for batch
            key = key.unsqueeze(0).expand(batch_size, -1, -1, -1)
            value = value.unsqueeze(0).expand(batch_size, -1, -1, -1)
            
            prefix_states.append((key, value))
        
        return prefix_states

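Each attention layer consumes its prefix states by prepending them to that layer's own keys and values, so every real token also attends to the learned prefix. A sketch of the concatenation, assuming the shapes produced by get_prefix_states ([batch, prefix_length, heads, head_dim]) and matching layer projections:

python
# Inside one attention layer, with (prefix_key, prefix_value) taken from
# get_prefix_states(batch_size) for this layer:
key = torch.cat([prefix_key, layer_key], dim=1)      # [batch, prefix+seq, heads, head_dim]
value = torch.cat([prefix_value, layer_value], dim=1)
# Attention then runs over the extended keys/values; queries are unchanged
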
Task-Specific Fine-tuning

Text Classification

python
class TextClassificationFineTuner:
    def __init__(self, model_name, num_classes, max_length=512):
        from transformers import AutoModel, AutoTokenizer
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_classes)
        self.max_length = max_length
        
    def prepare_data(self, texts, labels):
        """Tokenize and prepare data"""
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        dataset = torch.utils.data.TensorDataset(
            encodings['input_ids'],
            encodings['attention_mask'],
            torch.tensor(labels)
        )
        
        return dataset
    
    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pooled [CLS] representation; models without a pooler would need
        # outputs.last_hidden_state[:, 0] instead
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits
    
    def fine_tune(self, train_texts, train_labels, val_texts, val_labels, epochs=3):
        # Prepare datasets
        train_dataset = self.prepare_data(train_texts, train_labels)
        val_dataset = self.prepare_data(val_texts, val_labels)
        
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16)
        
        # Setup optimizer
        optimizer = torch.optim.AdamW([
            {'params': self.encoder.parameters(), 'lr': 2e-5},
            {'params': self.classifier.parameters(), 'lr': 1e-4}
        ])
        
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(epochs):
            # Training
            self.encoder.train()
            self.classifier.train()
            
            for batch in train_loader:
                input_ids, attention_mask, labels = batch
                
                optimizer.zero_grad()
                logits = self.forward(input_ids, attention_mask)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()
            
            # Validation
            accuracy = self.evaluate(val_loader)
            print(f"Epoch {epoch+1}: Validation Accuracy = {accuracy:.4f}")
    
    def evaluate(self, val_loader):
        self.encoder.eval()
        self.classifier.eval()
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in val_loader:
                input_ids, attention_mask, labels = batch
                logits = self.forward(input_ids, attention_mask)
                predictions = torch.argmax(logits, dim=1)
                
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
        
        return correct / total

Question Answering

python
class QAFineTuner:
    def __init__(self, model_name):
        from transformers import AutoModel, AutoTokenizer
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        
        # QA heads for start and end positions
        self.qa_outputs = nn.Linear(self.encoder.config.hidden_size, 2)
        
    def prepare_qa_data(self, contexts, questions, answers):
        """Prepare question-answering data"""
        encodings = self.tokenizer(
            questions,
            contexts,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors='pt'
        )
        
        # Map character-level answer spans to token positions
        start_positions = []
        end_positions = []
        
        for i, (context, answer) in enumerate(zip(contexts, answers)):
            # Find answer span in the raw context
            answer_start = context.find(answer['text'])
            start_token = end_token = None
            
            if answer_start != -1:
                answer_end = answer_start + len(answer['text'])
                
                # char_to_token (fast tokenizers) maps character offsets in the
                # context (sequence_index=1, since the question comes first)
                # to token positions in the joint encoding
                start_token = encodings.char_to_token(i, answer_start, sequence_index=1)
                end_token = encodings.char_to_token(i, answer_end - 1, sequence_index=1)
            
            if start_token is None or end_token is None:
                # Answer not found or truncated away: mark as unanswerable
                start_token, end_token = 0, 0
            
            start_positions.append(start_token)
            end_positions.append(end_token)
        
        return encodings, torch.tensor(start_positions), torch.tensor(end_positions)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        
        return start_logits, end_logits
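
Training then minimizes cross-entropy over the gold start and end positions from prepare_qa_data. A minimal loss sketch (F is torch.nn.functional, as in earlier snippets):

python
def compute_qa_loss(start_logits, end_logits, start_positions, end_positions):
    """Average the start- and end-position cross-entropy losses"""
    start_loss = F.cross_entropy(start_logits, start_positions)
    end_loss = F.cross_entropy(end_logits, end_positions)
    return (start_loss + end_loss) / 2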

Instruction Tuning

python
class InstructionTuner:
    """Fine-tune models to follow instructions"""
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        
    def format_instruction(self, instruction, input_text="", output_text=""):
        """Format training examples in instruction format"""
        if input_text:
            prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        else:
            prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
        
        if output_text:
            return prompt + output_text
        else:
            return prompt
    
    def prepare_instruction_data(self, examples):
        """Prepare instruction-following dataset"""
        formatted_examples = []
        
        for example in examples:
            formatted = self.format_instruction(
                example['instruction'],
                example.get('input', ''),
                example['output']
            )
            formatted_examples.append(formatted)
        
        # Tokenize
        encodings = self.tokenizer(
            formatted_examples,
            truncation=True,
            padding=True,
            max_length=1024,
            return_tensors='pt'
        )
        
        return encodings
    
    def compute_instruction_loss(self, input_ids, attention_mask):
        """Compute loss only on response tokens"""
        # Token ids of the response marker (without special tokens)
        marker = self.tokenizer.encode("### Response:\n", add_special_tokens=False)
        marker_len = len(marker)
        
        # Create labels; ignore everything up to and including the marker
        labels = input_ids.clone()
        for i, seq in enumerate(input_ids):
            ids = seq.tolist()
            response_start = 0
            # Find the last occurrence of the full marker subsequence
            for j in range(len(ids) - marker_len + 1):
                if ids[j:j + marker_len] == marker:
                    response_start = j + marker_len
            labels[i, :response_start] = -100  # ignore instruction tokens
        
        # Forward pass
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return outputs.loss

Reinforcement Learning from Human Feedback (RLHF)

python
class RLHFTrainer:
    """Reinforcement Learning from Human Feedback"""
    
    def __init__(self, policy_model, reward_model, ref_model, tokenizer):
        self.policy_model = policy_model
        self.reward_model = reward_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer  # needed for generation padding below
        
        # Freeze reference model
        for param in self.ref_model.parameters():
            param.requires_grad = False
    
    def compute_rewards(self, queries, responses):
        """Compute rewards using reward model"""
        with torch.no_grad():
            # Concatenate each query with its response along the sequence axis
            # (elementwise `+` on tensors would be wrong here)
            full_sequences = [
                torch.cat([query, response], dim=-1)
                for query, response in zip(queries, responses)
            ]
            
            # Get reward scores
            rewards = self.reward_model(full_sequences)
            return rewards
    
    def compute_kl_penalty(self, query_response_ids, attention_mask):
        """Compute KL divergence penalty with reference model"""
        with torch.no_grad():
            ref_logits = self.ref_model(query_response_ids, attention_mask=attention_mask).logits
            ref_log_probs = F.log_softmax(ref_logits, dim=-1)
        
        policy_logits = self.policy_model(query_response_ids, attention_mask=attention_mask).logits
        policy_log_probs = F.log_softmax(policy_logits, dim=-1)
        
        # KL(policy || ref): per-token KL summed over the vocab, then the sequence
        kl_div = F.kl_div(ref_log_probs, policy_log_probs, reduction='none', log_target=True)
        return kl_div.sum(dim=(-2, -1))
    
    def ppo_step(self, queries, responses, rewards, kl_coeff=0.1):
        """Proximal Policy Optimization step"""
        
        # Prepare input
        query_response_ids = []
        attention_masks = []
        
        for query, response in zip(queries, responses):
            seq = torch.cat([query, response], dim=0)
            query_response_ids.append(seq)
            attention_masks.append(torch.ones_like(seq))
        
        query_response_ids = torch.stack(query_response_ids)
        attention_masks = torch.stack(attention_masks)
        
        # Compute KL penalty
        kl_penalty = self.compute_kl_penalty(query_response_ids, attention_masks)
        
        # Total reward = base reward - KL penalty
        total_rewards = rewards - kl_coeff * kl_penalty
        
        # Policy loss (simplified; full PPO uses per-token log-prob ratios,
        # advantages, and clipping)
        policy_loss = -total_rewards.mean()
        
        return policy_loss
    
    def train_step(self, batch):
        """Single RLHF training step"""
        queries = batch['queries']
        
        # Generate responses
        with torch.no_grad():
            response_ids = self.policy_model.generate(
                queries,
                max_length=512,
                do_sample=True,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        # Extract only the generated part
        responses = response_ids[:, queries.size(1):]
        
        # Compute rewards
        rewards = self.compute_rewards(queries, responses)
        
        # PPO update
        loss = self.ppo_step(queries, responses, rewards)
        
        return loss

Advanced Training Techniques

Curriculum Learning

python
class CurriculumLearning:
    """Gradually increase task difficulty during training"""
    
    def __init__(self, easy_data, medium_data, hard_data):
        self.data_stages = [easy_data, medium_data, hard_data]
        self.current_stage = 0
        
    def get_current_data(self, epoch):
        # Switch to next stage every few epochs
        if epoch > 0 and epoch % 5 == 0:
            self.current_stage = min(self.current_stage + 1, len(self.data_stages) - 1)
        
        return self.data_stages[self.current_stage]
    
    def should_advance(self, validation_accuracy):
        # Alternative to the fixed epoch schedule above: advance when the
        # model performs well enough on the current stage
        return validation_accuracy > 0.8

Multi-task Learning

python
class MultiTaskTrainer:
    """Train on multiple tasks simultaneously"""
    
    def __init__(self, shared_encoder, task_heads):
        self.shared_encoder = shared_encoder
        self.task_heads = task_heads
        
    def compute_multi_task_loss(self, batch, task_weights=None):
        if task_weights is None:
            task_weights = {task: 1.0 for task in self.task_heads.keys()}
        
        # Shared encoding
        shared_features = self.shared_encoder(batch['input_ids'], batch['attention_mask'])
        
        total_loss = 0
        for task_name, head in self.task_heads.items():
            if task_name in batch:
                task_logits = head(shared_features)
                task_loss = F.cross_entropy(task_logits, batch[task_name])
                total_loss += task_weights[task_name] * task_loss
        
        return total_loss

# Example with classification and NER
multi_task_model = MultiTaskTrainer(
    shared_encoder=transformer_encoder,
    task_heads={
        'classification': nn.Linear(768, num_classes),
        'ner': nn.Linear(768, num_ner_tags)
    }
)

Continual Learning

python
class ContinualLearner:
    """Learn new tasks without forgetting old ones"""
    
    def __init__(self, model):
        self.model = model
        self.task_memories = {}
        self.ewc_lambda = 10000  # Elastic Weight Consolidation parameter
        
    def compute_fisher_information(self, dataloader):
        """Compute Fisher Information Matrix for EWC"""
        fisher = {}
        for name, param in self.model.named_parameters():
            fisher[name] = torch.zeros_like(param)
        
        self.model.eval()
        for batch in dataloader:
            self.model.zero_grad()
            loss = self.model(**batch).loss
            loss.backward()
            
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    fisher[name] += param.grad.data ** 2
        
        # Normalize
        for name in fisher:
            fisher[name] /= len(dataloader)
        
        return fisher
    
    def ewc_loss(self, current_params):
        """Elastic Weight Consolidation loss"""
        ewc_loss = 0
        for task_id, memory in self.task_memories.items():
            for name, param in current_params.items():
                if name in memory['params']:
                    ewc_loss += (memory['fisher'][name] * 
                               (param - memory['params'][name]) ** 2).sum()
        
        return self.ewc_lambda * ewc_loss
    
    def learn_new_task(self, task_id, train_loader, val_loader):
        """Learn new task while preserving old knowledge"""
        
        # Standard training with EWC regularization
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
        
        for epoch in range(3):
            for batch in train_loader:
                optimizer.zero_grad()
                
                # Task loss
                task_loss = self.model(**batch).loss
                
                # EWC regularization
                current_params = {name: param for name, param in self.model.named_parameters()}
                ewc_loss = self.ewc_loss(current_params)
                
                total_loss = task_loss + ewc_loss
                total_loss.backward()
                optimizer.step()
        
        # Store task memory
        fisher = self.compute_fisher_information(val_loader)
        params = {name: param.detach().clone() for name, param in self.model.named_parameters()}
        
        self.task_memories[task_id] = {
            'fisher': fisher,
            'params': params
        }

Evaluation and Monitoring

Training Metrics

python
from collections import defaultdict

class TrainingMonitor:
    """Monitor training progress and model performance"""
    
    def __init__(self):
        self.metrics = defaultdict(list)
        
    def log_metrics(self, epoch, **kwargs):
        for key, value in kwargs.items():
            self.metrics[key].append((epoch, value))
    
    def plot_training_curves(self):
        import matplotlib.pyplot as plt
        
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        
        # Loss curves
        if 'train_loss' in self.metrics:
            epochs, losses = zip(*self.metrics['train_loss'])
            axes[0, 0].plot(epochs, losses, label='Train Loss')
        if 'val_loss' in self.metrics:
            epochs, losses = zip(*self.metrics['val_loss'])
            axes[0, 0].plot(epochs, losses, label='Val Loss')
        axes[0, 0].set_title('Loss Curves')
        axes[0, 0].legend()
        
        # Learning rate
        if 'learning_rate' in self.metrics:
            epochs, lrs = zip(*self.metrics['learning_rate'])
            axes[0, 1].plot(epochs, lrs)
            axes[0, 1].set_title('Learning Rate')
        
        # Gradient norms
        if 'grad_norm' in self.metrics:
            epochs, norms = zip(*self.metrics['grad_norm'])
            axes[1, 0].plot(epochs, norms)
            axes[1, 0].set_title('Gradient Norm')
        
        # Accuracy
        if 'accuracy' in self.metrics:
            epochs, accs = zip(*self.metrics['accuracy'])
            axes[1, 1].plot(epochs, accs)
            axes[1, 1].set_title('Accuracy')
        
        plt.tight_layout()
        plt.show()
    
    def early_stopping_check(self, patience=5):
        """Stop when validation loss has not improved for `patience` epochs"""
        if 'val_loss' not in self.metrics or len(self.metrics['val_loss']) < patience:
            return False
        
        recent_losses = [loss for _, loss in self.metrics['val_loss'][-patience:]]
        # Trigger when the loss is monotonically non-decreasing (no improvement)
        return all(recent_losses[i] <= recent_losses[i+1] for i in range(len(recent_losses)-1))
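
A sketch of how the monitor plugs into a training loop; train_one_epoch and validate are hypothetical helpers standing in for the loops shown earlier:

python
monitor = TrainingMonitor()

for epoch in range(50):
    train_loss = train_one_epoch(model, train_loader)  # hypothetical helper
    val_loss = validate(model, val_loader)             # hypothetical helper
    monitor.log_metrics(epoch, train_loss=train_loss, val_loss=val_loss)
    
    if monitor.early_stopping_check(patience=5):
        print(f"Stopping early at epoch {epoch}")
        break

monitor.plot_training_curves()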

Model Evaluation

python
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score

def comprehensive_evaluation(model, tokenizer, test_datasets):
    """Comprehensive evaluation across multiple metrics"""
    
    results = {}
    
    for dataset_name, dataset in test_datasets.items():
        print(f"Evaluating on {dataset_name}...")
        
        # Task-specific evaluation
        if 'classification' in dataset_name:
            accuracy, f1 = evaluate_classification(model, dataset)
            results[dataset_name] = {'accuracy': accuracy, 'f1': f1}
            
        elif 'generation' in dataset_name:
            # evaluate_generation and evaluate_qa are assumed to be defined
            # analogously to evaluate_classification below
            bleu, rouge = evaluate_generation(model, tokenizer, dataset)
            results[dataset_name] = {'bleu': bleu, 'rouge': rouge}
            
        elif 'qa' in dataset_name:
            exact_match, f1 = evaluate_qa(model, tokenizer, dataset)
            results[dataset_name] = {'exact_match': exact_match, 'f1': f1}
    
    return results

def evaluate_classification(model, dataset):
    """Evaluate classification performance"""
    model.eval()
    predictions = []
    labels = []
    
    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=32):
            outputs = model(**batch)
            preds = torch.argmax(outputs.logits, dim=1)
            
            predictions.extend(preds.cpu().numpy())
            labels.extend(batch['labels'].cpu().numpy())
    
    accuracy = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average='weighted')
    
    return accuracy, f1

Best Practices

Hyperparameter Selection

python
# Learning rates by model size
LEARNING_RATES = {
    'small': 5e-4,    # < 100M parameters
    'base': 1e-4,     # 100M - 1B parameters  
    'large': 5e-5,    # 1B - 10B parameters
    'xl': 1e-5,       # > 10B parameters
}

# Batch sizes for different hardware
BATCH_SIZES = {
    'single_gpu': 8,
    'multi_gpu': 32,
    'tpu': 128
}

# Fine-tuning epochs by task
EPOCHS = {
    'classification': 3,
    'qa': 2,
    'generation': 1,
    'instruction_following': 3
}
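
These values are typical starting points rather than fixed rules, and in practice the learning rate is usually warmed up and then decayed rather than held constant. A minimal linear warmup / linear decay schedule:

python
from torch.optim.lr_scheduler import LambdaLR

def warmup_then_linear_decay(optimizer, warmup_steps, total_steps):
    """Linear warmup to the base LR, then linear decay to zero"""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return LambdaLR(optimizer, lr_lambda)

# scheduler = warmup_then_linear_decay(optimizer, warmup_steps=500, total_steps=10000)
# call scheduler.step() after each optimizer.step()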

Common Pitfalls and Solutions

python
class FineTuningBestPractices:
    """Common pitfalls and how to avoid them"""
    
    @staticmethod
    def prevent_catastrophic_forgetting():
        """Techniques to preserve pre-trained knowledge"""
        return [
            "Use lower learning rates for pre-trained layers",
            "Apply dropout and weight decay",
            "Use gradual unfreezing",
            "Implement EWC or similar regularization",
            "Keep some general domain data in training"
        ]
    
    @staticmethod
    def handle_data_imbalance():
        """Deal with imbalanced datasets"""
        return [
            "Use weighted loss functions",
            "Apply data augmentation",
            "Use focal loss for extreme imbalance", 
            "Implement class-balanced sampling",
            "Consider cost-sensitive learning"
        ]
    
    @staticmethod
    def optimize_for_deployment():
        """Prepare models for production"""
        return [
            "Apply model quantization",
            "Use knowledge distillation",
            "Implement model pruning",
            "Optimize inference with ONNX",
            "Consider edge deployment constraints"
        ]
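
As a concrete example of the "gradual unfreezing" item above, a sketch that unfreezes one more layer from the top of the network each epoch (assuming model.transformer.layers as in earlier snippets):

python
def gradual_unfreeze(model, epoch):
    """Unfreeze the top epoch + 1 transformer layers; keep the rest frozen"""
    layers = list(model.transformer.layers)
    boundary = max(0, len(layers) - (epoch + 1))
    
    for i, layer in enumerate(layers):
        for param in layer.parameters():
            param.requires_grad = i >= boundary

# Usage: call at the start of each epoch before the training loop runs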
