
Overview

Training callbacks allow you to execute custom logic periodically during training runs. They’re useful for monitoring model performance, generating sample outputs, saving checkpoints, and evaluating on validation sets without interrupting the main training loop. All trainer classes accept a callbacks parameter where you can pass a list of callback instances.

How Callbacks Work

Callbacks are triggered at regular intervals based on training progress (measured as a percentage of the job’s completion). Each callback:
  1. Runs periodically - You specify how often (e.g., every 10% of training via frequency=0.1)
  2. Returns metrics - Callbacks return dictionaries that can be logged to your metric logger
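
To make the triggering behavior concrete, here is a toy sketch in plain Python (not adaptive_harmony API) of how a frequency=0.1 callback ends up firing roughly every 10% of a run; the real trainers handle this scheduling internally:
import asyncio
from typing import Any

async def example_callback(progress: float) -> dict[str, Any]:
    return {"progress": progress}

async def toy_training_loop(total_steps: int = 100, frequency: float = 0.1) -> None:
    # Illustrative only: fire the callback each time another `frequency`
    # fraction of training has completed.
    next_trigger = frequency
    for step in range(1, total_steps + 1):
        progress = step / total_steps  # fraction of training completed, 0.0 to 1.0
        if progress >= next_trigger:
            metrics = await example_callback(progress)
            print(metrics)  # in a real run, metrics go to your metric logger
            next_trigger += frequency

asyncio.run(toy_training_loop())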

Using Callbacks

Pass callbacks to any training class via the callbacks parameter:
from adaptive_harmony.common.grpo import GRPO
from adaptive_harmony.common.callbacks import GraderEvalCallback
from adaptive_harmony.metric_logger import get_prod_logger

logger = get_prod_logger()

# Create callbacks
grader_callback = GraderEvalCallback(
    validation_set=validation_dataset,
    model=policy_model,
    grader=safety_grader,
    frequency=0.1  # Evaluate every 10% of training
)

# Pass callbacks to trainer
await GRPO(
    dataset=train_dataset,
    model=policy_model,
    grader=safety_grader,
    logger=logger,
    callbacks=[grader_callback],
    checkpoint_frequency=0.2  # Save checkpoints every 20% of training
).run()

Built-in Callbacks

GraderEvalCallback

Evaluates your model on a validation set using a grader. This is the most common callback for monitoring training progress on held-out data.
Make sure the threads in validation_dataset do not already contain completions; otherwise the model under evaluation will see them as context when it is asked to generate a new one.
from adaptive_harmony.common.callbacks import GraderEvalCallback

grader_callback = GraderEvalCallback(
    validation_set=validation_dataset,  # List of StringThread prompts
    model=policy_model,                 # InferenceModel to evaluate (will generate completions for dataset in callback)
    grader=safety_grader,               # Grader to score completions
    frequency=0.1,                      # Evaluate every 10% of training
    log_key="validation",               # Prefix for logged metrics (default: "validation")
    clear_grader_logs=True,             # Reset grader logs after each eval (default: True)
    temperature=0.0                     # Generation temperature (default: 0.0)
)
Logged metrics:
  • validation/rewards/* - All metrics from the grader’s get_logs() method
  • validation/generation_length_mean - Average generation length
  • validation/generation_length_std - Standard deviation of generation length
  • validation/num_samples - Number of samples evaluated

ValidationLossCallback

Computes the negative log-likelihood loss on a validation set. Useful for monitoring overfitting in supervised fine-tuning.
from adaptive_harmony.common.callbacks import ValidationLossCallback

loss_callback = ValidationLossCallback(
    validation_set=validation_dataset,  # List of StringThread with completions
    model=policy_model,                 # InferenceModel to evaluate
    frequency=0.1,                      # Evaluate every 10% of training
    log_key="loss"                      # Metric key (default: "loss")
)
Logged metrics:
  • validation/loss - Average negative log-likelihood on validation set

GenerateSamplesCallback

Generates and logs sample completions periodically. Useful for qualitatively inspecting model outputs during training.
from adaptive_harmony.common.callbacks import GenerateSamplesCallback

samples_callback = GenerateSamplesCallback(
    thread_set=sample_prompts,          # List of StringThread prompts to complete
    model=policy_model,                 # InferenceModel to generate completions
    frequency=0.1,                      # Generate every 10% of training
    log_key="samples"                   # Table name in logs (default: "samples")
)
Logged metrics:
  • generation/samples - Table with columns: system, prompt, response
  • generation/generation_length_mean - Average completion length
  • generation/generation_length_std - Standard deviation of completion length
  • generation/num_samples - Number of samples generated

Creating Custom Callbacks

To create your own callback, inherit from RecipeCallback and implement the callback method:
from adaptive_harmony.common.callbacks import RecipeCallback
from typing import Any

class MyCustomCallback(RecipeCallback):
    def __init__(self, frequency: float, my_param: str):
        super().__init__(frequency, log_key_prefix="custom")
        self.my_param = my_param

    async def callback(self, current_percentage: float) -> dict[str, Any]:
        # Your custom logic here
        # current_percentage is a float from 0.0 to 1.0

        # Return a dictionary of metrics to log
        return {
            "my_metric": some_value,
            "progress": current_percentage
        }
Key points:
  • frequency - How often to trigger (e.g., 0.1 = every 10% of training)
  • log_key_prefix - Optional prefix for all logged metric keys
  • Return a dictionary of metrics to be logged
  • Use async/await for any I/O operations
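
A custom callback is passed to a trainer just like the built-in ones. The sketch below reuses the GRPO setup from the earlier example and assumes train_dataset, policy_model, safety_grader, logger, and grader_callback are already defined:
from adaptive_harmony.common.grpo import GRPO

my_callback = MyCustomCallback(frequency=0.2, my_param="example")

await GRPO(
    dataset=train_dataset,
    model=policy_model,
    grader=safety_grader,
    logger=logger,
    callbacks=[grader_callback, my_callback]  # built-in and custom callbacks can be mixed
).run()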

Best Practices

  1. Don’t run callbacks too often - Frequent callbacks can slow down training. Start with frequency=0.1 or 0.2 (i.e., every 10-20% of training) and adjust as needed.
  2. Use small validation sets - Keep validation sets small (e.g., 100-500 samples) for faster evaluation (see the sketch after this list).
  3. Combine callbacks strategically:
    • GraderEvalCallback - Essential for RL training to monitor reward on validation data
    • GenerateSamplesCallback - Helpful for debugging and qualitative inspection
    • ValidationLossCallback - Useful for SFT to detect overfitting
  4. Set appropriate temperatures:
    • Use temperature=0.0 for deterministic evaluation (recommended for validation)
    • Use temperature=1.0 for diverse samples in GenerateSamplesCallback
  5. Mind your validation data - Make sure validation prompts don’t overlap with training data so generalization metrics are accurate. For callbacks that generate completions (GraderEvalCallback and GenerateSamplesCallback), the validation threads should not already contain completions; otherwise the evaluated model will see them when asked to generate a new one. ValidationLossCallback, by contrast, requires completions to compute the loss.
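
As a concrete illustration of practices 2 and 4, the sketch below subsamples a larger validation set and configures deterministic grading; it assumes the validation_dataset, policy_model, and safety_grader objects from the earlier examples:
import random

from adaptive_harmony.common.callbacks import GraderEvalCallback

# Practice 2: evaluate on a small random subset for faster callbacks
small_validation_set = random.sample(validation_dataset, k=min(200, len(validation_dataset)))

# Practice 4: deterministic generation for stable validation metrics
grader_callback = GraderEvalCallback(
    validation_set=small_validation_set,
    model=policy_model,
    grader=safety_grader,
    frequency=0.2,    # every 20% of training
    temperature=0.0   # greedy decoding for reproducible evaluation
)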

Example: Comprehensive Monitoring

Here’s a complete example combining multiple callbacks:
from adaptive_harmony.common.gspo import GSPO
from adaptive_harmony.common.callbacks import (
    GraderEvalCallback,
    GenerateSamplesCallback,
)
from adaptive_harmony.metric_logger import get_prod_logger

logger = get_prod_logger()

# Split dataset into train/validation
train_dataset = dataset[:900]
validation_dataset = dataset[900:]

# Sample prompts for qualitative inspection
sample_prompts = validation_dataset[:10]

# Create callbacks
grader_callback = GraderEvalCallback(
    validation_set=validation_dataset,
    model=policy_model,
    grader=safety_grader,
    frequency=0.1
)

samples_callback = GenerateSamplesCallback(
    thread_set=sample_prompts,
    model=policy_model,
    frequency=0.1
)

# Run training with all callbacks
await GSPO(
    dataset=train_dataset,
    model=policy_model,
    grader=safety_grader,
    logger=logger,
    callbacks=[grader_callback, samples_callback],
    max_num_gspo_steps=100,
    samples_per_batch=128,
    completions_per_sample=8
).run()
This setup will:
  • Evaluate on validation set every 10% of training
  • Generate sample completions every 10% for qualitative inspection
  • Log all metrics to your configured metric logger (W&B, MLflow, etc.)