core.trainers.ocp_trainer

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

OCPTrainer

Trainer class for the Structure to Energy & Force (S2EF) and Initial State to Relaxed State (IS2RS) tasks.

Module Contents

class core.trainers.ocp_trainer.OCPTrainer(task: dict[str, str | Any], model: dict[str, Any], outputs: dict[str, str | int], dataset: dict[str, str | float], optimizer: dict[str, str | float], loss_functions: dict[str, str | float], evaluation_metrics: dict[str, str], identifier: str, local_rank: int, timestamp_id: str | None = None, run_dir: str | None = None, is_debug: bool = False, print_every: int = 100, seed: int | None = None, logger: str = 'wandb', amp: bool = False, cpu: bool = False, name: str = 'ocp', slurm: dict | None = None, gp_gpus: int | None = None, inference_only: bool = False)

Bases: fairchem.core.trainers.base_trainer.BaseTrainer

Trainer class for the Structure to Energy & Force (S2EF) and Initial State to Relaxed State (IS2RS) tasks.

Note

Examples of configurations for task, model, dataset and optimizer can be found in configs/ocp_s2ef and configs/ocp_is2rs.

Parameters:
  • task (dict) – Task configuration.

  • model (dict) – Model configuration.

  • outputs (dict) – Dictionary of model output configuration.

  • dataset (dict) – Dataset configuration. The dataset needs to be a SinglePointLMDB dataset.

  • optimizer (dict) – Optimizer configuration.

  • loss_functions (dict) – Loss function configuration.

  • evaluation_metrics (dict) – Evaluation metrics configuration.

  • identifier (str) – Experiment identifier that is appended to log directory.

  • run_dir (str, optional) – Path to the run directory where logs, checkpoints, and results are saved. (default: None)

  • timestamp_id (str, optional) – Timestamp identifier. (default: None)

  • is_debug (bool, optional) – Run in debug mode. (default: False)

  • print_every (int, optional) – Frequency of printing logs. (default: 100)

  • seed (int, optional) – Random number seed. (default: None)

  • logger (str, optional) – Type of logger to be used. (default: wandb)

  • local_rank (int) – Local rank of the process; only relevant for distributed training.

  • amp (bool, optional) – Run using automatic mixed precision. (default: False)

  • cpu (bool, optional) – If True, run on CPU; otherwise, attempt to use CUDA. (default: False)

  • name (str, optional) – Trainer name. (default: ocp)

  • slurm (dict, optional) – Slurm configuration, currently used only for bookkeeping. (default: None)

  • gp_gpus (int, optional) – Number of graph-parallel GPUs. (default: None)

  • inference_only (bool, optional) – If True, load the trainer for inference only; datasets, optimizer, scheduler, etc. are not instantiated. (default: False)
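
As a rough guide to which constructor arguments are mandatory, the sketch below collects keyword arguments for the trainer: the nine parameters that have no default in the signature must be supplied, and the rest fall back to the signature's defaults. The `build_trainer_kwargs` helper and the empty placeholder configuration dictionaries are hypothetical; real task, model, dataset, and optimizer contents should come from the example configurations in configs/ocp_s2ef and configs/ocp_is2rs, and the commented-out import assumes the package is importable as `fairchem`.

```python
from typing import Any

# Parameters with no default value in the OCPTrainer signature.
REQUIRED_ARGS = (
    "task", "model", "outputs", "dataset", "optimizer",
    "loss_functions", "evaluation_metrics", "identifier", "local_rank",
)

# Defaults copied from the OCPTrainer signature.
OPTIONAL_DEFAULTS: dict[str, Any] = {
    "timestamp_id": None, "run_dir": None, "is_debug": False,
    "print_every": 100, "seed": None, "logger": "wandb", "amp": False,
    "cpu": False, "name": "ocp", "slurm": None, "gp_gpus": None,
    "inference_only": False,
}


def build_trainer_kwargs(configs: dict[str, Any], **overrides: Any) -> dict[str, Any]:
    """Hypothetical helper: check required arguments and fill in defaults."""
    missing = [name for name in REQUIRED_ARGS if name not in configs]
    if missing:
        raise ValueError(f"missing required trainer arguments: {missing}")
    unknown = set(overrides) - set(OPTIONAL_DEFAULTS) - set(REQUIRED_ARGS)
    if unknown:
        raise TypeError(f"unexpected trainer arguments: {sorted(unknown)}")
    # Later dictionaries win: explicit overrides beat configs beat defaults.
    return {**OPTIONAL_DEFAULTS, **configs, **overrides}


# Placeholder (empty) configuration sections; replace with real configs.
configs: dict[str, Any] = {name: {} for name in REQUIRED_ARGS[:-2]}
configs.update(identifier="demo", local_rank=0)
kwargs = build_trainer_kwargs(configs, cpu=True, inference_only=True)

# Assuming fairchem is installed and the configs above are filled in:
# from fairchem.core.trainers.ocp_trainer import OCPTrainer
# trainer = OCPTrainer(**kwargs)
```

Validating the argument set before construction surfaces a missing `identifier` or `local_rank` immediately, rather than partway through dataset or model setup.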

train(disable_eval_tqdm: bool = False) → None

Run model training iterations.

_denorm_preds(target_key: str, prediction: torch.Tensor, batch: torch_geometric.data.Batch)

Convert model outputs for a batch into raw predictions by denormalizing them and adding references.
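
A minimal sketch of the two steps named above, in plain Python rather than tensors. It assumes the target was normalized with a single mean and standard deviation, and that the reference for each structure is a per-element value summed over that structure's atoms (as is typical for energies). The function name, its arguments, and the numbers in the example are hypothetical and do not reflect the actual `_denorm_preds` internals.

```python
from collections.abc import Sequence


def denorm_preds(
    prediction: Sequence[float],
    mean: float,
    std: float,
    atomic_numbers: Sequence[int],
    natoms: Sequence[int],
    element_refs: dict[int, float],
) -> list[float]:
    """Hypothetical helper: undo normalization, then add per-structure references."""
    if len(prediction) != len(natoms) or sum(natoms) != len(atomic_numbers):
        raise ValueError("batch sizes are inconsistent")
    # Step 1: invert the normalization x_norm = (x - mean) / std.
    denormed = [p * std + mean for p in prediction]
    # Step 2: each structure's reference is the sum of its per-element references.
    refs, start = [], 0
    for n in natoms:
        refs.append(sum(element_refs[z] for z in atomic_numbers[start:start + n]))
        start += n
    return [d + r for d, r in zip(denormed, refs)]


# Illustrative (made-up) numbers: a 3-atom structure and a 1-atom structure.
print(denorm_preds([0.5, -1.0], mean=1.0, std=2.0,
                   atomic_numbers=[1, 1, 8, 6], natoms=[3, 1],
                   element_refs={1: -0.5, 8: -2.0, 6: -1.5}))
# → [-1.0, -2.5]
```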

_forward(batch)

_compute_loss(out, batch) → torch.Tensor
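
The signature gives no detail on how losses for several targets are combined. One common pattern for multi-target training, assuming the loss_functions configuration supplies one weighting coefficient per target, is a coefficient-weighted sum of per-target losses. The sketch below further assumes flattened per-target values and a plain mean absolute error for every target; `weighted_loss` and the coefficient values are hypothetical, not the trainer's actual loss code.

```python
def weighted_loss(
    preds: dict[str, list[float]],
    targets: dict[str, list[float]],
    coefficients: dict[str, float],
) -> float:
    """Hypothetical helper: coefficient-weighted sum of per-target MAE losses."""
    total = 0.0
    for key, coeff in coefficients.items():
        p, t = preds[key], targets[key]
        if len(p) != len(t) or not p:
            raise ValueError(f"bad shapes for target {key!r}")
        # Mean absolute error for this target, scaled by its coefficient.
        mae = sum(abs(a - b) for a, b in zip(p, t)) / len(p)
        total += coeff * mae
    return total


loss = weighted_loss(
    preds={"energy": [1.0, 2.0], "forces": [0.0, 0.5, -0.5, 1.0]},
    targets={"energy": [1.5, 2.0], "forces": [0.0, 0.0, 0.0, 0.0]},
    coefficients={"energy": 1.0, "forces": 10.0},  # illustrative weights
)
print(loss)  # → 5.25
```

Here the energy MAE is 0.25 and the force MAE is 0.5, so the larger force coefficient makes forces dominate the total.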

_compute_metrics(out, batch, evaluator, metrics=None)

predict(data_loader, per_image: bool = True, results_file: str | None = None, disable_tqdm: bool = False)
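
When per_image is True, predictions are presumably reported per structure rather than per batch. For per-atom outputs such as forces, one way to do this is to split the flat batch by the number of atoms in each structure, which is what `torch.split(tensor, natoms_list)` does for tensors. The plain-Python sketch below assumes a flat list of per-atom values and a `natoms` list; `split_per_image` is a hypothetical name.

```python
from collections.abc import Sequence


def split_per_image(per_atom: Sequence[float], natoms: Sequence[int]) -> list[list[float]]:
    """Hypothetical helper: split flat per-atom predictions into one list per structure."""
    if sum(natoms) != len(per_atom):
        raise ValueError("natoms does not match the number of per-atom values")
    chunks, start = [], 0
    for n in natoms:
        chunks.append(list(per_atom[start:start + n]))
        start += n
    return chunks


print(split_per_image([0.1, 0.2, 0.3, 0.4, 0.5], natoms=[2, 3]))
# → [[0.1, 0.2], [0.3, 0.4, 0.5]]
```

Checking `sum(natoms)` up front surfaces a mismatched batch immediately instead of silently truncating the last structure.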

run_relaxations()