Fine-tuning

This repo provides a number of scripts to quickly fine-tune a model on a custom ASE LMDB dataset. The scripts are merely a convenience: fine-tuning uses exactly the same tooling and infrastructure as our standard training (see the Training section). Training in the fairchem repo uses the fairchem CLI tool with configs in Hydra yaml format, and the training dataset must be in the ASE-lmdb format. For UMA models, we provide a simple script that generates ASE-lmdb datasets from a variety of input formats (such as CIF, traj, and extxyz files), together with a fine-tuning yaml config that can be used directly.

Generating training/fine-tuning datasets

First, we need to generate a dataset in the aselmdb format for fine-tuning. The only requirement is that your input files can be read as ASE Atoms objects by the ase.io.read routine and that they contain the energy (and, if used, forces and stress) in the correct format. For concrete examples, refer to the test at tests/core/scripts/test_create_finetune_dataset.py.

To access the scripts, first clone the fairchem repo and install it:

git clone git@github.com:facebookresearch/fairchem.git

pip install -e fairchem/src/packages/fairchem-core[dev]

Run this script to create the aselmdbs as well as a set of templated yamls for fine-tuning. Here we use a few dummy structures for demonstration purposes:

import os
# The docs notebook runs from inside the repo; point this at the root of your fairchem checkout.
os.chdir('../../../../fairchem')
! python src/fairchem/core/scripts/create_uma_finetune_dataset.py --train-dir docs/core/common_tasks/finetune_assets/train/ --val-dir docs/core/common_tasks/finetune_assets/val --output-dir /tmp/bulk --uma-task=omat --regression-task e
INFO:root:Generated dataset and data config yaml in /tmp/bulk
INFO:root:To run finetuning, run fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml
  • The uma-task can be one of the UMA tasks: omol, odac, oc20, omat, or omc. Although UMA was trained in a multi-task fashion, we ONLY support fine-tuning on a single UMA task at a time, because multi-task training can become very complicated. Feel free to contact us on GitHub if you have a special use case for multi-task fine-tuning, or refer to the training configs in /training_release to mimic the original UMA training configs.

  • The regression-task can be one of e, ef, or efs (energy; energy and forces; energy, forces, and stress), depending on the data available in your ASE db. For example, some aperiodic DFT codes only provide energies and forces but no stress, and some methods such as QMC only produce energies. Even if you train only on energies, or on energies and forces, forces and stress can still be computed from the model's gradients.
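The two options can be summarized as a small validation sketch. The helper `plan_finetune_targets` and its return shape are hypothetical and not part of fairchem. For `e`, the split into trained and inference-only properties mirrors the config printed in the fine-tuning log further down: energy gets a loss, while forces and stress are marked `inference_only`. The `ef` and `efs` rows are assumed by analogy.

```python
# Hypothetical helper: not part of fairchem, it only encodes the rules above.
UMA_TASKS = {"omol", "odac", "oc20", "omat", "omc"}
REGRESSION_TASKS = {
    "e": ("energy",),
    "ef": ("energy", "forces"),
    "efs": ("energy", "forces", "stress"),
}
ALL_PROPERTIES = ("energy", "forces", "stress")


def plan_finetune_targets(uma_task, regression_task):
    """Return (trained, inference_only) properties for a single UMA task."""
    if uma_task not in UMA_TASKS:
        raise ValueError(f"uma-task must be a single task from {sorted(UMA_TASKS)}, got {uma_task!r}")
    if regression_task not in REGRESSION_TASKS:
        raise ValueError(f"regression-task must be one of {sorted(REGRESSION_TASKS)}, got {regression_task!r}")
    trained = REGRESSION_TASKS[regression_task]
    # Unlabeled properties are still predicted (via gradients) but get no loss term.
    inference_only = tuple(p for p in ALL_PROPERTIES if p not in trained)
    return trained, inference_only


print(plan_finetune_targets("omat", "e"))  # → (('energy',), ('forces', 'stress'))
```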

This generates a folder of lmdbs and a uma_sm_finetune_template.yaml that you can run directly with the fairchem CLI to start training.

If you only want to create the aselmdbs, you can use src/fairchem/core/scripts/create_finetune_dataset.py, which is called by create_uma_finetune_dataset.py.

Model fine-tuning (default settings)

The previous step generated yaml files to get you started with fine-tuning, and you can run them directly with the fairchem CLI. By default, the config runs locally on a single GPU.

! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml

Advanced configuration

The scripts provide a simple way to get started with fine-tuning, but for your own use cases you will likely need to modify the parameters. The configuration uses Hydra-style yamls.

To modify the generated yamls, you can either edit the files directly or use Hydra override notation. For example, changing a few parameters on the command line is simple:

! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml epochs=2 lr=2e-4 job.run_dir=/tmp/finetune_dir +job.timestamp_id=some_id
INFO:root:saved canonical config to /tmp/finetune_dir/some_id/canonical_config.yaml
INFO:root:Running fairchemv2 cli with {'job': {'run_name': 'uma_finetune', 'timestamp_id': 'some_id', 'run_dir': '/tmp/finetune_dir', 'device_type': <DeviceType.CUDA: 'cuda'>, 'debug': True, 'scheduler': {'mode': <SchedulerType.LOCAL: 'local'>, 'distributed_init_method': <DistributedInitMethod.TCP: 'tcp'>, 'ranks_per_node': 1, 'num_nodes': 1, 'num_array_jobs': 1, 'slurm': {'mem_gb': 80, 'timeout_hr': 168, 'cpus_per_task': 8, 'partition': None, 'qos': None, 'account': None}}, 'logger': {'_target_': 'fairchem.core.common.logger.WandBSingletonLogger.init_wandb', '_partial_': True, 'entity': 'example', 'project': 'uma_finetune'}, 'seed': 0, 'deterministic': False, 'runner_state_path': None, 'metadata': {'commit': 'core:a596c961,experimental:NA', 'log_dir': '/tmp/finetune_dir/some_id/logs', 'checkpoint_dir': '/tmp/finetune_dir/some_id/checkpoints', 'results_dir': '/tmp/finetune_dir/some_id/results', 'config_path': '/tmp/finetune_dir/some_id/canonical_config.yaml', 'preemption_checkpoint_dir': '/tmp/finetune_dir/some_id/checkpoints/preemption_state', 'cluster_name': 'local-node', 'array_job_num': 0, 'slurm_env': {'job_id': None, 'raw_job_id': None, 'array_job_id': None, 'array_task_id': None, 'restart_count': None}}, 'graph_parallel_group_size': None}, 'runner': {'_target_': 'fairchem.core.components.train.train_runner.TrainEvalRunner', 'train_dataloader': {'_target_': 'fairchem.core.components.common.dataloader_builder.get_dataloader', 'dataset': {'_target_': 'fairchem.core.datasets.mt_concat_dataset.create_concat_dataset', 'dataset_configs': {'omat': {'splits': {'train': {'src': '/tmp/bulk/train'}}, 'format': 'ase_db', 'transforms': {'common_transform': {'dataset_name': 'omat'}, 'stress_reshape_transform': {'dataset_name': 'omat'}}}}, 'combined_dataset_config': {'sampling': {'type': 'temperature', 'temperature': 1.0}}}, 'batch_sampler_fn': {'_target_': 'fairchem.core.common.data_parallel.BalancedBatchSampler', '_partial_': True, 'batch_size': 2, 'shuffle': True, 
'seed': 0}, 'num_workers': 0, 'collate_fn': {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter', 'tasks': [{'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'energy', 'level': 'system', 'property': 'energy', 'loss_fn': {'_target_': 'fairchem.core.modules.loss.DDPMTLoss', 'loss_fn': {'_target_': 'fairchem.core.modules.loss.PerAtomMAELoss'}, 'coefficient': 20}, 'out_spec': {'dim': [1], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'element_references': {'_target_': 'fairchem.core.modules.normalization.element_references.ElementReferences', 'element_references': {'_target_': 'torch.DoubleTensor', '_args_': [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -2.082805460035017, 0.0, 0.0, 0.0, -3.256404920200298, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}}, 'datasets': ['omat'], 'metrics': ['mae', 'per_atom_mae']}, {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'forces', 'level': 'atom', 'property': 'forces', 'out_spec': {'dim': [3], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'datasets': ['omat'], 'inference_only': True}, {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'stress', 'level': 'system', 'property': 'stress', 'out_spec': {'dim': [1, 9], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'datasets': ['omat'], 'inference_only': True}]}}, 
'eval_dataloader': {'_target_': 'fairchem.core.components.common.dataloader_builder.get_dataloader', 'dataset': {'_target_': 'fairchem.core.datasets.mt_concat_dataset.create_concat_dataset', 'dataset_configs': {'omat': {'splits': {'val': {'src': '/tmp/bulk/val'}}, 'format': 'ase_db', 'transforms': {'common_transform': {'dataset_name': 'omat'}, 'stress_reshape_transform': {'dataset_name': 'omat'}}}}, 'combined_dataset_config': {'sampling': {'type': 'temperature', 'temperature': 1.0}}}, 'batch_sampler_fn': {'_target_': 'fairchem.core.common.data_parallel.BalancedBatchSampler', '_partial_': True, 'batch_size': 2, 'shuffle': False, 'seed': 0}, 'num_workers': 0, 'collate_fn': {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter', 'tasks': [{'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'energy', 'level': 'system', 'property': 'energy', 'loss_fn': {'_target_': 'fairchem.core.modules.loss.DDPMTLoss', 'loss_fn': {'_target_': 'fairchem.core.modules.loss.PerAtomMAELoss'}, 'coefficient': 20}, 'out_spec': {'dim': [1], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'element_references': {'_target_': 'fairchem.core.modules.normalization.element_references.ElementReferences', 'element_references': {'_target_': 'torch.DoubleTensor', '_args_': [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -2.082805460035017, 0.0, 0.0, 0.0, -3.256404920200298, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}}, 'datasets': ['omat'], 'metrics': ['mae', 'per_atom_mae']}, {'_target_': 
'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'forces', 'level': 'atom', 'property': 'forces', 'out_spec': {'dim': [3], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'datasets': ['omat'], 'inference_only': True}, {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'stress', 'level': 'system', 'property': 'stress', 'out_spec': {'dim': [1, 9], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'datasets': ['omat'], 'inference_only': True}]}}, 'train_eval_unit': {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.MLIPTrainEvalUnit', 'job_config': {'run_name': 'uma_finetune', 'timestamp_id': 'some_id', 'run_dir': '/tmp/finetune_dir', 'device_type': <DeviceType.CUDA: 'cuda'>, 'debug': True, 'scheduler': {'mode': <SchedulerType.LOCAL: 'local'>, 'distributed_init_method': <DistributedInitMethod.TCP: 'tcp'>, 'ranks_per_node': 1, 'num_nodes': 1, 'num_array_jobs': 1, 'slurm': {'mem_gb': 80, 'timeout_hr': 168, 'cpus_per_task': 8, 'partition': None, 'qos': None, 'account': None}}, 'logger': {'_target_': 'fairchem.core.common.logger.WandBSingletonLogger.init_wandb', '_partial_': True, 'entity': 'example', 'project': 'uma_finetune'}, 'seed': 0, 'deterministic': False, 'runner_state_path': None, 'metadata': {'commit': 'core:a596c961,experimental:NA', 'log_dir': '/tmp/finetune_dir/some_id/logs', 'checkpoint_dir': '/tmp/finetune_dir/some_id/checkpoints', 'results_dir': '/tmp/finetune_dir/some_id/results', 'config_path': '/tmp/finetune_dir/some_id/canonical_config.yaml', 'preemption_checkpoint_dir': '/tmp/finetune_dir/some_id/checkpoints/preemption_state', 'cluster_name': 'local-node', 'array_job_num': 0, 'slurm_env': {'job_id': None, 'raw_job_id': None, 'array_job_id': None, 'array_task_id': None, 'restart_count': None}}, 'graph_parallel_group_size': None}, 'tasks': [{'_target_': 
'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'energy', 'level': 'system', 'property': 'energy', 'loss_fn': {'_target_': 'fairchem.core.modules.loss.DDPMTLoss', 'loss_fn': {'_target_': 'fairchem.core.modules.loss.PerAtomMAELoss'}, 'coefficient': 20}, 'out_spec': {'dim': [1], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'element_references': {'_target_': 'fairchem.core.modules.normalization.element_references.ElementReferences', 'element_references': {'_target_': 'torch.DoubleTensor', '_args_': [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -2.082805460035017, 0.0, 0.0, 0.0, -3.256404920200298, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}}, 'datasets': ['omat'], 'metrics': ['mae', 'per_atom_mae']}, {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'forces', 'level': 'atom', 'property': 'forces', 'out_spec': {'dim': [3], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'datasets': ['omat'], 'inference_only': True}, {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.Task', 'name': 'stress', 'level': 'system', 'property': 'stress', 'out_spec': {'dim': [1, 9], 'dtype': 'float32'}, 'normalizer': {'_target_': 'fairchem.core.modules.normalization.normalizer.Normalizer', 'mean': 0.0, 'rmsd': 1.0}, 'datasets': ['omat'], 'inference_only': True}], 'model': {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit.initialize_finetuning_model', 'checkpoint_location': {'_target_': 
'fairchem.core.calculate.pretrained_mlip.pretrained_checkpoint_path_from_name', 'model_name': 'uma-s-1'}, 'overrides': {'backbone': {'otf_graph': True, 'max_neighbors': 300, 'regress_stress': True, 'always_use_pbc': False}, 'pass_through_head_outputs': True}, 'heads': {'efs': {'module': 'fairchem.core.models.uma.escn_md.MLP_EFS_Head'}}}, 'optimizer_fn': {'_target_': 'torch.optim.AdamW', '_partial_': True, 'lr': 0.0002, 'weight_decay': 0.001}, 'cosine_lr_scheduler_fn': {'_target_': 'fairchem.core.units.mlip_unit.mlip_unit._get_consine_lr_scheduler', '_partial_': True, 'warmup_factor': 0.2, 'warmup_epochs': 0.01, 'lr_min_factor': 0.01, 'epochs': 2, 'steps': None}, 'print_every': 10, 'clip_grad_norm': 100}, 'max_epochs': 2, 'max_steps': None, 'evaluate_every_n_steps': 100, 'callbacks': [{'_target_': 'fairchem.core.components.train.train_runner.TrainCheckpointCallback', 'checkpoint_every_n_steps': 1000, 'max_saved_checkpoints': 5}, {'_target_': 'torchtnt.framework.callbacks.TQDMProgressBar'}]}}
INFO:root:Running in local mode without elastic launch
INFO:root:Setting env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
INFO:root:Setting up distributed backend...
INFO:root:Calling runner.run() ...
INFO:root:get_dataloader::Calling batch_sampler_fn=functools.partial(<class 'fairchem.core.common.data_parallel.BalancedBatchSampler'>, batch_size=2, shuffle=True, seed=0)...
WARNING:root:Disabled BalancedBatchSampler because num_replicas=1.
INFO:root:rank: 0: Sampler created...
INFO:root:Created BalancedBatchSampler with sampler=<fairchem.core.common.data_parallel.StatefulDistributedSampler object at 0x72a5d6fa8620>, batch_size=2, drop_last=False
INFO:root:get_dataloader::Calling Dataloader...
INFO:root:get_dataloader::Done!
INFO:root:get_dataloader::Calling batch_sampler_fn=functools.partial(<class 'fairchem.core.common.data_parallel.BalancedBatchSampler'>, batch_size=2, shuffle=False, seed=0)...
WARNING:root:Disabled BalancedBatchSampler because num_replicas=1.
INFO:root:rank: 0: Sampler created...
INFO:root:Created BalancedBatchSampler with sampler=<fairchem.core.common.data_parallel.StatefulDistributedSampler object at 0x72a5d7207020>, batch_size=2, drop_last=False
INFO:root:get_dataloader::Calling Dataloader...
INFO:root:get_dataloader::Done!
WARNING:root:initialize_finetuning_model starting from checkpoint_location: /home/runner/.cache/fairchem/models--facebook--UMA/snapshots/abaa274e3612b2cfcc5be2d900ffa2a03cb42ee7/checkpoints/uma-s-1.pt
INFO:root:Train Dataloader size 1
INFO:root:Eval Dataloader size 1
INFO:root:No existing checkpoints found, starting from scratch
INFO:torchtnt.framework.fit:Started fit with max_epochs=2 max_steps=None max_train_steps_per_epoch=None max_eval_steps_per_epoch=None evaluate_every_n_steps=100 evaluate_every_n_epochs=1 
INFO:torchtnt.framework.train:Started train with max_epochs=2, max_steps=None, max_steps_per_epoch=None
INFO:root:on_train_start: setting sampler state to 0, 0
INFO:root:at beginning of epoch 0, setting sampler start step to 0
INFO:torchtnt.framework.train:Started train epoch
INFO:root:at beginning of epoch 0, setting sampler start step to 0
INFO:root:Saved dcp checkpoint to /tmp/finetune_dir/some_id/checkpoints/step_0
INFO:root:0: Expert variance: 3.60e-07,1.19e-02,7.17e-07,3.88e-07,4.74e-04,7.71e-07,4.15e-07,1.53e-02,8.34e-07,2.89e-05,1.65e-02,7.07e-07,7.77e-07,1.01e-06,2.53e-07,2.40e-07,3.66e-07,5.95e-04,1.57e-05,9.03e-04,1.30e-06,1.34e-04,4.34e-07,5.28e-07,9.35e-07,3.97e-07,9.43e-07,5.13e-05,4.00e-03,6.28e-07,6.77e-06,4.84e-05
INFO:root:0: Expert mean: 6.88e-03,1.02e-01,6.92e-03,6.72e-03,2.17e-02,6.83e-03,6.69e-03,1.28e-01,7.05e-03,1.22e-02,9.99e-02,7.17e-03,6.82e-03,7.00e-03,6.73e-03,6.83e-03,7.01e-03,3.28e-01,1.25e-02,5.07e-02,6.99e-03,1.45e-02,6.66e-03,7.03e-03,6.87e-03,7.00e-03,7.11e-03,1.17e-02,2.01e-01,6.74e-03,8.92e-03,3.75e-02
/home/runner/work/_tool/Python/3.12.11/x64/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:384: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  _warn_get_lr_called_within_step(self)
INFO:root:{'train/loss': 1.2937601246535735, 'train/lr': 4e-05, 'train/step': 0, 'train/epoch': 0.0, 'train/samples_per_second(approx)': 0.06165631355940083, 'train/atoms_per_second(approx)': 1.47975152542562, 'train/num_atoms_on_rank': 48, 'train/num_samples_on_rank': 2}
INFO:torchtnt.framework.train:Reached end of train dataloader
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
INFO:root:Done eval epoch, aggregating metrics
INFO:root:Finished aggregating metrics: 
  val/atoms_per_second: 611.0012
  val/epoch: 0.0000
  val/loss: 61.0236
  val/omat.val,energy,mae: 97.6377
  val/omat.val,energy,per_atom_mae: 3.0512
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:After train epoch, train progress: num_epochs_completed = 1, num_steps_completed = 1
INFO:torchtnt.framework.train:Started train epoch
INFO:root:at beginning of epoch 1, setting sampler start step to 0
INFO:torchtnt.framework.train:Reached end of train dataloader
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
INFO:root:Done eval epoch, aggregating metrics
INFO:root:Finished aggregating metrics: 
  val/atoms_per_second: 632.5744
  val/epoch: 0.0000
  val/loss: 61.0235
  val/omat.val,energy,mae: 97.6376
  val/omat.val,energy,per_atom_mae: 3.0512
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:After train epoch, train progress: num_epochs_completed = 2, num_steps_completed = 2
INFO:root:Training Completed 2 steps
INFO:root:Saved dcp checkpoint to /tmp/finetune_dir/some_id/checkpoints/final
INFO:torchtnt.framework.fit:Finished fit
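If you launch runs from Python (for example, in a small sweep), the same override list can be assembled programmatically. This is a sketch: `build_fairchem_command` is a hypothetical helper, not part of fairchem, and it only reproduces the command above as an argument list. Plain `key=value` entries change values that already exist in the yaml. The command above writes `job.timestamp_id` as `+job.timestamp_id`, and in Hydra's override grammar that prefix adds a key instead of overriding an existing one.

```python
import shlex


def build_fairchem_command(config_path, overrides=None, new_keys=None):
    """Assemble a fairchem CLI call with Hydra-style overrides.

    overrides: keys already present in the yaml, written as key=value.
    new_keys:  keys to add to the config, written as +key=value.
    """
    cmd = ["fairchem", "-c", str(config_path)]
    cmd += [f"{key}={value}" for key, value in (overrides or {}).items()]
    cmd += [f"+{key}={value}" for key, value in (new_keys or {}).items()]
    return cmd


cmd = build_fairchem_command(
    "/tmp/bulk/uma_sm_finetune_template.yaml",
    overrides={"epochs": 2, "lr": "2e-4", "job.run_dir": "/tmp/finetune_dir"},
    new_keys={"job.timestamp_id": "some_id"},
)
print(shlex.join(cmd))
# To launch: import subprocess; subprocess.run(cmd, check=True)
```

Keeping the arguments as a list, rather than one shell string, avoids quoting problems when values contain spaces.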

The basic yaml configuration looks like the following:

job:
  device_type: CUDA
  scheduler:
    mode: LOCAL
    ranks_per_node: 1
    num_nodes: 1
  debug: True
  run_dir: /tmp/uma_finetune_runs/
  run_name: uma_finetune
  logger:
    _target_: fairchem.core.common.logger.WandBSingletonLogger.init_wandb
    _partial_: true
    entity: example
    project: uma_finetune


base_model_name: uma-s-1p1
max_neighbors: 300
epochs: 1
steps: null
batch_size: 2
lr: 4e-4

train_dataloader ...
eval_dataloader ...
runner ...
  • base_model_name: the name of a model that can be retrieved from Hugging Face. To use your own UMA checkpoint instead, provide its path directly in the runner:

    model:
      _target_: fairchem.core.units.mlip_unit.mlip_unit.initialize_finetuning_model
      checkpoint_location: /path/to/your/checkpoint.pt
  • max_neighbors: the maximum number of neighbors used in the equivariant SO(2) convolutions. 300 is the default used in UMA training, but if memory is limited, 100 is usually enough while still keeping the potential smooth (see the eSEN paper).

  • epochs, steps: run for either an integer number of epochs or a fixed number of steps. Only one can be specified; the other must be null.

  • batch_size: this configuration uses a batch sampler. Start with the largest batch size that fits on your system without running out of memory, but avoid a batch size so large that training completes in very few steps. The optimal batch size is usually the one that minimizes the final validation loss for a fixed compute budget.

  • lr, weight_decay: standard optimizer hyperparameters; we recommend starting from the default values.
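To see how `epochs`, `steps`, and `batch_size` interact, here is a rough estimate of how many optimizer steps a run will take. The helper `planned_steps` is hypothetical. It assumes each rank sees `ceil(num_samples / world_size)` samples and that `drop_last=False` (as in the log above), so a partial final batch still counts as a step. For the demo run it gives 2, which matches the "Training Completed 2 steps" line in that log.

```python
import math


def planned_steps(num_samples, batch_size, world_size=1, epochs=None, steps=None):
    """Estimate the optimizer steps a fine-tuning run will take.

    Exactly one of `epochs` or `steps` must be given; the other stays None,
    mirroring the epochs/steps rule in the yaml.
    """
    if (epochs is None) == (steps is None):
        raise ValueError("set exactly one of epochs or steps; the other must be None")
    if steps is not None:
        return steps
    samples_per_rank = math.ceil(num_samples / world_size)
    steps_per_epoch = math.ceil(samples_per_rank / batch_size)  # drop_last=False
    return epochs * steps_per_epoch


# The demo run above: 2 training structures, batch_size=2, epochs=2.
print(planned_steps(num_samples=2, batch_size=2, epochs=2))
# A hypothetical 10,000-structure dataset: larger batches mean fewer steps.
for bs in (2, 16, 64):
    print(bs, planned_steps(num_samples=10_000, batch_size=bs, epochs=1))
```

If the step count for your chosen batch size is very small, that is a sign the batch size is too large for the dataset.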

Logging and Artifacts

For logging and checkpoints, all artifacts are stored in the location specified by job.run_dir. The only supported visual logger is Weights & Biases (W&B); TensorBoard is no longer supported. You must set up your W&B account separately, and job.debug must be set to False for W&B logging to work.
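The job metadata printed in the fine-tuning log above shows how artifacts are organized under job.run_dir. Each run gets a subfolder named by job.timestamp_id, which holds logs, checkpoints, results, and the canonical config. The helper below is a hypothetical convenience that rebuilds those paths. The layout is taken from that single run and may differ between fairchem versions.

```python
from pathlib import Path


def run_artifact_paths(run_dir, timestamp_id):
    """Locate a run's artifacts, following the job metadata in the log above."""
    root = Path(run_dir) / timestamp_id
    return {
        "config": root / "canonical_config.yaml",
        "logs": root / "logs",
        "checkpoints": root / "checkpoints",
        "results": root / "results",
        "preemption_state": root / "checkpoints" / "preemption_state",
    }


for name, path in run_artifact_paths("/tmp/finetune_dir", "some_id").items():
    print(f"{name:>16}: {path}")
```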

Distributed training

We support multi-GPU distributed training on a single machine without additional infrastructure, and multi-node distributed training on SLURM only.

To train on multiple GPUs locally, set job.scheduler.ranks_per_node=N, where N is the number of GPUs you want to train on.

To train on multiple nodes of a SLURM cluster, set job.scheduler.mode=SLURM and set both job.scheduler.ranks_per_node and job.scheduler.num_nodes to the desired values. The run_dir must be on a shared, network-accessible mount for this to work.
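The scheduler rules above can be expressed as a small override builder whose output is appended to the fairchem command line. `scheduler_overrides` and the partition name `my_partition` are hypothetical. The `job.scheduler.slurm.partition` key comes from the slurm block in the config printed in the log above. The builder enforces that local mode is single-node and that more than one node requires SLURM.

```python
def scheduler_overrides(gpus_per_node, num_nodes=1, slurm=False, partition=None):
    """Build Hydra overrides for the job.scheduler settings described above."""
    if gpus_per_node < 1 or num_nodes < 1:
        raise ValueError("gpus_per_node and num_nodes must be positive")
    if num_nodes > 1 and not slurm:
        raise ValueError("multi-node training is only supported with SLURM")
    overrides = [
        f"job.scheduler.ranks_per_node={gpus_per_node}",
        f"job.scheduler.num_nodes={num_nodes}",
    ]
    if slurm:
        # run_dir must also live on a shared mount visible to every node.
        overrides.insert(0, "job.scheduler.mode=SLURM")
        if partition is not None:
            overrides.append(f"job.scheduler.slurm.partition={partition}")
    return overrides


print(scheduler_overrides(4))  # one machine with 4 GPUs
print(scheduler_overrides(8, num_nodes=2, slurm=True, partition="my_partition"))
```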

Resuming runs

To resume a run from a checkpoint, find the checkpoint folder for the step you want and pass its resume.yaml to the same fairchem command, e.g.:

! fairchem -c /tmp/finetune_dir/some_id/checkpoints/final/resume.yaml
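When a run has saved several checkpoints, picking the newest folder by hand is error-prone because `step_<N>` names sort lexically (`step_1000` sorts before `step_200`). The sketch below is a hypothetical helper. It assumes the `checkpoints/step_<N>/` and `checkpoints/final/` layout seen in the log above, and it assumes that every checkpoint folder contains a resume.yaml, as the final folder does in the command above. It returns the path to pass to `fairchem -c`.

```python
import tempfile
from pathlib import Path


def find_resume_config(checkpoint_dir, folder=None):
    """Return the resume.yaml to pass to `fairchem -c`.

    With folder=None, the highest-numbered step_<N> folder is chosen;
    otherwise `folder` (e.g. "final" or "step_1000") is used as given.
    """
    ckpt = Path(checkpoint_dir)
    if folder is None:
        numbered = [p for p in ckpt.glob("step_*") if p.name[len("step_"):].isdigit()]
        if not numbered:
            raise FileNotFoundError(f"no step_<N> folders in {ckpt}")
        # Compare step numbers as integers so step_1000 beats step_200.
        target = max(numbered, key=lambda p: int(p.name[len("step_"):]))
    else:
        target = ckpt / folder
    resume = target / "resume.yaml"
    if not resume.is_file():
        raise FileNotFoundError(resume)
    return resume


# Demo on a throwaway directory that mimics the checkpoint layout.
demo = Path(tempfile.mkdtemp()) / "checkpoints"
for name in ("step_0", "step_200", "step_1000", "final"):
    (demo / name).mkdir(parents=True)
    (demo / name / "resume.yaml").touch()
picked = find_resume_config(demo)
print(picked.parent.name)  # step_1000
```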

Running inference on the fine-tuned model

Inference works the same way as for the UMA models, except that you load the checkpoint from a local path. You must also use the same task that you used for fine-tuning:

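from ase.build import bulk
from fairchem.core import FAIRChemCalculator
from fairchem.core.units.mlip_unit import load_predict_unit

predictor = load_predict_unit("/tmp/finetune_dir/some_id/checkpoints/final/inference_ckpt.pt")
calc = FAIRChemCalculator(predictor, task_name="omat")  # must match the --uma-task used for fine-tuning

atoms = bulk("Cu", "fcc", a=3.6)  # placeholder structure; use your own system
atoms.calc = calc
energy = atoms.get_potential_energy()  # forces via atoms.get_forces(), stress via atoms.get_stress()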
WARNING:root:device was not explicitly set, using device='cuda'.