core.trainers.base_trainer#

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes#

BaseTrainer

Helper class that provides a standard way to create an ABC using inheritance.

Module Contents#

class core.trainers.base_trainer.BaseTrainer(task: dict[str, str | Any], model: dict[str, Any], outputs: dict[str, str | int], dataset: dict[str, str | float], optimizer: dict[str, str | float], loss_functions: dict[str, str | float], evaluation_metrics: dict[str, str], identifier: str, local_rank: int, timestamp_id: str | None = None, run_dir: str | None = None, is_debug: bool = False, print_every: int = 100, seed: int | None = None, logger: str = 'wandb', amp: bool = False, cpu: bool = False, name: str = 'ocp', slurm=None, gp_gpus: int | None = None, inference_only: bool = False)#

Bases: abc.ABC

Helper class that provides a standard way to create an ABC using inheritance.

name#
is_debug#
cpu#
epoch = 0#
step = 0#
ema = None#
timestamp_id: str#
config#
scaler#
elementrefs#
normalizers#
train_dataset = None#
val_dataset = None#
test_dataset = None#
best_val_metric = None#
primary_metric = None#
abstract train(disable_eval_tqdm: bool = False) → None#

Run model training iterations.

static _get_timestamp(device: torch.device, suffix: str | None) → str#
load(inference_only: bool) → None#
static set_seed(seed) → None#
load_seed_from_config() → None#
load_logger() → None#
get_sampler(dataset, batch_size: int, shuffle: bool) → fairchem.core.common.data_parallel.BalancedBatchSampler#
get_dataloader(dataset, sampler, workers=None) → torch.utils.data.DataLoader#
load_datasets() → None#
load_references_and_normalizers()#

Load or create element references and normalizers from config

load_task()#
load_model() → None#
property _unwrapped_model#
load_checkpoint(checkpoint_path: str, checkpoint: dict | None = None, inference_only: bool = False) → None#
load_loss() → None#
load_optimizer() → None#
load_extras() → None#
save(metrics=None, checkpoint_file: str = 'checkpoint.pt', training_state: bool = True) → str | None#
update_best(primary_metric, val_metrics, disable_eval_tqdm: bool = True) → None#
_aggregate_metrics(metrics)#
validate(split: str = 'val', disable_tqdm: bool = False)#
_backward(loss) → None#
save_results(predictions: dict[str, numpy.typing.NDArray], results_file: str | None, keys: collections.abc.Sequence[str] | None = None) → None#