This repo provides a number of scripts to quickly fine-tune a model using a custom ASE LMDB dataset. These scripts are merely for convenience; fine-tuning uses the exact same tooling and infrastructure as our standard training (see the Training section). Training in the fairchem repo uses the fairchem CLI tool, and configs are in Hydra YAML format.
Generating Training/Fine-tuning Datasets¶
First we need to generate a dataset in the aselmdb format for fine-tuning. Check out the fairchem repo and install it to access the scripts:
git clone git@github.com:facebookresearch/fairchem.git
pip install -e fairchem/src/packages/fairchem-core[dev]
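If you do not already have labeled structures on disk, you can create some with ASE. Here is a minimal sketch; the file path and the dummy Cu structures/energies are illustrative assumptions, so check the script's help for the exact input layout it expects:
import os
from ase.build import bulk
from ase.calculators.singlepoint import SinglePointCalculator
from ase.io import write

# Build a few dummy bulk structures and attach single-point energies so the
# dataset-creation script can read them with ASE.
os.makedirs("finetune_assets/train", exist_ok=True)
structures = []
for a in (3.5, 3.6, 3.7):
    atoms = bulk("Cu", "fcc", a=a)
    # Dummy energy; in practice this comes from your DFT (or other) code.
    atoms.calc = SinglePointCalculator(atoms, energy=-3.5 * len(atoms))
    structures.append(atoms)
write("finetune_assets/train/cu_bulk.traj", structures)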
import os
os.chdir('../../../../fairchem')
Run this script to create the aselmdbs as well as a set of templated YAMLs for fine-tuning; we will use a few dummy structures for demonstration purposes.
! python src/fairchem/core/scripts/create_uma_finetune_dataset.py --train-dir docs/core/common_tasks/finetune_assets/train/ --val-dir docs/core/common_tasks/finetune_assets/val --output-dir /tmp/bulk --uma-task=omat --regression-task e
Computing normalizer values.: 100%|██████████████| 2/2 [00:00<00:00, 343.57it/s]
INFO:root:Generated dataset and data config yaml in /tmp/bulk
INFO:root:To run finetuning, run fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml
Regression Task Options
The --regression-task flag can be one of:
e: Energy only
ef: Energy + forces
efs: Energy + forces + stress
Choose based on the data you have available in the ASE db. For example, some aperiodic DFT codes only support energies and forces but not stress, and some very fancy codes like QMC only produce energies.
Note: Even if you train on just energy or energy/forces, all gradient-based properties (forces/stresses) remain computable at inference time via the model's gradients.
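For example, here is a minimal sketch of attaching all three labels to a structure with ASE's SinglePointCalculator (the zero-valued labels are placeholders); a row like this in your db would support --regression-task efs:
import numpy as np
from ase.build import bulk
from ase.calculators.singlepoint import SinglePointCalculator

atoms = bulk("Cu", "fcc", a=3.6)
# energy/forces/stress must all be present for an efs fine-tune;
# stress uses ASE's 6-component Voigt convention.
atoms.calc = SinglePointCalculator(
    atoms,
    energy=-3.5,
    forces=np.zeros((len(atoms), 3)),
    stress=np.zeros(6),
)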
This will generate a folder of LMDBs and a uma_sm_finetune_template.yaml that you can run directly with the fairchem CLI to start training.
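You can sanity-check the generated LMDBs by loading them back with fairchem's AseDBDataset. A sketch, assuming the --output-dir used above and that src accepts a directory of aselmdb files:
from fairchem.core.datasets import AseDBDataset

dataset = AseDBDataset(config=dict(src="/tmp/bulk/train"))
print(len(dataset))          # number of structures written
print(dataset.get_atoms(0))  # first structure as an ase.Atoms object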
Model Fine-tuning (Default Settings)¶
The previous step should have generated the YAML files you need to get started on fine-tuning. You can run them directly with the fairchem CLI; the default configuration runs locally on 1 GPU.
! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml
Advanced Configuration¶
The scripts provide a simple way to get started with fine-tuning, but you will likely need to modify parameters for your own use cases. The configuration uses Hydra-style YAMLs, so individual values can be overridden on the command line:
! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml epochs=2 lr=2e-4 job.run_dir=/tmp/finetune_dir +job.timestamp_id=some_id
INFO:root:saved canonical config to /tmp/finetune_dir/some_id/canonical_config.yaml
INFO:root:Running fairchemv2 cli with {'job': {'run_name': 'uma_finetune', 'timestamp_id': 'some_id', 'run_dir': '/tmp/finetune_dir', 'device_type': <DeviceType.CUDA: 'cuda'>, 'debug': True, 'scheduler': {'mode': <SchedulerType.LOCAL: 'local'>, 'ranks_per_node': 1, 'num_nodes': 1, ...}, 'logger': {...}, 'seed': 0, ...}, 'runner': {'_target_': 'fairchem.core.components.train.train_runner.TrainEvalRunner', 'train_dataloader': {...}, 'eval_dataloader': {...}, 'train_eval_unit': {...}, 'max_epochs': 2, 'max_steps': None, 'evaluate_every_n_steps': 100, 'callbacks': [...]}} (full canonical config elided for brevity)
INFO:root:Running in local mode with 1 ranks using device_type:cuda
INFO:root:Running in local mode without elastic launch
INFO:root:Setting env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
INFO:root:Setting up distributed backend...
INFO:root:Calling runner.run() ...
INFO:root:get_dataloader::Calling batch_sampler_fn=functools.partial(<class 'fairchem.core.common.data_parallel.BalancedBatchSampler'>, batch_size=2, shuffle=True, seed=0)...
WARNING:root:Disabled BalancedBatchSampler because num_replicas=1.
INFO:root:rank: 0: Sampler created...
INFO:root:Created BalancedBatchSampler with sampler=<fairchem.core.common.data_parallel.StatefulDistributedSampler object at 0x7814f6a24680>, batch_size=2, drop_last=False
INFO:root:get_dataloader::Calling Dataloader...
INFO:root:get_dataloader::Done!
INFO:root:get_dataloader::Calling batch_sampler_fn=functools.partial(<class 'fairchem.core.common.data_parallel.BalancedBatchSampler'>, batch_size=2, shuffle=False, seed=0)...
WARNING:root:Disabled BalancedBatchSampler because num_replicas=1.
INFO:root:rank: 0: Sampler created...
INFO:root:Created BalancedBatchSampler with sampler=<fairchem.core.common.data_parallel.StatefulDistributedSampler object at 0x7814f70112b0>, batch_size=2, drop_last=False
INFO:root:get_dataloader::Calling Dataloader...
INFO:root:get_dataloader::Done!
INFO:httpx:HTTP Request: HEAD https://huggingface.co/facebook/UMA/resolve/main/checkpoints/uma-s-1.pt "HTTP/1.1 302 Found"
INFO:httpx:HTTP Request: GET https://huggingface.co/api/models/facebook/UMA/xet-read-token/38529caa2c51a9a8a0d71f0b56b79ac33bc9eceb "HTTP/1.1 200 OK"
checkpoints/uma-s-1.pt: 100%|███████████████| 1.17G/1.17G [00:02<00:00, 430MB/s]
WARNING:root:initialize_finetuning_model starting from checkpoint_location: /home/runner/.cache/fairchem/models--facebook--UMA/snapshots/38529caa2c51a9a8a0d71f0b56b79ac33bc9eceb/checkpoints/uma-s-1.pt
INFO:root:Train Dataloader size 1
INFO:root:Eval Dataloader size 1
INFO:root:No existing checkpoints found, starting from scratch
INFO:torchtnt.framework.fit:Started fit with max_epochs=2 max_steps=None max_train_steps_per_epoch=None max_eval_steps_per_epoch=None evaluate_every_n_steps=100 evaluate_every_n_epochs=1
INFO:torchtnt.framework.train:Started train with max_epochs=2, max_steps=None, max_steps_per_epoch=None
INFO:root:on_train_start: setting sampler state to 0, 0
INFO:root:at beginning of epoch 0, setting sampler start step to 0
INFO:torchtnt.framework.train:Started train epoch
INFO:root:at beginning of epoch 0, setting sampler start step to 0
Train Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]
INFO:root:Saved dcp checkpoint to /tmp/finetune_dir/some_id/checkpoints/step_0
INFO:root:0: Expert variance: 3.60e-07,1.19e-02,7.17e-07,3.88e-07,4.74e-04,7.71e-07,4.15e-07,1.53e-02,8.34e-07,2.89e-05,1.65e-02,7.07e-07,7.77e-07,1.01e-06,2.53e-07,2.40e-07,3.66e-07,5.95e-04,1.57e-05,9.03e-04,1.30e-06,1.34e-04,4.34e-07,5.28e-07,9.35e-07,3.97e-07,9.43e-07,5.13e-05,4.00e-03,6.28e-07,6.77e-06,4.84e-05
INFO:root:0: Expert mean: 6.88e-03,1.02e-01,6.92e-03,6.72e-03,2.17e-02,6.83e-03,6.69e-03,1.28e-01,7.05e-03,1.22e-02,9.99e-02,7.17e-03,6.82e-03,7.00e-03,6.73e-03,6.83e-03,7.01e-03,3.28e-01,1.25e-02,5.07e-02,6.99e-03,1.45e-02,6.66e-03,7.03e-03,6.87e-03,7.00e-03,7.11e-03,1.17e-02,2.01e-01,6.74e-03,8.92e-03,3.75e-02
/home/runner/work/_tool/Python/3.12.12/x64/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:367: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
_warn_get_lr_called_within_step(self)
INFO:root:{'train/loss': 1.2937601246535735, 'train/lr': 4e-05, 'train/step': 0, 'train/epoch': 0.0, 'train/samples_per_second(approx)': 0.06049049141882921, 'train/atoms_per_second(approx)': 1.4517717940519008, 'train/num_atoms_on_rank': 48, 'train/num_samples_on_rank': 2}
Train Epoch 0: 100%|█████████████████████████████| 1/1 [00:33<00:00, 33.14s/it]
INFO:torchtnt.framework.train:Reached end of train dataloader
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
Eval Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]
Eval Epoch 0: 100%|██████████████████████████████| 1/1 [00:00<00:00, 4.46it/s]
INFO:root:Done eval epoch, aggregating metrics
INFO:root:Finished aggregating metrics:
val/atoms_per_second: 368.0586
val/epoch: 0.0000
val/loss: 61.0236
val/omat.val,energy,mae: 97.6377
val/omat.val,energy,per_atom_mae: 3.0512
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:After train epoch, train progress: num_epochs_completed = 1, num_steps_completed = 1
INFO:torchtnt.framework.train:Started train epoch
INFO:root:at beginning of epoch 1, setting sampler start step to 0
Train Epoch 1: 0%| | 0/1 [00:00<?, ?it/s]
Train Epoch 1: 100%|█████████████████████████████| 1/1 [00:00<00:00, 2.26it/s]
INFO:torchtnt.framework.train:Reached end of train dataloader
INFO:torchtnt.framework.evaluate:Started evaluate with max_steps_per_epoch=None
Eval Epoch 1: 0%| | 0/1 [00:00<?, ?it/s]
Eval Epoch 1: 100%|██████████████████████████████| 1/1 [00:00<00:00, 4.54it/s]
INFO:root:Done eval epoch, aggregating metrics
INFO:root:Finished aggregating metrics:
val/atoms_per_second: 372.8167
val/epoch: 0.0000
val/loss: 61.0235
val/omat.val,energy,mae: 97.6376
val/omat.val,energy,per_atom_mae: 3.0512
INFO:torchtnt.framework.train:Ended train epoch
INFO:torchtnt.framework.train:After train epoch, train progress: num_epochs_completed = 2, num_steps_completed = 2
INFO:root:Training Completed 2 steps
INFO:root:Saved dcp checkpoint to /tmp/finetune_dir/some_id/checkpoints/final
INFO:torchtnt.framework.fit:Finished fit
The basic YAML configuration looks like the following:
job:
  device_type: CUDA
  scheduler:
    mode: LOCAL
    ranks_per_node: 1
    num_nodes: 1
  debug: True
  run_dir: /tmp/uma_finetune_runs/
  run_name: uma_finetune
  logger:
    _target_: fairchem.core.common.logger.WandBSingletonLogger.init_wandb
    _partial_: true
    entity: example
    project: uma_finetune

base_model_name: uma-s-1p1
max_neighbors: 300
epochs: 1
steps: null
batch_size: 2
lr: 4e-4

train_dataloader: ...
eval_dataloader: ...
runner: ...
Configuration Parameters
base_model_name: Refers to a model name that can be retrieved from HuggingFace. If you want to use a custom UMA checkpoint, provide its path directly in the runner config:
model:
  _target_: fairchem.core.units.mlip_unit.mlip_unit.initialize_finetuning_model
  checkpoint_location: /path/to/your/checkpoint.pt
max_neighbors: The number of neighbors used for the equivariant SO(2) convolutions. 300 is the default used in UMA training, but if you don't have a lot of memory, 100 is usually fine and still ensures smoothness of the potential (see the eSEN paper).
epochs, steps: Choose to run for either an integer number of epochs or a fixed number of steps. Only one can be specified; the other must be null.
batch_size: In this configuration we use the batch sampler. Start with the largest batch size that can fit on your system without running out of memory. However, don’t use a batch size so large that you complete training in very few steps. The optimal batch size is usually the one that minimizes the final validation loss for a fixed compute budget.
lr, weight_decay: These are standard optimizer parameters. The defaults are the values we recommend.
Logging and Artifacts¶
For logging and checkpoints, all artifacts are stored in the location specified by job.run_dir. Weights and Biases is the supported visual logger.
Distributed Training¶
We support multi-GPU distributed training locally without additional infrastructure; multi-node distributed training is supported on SLURM only.
Multi-GPU locally: Simply set job.scheduler.ranks_per_node=N where N is the number of GPUs you want to train on.
Multi-node on SLURM: Change job.scheduler.mode=SLURM and set both job.scheduler.ranks_per_node and job.scheduler.num_nodes to the desired values.
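For example, a local 2-GPU run and a 2-node SLURM run might look like the following (the SLURM account and partition values are placeholders for your cluster's settings):
! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml job.scheduler.ranks_per_node=2
! fairchem -c /tmp/bulk/uma_sm_finetune_template.yaml job.scheduler.mode=SLURM job.scheduler.num_nodes=2 job.scheduler.ranks_per_node=8 job.scheduler.slurm.account=my_account job.scheduler.slurm.partition=my_partition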
Resuming Runs¶
To resume from a checkpoint in the middle of a run, find the checkpoint folder for the step you want and pass its resume.yaml to the same fairchem command:
! fairchem -c /tmp/finetune_dir/some_id/checkpoints/final/resume.yaml
Running Inference on the Fine-tuned Model¶
Inference runs the same way as with the standard UMA models, except that you load the checkpoint from a local path:
from fairchem.core.units.mlip_unit import load_predict_unit
from fairchem.core import FAIRChemCalculator
predictor = load_predict_unit("/tmp/finetune_dir/some_id/checkpoints/final/inference_ckpt.pt")
calc = FAIRChemCalculator(predictor, task_name="omat")
WARNING:root:device was not explicitly set, using device='cuda'.
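The calculator can then be used like any other ASE calculator; for example, on an illustrative bulk structure:
from ase.build import bulk

atoms = bulk("Cu", "fcc", a=3.6)
atoms.calc = calc
print(atoms.get_potential_energy())  # energy in eV
print(atoms.get_forces())            # forces via the model's gradients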