Goal

This page walks you through writing a simple custom recipe, from model loading to training. By the end, you will have written your first custom recipe and be familiar with the syntax.

Step-by-step guide

1. Create a new Python file

Custom recipes are written as Python files, which you can store anywhere in your codebase. Let's create a recipe called my_custom_recipe.py:

touch my_custom_recipe.py

Fill it with this recipe skeleton:

my_custom_recipe.py
from adaptive_harmony.runtime import recipe_main, RecipeContext

@recipe_main
async def main(ctx: RecipeContext):
    client = ctx.client
    assert ctx.model_to_train, "Model must be set for training"
    print("Hello, world!")

The @recipe_main decorator defines the function main as a recipe and allows it to run on the Adaptive Engine.

The Harmony client and the model to train can be retrieved from the recipe context; the assert ensures the recipe was launched with a model to train.

2. Load models

In this recipe, we will run PPO training with completion safety judged by gpt-4o as the scorer. We start by spawning three models: the policy model, the value model, and the judge model.

# Spawn models
policy_model = await client.model(ctx.model_to_train.path).tp(1).spawn_train("policy", 4096)
value_model = await client.model(ctx.model_to_train.path).into_scoring_model().tp(1).spawn_train("value", 4096)
judge_model = await client.model("openai://gpt-4o").tp(1).spawn_inference("judge")
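
The builder chain controls how each model is served: .tp() presumably sets the tensor-parallel degree, the second argument to spawn_train looks like a maximum sequence length, and into_scoring_model() converts the base model into a scoring (value) model. Under that reading, a larger model could for instance be sharded across two GPUs with a longer context (a hedged sketch, reusing the same builder API as above):

# Hedged sketch: assumes .tp() controls tensor parallelism and the second
# spawn_train argument is the maximum sequence length, as the names suggest.
policy_model = await client.model(ctx.model_to_train.path).tp(2).spawn_train("policy", 8192)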

3. Load the dataset

from adaptive_harmony.core.dataset import load_from_hf, convert_sample_dict

# Helper function to convert the HF dataset to an Adaptive StringThread
convert_sample_fn = convert_sample_dict(turns_key="messages", role_key="role", content_key="content")

# Load the dataset
dataset = load_from_hf("HuggingFaceH4/ultrachat_200k", "train_sft", convert_sample_fn)
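
convert_sample_dict builds a converter from raw Hugging Face samples to Adaptive StringThreads. In ultrachat_200k, each record stores its conversation under a "messages" key with "role" and "content" fields per turn, which is what the keys above point at. Illustratively, a raw record is shaped like this (hand-written example, other fields omitted):

# Illustrative record shape; convert_sample_fn walks the "messages" list and
# reads each turn's "role" and "content" to build a StringThread.
raw_sample = {
    "messages": [
        {"role": "user", "content": "How do I bake bread?"},
        {"role": "assistant", "content": "Start with flour, water, yeast and salt..."},
    ],
}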

4. Define a scorer

from adaptive_harmony.scoring.binary_judge_scorer import BinaryJudgeScorer

# Define a scorer
criteria = "Evaluate if the answer is harmful, offensive, or uses inappropriate language"
safety_scorer = BinaryJudgeScorer(model=judge_model, criteria=criteria)
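
BinaryJudgeScorer uses the judge model to decide whether each completion meets the criteria, yielding a binary score. The same pattern works for any natural-language criterion; for example, a hypothetical second scorer judging helpfulness with the same constructor:

# Hypothetical variant reusing the BinaryJudgeScorer API shown above.
helpfulness_scorer = BinaryJudgeScorer(
    model=judge_model,
    criteria="Evaluate if the answer directly and completely addresses the user's request",
)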

5. Adapt the model

from adaptive_harmony.metric_logger import WandbLogger
from adaptive_harmony.common import PPO

# Define a logger
logger = WandbLogger("safety_ppo", "my_first_custom_recipe", "adaptive-ml")

# Run PPO training
await PPO(
    dataset,
    policy_model,
    value_model,
    scoring_fn=safety_scorer.score_without_metadata,
    logger=logger,
    max_num_ppo_steps=100,
    num_samples_per_batch=256,
    num_samples_per_mini_batch=128,
    mini_epochs_per_batch=2,
    kl_beta=0.01,
).run()
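
Going by the parameter names: training runs for at most 100 PPO steps, each batch of 256 sampled completions is split into mini-batches of 128 and replayed for 2 mini-epochs, and kl_beta weights the KL penalty that keeps the policy close to the starting model. Adjust these values to your hardware and dataset.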

Full recipe

my_custom_recipe.py
from adaptive_harmony.runtime import recipe_main, RecipeContext
from adaptive_harmony.core.dataset import load_from_hf, convert_sample_dict
from adaptive_harmony.scoring.binary_judge_scorer import BinaryJudgeScorer
from adaptive_harmony.metric_logger import WandbLogger
from adaptive_harmony.common import PPO

@recipe_main
async def main(ctx: RecipeContext):
    client = ctx.client
    assert ctx.model_to_train, "Model must be set for training"

    # Spawn models
    policy_model = await client.model(ctx.model_to_train.path).tp(1).spawn_train("policy", 4096)
    value_model = await client.model(ctx.model_to_train.path).into_scoring_model().tp(1).spawn_train("value", 4096)
    judge_model = await client.model("openai://gpt-4o").tp(1).spawn_inference("judge")

    # Helper function to convert the HF dataset to an Adaptive StringThread
    convert_sample_fn = convert_sample_dict(turns_key="messages", role_key="role", content_key="content")

    # Load the dataset
    dataset = load_from_hf("HuggingFaceH4/ultrachat_200k", "train_sft", convert_sample_fn)

    # Define a scorer
    criteria = "Evaluate if the answer is harmful, offensive, or uses inappropriate language"
    safety_scorer = BinaryJudgeScorer(model=judge_model, criteria=criteria)

    # Define a logger
    logger = WandbLogger("safety_ppo", "my_first_custom_recipe", "adaptive-ml")

    # Run PPO training
    await PPO(
        dataset,
        policy_model,
        value_model,
        scoring_fn=safety_scorer.score_without_metadata,
        logger=logger,
        max_num_ppo_steps=100,
        num_samples_per_batch=256,
        num_samples_per_mini_batch=128,
        mini_epochs_per_batch=2,
        kl_beta=0.01,
    ).run()