core.components.train.train_runner#

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes#

TrainCheckpointCallback

A Callback is an optional extension that can be used to supplement your loop with additional functionality. Good candidates for such logic are ones that can be re-used across units.

TrainEvalRunner

Represents an abstraction over things that run in a loop and can save/load state.

Functions#

get_most_recent_viable_checkpoint_path

Module Contents#

core.components.train.train_runner.get_most_recent_viable_checkpoint_path(checkpoint_dir: str | None) → str | None#
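A minimal usage sketch, assuming the module is importable as fairchem.core.components.train.train_runner; the checkpoint directory shown is a hypothetical example:

from fairchem.core.components.train.train_runner import (
    get_most_recent_viable_checkpoint_path,
)

# Expected to return None when checkpoint_dir is None or holds no viable
# checkpoint, and otherwise the path of the most recent loadable checkpoint
# (behavior inferred from the name and signature above).
ckpt_path = get_most_recent_viable_checkpoint_path("/tmp/my_run/checkpoints")
if ckpt_path is not None:
    print(f"resuming from {ckpt_path}")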
class core.components.train.train_runner.TrainCheckpointCallback(checkpoint_every_n_steps: int, max_saved_checkpoints: int = 2)#

Bases: torchtnt.framework.callback.Callback

A Callback is an optional extension that can be used to supplement your loop with additional functionality. Good candidates for such logic are ones that can be re-used across units. Callbacks are generally not intended for modeling code; this should go in your Unit. To write your own callback, subclass the Callback class and add your own code into the hooks.

Below is an example of a basic callback which prints a message at various points during execution.

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit

class PrintingCallback(Callback):
    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        print("Starting training")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        print("Ending training")

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        print("Starting evaluation")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        print("Ending evaluation")

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        print("Starting prediction")

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        print("Ending prediction")

To use a callback, instantiate the class and pass it in the callbacks parameter to the train(), evaluate(), predict(), or fit() entry point.

from torchtnt.framework import train

printing_callback = PrintingCallback()
train(train_unit, train_dataloader, callbacks=[printing_callback])
checkpoint_every_n_steps#
max_saved_checkpoints#
save_callback = None#
load_callback = None#
checkpoint_dir = None#
set_runner_callbacks(save_callback: callable, load_callback: callable, checkpoint_dir: str) → None#
on_train_step_start(state: torchtnt.framework.state.State, unit: torchtnt.framework.unit.TTrainUnit) → None#

Hook called before a new train step starts.

on_train_end(state: torchtnt.framework.state.State, unit: torchtnt.framework.unit.TTrainUnit) → None#

Hook called after training ends.
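A minimal sketch of how this callback might be wired up, based on the attributes and the set_runner_callbacks signature above; the save/load callables and the checkpoint directory are placeholders, not part of this module:

from fairchem.core.components.train.train_runner import TrainCheckpointCallback

checkpoint_callback = TrainCheckpointCallback(
    checkpoint_every_n_steps=1000,
    max_saved_checkpoints=2,
)

# Placeholder callables standing in for a runner's real save/load methods.
def save_state(checkpoint_location: str, is_preemption: bool = False) -> bool:
    print(f"saving to {checkpoint_location}")
    return True

def load_state(checkpoint_location: str | None) -> None:
    print(f"loading from {checkpoint_location}")

checkpoint_callback.set_runner_callbacks(
    save_callback=save_state,
    load_callback=load_state,
    checkpoint_dir="/tmp/my_run/checkpoints",  # hypothetical directory
)
# Once registered, the hooks above can invoke save_state every
# checkpoint_every_n_steps train steps and at the end of training, keeping at
# most max_saved_checkpoints checkpoints (behavior inferred from the names).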

class core.components.train.train_runner.TrainEvalRunner(train_dataloader: torch.utils.data.dataloader, eval_dataloader: torch.utils.data.dataloader, train_eval_unit: torchtnt.framework.TrainUnit | torchtnt.framework.EvalUnit | torch.distributed.checkpoint.stateful.Stateful, callbacks: list[torchtnt.framework.callback.Callback] | None = None, max_epochs: int | None = 1, evaluate_every_n_steps: int | None = None, max_steps: int | None = None, save_inference_ckpt: bool = True)#

Bases: fairchem.core.components.runner.Runner

Represents an abstraction over things that run in a loop and can save/load state.

i.e., Trainers, Validators, and Relaxation all fall into this category.

Note

When running with the fairchemv2 CLI, the job_config attribute is set at runtime to the values given in the config file.

job_config#

a managed attribute that gives access to the job config

Type: DictConfig

train_dataloader#
eval_dataloader#
train_eval_unit#
callbacks#
max_epochs#
max_steps#
evaluate_every_n_steps#
save_inference_ckpt#
checkpoint_callback#
run() → None#
save_state(checkpoint_location: str, is_preemption: bool = False) → bool#
load_state(checkpoint_location: str | None) → None#
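For orientation, a sketch of constructing and running a TrainEvalRunner from the signature above; the train/eval unit is a placeholder (it must satisfy the torchtnt TrainUnit and EvalUnit protocols and be Stateful), and the import path assumes the fairchem package layout:

import torch
from torch.utils.data import DataLoader, TensorDataset
from fairchem.core.components.train.train_runner import TrainEvalRunner

# Toy dataloaders so the constructor arguments have the expected types.
train_loader = DataLoader(TensorDataset(torch.randn(16, 4)), batch_size=4)
eval_loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)

runner = TrainEvalRunner(
    train_dataloader=train_loader,
    eval_dataloader=eval_loader,
    train_eval_unit=my_unit,  # placeholder torchtnt train/eval unit
    max_epochs=1,
    evaluate_every_n_steps=100,
)

runner.run()  # runs the train loop, evaluating every evaluate_every_n_steps steps
# save_state/load_state checkpoint and restore the whole runner, e.g.
# runner.save_state("/tmp/my_run/checkpoints/step_100")  # hypothetical path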