Overview
Training callbacks allow you to execute custom logic periodically during training runs. They’re useful for monitoring model performance, generating sample outputs, saving checkpoints, and evaluating on validation sets without interrupting the main training loop.
All trainer classes accept a callbacks parameter where you can pass a list of callback instances.
How Callbacks Work
Callbacks are triggered at regular intervals based on training progress (measured as a percentage of the job’s completion). Each callback:
- Runs periodically - You specify how often (e.g., every 10% of training via frequency=0.1); a simplified sketch of this scheduling follows the list
- Returns metrics - Callbacks return dictionaries that can be logged to your metric logger
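The scheduling itself is handled inside the trainers; the sketch below only illustrates the frequency semantics. The maybe_run_callbacks helper, the .frequency attribute, and the logger.log method are assumptions for illustration, not part of the documented API.

# Illustrative sketch only: a simplified view of how a trainer could schedule
# callbacks by training progress; not the actual adaptive_harmony implementation.
async def maybe_run_callbacks(callbacks, logger, current_percentage, last_run):
    for cb in callbacks:
        # Fire a callback once training has advanced by at least `frequency`
        # since the last time it ran (assumes frequency is stored on the instance)
        if current_percentage - last_run.get(cb, 0.0) >= cb.frequency:
            metrics = await cb.callback(current_percentage)  # dict of metrics
            logger.log(metrics)  # hypothetical logger method
            last_run[cb] = current_percentage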
Using Callbacks
Pass callbacks to any trainer class via the callbacks parameter:
from adaptive_harmony.common.grpo import GRPO
from adaptive_harmony.common.callbacks import GraderEvalCallback
from adaptive_harmony.metric_logger import get_prod_logger

logger = get_prod_logger()

# Create callbacks
grader_callback = GraderEvalCallback(
    validation_set=validation_dataset,
    model=policy_model,
    grader=safety_grader,
    frequency=0.1,  # Evaluate every 10% of training
)

# Pass callbacks to trainer
await GRPO(
    dataset=train_dataset,
    model=policy_model,
    grader=safety_grader,
    logger=logger,
    callbacks=[grader_callback],
    checkpoint_frequency=0.2,  # Save checkpoints every 20% of training
).run()
Built-in Callbacks
GraderEvalCallback
Evaluates your model on a validation set using a grader. This is the most common callback for monitoring training progress on held-out data.
Make sure the threads in validation_dataset do not already contain completions, as the evaluated model will see them when asked to generate a new one.
from adaptive_harmony.common.callbacks import GraderEvalCallback

grader_callback = GraderEvalCallback(
    validation_set=validation_dataset,  # List of StringThread prompts
    model=policy_model,  # InferenceModel to evaluate (generates completions for the dataset in the callback)
    grader=safety_grader,  # Grader to score completions
    frequency=0.1,  # Evaluate every 10% of training
    log_key="validation",  # Prefix for logged metrics (default: "validation")
    clear_grader_logs=True,  # Reset grader logs after each eval (default: True)
    temperature=0.0,  # Generation temperature (default: 0.0)
)
Logged metrics:
- validation/rewards/* - All metrics from the grader's get_logs() method
- validation/generation_length_mean - Average generation length
- validation/generation_length_std - Standard deviation of generation length
- validation/num_samples - Number of samples evaluated
ValidationLossCallback
Computes the negative log-likelihood loss on a validation set. Useful for monitoring overfitting in supervised fine-tuning.
from adaptive_harmony.common.callbacks import ValidationLossCallback

loss_callback = ValidationLossCallback(
    validation_set=validation_dataset,  # List of StringThread with completions
    model=policy_model,  # InferenceModel to evaluate
    frequency=0.1,  # Evaluate every 10% of training
    log_key="loss",  # Metric key (default: "loss")
)
Logged metrics:
- validation/loss - Average negative log-likelihood on the validation set
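Since every trainer class accepts a callbacks parameter, the loss callback plugs into a supervised fine-tuning run the same way. The SFT import path and class name below are assumptions for illustration; only the callbacks and logger wiring follows the documented pattern:

# Hypothetical SFT trainer; the import path and class name are assumed,
# only the callbacks/logger wiring follows the documented pattern
from adaptive_harmony.common.sft import SFT

await SFT(
    dataset=train_dataset,  # StringThreads with completions for supervised fine-tuning
    model=policy_model,
    logger=logger,
    callbacks=[loss_callback],  # logs validation loss every 10% of training
).run()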
GenerateSamplesCallback
Generates and logs sample completions periodically. Useful for qualitatively inspecting model outputs during training.
from adaptive_harmony.common.callbacks import GenerateSamplesCallback

samples_callback = GenerateSamplesCallback(
    thread_set=sample_prompts,  # List of StringThread prompts to complete
    model=policy_model,  # InferenceModel to generate completions
    frequency=0.1,  # Generate every 10% of training
    log_key="samples",  # Table name in logs (default: "samples")
)
Logged metrics:
- generation/samples - Table with columns: system, prompt, response
- generation/generation_length_mean - Average completion length
- generation/generation_length_std - Standard deviation of completion length
- generation/num_samples - Number of samples generated
Creating Custom Callbacks
To create your own callback, inherit from RecipeCallback and implement the callback method:
from adaptive_harmony.common.callbacks import RecipeCallback
from typing import Any

class MyCustomCallback(RecipeCallback):
    def __init__(self, frequency: float, my_param: str):
        super().__init__(frequency, log_key_prefix="custom")
        self.my_param = my_param

    async def callback(self, current_percentage: float) -> dict[str, Any]:
        # current_percentage is a float from 0.0 to 1.0
        # Run your custom logic here, then return a dictionary of metrics to log
        my_metric_value = 0.0  # replace with your own computation
        return {
            "my_metric": my_metric_value,
            "progress": current_percentage,
        }
Key points:
- frequency - How often to trigger (e.g., 0.1 = every 10% of training)
- log_key_prefix - Optional prefix for all logged metric keys
- Return a dictionary of metrics to be logged
- Use async/await for any I/O operations
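As an illustration, the custom callback above can be passed to a trainer alongside the built-in ones, reusing the GRPO setup from the Using Callbacks example (the my_param value here is just a placeholder):

# Combine the custom callback with a built-in one; my_param is a placeholder
custom_callback = MyCustomCallback(frequency=0.2, my_param="example")

await GRPO(
    dataset=train_dataset,
    model=policy_model,
    grader=safety_grader,
    logger=logger,
    callbacks=[grader_callback, custom_callback],
).run()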
Best Practices
- Start with low frequency - Callbacks can slow down training if you run them too often. Start with frequency=0.1 or 0.2 and adjust as needed.
- Use small validation sets - Keep validation sets small (e.g., 100-500 samples) for faster evaluation.
- Combine callbacks strategically:
  - GraderEvalCallback - Essential for RL training to monitor reward on validation data
  - GenerateSamplesCallback - Helpful for debugging and qualitative inspection
  - ValidationLossCallback - Useful for SFT to detect overfitting
- Set appropriate temperatures (see the sketch after this list):
  - Use temperature=0.0 for deterministic evaluation (recommended for validation)
  - Use temperature=1.0 for diverse samples in GenerateSamplesCallback
- Mind your validation data - Make sure validation prompts don't overlap with training data so generalization metrics are accurate, and don't include completions in validation data, otherwise the evaluated model will see them when asked to generate a new one.
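A sketch of those temperature settings side by side: the temperature argument on GenerateSamplesCallback is inferred from the note above and does not appear in its constructor example, so treat it as an assumption.

# Deterministic generation for validation scoring
grader_callback = GraderEvalCallback(
    validation_set=validation_dataset,
    model=policy_model,
    grader=safety_grader,
    frequency=0.1,
    temperature=0.0,  # documented default
)

# Diverse samples for qualitative inspection
samples_callback = GenerateSamplesCallback(
    thread_set=sample_prompts,
    model=policy_model,
    frequency=0.1,
    temperature=1.0,  # assumed parameter, per the best practice above
)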
Example: Comprehensive Monitoring
Here’s a complete example combining multiple callbacks:
from adaptive_harmony.common.gspo import GSPO
from adaptive_harmony.common.callbacks import (
    GraderEvalCallback,
    GenerateSamplesCallback,
)
from adaptive_harmony.metric_logger import get_prod_logger

logger = get_prod_logger()

# Split dataset into train/validation
train_dataset = dataset[:900]
validation_dataset = dataset[900:]

# Sample prompts for qualitative inspection
sample_prompts = validation_dataset[:10]

# Create callbacks
grader_callback = GraderEvalCallback(
    validation_set=validation_dataset,
    model=policy_model,
    grader=safety_grader,
    frequency=0.1,
)

samples_callback = GenerateSamplesCallback(
    thread_set=sample_prompts,
    model=policy_model,
    frequency=0.1,
)

# Run training with all callbacks
await GSPO(
    dataset=train_dataset,
    model=policy_model,
    grader=safety_grader,
    logger=logger,
    callbacks=[grader_callback, samples_callback],
    max_num_gspo_steps=100,
    samples_per_batch=128,
    completions_per_sample=8,
).run()
This setup will:
- Evaluate on validation set every 10% of training
- Generate sample completions every 10% for qualitative inspection
- Log all metrics to your configured metric logger (W&B, MLflow, etc.)