from adaptive_harmony.runtime import InputConfig, recipe_main
from adaptive_harmony.runtime.context import RecipeContext
from adaptive_harmony.common.sft import SFT
from adaptive_harmony.common.gspo import GSPO
from adaptive_harmony.parameters import Model, Dataset, dataset_kinds
from adaptive_harmony.graders import BinaryJudgeGrader
from adaptive_harmony.graders.combined_grader import CombinedGrader
from adaptive_harmony.metric_logger import get_prod_logger
class SFTtoGSPOConfig(InputConfig):
    """Configuration for SFT warm-up followed by GSPO training.

    Groups the recipe inputs (two datasets and the model to train) with the
    tunable hyperparameters of each stage. Every hyperparameter has a
    default, so only the datasets and the model must be supplied.
    """

    # ---- Required inputs -------------------------------------------------
    # Completion-kind dataset for the SFT stage: demonstrations that carry
    # the expected outputs.
    sft_dataset: Dataset[dataset_kinds.Completion]
    # Prompt-kind dataset for the GSPO stage: prompts only, no ground truth
    # is needed for RL.
    gspo_dataset: Dataset[dataset_kinds.Prompt]
    # The model that both stages train.
    model_to_train: Model

    # ---- Stage 1: SFT hyperparameters ------------------------------------
    sft_learning_rate: float = 1e-5
    sft_epochs: int = 1
    sft_batch_size: int = 512

    # ---- Stage 2: GSPO hyperparameters -----------------------------------
    gspo_learning_rate: float = 5e-7
    gspo_max_steps: int = 1000
    gspo_batch_size: int = 256
@recipe_main
async def sft_to_gspo(config: SFTtoGSPOConfig, ctx: RecipeContext):
    """Run an SFT warm-up, then continue training the same model with GSPO.

    The recipe proceeds in order:

    1. Spawn a trainable instance of ``config.model_to_train``.
    2. Run supervised fine-tuning over ``config.sft_dataset``.
    3. Spawn an inference-only judge from the same model config and wrap it
       in a ``BinaryJudgeGrader`` inside a ``CombinedGrader``.
    4. Run GSPO over ``config.gspo_dataset`` on the trainable instance from
       step 1, using that grader as the reward signal.
    5. Save the trained model under the name ``"sft-then-gspo-final"``.

    Args:
        config: Datasets, model and per-stage hyperparameters.
        ctx: Recipe context handed to every builder, loader, logger and the
            final save call.
    """
    # ---- Policy model ----------------------------------------------------
    # One trainable handle is shared by both stages, so GSPO picks up from
    # whatever state SFT leaves behind. spawn_train() is used because the
    # stages need gradient updates.
    policy_builder = await config.model_to_train.to_builder(
        ctx,
        tp=1,
        kv_cache_len=131072,
    )
    policy = await policy_builder.spawn_train("sft_then_gspo", max_batch_size=10_000)

    metrics = get_prod_logger(ctx=ctx)

    # ---- Stage 1: supervised warm-up -------------------------------------
    # The Completion dataset supplies demonstrations with expected outputs
    # for the model to imitate.
    demonstrations = await config.sft_dataset.load(ctx)
    await SFT(
        dataset=demonstrations,
        model=policy,
        logger=metrics,
        lr=config.sft_learning_rate,
        epochs=config.sft_epochs,  # Usually 1-3; more risks overfitting
        samples_per_batch=config.sft_batch_size,
        max_grad_norm=1.0,
        weight_decay=0.01,
    ).run()

    # ---- Stage 2: GSPO ---------------------------------------------------
    # The Prompt dataset holds prompts only; the trainer receives a grader
    # to score what the policy generates.
    prompts = await config.gspo_dataset.load(ctx)

    # The judge is a second instance built from the same model config. It is
    # spawned with spawn_inference() since it is never trained.
    judge_builder = await config.model_to_train.to_builder(
        ctx,
        tp=1,
        kv_cache_len=131072,
    )
    judge = await judge_builder.spawn_inference("judge")

    # A single LLM-judge grader, wrapped in a CombinedGrader because that is
    # the grader type handed to the trainer even when there is only one.
    reward = CombinedGrader(
        grader_key="combined",
        graders=[
            BinaryJudgeGrader(
                grader_key="output-quality",
                criteria="Is the output high-quality, well-formed, and correct?",
                model=judge,
            ),
        ],
    )

    await GSPO(
        dataset=prompts,
        model=policy,  # Same handle that just went through SFT
        grader=reward,
        logger=metrics,
        lr=config.gspo_learning_rate,  # Default is lower than the SFT LR
        samples_per_batch=config.gspo_batch_size,
        samples_per_mini_batch=64,
        max_grad_norm=1.0,
        clip_range=0.2,
        kl_beta=0.1,
        max_num_gspo_steps=config.gspo_max_steps,
        completions_per_sample=4,  # Four completions generated per prompt
    ).run()

    # ---- Persist ---------------------------------------------------------
    # The saved model has been through both the SFT and GSPO stages.
    await policy.save(model_name="sft-then-gspo-final", ctx=ctx)