core.components.train.train_runner#
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes#
- Checkpointable: Protocol that Units used by this trainer should implement if they want save and resume functionality.
- TrainCheckpointCallback: A Callback is an optional extension that can be used to supplement your loop with additional functionality.
- TrainEvalRunner: Represents an abstraction over things that run in a loop and can save/load state.
Functions#
- get_most_recent_viable_checkpoint_path
Module Contents#
- class core.components.train.train_runner.Checkpointable#
Bases:
Protocol
Protocol that Units used by this trainer should implement if they want save and resume functionality. This is in addition to PyTorch's Stateful protocol because it allows units to implement custom logic that's required for checkpointing.
- save_state(checkpoint_location: str) -> None #
Save the unit state to a checkpoint path
- Parameters:
checkpoint_location – The checkpoint path to save to
- load_state(checkpoint_location: str | None) -> None #
Loads the state given a checkpoint path
- Parameters:
checkpoint_location – The checkpoint path to restore from
- core.components.train.train_runner.get_most_recent_viable_checkpoint_path(checkpoint_dir: str | None) -> str | None #
- class core.components.train.train_runner.TrainCheckpointCallback(checkpoint_every_n_steps: int, max_saved_checkpoints: int = 2)#
Bases:
torchtnt.framework.callback.Callback
A Callback is an optional extension that can be used to supplement your loop with additional functionality. Good candidates for such logic are ones that can be re-used across units. Callbacks are generally not intended for modeling code; this should go in your Unit. To write your own callback, subclass the Callback class and add your own code into the hooks.
Below is an example of a basic callback which prints a message at various points during execution.
```python
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit


class PrintingCallback(Callback):
    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        print("Starting training")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        print("Ending training")

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        print("Starting evaluation")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        print("Ending evaluation")

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        print("Starting prediction")

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        print("Ending prediction")
```
To use a callback, instantiate the class and pass it in the callbacks parameter to the train(), evaluate(), predict(), or fit() entry point.

```python
printing_callback = PrintingCallback()
train(train_unit, train_dataloader, callbacks=[printing_callback])
```
- checkpoint_every_n_steps#
- max_saved_checkpoints#
- save_callback = None#
- load_callback = None#
- checkpoint_dir = None#
- set_runner_callbacks(save_callback: callable, load_callback: callable, checkpoint_dir: str) -> None #
- on_train_step_start(state: torchtnt.framework.state.State, unit: torchtnt.framework.unit.TTrainUnit) -> None #
Hook called before a new train step starts.
- on_train_end(state: torchtnt.framework.state.State, unit: torchtnt.framework.unit.TTrainUnit) -> None #
Hook called after training ends.
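The max_saved_checkpoints parameter implies a retention policy: once the limit is reached, older checkpoints are pruned. The sketch below models that policy in isolation; the helper name and the assumption that directory names encode the global step (e.g. "step_200") are illustrative, not the callback's actual internal logic.

```python
def keep_most_recent(checkpoint_dirs: list[str], max_saved_checkpoints: int) -> list[str]:
    """Keep only the newest checkpoint dirs, ordered oldest to newest."""
    # Assumes directory names end with the global step, e.g. "step_200".
    ordered = sorted(checkpoint_dirs, key=lambda d: int(d.rsplit("_", 1)[-1]))
    return ordered[-max_saved_checkpoints:]


print(keep_most_recent(["step_100", "step_300", "step_200"], 2))
# ['step_200', 'step_300']
```

With the default max_saved_checkpoints=2, this keeps the two most recent checkpoints and discards the rest.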
- class core.components.train.train_runner.TrainEvalRunner(train_dataloader: torch.utils.data.dataloader, eval_dataloader: torch.utils.data.dataloader, train_eval_unit: torchtnt.framework.TrainUnit | torchtnt.framework.EvalUnit | Checkpointable, callbacks: list[torchtnt.framework.callback.Callback] | None = None, max_epochs: int | None = 1, evaluate_every_n_steps: int | None = None, max_steps: int | None = None)#
Bases:
fairchem.core.components.runner.Runner
Represents an abstraction over things that run in a loop and can save/load state.
i.e., Trainers, Validators, and Relaxations all fall into this category.
Note
When running with the fairchemv2 cli, the job_config attribute is set at runtime to the values given in the config file.
- job_config#
a managed attribute that gives access to the job config
- Type:
DictConfig
- train_dataloader#
- eval_dataloader#
- train_eval_unit#
- callbacks#
- max_epochs#
- max_steps#
- evaluate_every_n_steps#
- checkpoint_callback#
- run() -> None #
- save_state(checkpoint_location: str, is_preemption: bool = False) -> bool #
- load_state(checkpoint_location: str | None) -> None #
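A toy illustration of the control flow run() implements, under the assumption that the runner interleaves training steps with periodic evaluation driven by evaluate_every_n_steps and bounded by max_steps. This models the scheduling only; the real runner drives torchtnt entry points with the unit, dataloaders, and callbacks shown in the constructor above.

```python
def run_loop(max_steps: int, evaluate_every_n_steps: int) -> list[str]:
    """Interleave train steps with periodic evaluation, returning the event order."""
    events: list[str] = []
    for step in range(1, max_steps + 1):
        events.append(f"train:{step}")
        # Evaluate after every n-th train step.
        if step % evaluate_every_n_steps == 0:
            events.append(f"eval:{step}")
    return events


print(run_loop(max_steps=4, evaluate_every_n_steps=2))
# ['train:1', 'train:2', 'eval:2', 'train:3', 'train:4', 'eval:4']
```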