Fine-tuning

This repo provides a number of scripts to quickly fine-tune a model on a custom ASE LMDB dataset. The scripts are provided purely for convenience; fine-tuning uses exactly the same tooling and infrastructure as our standard training (see the Training section). Training in the fairchem repo uses the fairchem CLI tool, and configs are written in Hydra YAML format.

Generating Training/Fine-tuning Datasets

First we need to generate a dataset in the aselmdb format for fine-tuning.

Check out the fairchem repo and install it to access the scripts:

git clone git@github.com:facebookresearch/fairchem.git

pip install -e fairchem/src/packages/fairchem-core[dev]

Run the script below to create the aselmdbs, along with a set of templated YAMLs for fine-tuning. We will use a few dummy structures for demonstration purposes; here --uma-task=omat targets the UMA materials task and --regression-task e fine-tunes on energies only.

import os
os.chdir('../../../../fairchem')
! python src/fairchem/core/scripts/create_uma_finetune_dataset.py --train-dir docs/core/common_tasks/finetune_assets/train/ --val-dir docs/core/common_tasks/finetune_assets/val --output-dir /tmp/bulk --uma-task=omat --regression-task e
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.64it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 22.38it/s]
Computing normalizer values.: 100%|██████████████| 2/2 [00:00<00:00, 343.57it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 32.46it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.70it/s]
INFO:root:Generated dataset and data config yaml in /tmp/bulk
INFO:root:To run finetuning, run fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml

This will generate a folder of LMDBs and a uma_sm_finetune_template.yaml that you can run directly with the fairchem CLI to start training.
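
The train and val directories are simply folders of ASE-readable structure files with reference energies (and forces, where available) attached. To build them from your own data, a minimal sketch using ASE follows; the .traj format, file names, and toy Cu structures are illustrative assumptions, not requirements of the script:

import os
import numpy as np
from ase.build import bulk
from ase.calculators.singlepoint import SinglePointCalculator
from ase.io import write

# Assumption: the dataset-creation script accepts any ASE-readable files
# with single-point results attached; .traj is one such format.
os.makedirs("my_data/train", exist_ok=True)
os.makedirs("my_data/val", exist_ok=True)

for i in range(10):
    # Toy structures with made-up energies; use your own DFT data here
    atoms = bulk("Cu", "fcc", a=3.58 + 0.01 * i)
    atoms.calc = SinglePointCalculator(
        atoms, energy=-3.5 - 0.01 * i, forces=np.zeros((len(atoms), 3))
    )
    split = "train" if i < 8 else "val"
    write(f"my_data/{split}/structure_{i}.traj", atoms)

You can then point --train-dir and --val-dir at my_data/train and my_data/val in the command above.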

Model Fine-tuning (Default Settings)

The previous step generated a set of YAML files to get you started on fine-tuning. You can run these directly with the fairchem CLI; the default configuration runs locally on 1 GPU.

! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml

Advanced Configuration

The scripts provide a simple way to get started with fine-tuning, but you will likely need to modify the parameters for your own use case. The configuration uses Hydra-style YAMLs, so any value can be overridden on the command line:

! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml epochs=2 lr=2e-4 job.run_dir=/tmp/finetune_dir +job.timestamp_id=some_id
INFO:root:saved canonical config to /tmp/finetune_dir/some_id/canonical_config.yaml
INFO:root:Running fairchemv2 cli with {'job': {'run_name': 'uma_finetune', 'timestamp_id': 'some_id', 'run_dir': '/tmp/finetune_dir', 'device_type': <DeviceType.CUDA: 'cuda'>, 'debug': True, 'scheduler': {'mode': <SchedulerType.LOCAL: 'local'>, 'ranks_per_node': 1, 'num_nodes': 1, ...}, 'logger': {...}, 'seed': 0, ...}, 'runner': {'_target_': 'fairchem.core.components.train.train_runner.TrainEvalRunner', 'train_dataloader': {...}, 'eval_dataloader': {...}, 'train_eval_unit': {...}, 'max_epochs': 2, 'max_steps': None, 'evaluate_every_n_steps': 100, 'callbacks': [...]}}
INFO:root:Running in local mode with 1 ranks using device_type:cuda
INFO:root:Running in local mode without elastic launch
INFO:root:Setting env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
INFO:root:Setting up distributed backend...
INFO:root:Calling runner.run() ...
INFO:root:get_dataloader::Calling batch_sampler_fn=functools.partial(<class 'fairchem.core.common.data_parallel.BalancedBatchSampler'>, batch_size=2, shuffle=True, seed=0)...
WARNING:root:Disabled BalancedBatchSampler because num_replicas=1.
INFO:root:rank: 0: Sampler created...
INFO:root:Created BalancedBatchSampler with sampler=<fairchem.core.common.data_parallel.StatefulDistributedSampler object at 0x7814f6a24680>, batch_size=2, drop_last=False
INFO:root:get_dataloader::Calling Dataloader...
INFO:root:get_dataloader::Done!
INFO:root:get_dataloader::Calling batch_sampler_fn=functools.partial(<class 'fairchem.core.common.data_parallel.BalancedBatchSampler'>, batch_size=2, shuffle=False, seed=0)...
WARNING:root:Disabled BalancedBatchSampler because num_replicas=1.
INFO:root:rank: 0: Sampler created...
INFO:root:Created BalancedBatchSampler with sampler=<fairchem.core.common.data_parallel.StatefulDistributedSampler object at 0x7814f70112b0>, batch_size=2, drop_last=False
INFO:root:get_dataloader::Calling Dataloader...
INFO:root:get_dataloader::Done!
INFO:httpx:HTTP Request: HEAD https://huggingface.co/facebook/UMA/resolve/main/checkpoints/uma-s-1.pt "HTTP/1.1 302 Found"
INFO:httpx:HTTP Request: GET https://huggingface.co/api/models/facebook/UMA/xet-read-token/38529caa2c51a9a8a0d71f0b56b79ac33bc9eceb "HTTP/1.1 200 OK"
checkpoints/uma-s-1.pt:   0%|                       | 0.00/1.17G [00:00<?, ?B/s]
checkpoints/uma-s-1.pt: 100%|███████████████| 1.17G/1.17G [00:02<00:00, 430MB/s]
WARNING:root:initialize_finetuning_model starting from checkpoint_location: /home/runner/.cache/fairchem/models--facebook--UMA/snapshots/38529caa2c51a9a8a0d71f0b56b79ac33bc9eceb/checkpoints/uma-s-1.pt
INFO:root:Train Dataloader size 1
INFO:root:Eval Dataloader size 1
INFO:root:No existing checkpoints found, starting from scratch
INFO:torchtnt.framework.fit:Started fit with max_epochs=2 max_steps=None max_train_steps_per_epoch=None max_eval_steps_per_epoch=None evaluate_every_n_steps=100 evaluate_every_n_epochs=1 
INFO:torchtnt.framework.train:Started train with max_epochs=2, max_steps=None, max_steps_per_epoch=None
INFO:root:on_train_start: setting sampler state to 0, 0
INFO:root:at beginning of epoch 0, setting sampler start step to 0
INFO:torchtnt.framework.train:Started train epoch
INFO:root:at beginning of epoch 0, setting sampler start step to 0
Train Epoch 0:   0%|                                     | 0/1 [00:00<?, ?it/s]
INFO:root:Saved dcp checkpoint to /tmp/finetune_dir/some_id/checkpoints/step_0
INFO:root:0: Expert variance: 3.60e-07,1.19e-02,7.17e-07,3.88e-07,4.74e-04,7.71e-07,4.15e-07,1.53e-02,8.34e-07,2.89e-05,1.65e-02,7.07e-07,7.77e-07,1.01e-06,2.53e-07,2.40e-07,3.66e-07,5.95e-04,1.57e-05,9.03e-04,1.30e-06,1.34e-04,4.34e-07,5.28e-07,9.35e-07,3.97e-07,9.43e-07,5.13e-05,4.00e-03,6.28e-07,6.77e-06,4.84e-05
INFO:root:0: Expert mean: 6.88e-03,1.02e-01,6.92e-03,6.72e-03,2.17e-02,6.83e-03,6.69e-03,1.28e-01,7.05e-03,1.22e-02,9.99e-02,7.17e-03,6.82e-03,7.00e-03,6.73e-03,6.83e-03,7.01e-03,3.28e-01,1.25e-02,5.07e-02,6.99e-03,1.45e-02,6.66e-03,7.03e-03,6.87e-03,7.00e-03,7.11e-03,1.17e-02,2.01e-01,6.74e-03,8.92e-03,3.75e-02
/home/runner/work/_tool/Python/3.12.12/x64/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:367: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  _warn_get_lr_called_within_step(self)
INFO:root:{'train/loss': 1.2937601246535735, 'train/lr': 4e-05, 'train/step': 0, 'train/epoch': 0.0, 'train/samples_per_second(approx)': 0.06049049141882921, 'train/atoms_per_second(approx)': 1.4517717940519008, 'train/num_atoms_on_rank': 48, 'train/num_samples_on_rank': 2}
Train Epoch 0: 100%|█████████████████████████████| 1/1 [00:33<00:00, 33.14s/it]
INFO:torchtnt.framework.train:Reached end of train dataloader
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
Eval Epoch 0:   0%|                                      | 0/1 [00:00<?, ?it/s]
Eval Epoch 0: 100%|██████████████████████████████| 1/1 [00:00<00:00,  4.46it/s]
INFO:root:Done eval epoch, aggregating metrics
INFO:root:Finished aggregating metrics: 
  val/atoms_per_second: 368.0586
  val/epoch: 0.0000
  val/loss: 61.0236
  val/omat.val,energy,mae: 97.6377
  val/omat.val,energy,per_atom_mae: 3.0512

INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:After train epoch, train progress: num_epochs_completed = 1, num_steps_completed = 1
INFO:torchtnt.framework.train:Started train epoch
INFO:root:at beginning of epoch 1, setting sampler start step to 0
Train Epoch 1:   0%|                                     | 0/1 [00:00<?, ?it/s]
Train Epoch 1: 100%|█████████████████████████████| 1/1 [00:00<00:00,  2.26it/s]
INFO:torchtnt.framework.train:Reached end of train dataloader
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
Eval Epoch 1:   0%|                                      | 0/1 [00:00<?, ?it/s]
Eval Epoch 1: 100%|██████████████████████████████| 1/1 [00:00<00:00,  4.54it/s]
INFO:root:Done eval epoch, aggregating metrics
INFO:root:Finished aggregating metrics: 
  val/atoms_per_second: 372.8167
  val/epoch: 0.0000
  val/loss: 61.0235
  val/omat.val,energy,mae: 97.6376
  val/omat.val,energy,per_atom_mae: 3.0512

INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:After train epoch, train progress: num_epochs_completed = 2, num_steps_completed = 2
INFO:root:Training Completed 2 steps
INFO:root:Saved dcp checkpoint to /tmp/finetune_dir/some_id/checkpoints/final
INFO:torchtnt.framework.fit:Finished fit

The basic YAML configuration looks like the following; the top-level keys such as epochs, lr, and batch_size are the ones we overrode on the command line above:

job:
  device_type: CUDA
  scheduler:
    mode: LOCAL
    ranks_per_node: 1
    num_nodes: 1
  debug: True
  run_dir: /tmp/uma_finetune_runs/
  run_name: uma_finetune
  logger:
    _target_: fairchem.core.common.logger.WandBSingletonLogger.init_wandb
    _partial_: true
    entity: example
    project: uma_finetune


base_model_name: uma-s-1p1
max_neighbors: 300
epochs: 1
steps: null
batch_size: 2
lr: 4e-4

train_dataloader ...
eval_dataloader ...
runner ...

Logging and Artifacts

All artifacts, including logs and checkpoints, are stored in the location specified by job.run_dir. Weights and Biases is the visual logger we support.
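
For example, to send logs to your own Weights and Biases entity and project, you can override the logger fields on the command line (the entity and project values below are placeholders):

! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml job.logger.entity=my_entity job.logger.project=my_project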

Distributed Training

We support single-node multi-GPU distributed training without any additional infrastructure, and multi-node distributed training on SLURM only.

Multi-GPU locally: Simply set job.scheduler.ranks_per_node=N where N is the number of GPUs you want to train on.

Multi-node on SLURM: Change job.scheduler.mode=SLURM and set both job.scheduler.ranks_per_node and job.scheduler.num_nodes to the desired values.
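
For example, using the template from above (the GPU and node counts here are illustrative):

! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml job.scheduler.ranks_per_node=4

! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml job.scheduler.mode=SLURM job.scheduler.ranks_per_node=8 job.scheduler.num_nodes=2

On a real cluster you may also need to set the job.scheduler.slurm options (partition, account, qos) to match your allocation.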

Resuming Runs

To resume from a checkpoint in the middle of a run, find the checkpoint folder at the step you want and run the same fairchem command on the resume.yaml inside it:

! fairchem -c /tmp/finetune_dir/some_id/checkpoints/final/resume.yaml

Running Inference on the Fine-tuned Model

Inference works the same way as with the standard UMA models, except that you load the checkpoint from a local path.

from fairchem.core.units.mlip_unit import load_predict_unit
from fairchem.core import FAIRChemCalculator

# Load the fine-tuned checkpoint written at the end of training
predictor = load_predict_unit("/tmp/finetune_dir/some_id/checkpoints/final/inference_ckpt.pt")
# task_name should match the UMA task used for fine-tuning (omat here)
calc = FAIRChemCalculator(predictor, task_name="omat")
WARNING:root:device was not explicitly set, using device='cuda'.
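
The calculator then behaves like any other ASE calculator. A quick sanity check on a toy structure (the Cu bulk below is just an example):

from ase.build import bulk

atoms = bulk("Cu", "fcc", a=3.58)
atoms.calc = calc
print(atoms.get_potential_energy())  # energy in eV from the fine-tuned model
print(atoms.get_forces())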