Fine-tuning with Python#

The recommended way to do training is with the main.py script in ocp. One reason for this is that training often takes a long time and is better suited to queue systems like Slurm. However, you can also submit Python scripts, and it is possible to run notebooks in Slurm as well. Here we work out a proof of concept for training from Python and a Jupyter notebook.

import logging
from fairchem.core.common.utils import SeverityLevelBetween

root = logging.getLogger()
root.setLevel(logging.INFO)

log_formatter = logging.Formatter(
    "%(asctime)s (%(levelname)s): %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Send INFO (but not WARNING or higher) to out.txt
handler_out = logging.FileHandler('out.txt', 'w')
handler_out.addFilter(SeverityLevelBetween(logging.INFO, logging.WARNING))
handler_out.setFormatter(log_formatter)
root.addHandler(handler_out)

# Send WARNING (and higher) to the same file, appending so the two handlers do not clobber each other
handler_err = logging.FileHandler('out.txt', 'a')
handler_err.setLevel(logging.WARNING)
handler_err.setFormatter(log_formatter)
root.addHandler(handler_err)
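As a quick check that the handlers are wired up, anything logged at INFO level or above should now land in out.txt rather than in the notebook output; a minimal sketch:

logging.info('Logging configured; INFO and above will be written to out.txt')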
! ase db ../../core/fine-tuning/oxides.db
id|age|user  |formula|calculator| energy|natoms| fmax|pbc| volume|charge|   mass
 1|47m|runner|Sn2O4  |unknown   |-41.359|     6|0.045|TTT| 64.258| 0.000|301.416
 2|47m|runner|Sn2O4  |unknown   |-41.853|     6|0.025|TTT| 66.526| 0.000|301.416
 3|47m|runner|Sn2O4  |unknown   |-42.199|     6|0.010|TTT| 68.794| 0.000|301.416
 4|47m|runner|Sn2O4  |unknown   |-42.419|     6|0.006|TTT| 71.062| 0.000|301.416
 5|47m|runner|Sn2O4  |unknown   |-42.534|     6|0.011|TTT| 73.330| 0.000|301.416
 6|47m|runner|Sn2O4  |unknown   |-42.562|     6|0.029|TTT| 75.598| 0.000|301.416
 7|47m|runner|Sn2O4  |unknown   |-42.518|     6|0.033|TTT| 77.866| 0.000|301.416
 8|47m|runner|Sn2O4  |unknown   |-42.415|     6|0.010|TTT| 80.134| 0.000|301.416
 9|47m|runner|Sn2O4  |unknown   |-42.266|     6|0.006|TTT| 82.402| 0.000|301.416
10|47m|runner|Sn2O4  |unknown   |-42.083|     6|0.017|TTT| 84.670| 0.000|301.416
11|47m|runner|Sn4O8  |unknown   |-81.424|    12|0.012|TTT|117.473| 0.000|602.832
12|47m|runner|Sn4O8  |unknown   |-82.437|    12|0.005|TTT|121.620| 0.000|602.832
13|47m|runner|Sn4O8  |unknown   |-83.147|    12|0.015|TTT|125.766| 0.000|602.832
14|47m|runner|Sn4O8  |unknown   |-83.599|    12|0.047|TTT|129.912| 0.000|602.832
15|47m|runner|Sn4O8  |unknown   |-83.831|    12|0.081|TTT|134.058| 0.000|602.832
16|47m|runner|Sn4O8  |unknown   |-83.898|    12|0.001|TTT|138.204| 0.000|602.832
17|47m|runner|Sn4O8  |unknown   |-83.805|    12|0.001|TTT|142.350| 0.000|602.832
18|47m|runner|Sn4O8  |unknown   |-83.586|    12|0.002|TTT|146.496| 0.000|602.832
19|47m|runner|Sn4O8  |unknown   |-83.262|    12|0.002|TTT|150.642| 0.000|602.832
20|47m|runner|Sn4O8  |unknown   |-82.851|    12|0.013|TTT|154.788| 0.000|602.832
Rows: 295 (showing first 20)
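If you prefer to inspect these structures from Python rather than the shell, you can read the same database with ase. This is a minimal sketch using the standard ase.db interface:

from ase.db import connect

db = connect('../../core/fine-tuning/oxides.db')
for row in db.select(limit=3):
    atoms = row.toatoms()  # rebuild the Atoms object from the database row
    print(atoms.get_chemical_formula(), row.energy, row.natoms)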
from fairchem.core.models.model_registry import model_name_to_local_file

checkpoint_path = model_name_to_local_file('GemNet-OC-S2EFS-OC20+OC22', local_cache='/tmp/fairchem_checkpoints/')
from fairchem.core.common.relaxation.ase_utils import OCPCalculator
calc = OCPCalculator(checkpoint_path=checkpoint_path, trainer='forces', cpu=False)
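Before fine-tuning, you can sanity-check the pretrained calculator on one of the structures from the database. A minimal sketch (the predicted energy will generally not agree well with the reference energies above until after fine-tuning):

from ase.db import connect

atoms = connect('../../core/fine-tuning/oxides.db').get_atoms(1)
atoms.calc = calc
print(atoms.get_potential_energy())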

Split the data into train, test, val sets#

! rm -fr train.db test.db val.db

from fairchem.core.common.tutorial_utils import train_test_val_split

train, test, val = train_test_val_split('../../core/fine-tuning/oxides.db')
train, test, val
(PosixPath('/home/runner/work/fairchem/fairchem/docs/tutorials/advanced/train.db'),
 PosixPath('/home/runner/work/fairchem/fairchem/docs/tutorials/advanced/test.db'),
 PosixPath('/home/runner/work/fairchem/fairchem/docs/tutorials/advanced/val.db'))
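A quick row count on each database is a handy sanity check that the split worked as expected:

from ase.db import connect

print('train:', connect(str(train)).count(),
      'test:', connect(str(test)).count(),
      'val:', connect(str(val)).count())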

Set up the training code#

We start by making the config.yml file, which we build from the calculator checkpoint.

from fairchem.core.common.tutorial_utils import generate_yml_config

yml = generate_yml_config(checkpoint_path, 'config.yml',
                   delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes','optim.load_balancing',
                           'optim.loss_force', # the checkpoint setting causes an error
                           'dataset', 'test_dataset', 'val_dataset'],
                   update={'gpus': 1,
                           'optim.eval_every': 10,
                           'optim.max_epochs': 1,
                           'optim.batch_size': 4,
                           'logger': 'tensorboard', # don't use wandb unless you are already logged in
                           # Train data
                           'dataset.train.src': 'train.db',
                           'dataset.train.format': 'ase_db',
                           'dataset.train.a2g_args.r_energy': True,
                           'dataset.train.a2g_args.r_forces': True,
                            # Test data - prediction only so no regression
                           'dataset.test.src': 'test.db',
                           'dataset.test.format': 'ase_db',
                           'dataset.test.a2g_args.r_energy': False,
                           'dataset.test.a2g_args.r_forces': False,
                           # val data
                           'dataset.val.src': 'val.db',
                           'dataset.val.format': 'ase_db',
                           'dataset.val.a2g_args.r_energy': True,
                           'dataset.val.a2g_args.r_forces': True,
                          })

yml
PosixPath('/home/runner/work/fairchem/fairchem/docs/tutorials/advanced/config.yml')
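Since yml is just a pathlib.Path to the generated file, you can review the settings before training, either as raw text or loaded back into a dict:

import yaml

print(yml.read_text())                 # the raw YAML
cfg = yaml.safe_load(yml.read_text())  # or as a nested dict
print(cfg['optim'])                    # e.g. inspect just the optimizer settings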

Set up the training task#

This step gives you several opportunities to define and override the config. You start with the base config.yml, and then specify the changes you want to make via “command-line” arguments.

The code is built around submitit, which is often used with Slurm, but it also works locally.

We have to mimic the main.py setup to construct the arguments and the config. Here is a minimal way to do that.

from fairchem.core.common.flags import flags
parser = flags.get_parser()
args, args_override = parser.parse_known_args(["--mode=train",
                                               "--config-yml=config.yml",
                                               f"--checkpoint={checkpoint_path}",
                                               "--cpu"])
args, args_override
(Namespace(mode='train', config_yml=PosixPath('config.yml'), identifier='', debug=False, run_dir='/home/runner/work/fairchem/fairchem/docs/tutorials/advanced', print_every=10, seed=0, amp=False, checkpoint='/tmp/fairchem_checkpoints/gnoc_oc22_oc20_all_s2ef.pt', timestamp_id=None, sweep_yml=None, submit=False, summit=False, logdir=PosixPath('logs'), slurm_partition=None, slurm_account=None, slurm_qos=None, slurm_mem=80, slurm_timeout=72, num_gpus=1, cpu=True, num_nodes=1, gp_gpus=None),
 [])

Next, we build the first stage of our config. This starts from the file config.yml and then updates it with the args.

from fairchem.core.common.utils import build_config, new_trainer_context

config = build_config(args=args, args_override={})
config
{'amp': False,
 'checkpoint': '/tmp/fairchem_checkpoints/gnoc_oc22_oc20_all_s2ef.pt',
 'dataset': {'test': {'a2g_args': {'r_energy': False, 'r_forces': False},
   'format': 'ase_db',
   'src': 'test.db'},
  'train': {'a2g_args': {'r_energy': True, 'r_forces': True},
   'format': 'ase_db',
   'src': 'train.db'},
  'val': {'a2g_args': {'r_energy': True, 'r_forces': True},
   'format': 'ase_db',
   'src': 'val.db'}},
 'evaluation_metrics': {'metrics': {'energy': ['mae'],
   'forces': ['forcesx_mae',
    'forcesy_mae',
    'forcesz_mae',
    'mae',
    'cosine_similarity',
    'magnitude_error'],
   'misc': ['energy_forces_within_threshold']},
  'primary_metric': 'forces_mae'},
 'gpus': 1,
 'logger': 'tensorboard',
 'loss_functions': [{'energy': {'coefficient': 1, 'fn': 'mae'}},
  {'forces': {'coefficient': 1, 'fn': 'l2mae'}}],
 'model': {'activation': 'silu',
  'atom_edge_interaction': True,
  'atom_interaction': True,
  'cbf': {'name': 'spherical_harmonics'},
  'cutoff': 12.0,
  'cutoff_aeaint': 12.0,
  'cutoff_aint': 12.0,
  'cutoff_qint': 12.0,
  'direct_forces': True,
  'edge_atom_interaction': True,
  'emb_size_aint_in': 64,
  'emb_size_aint_out': 64,
  'emb_size_atom': 256,
  'emb_size_cbf': 16,
  'emb_size_edge': 512,
  'emb_size_quad_in': 32,
  'emb_size_quad_out': 32,
  'emb_size_rbf': 16,
  'emb_size_sbf': 32,
  'emb_size_trip_in': 64,
  'emb_size_trip_out': 64,
  'envelope': {'exponent': 5, 'name': 'polynomial'},
  'extensive': True,
  'forces_coupled': False,
  'max_neighbors': 30,
  'max_neighbors_aeaint': 20,
  'max_neighbors_aint': 1000,
  'max_neighbors_qint': 8,
  'name': 'gemnet_oc',
  'num_after_skip': 2,
  'num_atom': 3,
  'num_atom_emb_layers': 2,
  'num_before_skip': 2,
  'num_blocks': 4,
  'num_concat': 1,
  'num_global_out_layers': 2,
  'num_output_afteratom': 3,
  'num_radial': 128,
  'num_spherical': 7,
  'otf_graph': True,
  'output_init': 'HeOrthogonal',
  'qint_tags': [1, 2],
  'quad_interaction': True,
  'rbf': {'name': 'gaussian'},
  'regress_forces': True,
  'sbf': {'name': 'legendre_outer'},
  'symmetric_edge_symmetrization': False},
 'noddp': False,
 'optim': {'batch_size': 4,
  'clip_grad_norm': 10,
  'ema_decay': 0.999,
  'energy_coefficient': 1,
  'eval_batch_size': 16,
  'eval_every': 10,
  'factor': 0.8,
  'force_coefficient': 1,
  'loss_energy': 'mae',
  'lr_initial': 0.0005,
  'max_epochs': 1,
  'mode': 'min',
  'num_workers': 2,
  'optimizer': 'AdamW',
  'optimizer_params': {'amsgrad': True},
  'patience': 3,
  'scheduler': 'ReduceLROnPlateau',
  'weight_decay': 0},
 'outputs': {'energy': {'level': 'system'},
  'forces': {'eval_on_free_atoms': True,
   'level': 'atom',
   'train_on_free_atoms': True}},
 'trainer': 'ocp',
 'mode': 'train',
 'identifier': '',
 'timestamp_id': None,
 'seed': 0,
 'is_debug': False,
 'run_dir': '/home/runner/work/fairchem/fairchem/docs/tutorials/advanced',
 'print_every': 10,
 'cpu': True,
 'submit': False,
 'summit': False,
 'world_size': 1,
 'distributed_backend': 'gloo',
 'gp_gpus': None}
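At this point config is an ordinary nested Python dict, so you can make last-minute adjustments here instead of regenerating config.yml. The values below are only illustrative:

# hypothetical tweaks; adjust to your own needs
config['optim']['max_epochs'] = 2
config['optim']['eval_every'] = 20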

Run the training task#

It is still annoying that if the output of a cell is too large, the notebook cannot be saved. On the other hand, simply capturing the output is also annoying, because then you cannot see the progress.

Above, we redirected most of the logging to a file, but not all of it. The link below will open that file in a browser, and the subsequent cell captures any residual output. We do not need any of that output, so it is ultimately discarded.

Alternatively, you can open a Terminal and use tail -f out.txt to see the progress.

from IPython.display import display, FileLink
display(FileLink('out.txt'))
with new_trainer_context(config=config) as ctx:
    config = ctx.config
    task = ctx.task
    trainer = ctx.trainer
    task.setup(trainer)
    task.run()
device 0: 100%|██████████| 2/2 [00:06<00:00,  3.33s/it]
device 0: 100%|██████████| 2/2 [00:05<00:00,  2.79s/it]
(similar progress bars repeat for the remaining training and validation passes)

! head out.txt
! tail out.txt
2024-09-18 21:59:57 (WARNING): Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
2024-09-18 21:59:57 (WARNING): Unrecognized arguments: ['symmetric_edge_symmetrization']
2024-09-18 21:59:59 (WARNING): No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
2024-09-18 22:00:00 (WARNING): Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
2024-09-18 22:00:00 (WARNING): Unrecognized arguments: ['symmetric_edge_symmetrization']
2024-09-18 22:00:02 (WARNING): No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
2024-09-18 22:00:03 (WARNING): Unrecognized arguments: ['symmetric_edge_symmetrization']
2024-09-18 22:00:05 (WARNING): log_summary for Tensorboard not supported
2024-09-18 22:00:05 (WARNING): Could not find dataset metadata.npz files in '[PosixPath('train.db')]'
2024-09-18 22:00:05 (WARNING): Disabled BalancedBatchSampler because num_replicas=1.
2024-09-18 22:00:05 (WARNING): Could not find dataset metadata.npz files in '[PosixPath('train.db')]'
2024-09-18 22:00:05 (WARNING): Disabled BalancedBatchSampler because num_replicas=1.
2024-09-18 22:00:05 (WARNING): Failed to get data sizes, falling back to uniform partitioning. BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms.
2024-09-18 22:00:05 (WARNING): Could not find dataset metadata.npz files in '[PosixPath('val.db')]'
2024-09-18 22:00:05 (WARNING): Disabled BalancedBatchSampler because num_replicas=1.
2024-09-18 22:00:05 (WARNING): Failed to get data sizes, falling back to uniform partitioning. BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms.
2024-09-18 22:00:05 (WARNING): Could not find dataset metadata.npz files in '[PosixPath('test.db')]'
2024-09-18 22:00:05 (WARNING): Disabled BalancedBatchSampler because num_replicas=1.
2024-09-18 22:00:05 (WARNING): Failed to get data sizes, falling back to uniform partitioning. BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms.
2024-09-18 22:00:05 (WARNING): Using `weight_decay` from `optim` instead of `optim.optimizer_params`.Please update your config to use `optim.optimizer_params.weight_decay`.`optim.weight_decay` will soon be deprecated.

Now you are all set to carry on with whatever subsequent analysis you want to do.
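For example, a common next step is to load the fine-tuned checkpoint back into an OCPCalculator, exactly as we did with the pretrained checkpoint above. The path below is only a placeholder; use the checkpoint directory reported by your own run (it is created under the run directory with a timestamped name):

from fairchem.core.common.relaxation.ase_utils import OCPCalculator

# placeholder path: substitute the checkpoint written by your training run
new_checkpoint = 'checkpoints/<timestamp_id>/checkpoint.pt'
new_calc = OCPCalculator(checkpoint_path=new_checkpoint, trainer='forces', cpu=True)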