How to write your own recipe for model training
Create a trainable instance of your model with the `spawn_train` method (see details here). We set `world_size` (the number of GPUs the job is running on) as the `tp` degree of the model by grabbing it from the recipe context `ctx`.
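As a minimal sketch, assuming `spawn_train` is exposed on the recipe context and that `ctx.world_size` holds the number of GPUs, the start of a recipe could look like this; everything except `spawn_train`, `tp`, and `world_size` is illustrative:

```python
# Minimal sketch: spawn a trainable model inside a recipe.
# spawn_train's exact location and signature, and the base model
# identifier, are assumptions rather than the exact platform API.
def my_recipe(ctx):
    model = ctx.spawn_train(
        "my-base-model",       # hypothetical base model identifier
        tp=ctx.world_size,     # use every GPU in the job as the tp degree
    )
    return model
```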
Use the `CombinedGrader` class to aggregate multiple graders. The `setup` method should always be called on a grader before it is used; it spawns your AI judges if you use AI-judge graders (this is why you also pass the `client` to graders). If you have a `CombinedGrader`, calling `setup` on it will call the `setup` method on every child grader.
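To make the grader flow concrete, here is a hedged sketch; `CombinedGrader`, `setup`, and the `client` come from the description above, while the child graders and the constructor arguments are hypothetical:

```python
# Sketch: aggregate several graders and initialize them before use.
# The child graders and the CombinedGrader constructor arguments are
# hypothetical; only the class name, setup(), and client are from the text.
grader = CombinedGrader([format_grader, llm_judge_grader])

# setup() must run before the grader is used: it spawns the AI judges of
# any judge-based child graders (hence the client argument), and on a
# CombinedGrader it calls setup() on every child grader.
grader.setup(client)
```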
You can use the `CombinedGrader` alongside the other graders you've created and registered in the platform. The `GRPO` class will run the main training loop.
When training finishes, save your model with the `model.save()` method. If you want to guarantee there are no model key collisions (in case you've saved a model with the same key in the past), you can use the `save_model_safely` method instead.
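A short sketch of the save step; the model key is illustrative, and only the method names come from the text:

```python
# Sketch: persist the trained model once the recipe finishes.
# The model key is illustrative.
model.save("my-finetuned-model")

# If a model may already have been saved under the same key, use the
# save_model_safely method instead to avoid key collisions
# (its exact call signature is not shown here).
```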
A supervised fine-tuning recipe takes the following parameters (a hedged instantiation sketch follows the list):

- `dataset`: List of `StringThread` objects containing training examples
- `model`: Training model instance
- `lr`: Learning rate
- `samples_per_batch`: Batch size
- `max_grad_norm`: Gradient clipping norm
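For illustration only, a supervised recipe built from these parameters might be constructed as below; the `SFT` class name, the `run()` method, and all values are assumptions, and only the parameter names come from the list above:

```python
# Hypothetical supervised fine-tuning recipe: only the parameter names
# are taken from the documentation; class name and values are illustrative.
recipe = SFT(
    dataset=train_threads,    # list of StringThread training examples
    model=model,              # trainable model from spawn_train
    lr=1e-5,                  # learning rate
    samples_per_batch=32,     # batch size
    max_grad_norm=1.0,        # gradient clipping norm
)
recipe.run()                  # assumed entry point
```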
The PPO recipe takes the following parameters (see the sketch after the list):

- `dataset`: List of `StringThread` prompts
- `policy_model`: Policy model for training
- `value_model`: Value model for advantage estimation
- `grader`: Grader that returns reward values
- `lr_policy`: Policy learning rate
- `lr_value`: Value learning rate
- `kl_beta`: KL divergence penalty coefficient
- `samples_per_batch`: Number of samples that compose a single PPO step
- `max_num_ppo_steps`: Total number of PPO training steps to take

Make sure that `samples_per_batch` is larger than `samples_per_mini_batch`, so that there is sample diversity for each gradient update.
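A hedged PPO sketch in the same spirit; the `PPO` class name and all values are assumptions, and only the parameter names come from the list above:

```python
# Hypothetical PPO recipe: parameter names from the documentation,
# class name and values illustrative.
recipe = PPO(
    dataset=prompt_threads,     # list of StringThread prompts
    policy_model=policy_model,  # policy model being trained
    value_model=value_model,    # value model for advantage estimation
    grader=grader,              # grader that returns reward values
    lr_policy=1e-6,
    lr_value=1e-5,
    kl_beta=0.05,               # KL divergence penalty coefficient
    samples_per_batch=256,      # samples composing one PPO step
    max_num_ppo_steps=100,      # total PPO steps to take
)
```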
The GRPO recipe takes the following parameters (see the sketch after the list):

- `dataset`: List of `StringThread` prompts
- `model`: Training model
- `grader`: Grader that returns reward values
- `completions_per_sample`: Number of completions per prompt
- `lr`: Learning rate
- `kl_beta`: KL divergence penalty coefficient
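`GRPO` is the class named earlier as running the main training loop; the constructor call and `run()` below are still a sketch, with illustrative values:

```python
# GRPO recipe sketch: parameter names from the documentation,
# values and the run() call are illustrative.
recipe = GRPO(
    dataset=prompt_threads,    # list of StringThread prompts
    model=model,               # training model
    grader=grader,             # e.g. the CombinedGrader set up above
    completions_per_sample=8,  # completions generated per prompt
    lr=1e-6,
    kl_beta=0.05,              # KL divergence penalty coefficient
)
recipe.run()                   # runs the main training loop (assumed method name)
```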
The DPO recipe takes the following parameters (see the sketch after the list):

- `dataset`: List of tuples containing `(preferred_response, dispreferred_response)`
- `model`: Training model
- `lr`: Learning rate
- `samples_per_batch`: Batch size
- `beta`: DPO beta parameter
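Finally, a hedged DPO sketch; the `DPO` class name and all values are assumptions, and the dataset shape follows the tuple format in the list above:

```python
# Hypothetical DPO recipe: parameter names from the documentation,
# class name and values illustrative.
pairs = [
    (preferred_response, dispreferred_response),  # one tuple per example
    # ...
]
recipe = DPO(
    dataset=pairs,
    model=model,
    lr=5e-7,
    samples_per_batch=64,
    beta=0.1,                  # DPO beta parameter
)
```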