onnxruntime
Add Learning Rate Scheduler C API
This pull request adds C APIs for setting the learning rate of a training session, registering a learning rate scheduler with that session, and stepping the scheduler.
// C API usage example: create a training session, set its initial learning
// rate, register a linear learning rate scheduler, and step the scheduler.
// NOTE(review): env, session_options, checkpoint_state, training_model,
// eval_model, optimizer_model, learning_rate, warm_up_step_count and
// total_step_count are assumed to be defined by surrounding example code
// that is not shown here.
// NOTE(review): the g_ort_training_api calls below presumably return an
// OrtStatus* like the rest of the ORT C API. Real code should check (and
// release) each returned status instead of discarding it. TODO confirm.
OrtTrainingSession* session = nullptr;  // out-parameter; never left uninitialized
g_ort_training_api->CreateTrainingSession(env, session_options, checkpoint_state,
                                          training_model.c_str(), eval_model.c_str(),
                                          optimizer_model.c_str(), &session);
// Set the initial learning rate.
g_ort_training_api->SetLearningRate(session, learning_rate);
// Register the linear learning rate scheduler, parameterized by
// warm_up_step_count and total_step_count.
g_ort_training_api->RegisterLinearLRScheduler(session, warm_up_step_count, total_step_count);
// Take a scheduler step.
g_ort_training_api->SchedulerStep(session);