Fine tuning a model

In this section we look at how to improve a model. We start by using a pre-trained model to predict the energies of several oxide polymorphs. The data here comes from this paper:

Mehta, P., Salvador, P. A., & Kitchin, J. R. (2015). Identifying potential BO2 oxide polymorphs for epitaxial growth candidates. ACS Appl. Mater. Interfaces, 6(5), 3630–3639. http://dx.doi.org/10.1021/am4059149.

This dataset contains equations of state for several oxide/polymorph combinations. We use it to compare with predictions from OCP.

First we get the checkpoint that we want. According to the MODELS page, the GemNet-OC OC20+OC22 combination has an energy MAE of 0.483, which seems like a good place to start. This model was trained in part on oxides (OC22).

We get this checkpoint here.

from fairchem.core.models.model_registry import model_name_to_local_file

checkpoint_path = model_name_to_local_file('GemNet-OC-S2EFS-OC20+OC22', local_cache='/tmp/fairchem_checkpoints/')

The data we need is provided in supporting-information.json. That file is embedded in the supporting information for the article, and is provided here in the tutorial. We load this data and explore it a little. The json file provides a dictionary with the structure:

[oxide][polymorph][xc][EOS][calculations]

The first key is a composition, the second is a string for the polymorph structure, the third indicates which XC functional was used (we focus on PBE here), the fourth key ('EOS') selects the Equation of State calculations, and the last key ('calculations') is a list of results for the EOS.
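To make the nesting concrete, here is a minimal sketch that walks a dictionary with the [oxide][polymorph][xc][EOS][calculations] layout and collects the energy per atom of every calculation. The helper name per_atom_energies is hypothetical, and the toy dictionary and its energies are invented stand-ins for the real contents of supporting-information.json; only the key names match the ones used in this tutorial's code.

```python
def per_atom_energies(data, xc="PBE"):
    """Return {(oxide, polymorph): [energy per atom, ...]} for one XC functional.

    Assumes the layout data[oxide][polymorph][xc]['EOS']['calculations'],
    where each calculation has c['atoms']['symbols'] and c['data']['total_energy'].
    """
    result = {}
    for oxide, by_polymorph in data.items():
        for polymorph, by_xc in by_polymorph.items():
            calculations = by_xc[xc]["EOS"]["calculations"]
            result[(oxide, polymorph)] = [
                c["data"]["total_energy"] / len(c["atoms"]["symbols"])
                for c in calculations
            ]
    return result


# Hypothetical stand-in for the real file contents (energies are invented).
toy = {
    "TiO2": {
        "rutile": {
            "PBE": {
                "EOS": {
                    "calculations": [
                        {"atoms": {"symbols": ["Ti", "Ti", "O", "O", "O", "O"]},
                         "data": {"total_energy": -54.0}},
                        {"atoms": {"symbols": ["Ti", "Ti", "O", "O", "O", "O"]},
                         "data": {"total_energy": -57.0}},
                    ]
                }
            }
        }
    }
}

print(per_atom_energies(toy))
```

With the real data you would call per_atom_energies(d) after loading the file in the next cell.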

import json
import numpy as np
import matplotlib.pyplot as plt
from ase import Atoms

from fairchem.core.scripts import download_large_files

download_large_files.download_file_group("docs")

with open('supporting-information.json') as f:
    d = json.load(f)

oxides = list(d.keys())
polymorphs = list(d['TiO2'].keys())
oxides, polymorphs
Downloading docs/tutorials/NRR/NRR_example_bulks.pkl...
Downloading docs/core/fine-tuning/supporting-information.json...
Downloading docs/core/data.db...
(['SnO2', 'IrO2', 'RuO2', 'TiO2', 'VO2'],
 ['rutile', 'pyrite', 'columbite', 'brookite', 'fluorite', 'anatase'])

Each calculation record contains a lot of detail, including the INCAR parameters, POTCARs, k-points and atomic geometry, so you can use it to recreate the calculation. Shortly we will use these records to compare the DFT results to OCP.

d['TiO2']['rutile']['PBE']['EOS']['calculations'][0]
{'incar': {'doc': 'INCAR parameters',
  'prec': 'Normal',
  'isif': 4,
  'nbands': 20,
  'ibrion': 2,
  'gga': 'PE',
  'encut': 520.0,
  'ismear': 0,
  'sigma': 0.001,
  'nsw': 50},
 'doc': 'JSON representation of a VASP calculation.\n\nenergy is in eV\nforces are in eV/\\AA\nstress is in GPa (sxx, syy, szz,  syz, sxz, sxy)\nmagnetic moments are in Bohr-magneton\nThe density of states is reported with E_f at 0 eV.\nVolume is reported in \\AA^3\nCoordinates and cell parameters are reported in \\AA\n\nIf atom-projected dos are included they are in the form:\n{ados:{energy:data, {atom index: {orbital : dos}}}\n',
 'potcar': [['O',
   'potpaw/O/POTCAR',
   '0cf2ce56049ca395c567026b700ed66c94a85161'],
  ['Ti', 'potpaw/Ti/POTCAR', '51f7f05982d6b4052becc160375a8b8b670177a7']],
 'input': {'kpts': [6, 6, 6],
  'reciprocal': False,
  'xc': 'LDA',
  'kpts_nintersections': None,
  'setups': None,
  'txt': '-',
  'gamma': False},
 'atoms': {'cell': [[4.3789762519649225, 0.0, 0.0],
   [0.0, 4.3789762519649225, 0.0],
   [0.0, 0.0, 2.864091775985314]],
  'symbols': ['Ti', 'Ti', 'O', 'O', 'O', 'O'],
  'tags': [0, 0, 0, 0, 0, 0],
  'pbc': [True, True, True],
  'positions': [[0.0, 0.0, 0.0],
   [2.1894881259824612, 2.1894881259824612, 1.432045887992657],
   [1.3181554154438013, 1.3181554154438013, 0.0],
   [3.0608208365211214, 3.0608208365211214, 0.0],
   [3.5076435414262623, 0.87133271053866, 1.432045887992657],
   [0.87133271053866, 3.5076435414262623, 1.432045887992657]]},
 'data': {'stress': [496.18519999, 496.18519999, 502.82679392, 0.0, 0.0, 0.0],
  'doc': 'Data from the output of the calculation',
  'volume': 54.92019999999996,
  'total_energy': -56.230672,
  'forces': [[0.0, 0.0, 0.0],
   [0.0, 0.0, 0.0],
   [-0.001264, -0.001264, 0.0],
   [0.001264, 0.001264, 0.0],
   [-0.001264, 0.001264, 0.0],
   [0.001264, -0.001264, 0.0]],
  'fermi_level': 3.153}}

For each result we can retrieve the atomic geometry, energy and forces from the json file. We use these to recreate an Atoms object. Here is an example.

c = d['TiO2']['rutile']['PBE']['EOS']['calculations'][0]
atoms = Atoms(symbols=c['atoms']['symbols'],
              positions=c['atoms']['positions'],
              cell=c['atoms']['cell'],
              pbc=c['atoms']['pbc'])
atoms.set_tags(np.ones(len(atoms)))
atoms, c['data']['total_energy'], c['data']['forces']
(Atoms(symbols='Ti2O4', pbc=True, cell=[4.3789762519649225, 4.3789762519649225, 2.864091775985314], tags=...),
 -56.230672,
 [[0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0],
  [-0.001264, -0.001264, 0.0],
  [0.001264, 0.001264, 0.0],
  [-0.001264, 0.001264, 0.0],
  [0.001264, -0.001264, 0.0]])

Next, we create an OCP calculator to get predictions from the pre-trained model.

from fairchem.core.common.relaxation.ase_utils import OCPCalculator
calc = OCPCalculator(checkpoint_path=checkpoint_path, trainer='forces', cpu=False)
WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run

Now, we loop through each structure and accumulate the OCP predictions. Then, we plot the parity results.

import time
t0 = time.time()

eos_data = {}
for oxide in oxides:
    eos_data[oxide] = {}
    for polymorph in polymorphs:
        dft = []
        ocp = []
        vols = []
        calculations = d[oxide][polymorph]['PBE']['EOS']['calculations']
        for c in calculations:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            atoms.calc = calc
            ocp += [atoms.get_potential_energy() / len(atoms)]
            dft += [c['data']['total_energy'] / len(atoms)]
            vols += [atoms.get_volume()]
        plt.plot(dft, ocp, marker='s' if oxide == 'VO2' else '.',
                 alpha=0.5, label=f'{oxide}-{polymorph}')
        eos_data[oxide][polymorph] = (vols, dft, ocp)
        
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=3);
print(f'Elapsed time {time.time() - t0:1.1f} seconds.')
Elapsed time 84.4 seconds.
[Figure: parity plot of OCP vs. DFT energies (eV/atom) for each oxide-polymorph pair, pretrained model]
mae = np.mean(np.abs(np.array(dft) - np.array(ocp)))
print(f'MAE = {mae:1.3f} eV/atom')
MAE = 1.669 eV/atom

Note that dft and ocp are the lists left over from the last iteration of the loop, so this MAE covers only the final oxide-polymorph pair (VO2-anatase), not the whole dataset. It is high compared to the reported value of 0.458 eV. That is not too surprising; although OC22 contains oxides, the model was not trained on all of these structures. The parity plot also shows that the largest errors come from the VO2 structures, and these skew the MAE. Next we look at the EOS for each material. Here you can see offsets and qualitative issues in the shapes. The offset is largely a PBE/RPBE difference.
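To get a single MAE over every structure, rather than only the last oxide-polymorph pair, you can pool the values stored in eos_data, where each eos_data[oxide][polymorph] is a (vols, dft, ocp) tuple of per-atom energies. The sketch below does that and also reports the MAE per pair, so you can see which groups (for example the VO2 ones) dominate. The function names pooled_mae and group_maes are our own, and the toy numbers only stand in for real predictions.

```python
import numpy as np


def pooled_mae(eos_data):
    """Mean absolute error (eV/atom) over all structures in eos_data.

    eos_data[oxide][polymorph] is a (vols, dft, ocp) tuple of lists.
    """
    dft_all, ocp_all = [], []
    for by_polymorph in eos_data.values():
        for _vols, dft, ocp in by_polymorph.values():
            dft_all.extend(dft)
            ocp_all.extend(ocp)
    return float(np.mean(np.abs(np.array(dft_all) - np.array(ocp_all))))


def group_maes(eos_data):
    """MAE (eV/atom) for each (oxide, polymorph) pair separately."""
    return {
        (oxide, polymorph): float(np.mean(np.abs(np.array(dft) - np.array(ocp))))
        for oxide, by_polymorph in eos_data.items()
        for polymorph, (_vols, dft, ocp) in by_polymorph.items()
    }


# Toy data with invented per-atom energies.
toy = {
    "A": {"p1": ([10.0, 11.0], [-5.0, -5.5], [-5.25, -5.5])},
    "B": {"p1": ([12.0], [-6.0], [-7.0])},
}
print(f"pooled MAE = {pooled_mae(toy):1.3f} eV/atom")
print(group_maes(toy))
```

With the real data you would call pooled_mae(eos_data) right after the prediction loop, and sorted(group_maes(eos_data).items(), key=lambda kv: kv[1]) to rank the pairs from best to worst.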

Some notable issues are with fluorite structures. We show one here for VO2.

oxide, polymorph = 'VO2', 'fluorite'
V, D, O = eos_data[oxide][polymorph]
        
plt.plot(V, D, label='dft')
plt.plot(V, O, label='ocp')
plt.title(f'pretrained {oxide} - {polymorph}')
plt.legend();
[Figure: pretrained VO2 - fluorite; DFT and OCP energies (eV/atom) vs. volume]

We can improve this by fine-tuning the model.

Fine tuning the checkpoint

There are several steps in fine-tuning:

  1. Create train, test, val datasets

  2. Create a configuration yml file

  3. Run the training

  4. Check the results

First we create the training databases. OCP supports ASE databases, so we use those here. The first step is to make a database containing all the data, which means writing an Atoms object with an attached SinglePointCalculator (holding the DFT energy and forces) for each structure.

from ase.db import connect
from ase.calculators.singlepoint import SinglePointCalculator

! rm -fr oxides.db  # start clean when you run this cell
db = connect('oxides.db')
for oxide in oxides:
    for polymorph in polymorphs:
        for c in d[oxide][polymorph]['PBE']['EOS']['calculations']:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            # Store the DFT energy and forces with the structure.
            # Use a new name so the OCP calculator `calc` is not overwritten.
            spc = SinglePointCalculator(atoms,
                                        energy=c['data']['total_energy'],
                                        forces=c['data']['forces'])
            atoms.calc = spc
            db.write(atoms)

Let’s see what we made.

! ase db oxides.db
id|age|user  |formula|calculator| energy|natoms| fmax|pbc| volume|charge|   mass
 1| 1s|runner|Sn2O4  |unknown   |-41.359|     6|0.045|TTT| 64.258| 0.000|301.416
 2| 1s|runner|Sn2O4  |unknown   |-41.853|     6|0.025|TTT| 66.526| 0.000|301.416
 3| 1s|runner|Sn2O4  |unknown   |-42.199|     6|0.010|TTT| 68.794| 0.000|301.416
 4| 1s|runner|Sn2O4  |unknown   |-42.419|     6|0.006|TTT| 71.062| 0.000|301.416
 5| 1s|runner|Sn2O4  |unknown   |-42.534|     6|0.011|TTT| 73.330| 0.000|301.416
 6| 1s|runner|Sn2O4  |unknown   |-42.562|     6|0.029|TTT| 75.598| 0.000|301.416
 7| 1s|runner|Sn2O4  |unknown   |-42.518|     6|0.033|TTT| 77.866| 0.000|301.416
 8| 1s|runner|Sn2O4  |unknown   |-42.415|     6|0.010|TTT| 80.134| 0.000|301.416
 9| 1s|runner|Sn2O4  |unknown   |-42.266|     6|0.006|TTT| 82.402| 0.000|301.416
10| 1s|runner|Sn2O4  |unknown   |-42.083|     6|0.017|TTT| 84.670| 0.000|301.416
11| 1s|runner|Sn4O8  |unknown   |-81.424|    12|0.012|TTT|117.473| 0.000|602.832
12| 1s|runner|Sn4O8  |unknown   |-82.437|    12|0.005|TTT|121.620| 0.000|602.832
13| 1s|runner|Sn4O8  |unknown   |-83.147|    12|0.015|TTT|125.766| 0.000|602.832
14| 1s|runner|Sn4O8  |unknown   |-83.599|    12|0.047|TTT|129.912| 0.000|602.832
15| 1s|runner|Sn4O8  |unknown   |-83.831|    12|0.081|TTT|134.058| 0.000|602.832
16| 1s|runner|Sn4O8  |unknown   |-83.898|    12|0.001|TTT|138.204| 0.000|602.832
17| 1s|runner|Sn4O8  |unknown   |-83.805|    12|0.001|TTT|142.350| 0.000|602.832
18| 1s|runner|Sn4O8  |unknown   |-83.586|    12|0.002|TTT|146.496| 0.000|602.832
19| 1s|runner|Sn4O8  |unknown   |-83.262|    12|0.002|TTT|150.642| 0.000|602.832
20| 1s|runner|Sn4O8  |unknown   |-82.851|    12|0.013|TTT|154.788| 0.000|602.832
Rows: 295 (showing first 20)

Make the train, test, val splits

We need to split the ase-db into three separate databases: one for training (80%), one for testing (10%) and one for validation (10%). We generate a list of the row ids and shuffle them. Then we write the first 80% into train.db, the next 10% into test.db, and the remainder into val.db.

The train set is used to fit the model. The val set is evaluated periodically during training (controlled by optim.eval_every), and best_checkpoint.pt is the checkpoint that performs best on it, so comparing training and validation errors is how you check for overfitting. The test set is used only for predictions (no energies or forces are read from it in the configuration below).

You can choose other ratios; 80:10:10 is common. This is deliberately a simple approach.
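The shuffle-and-slice procedure can be sketched with the standard library as follows. This is an illustration, not the implementation of train_test_val_split; the function name split_ids and the fixed seed (used so the split is reproducible) are our assumptions, and writing the selected rows into train.db, test.db and val.db is left out.

```python
import random


def split_ids(ids, train_frac=0.8, test_frac=0.1, seed=42):
    """Shuffle ids and split them into train/test/val lists.

    The first train_frac of the shuffled ids go to train, the next
    test_frac to test, and whatever remains to val.
    """
    ids = list(ids)
    rng = random.Random(seed)  # seeded so the split is reproducible
    rng.shuffle(ids)
    n_train = int(train_frac * len(ids))
    n_test = int(test_frac * len(ids))
    train = ids[:n_train]
    test = ids[n_train:n_train + n_test]
    val = ids[n_train + n_test:]
    return train, test, val


# ase db row ids start at 1; oxides.db has 295 rows.
train, test, val = split_ids(range(1, 296))
print(len(train), len(test), len(val))
```

With ase you could then copy each selected row into the matching database, for example by writing db.get(id=i).toatoms() for each id i into a database opened with connect('train.db').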

We provide some helper functions in fairchem.core.common.tutorial_utils to streamline this process.

from fairchem.core.common.tutorial_utils import train_test_val_split
! rm -fr train.db test.db val.db

train, test, val = train_test_val_split('oxides.db')
train, test, val
(PosixPath('/home/runner/work/fairchem/fairchem/docs/core/fine-tuning/train.db'),
 PosixPath('/home/runner/work/fairchem/fairchem/docs/core/fine-tuning/test.db'),
 PosixPath('/home/runner/work/fairchem/fairchem/docs/core/fine-tuning/val.db'))

Setting up the configuration yaml file

We have to create a yaml configuration file for the model we are using. The pre-trained checkpoints contain their config data, so we use this to get the base configuration, and then remove pieces we don’t need, and update pieces we do need.
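The delete and update arguments in the next cell use dotted keys such as 'optim.batch_size' to address nested entries of the configuration. The sketch below shows one way such keys can be applied to a nested dictionary, applying deletions before updates; we assume that ordering because 'logger' appears in both lists. It is an assumption about the semantics, not the code of generate_yml_config, and the helper names (set_dotted, delete_dotted, edit_config) and the toy base config are hypothetical.

```python
import copy


def set_dotted(cfg, dotted_key, value):
    """Set cfg['a']['b']['c'] = value for 'a.b.c', creating dicts as needed."""
    *parents, last = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})
    node[last] = value


def delete_dotted(cfg, dotted_key):
    """Remove the entry addressed by dotted_key if it exists."""
    *parents, last = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.get(key)
        if not isinstance(node, dict):
            return  # path does not exist; nothing to delete
    node.pop(last, None)


def edit_config(base, delete=(), update=None):
    """Return a copy of base with deletions applied first, then updates."""
    cfg = copy.deepcopy(base)  # leave the original config untouched
    for key in delete:
        delete_dotted(cfg, key)
    for key, value in (update or {}).items():
        set_dotted(cfg, key, value)
    return cfg


# Toy base config standing in for the one stored in the checkpoint.
base = {"logger": "wandb", "optim": {"batch_size": 16, "loss_force": "l2mae"}}
cfg = edit_config(base,
                  delete=["logger", "optim.loss_force"],
                  update={"logger": "tensorboard",
                          "optim.batch_size": 4,
                          "dataset.train.src": "train.db"})
print(cfg)
```

Dotted keys keep the update list flat and readable even when the target entries are several levels deep, as with the dataset.train.a2g_args entries below.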

from fairchem.core.common.tutorial_utils import generate_yml_config

yml = generate_yml_config(checkpoint_path, 'config.yml',
                          delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes',
                                  'optim.loss_force',  # the checkpoint setting causes an error
                                  'optim.load_balancing',
                                  'dataset', 'test_dataset', 'val_dataset'],
                          update={'gpus': 1,
                                  'optim.eval_every': 10,
                                  'optim.max_epochs': 1,
                                  'optim.batch_size': 4,
                                  'logger': 'tensorboard',  # don't use wandb!
                                  # Train data
                                  'dataset.train.src': 'train.db',
                                  'dataset.train.format': 'ase_db',
                                  'dataset.train.a2g_args.r_energy': True,
                                  'dataset.train.a2g_args.r_forces': True,
                                  # Test data - prediction only so no regression
                                  'dataset.test.src': 'test.db',
                                  'dataset.test.format': 'ase_db',
                                  'dataset.test.a2g_args.r_energy': False,
                                  'dataset.test.a2g_args.r_forces': False,
                                  # Val data
                                  'dataset.val.src': 'val.db',
                                  'dataset.val.format': 'ase_db',
                                  'dataset.val.a2g_args.r_energy': True,
                                  'dataset.val.a2g_args.r_forces': True,
                                  })

yml
WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
PosixPath('/home/runner/work/fairchem/fairchem/docs/core/fine-tuning/config.yml')
! cat config.yml
amp: true
checkpoint: /tmp/fairchem_checkpoints/gnoc_oc22_oc20_all_s2ef.pt
dataset:
  test:
    a2g_args:
      r_energy: false
      r_forces: false
    format: ase_db
    src: test.db
  train:
    a2g_args:
      r_energy: true
      r_forces: true
    format: ase_db
    src: train.db
  val:
    a2g_args:
      r_energy: true
      r_forces: true
    format: ase_db
    src: val.db
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - forcesx_mae
    - forcesy_mae
    - forcesz_mae
    - mae
    - cosine_similarity
    - magnitude_error
    misc:
    - energy_forces_within_threshold
  primary_metric: forces_mae
gpus: 1
logger: tensorboard
loss_functions:
- energy:
    coefficient: 1
    fn: mae
- forces:
    coefficient: 1
    fn: l2mae
model:
  activation: silu
  atom_edge_interaction: true
  atom_interaction: true
  cbf:
    name: spherical_harmonics
  cutoff: 12.0
  cutoff_aeaint: 12.0
  cutoff_aint: 12.0
  cutoff_qint: 12.0
  direct_forces: true
  edge_atom_interaction: true
  emb_size_aint_in: 64
  emb_size_aint_out: 64
  emb_size_atom: 256
  emb_size_cbf: 16
  emb_size_edge: 512
  emb_size_quad_in: 32
  emb_size_quad_out: 32
  emb_size_rbf: 16
  emb_size_sbf: 32
  emb_size_trip_in: 64
  emb_size_trip_out: 64
  envelope:
    exponent: 5
    name: polynomial
  extensive: true
  forces_coupled: false
  max_neighbors: 30
  max_neighbors_aeaint: 20
  max_neighbors_aint: 1000
  max_neighbors_qint: 8
  name: gemnet_oc
  num_after_skip: 2
  num_atom: 3
  num_atom_emb_layers: 2
  num_before_skip: 2
  num_blocks: 4
  num_concat: 1
  num_global_out_layers: 2
  num_output_afteratom: 3
  num_radial: 128
  num_spherical: 7
  otf_graph: true
  output_init: HeOrthogonal
  qint_tags:
  - 1
  - 2
  quad_interaction: true
  rbf:
    name: gaussian
  regress_forces: true
  sbf:
    name: legendre_outer
  symmetric_edge_symmetrization: false
noddp: false
optim:
  batch_size: 4
  clip_grad_norm: 10
  ema_decay: 0.999
  energy_coefficient: 1
  eval_batch_size: 16
  eval_every: 10
  factor: 0.8
  force_coefficient: 1
  loss_energy: mae
  lr_initial: 0.0005
  max_epochs: 1
  mode: min
  num_workers: 2
  optimizer: AdamW
  optimizer_params:
    amsgrad: true
  patience: 3
  scheduler: ReduceLROnPlateau
  weight_decay: 0
outputs:
  energy:
    level: system
  forces:
    eval_on_free_atoms: true
    level: atom
    train_on_free_atoms: true
trainer: ocp

Running the training job

fairchem provides a main.py file that is used for training. Here we construct the Python command you need to run, and run it. main.py is not executable, so we have to run it with python, and we need its absolute path, which we get from the fairchem_main() helper defined in fairchem.core.common.tutorial_utils.

You must set a mode (--mode) and provide a config file (--config-yml). We also provide a checkpoint (--checkpoint) as the starting point; if you don’t, training starts from scratch.

By default, the files that are created go in directories called checkpoints, logs and results. You can change that location with the --run-dir option, e.g. --run-dir preferred-place.

You can also append an identifier to the timestamped directory names with --identifier tag. This can make them easier to find later.

The cell below uses some IPython magic to put Python variables in the shell command.

This command takes some time to run even on a GPU, e.g. about 30 minutes.

It is advisable to redirect the outputs to files. If the outputs are very large, the notebook may not be able to be saved; redirecting also makes your notebooks more reproducible. The checkpoints are stored in time-stamped directories whose names change every time you run the training. Below we show how to retrieve this directory name in a way that still works when you rerun the notebook later.

You can follow how the training is going by opening a terminal and running

tail -f train.txt

You can also open train.txt in a browser or the Jupyter file viewer, but you have to refresh the view periodically to see updates.

A full run can take up to 30 minutes for 80 epochs, so here we train for only one epoch (optim.max_epochs: 1 in the config). If you have one or more GPUs, use the --num-gpus flag to set how many, and remove the --cpu flag.
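The shell command in the next cell can also be assembled in plain Python, which avoids IPython magic and makes the CPU/GPU choice explicit. In this sketch, build_train_command is a hypothetical helper, "main.py" stands in for the path returned by fairchem_main(), and the flags are the ones described in this section.

```python
import shlex
import sys


def build_train_command(main_py, config_yml, checkpoint=None, run_dir=None,
                        identifier=None, num_gpus=None):
    """Assemble the argument list for a fairchem training run.

    With num_gpus=None the run is forced onto the CPU with --cpu;
    otherwise --num-gpus is passed and --cpu is omitted.
    """
    cmd = [sys.executable, str(main_py), "--mode", "train",
           "--config-yml", str(config_yml)]
    if checkpoint is not None:  # omit to train from scratch
        cmd += ["--checkpoint", str(checkpoint)]
    if run_dir is not None:
        cmd += ["--run-dir", str(run_dir)]
    if identifier is not None:
        cmd += ["--identifier", identifier]
    if num_gpus is None:
        cmd.append("--cpu")
    else:
        cmd.append(f"--num-gpus={num_gpus}")
    return cmd


cmd = build_train_command("main.py", "config.yml",
                          checkpoint="/tmp/fairchem_checkpoints/gnoc_oc22_oc20_all_s2ef.pt",
                          run_dir="fine-tuning", identifier="ft-oxides")
print(shlex.join(cmd))
```

To run it with the output redirected as in the cell below, you could open train.txt for writing as f and call subprocess.run(cmd, stdout=f, stderr=subprocess.STDOUT, check=True).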

import time
from fairchem.core.common.tutorial_utils import fairchem_main

t0 = time.time()
! python {fairchem_main()} --mode train --config-yml {yml} --checkpoint {checkpoint_path} --run-dir fine-tuning --identifier ft-oxides --cpu > train.txt 2>&1
print(f'Elapsed time = {time.time() - t0:1.1f} seconds')
Elapsed time = 214.4 seconds

Because the output went to a file, we can search it for the training results (see train.txt). The config is printed at the top, including the checkpoint directory. Here we use a shell command to find that line and Python to split and strip it.

cpline = !grep "checkpoint_dir:" train.txt
cpdir = cpline[0].split(':')[-1].strip()
cpdir
'fine-tuning/checkpoints/2024-09-18-21-13-36-ft-oxides'
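If you prefer not to rely on the grep shell magic, the same lookup can be done in plain Python. This is a sketch with a hypothetical helper, find_checkpoint_dir; it applies the same split-and-strip as the cell above and therefore assumes the directory path contains no colon. The demo writes a small fake log, and its format is an assumption about how the config line appears in train.txt.

```python
import os
import tempfile


def find_checkpoint_dir(log_file):
    """Return the value on the first 'checkpoint_dir:' line of log_file, or None."""
    with open(log_file) as f:
        for line in f:
            if "checkpoint_dir:" in line:
                # Same split-and-strip as above; assumes the path has no ':'.
                return line.split(":")[-1].strip()
    return None


# Demo on a fake log file in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    log = os.path.join(tmp, "train.txt")
    with open(log, "w") as f:
        f.write("amp: true\n")
        f.write("  checkpoint_dir: fine-tuning/checkpoints/2024-09-18-21-13-36-ft-oxides\n")
    cpdir_demo = find_checkpoint_dir(log)

print(cpdir_demo)
```

With the real log, find_checkpoint_dir('train.txt') gives the directory, and the two checkpoint files are then os.path.join(cpdir, 'checkpoint.pt') and os.path.join(cpdir, 'best_checkpoint.pt').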

There will be two files in there: checkpoint.pt and best_checkpoint.pt.

The best_checkpoint.pt is the one that performs best on the validation dataset. The checkpoint.pt is the most recently saved one. It probably has the lowest training loss, but if it does worse than best_checkpoint.pt on the validation set, that indicates overfitting. You have to use some judgement in deciding which one to use, and whether it is accurate enough for your needs. It may need additional training to improve it further.

newckpt = cpdir + '/checkpoint.pt'
newcalc = OCPCalculator(checkpoint_path=newckpt, cpu=False)
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
eos_data = {}
for oxide in oxides:
    eos_data[oxide] = {}
    for polymorph in polymorphs:
        dft = []
        ocp = []
        vols = []
        calculations = d[oxide][polymorph]['PBE']['EOS']['calculations']
        for c in calculations:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            atoms.calc = newcalc
            ocp += [atoms.get_potential_energy() / len(atoms)]
            dft += [c['data']['total_energy'] / len(atoms)]
            vols += [atoms.get_volume()]
        plt.plot(dft, ocp, marker='s' if oxide == 'VO2' else '.',
                 alpha=0.5, label=f'{oxide}-{polymorph}')
        eos_data[oxide][polymorph] = (vols, dft, ocp)
        
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=3);
[Figure: parity plot of OCP vs. DFT energies (eV/atom) for each oxide-polymorph pair, fine-tuned model]
mae = np.mean(np.abs(np.array(dft) - np.array(ocp)))
print(f'New MAE = {mae:1.3f} eV/atom')
New MAE = 0.029 eV/atom

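This is a substantial improvement! This MAE is again computed from the lists left over from the last loop iteration (VO2-anatase), so it is directly comparable to the 1.669 eV/atom value above. Keep in mind that most of these structures were in the training set, so this is not an independent test of the fine-tuned model. We can see what the improvement means in terms of the Equations of State. There is still not perfect agreement, but the curves are closer together. Additional fine-tuning, or a better model, could probably improve this further.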

oxide, polymorph = 'VO2', 'fluorite'
V, D, O = eos_data[oxide][polymorph]
        
plt.plot(V, D, label='dft')
plt.plot(V, O, label='ocp')
plt.title(f'fine-tuned {oxide} - {polymorph}')
plt.legend();
[Figure: fine-tuned VO2 - fluorite; DFT and OCP energies (eV/atom) vs. volume]

It is possible to continue refining the fit. The simplest thing to do is to train for more epochs. Eventually the MAE will stabilize, and then it may be necessary to adjust other optimization parameters such as the learning rate (usually you decrease it).
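As a sketch of what a second round might change, the hypothetical helper next_round_updates below builds dotted-key entries in the same style as the update argument to generate_yml_config: it sets a new number of epochs and lowers the initial learning rate. The 0.5 reduction factor and the 10 epochs are arbitrary illustrative choices; 0.0005 is the lr_initial value from the config above.

```python
def next_round_updates(current_lr, extra_epochs, lr_factor=0.5):
    """Build config update entries for a further fine-tuning round.

    Sets optim.max_epochs for the new run and scales the initial
    learning rate down by lr_factor.
    """
    if not 0 < lr_factor <= 1:
        raise ValueError("lr_factor should be in (0, 1] so the learning rate does not increase")
    return {
        "optim.max_epochs": extra_epochs,
        "optim.lr_initial": current_lr * lr_factor,
    }


updates = next_round_updates(current_lr=0.0005, extra_epochs=10)
print(updates)
```

One possible workflow is to merge these entries into the update dictionary used earlier, generate a new config file, and pass the fine-tuned checkpoint (newckpt) to --checkpoint so training continues from it rather than from the original checkpoint.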

Depending on what is important to you, you may consider changing the relative importance of energy and forces; you can often trade accuracy in one for the other. It may also be necessary to add data that covers the composition and configuration space more thoroughly.
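To see how the energy/force trade-off enters, here is a simplified sketch of a weighted training loss that uses the two loss functions named in the config above: mae for energies and l2mae for forces, with coefficients playing the role of energy_coefficient and force_coefficient. fairchem's own loss code may reduce or normalize the terms differently (for example per batch or per atom), so treat this as an illustration of the weighting only; the toy energies and forces are invented.

```python
import numpy as np


def weighted_loss(e_pred, e_true, f_pred, f_true,
                  energy_coefficient=1.0, force_coefficient=1.0):
    """Simplified energy + force training loss.

    energy term: MAE over structures (eV)
    force term:  'l2mae', the mean over atoms of the Euclidean norm
                 of each atom's force error (eV/Å)
    """
    e_term = np.mean(np.abs(np.asarray(e_pred) - np.asarray(e_true)))
    f_err = np.asarray(f_pred) - np.asarray(f_true)  # shape (n_atoms, 3)
    f_term = np.mean(np.linalg.norm(f_err, axis=1))
    return float(energy_coefficient * e_term + force_coefficient * f_term)


# Toy data: one structure with two atoms (values invented).
e_pred, e_true = [-56.0], [-56.5]
f_pred = [[0.0, 0.3, 0.4], [0.0, 0.0, 0.0]]
f_true = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

balanced = weighted_loss(e_pred, e_true, f_pred, f_true)
force_heavy = weighted_loss(e_pred, e_true, f_pred, f_true, force_coefficient=10.0)
print(f"balanced: {balanced:.3f}, force-weighted: {force_heavy:.3f}")
```

Raising force_coefficient makes the optimizer care more about force errors, usually at some cost in energy accuracy, and lowering it does the reverse.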

There are also other models you could consider. Newer models tend to be more accurate, but they may also be more expensive to run. These are all compromises you can consider.