Fine tuning a model
In this section we look at how to improve a model. We start by using a pre-trained model to predict the energies of several oxide polymorphs. The data here comes from this paper:
Mehta, P., Salvador, P. A., & Kitchin, J. R. (2015). Identifying potential BO2 oxide polymorphs for epitaxial growth candidates. ACS Appl. Mater. Interfaces, 6(5), 3630–3639. http://dx.doi.org/10.1021/am4059149.
This data set shows equations of state for several oxide/polymorph combinations. I use this dataset to compare with predictions from OCP.
First we get the checkpoint that we want. According to the MODELS listing, the GemNet-OC OC20+OC22 combination has an energy MAE of 0.483 eV, which seems like a good place to start. This model was trained on oxides.
We get this checkpoint here.
from fairchem.core.models.model_registry import model_name_to_local_file
checkpoint_path = model_name_to_local_file('GemNet-OC-S2EFS-OC20+OC22', local_cache='/tmp/fairchem_checkpoints/')
The data we need is provided in supporting-information.json. That file is embedded in the supporting information for the article, and is provided here in the tutorial. We load this data and explore it a little. The json file provides a dictionary with the structure:
[oxide][polymorph][xc][EOS][calculations]
The first key is a composition, the second is a string for the polymorph structure, the third indicates which XC functional was used (we focus on PBE here), the fourth key is for the Equation of State calculations, and the last key is a list of results for the EOS.
import json
import numpy as np
import matplotlib.pyplot as plt
from ase import Atoms
from fairchem.core.scripts import download_large_files
download_large_files.download_file_group("docs")
with open('supporting-information.json', 'rb') as f:
    d = json.loads(f.read())
oxides = list(d.keys())
polymorphs = list(d['TiO2'].keys())
oxides, polymorphs
Downloading docs/tutorials/NRR/NRR_example_bulks.pkl...
Downloading docs/core/fine-tuning/supporting-information.json...
Downloading docs/core/data.db...
(['SnO2', 'IrO2', 'RuO2', 'TiO2', 'VO2'],
['rutile', 'pyrite', 'columbite', 'brookite', 'fluorite', 'anatase'])
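To make the nesting concrete, here is a minimal sketch that walks a dictionary with the same layout and counts the EOS calculations under each oxide/polymorph pair. The toy dictionary and the helper name count_calculations are invented for illustration; with the real data you would pass d instead.

```python
def count_calculations(data, xc="PBE"):
    """Return {(oxide, polymorph): number of EOS calculations} for one functional."""
    counts = {}
    for oxide, polymorphs in data.items():
        for polymorph, by_xc in polymorphs.items():
            # Missing levels fall back to empty containers, giving a count of 0.
            calcs = by_xc.get(xc, {}).get("EOS", {}).get("calculations", [])
            counts[(oxide, polymorph)] = len(calcs)
    return counts


# Toy stand-in with the same nesting as supporting-information.json.
toy = {
    "TiO2": {
        "rutile": {"PBE": {"EOS": {"calculations": [{"data": {}}, {"data": {}}]}}},
        "anatase": {"PBE": {"EOS": {"calculations": [{"data": {}}]}}},
    },
}

counts = count_calculations(toy)
print(counts)
```

Running the same function on d is a quick way to see which oxide/polymorph pairs have fewer points on their equation of state.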
A specific calculation has a lot of details in it. You can use this to recreate the calculations. Shortly we will use these to compare the DFT results to OCP.
d['TiO2']['rutile']['PBE']['EOS']['calculations'][0]
{'incar': {'doc': 'INCAR parameters',
'prec': 'Normal',
'isif': 4,
'nbands': 20,
'ibrion': 2,
'gga': 'PE',
'encut': 520.0,
'ismear': 0,
'sigma': 0.001,
'nsw': 50},
'doc': 'JSON representation of a VASP calculation.\n\nenergy is in eV\nforces are in eV/\\AA\nstress is in GPa (sxx, syy, szz, syz, sxz, sxy)\nmagnetic moments are in Bohr-magneton\nThe density of states is reported with E_f at 0 eV.\nVolume is reported in \\AA^3\nCoordinates and cell parameters are reported in \\AA\n\nIf atom-projected dos are included they are in the form:\n{ados:{energy:data, {atom index: {orbital : dos}}}\n',
'potcar': [['O',
'potpaw/O/POTCAR',
'0cf2ce56049ca395c567026b700ed66c94a85161'],
['Ti', 'potpaw/Ti/POTCAR', '51f7f05982d6b4052becc160375a8b8b670177a7']],
'input': {'kpts': [6, 6, 6],
'reciprocal': False,
'xc': 'LDA',
'kpts_nintersections': None,
'setups': None,
'txt': '-',
'gamma': False},
'atoms': {'cell': [[4.3789762519649225, 0.0, 0.0],
[0.0, 4.3789762519649225, 0.0],
[0.0, 0.0, 2.864091775985314]],
'symbols': ['Ti', 'Ti', 'O', 'O', 'O', 'O'],
'tags': [0, 0, 0, 0, 0, 0],
'pbc': [True, True, True],
'positions': [[0.0, 0.0, 0.0],
[2.1894881259824612, 2.1894881259824612, 1.432045887992657],
[1.3181554154438013, 1.3181554154438013, 0.0],
[3.0608208365211214, 3.0608208365211214, 0.0],
[3.5076435414262623, 0.87133271053866, 1.432045887992657],
[0.87133271053866, 3.5076435414262623, 1.432045887992657]]},
'data': {'stress': [496.18519999, 496.18519999, 502.82679392, 0.0, 0.0, 0.0],
'doc': 'Data from the output of the calculation',
'volume': 54.92019999999996,
'total_energy': -56.230672,
'forces': [[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[-0.001264, -0.001264, 0.0],
[0.001264, 0.001264, 0.0],
[-0.001264, 0.001264, 0.0],
[0.001264, -0.001264, 0.0]],
'fermi_level': 3.153}}
For each result we can retrieve the atomic geometry, energy and forces from the json file. We use these to recreate an Atoms object. Here is an example.
c = d['TiO2']['rutile']['PBE']['EOS']['calculations'][0]
atoms = Atoms(symbols=c['atoms']['symbols'],
              positions=c['atoms']['positions'],
              cell=c['atoms']['cell'],
              pbc=c['atoms']['pbc'])
atoms.set_tags(np.ones(len(atoms)))
atoms, c['data']['total_energy'], c['data']['forces']
(Atoms(symbols='Ti2O4', pbc=True, cell=[4.3789762519649225, 4.3789762519649225, 2.864091775985314], tags=...),
-56.230672,
[[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[-0.001264, -0.001264, 0.0],
[0.001264, 0.001264, 0.0],
[-0.001264, 0.001264, 0.0],
[0.001264, -0.001264, 0.0]])
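One quick consistency check on a record is that the volume stored under data agrees with the volume implied by the cell vectors. The sketch below uses only plain Python; the helper name cell_volume is made up, and the numbers are copied from the rutile TiO2 record printed above.

```python
def cell_volume(cell):
    """Volume of a cell given as three row vectors: |a . (b x c)|."""
    a, b, c = cell
    cross = (
        b[1] * c[2] - b[2] * c[1],
        b[2] * c[0] - b[0] * c[2],
        b[0] * c[1] - b[1] * c[0],
    )
    return abs(sum(ai * xi for ai, xi in zip(a, cross)))


# Cell and reported volume copied from the rutile TiO2 record shown above.
cell = [[4.3789762519649225, 0.0, 0.0],
        [0.0, 4.3789762519649225, 0.0],
        [0.0, 0.0, 2.864091775985314]]
reported_volume = 54.92019999999996

volume = cell_volume(cell)
print(f"{volume:.4f} vs {reported_volume:.4f}")
```

With an Atoms object, atoms.get_volume() gives the same quantity, as used in the loops later in this section.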
Next, we create an OCP calculator that we can use to get predictions from the pre-trained checkpoint.
from fairchem.core.common.relaxation.ase_utils import OCPCalculator
calc = OCPCalculator(checkpoint_path=checkpoint_path, trainer='forces', cpu=False)
WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
Now, we loop through each structure and accumulate the OCP predictions. Then, we plot the parity results.
import time
t0 = time.time()
eos_data = {}
for oxide in oxides:
    eos_data[oxide] = {}
    for polymorph in polymorphs:
        dft = []
        ocp = []
        vols = []
        calculations = d[oxide][polymorph]['PBE']['EOS']['calculations']
        for c in calculations:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            atoms.calc = calc
            ocp += [atoms.get_potential_energy() / len(atoms)]
            dft += [c['data']['total_energy'] / len(atoms)]
            vols += [atoms.get_volume()]
        plt.plot(dft, ocp, marker='s' if oxide == 'VO2' else '.',
                 alpha=0.5, label=f'{oxide}-{polymorph}')
        eos_data[oxide][polymorph] = (vols, dft, ocp)
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=3);
print(f'Elapsed time {time.time() - t0:1.1f} seconds.')
Elapsed time 84.4 seconds.
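# Note: dft and ocp are re-created for every polymorph in the loop above,
# so this MAE covers only the last set processed (VO2, anatase).
mae = np.mean(np.abs(np.array(dft) - np.array(ocp)))
print(f'MAE = {mae:1.3f} eV/atom')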
MAE = 1.669 eV/atom
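Because dft and ocp are re-created for every polymorph, an MAE over the whole dataset has to be assembled from eos_data instead. A minimal sketch, assuming eos_data[oxide][polymorph] holds (vols, dft, ocp) tuples as in the loop above. The function names and the toy numbers are made up and are not results from this tutorial.

```python
def mae(a, b):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def mae_by_oxide(eos_data):
    """Return ({oxide: MAE}, overall MAE), pooling all polymorphs."""
    per_oxide = {}
    all_dft, all_ocp = [], []
    for oxide, polymorphs in eos_data.items():
        dft, ocp = [], []
        for _vols, d_vals, o_vals in polymorphs.values():
            dft.extend(d_vals)
            ocp.extend(o_vals)
        per_oxide[oxide] = mae(dft, ocp)
        all_dft.extend(dft)
        all_ocp.extend(ocp)
    return per_oxide, mae(all_dft, all_ocp)


# Toy numbers only, in eV/atom.
toy = {
    "AO2": {"rutile": ([10.0, 11.0], [-9.0, -9.5], [-9.0, -9.0])},
    "BO2": {"rutile": ([10.0, 11.0], [-8.0, -8.5], [-7.0, -7.5])},
}
per_oxide, overall = mae_by_oxide(toy)
print(per_oxide, overall)
```

The per-oxide breakdown makes it easy to see whether one composition dominates the overall error.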
The MAE is high compared to the reported value quoted above. That is not too surprising; although OC22 was trained on oxides, it was not trained on all of these structures. Note that dft and ocp are re-created for every polymorph in the loop, so the value printed above covers only the last set processed (VO2 anatase) rather than the whole dataset. It is also evident on inspection of the parity plot that the main issues are all the VO2 structures. Next we look at the EOS for each material. Here you can see offsets and qualitative issues in the shapes. The offset is largely a PBE/RPBE difference.
Some notable issues are with fluorite structures. We show one here for VO2.
oxide, polymorph = 'VO2', 'fluorite'
V, D, O = eos_data[oxide][polymorph]
plt.plot(V, D, label='dft')
plt.plot(V, O, label='ocp')
plt.title(f'pretrained {oxide} - {polymorph}')
plt.legend();
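A constant offset shifts a whole curve without changing its shape, so it can help to compare each curve relative to its own minimum. The sketch below illustrates the idea with made-up energies; relative_to_minimum is a hypothetical helper, not part of fairchem.

```python
def relative_to_minimum(energies):
    """Shift a curve so that its lowest energy is zero."""
    e_min = min(energies)
    return [e - e_min for e in energies]


# Toy curves in eV/atom (made-up numbers): same shape, constant 0.5 offset.
dft = [-8.0, -8.25, -8.0]
ocp = [-7.5, -7.75, -7.5]

dft_rel = relative_to_minimum(dft)
ocp_rel = relative_to_minimum(ocp)

# Largest remaining difference once the constant offset is removed.
shape_difference = max(abs(a - b) for a, b in zip(dft_rel, ocp_rel))
print(dft_rel, ocp_rel, shape_difference)
```

With the real data you could pass D and O from eos_data through the same function before plotting, so that differences in curvature and in the position of the minimum stand out.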
We can improve this by fine-tuning the model.
Fine tuning the checkpoint
There are several steps in fine-tuning:
1. Create train, test, val datasets
2. Create a configuration yml file
3. Run the training
4. Check the results
First we create the training databases. OCP supports ase databases, so we use those here. The first step is to make a database containing the data. We need to write an atoms object with an attached SinglePointCalculator for each structure.
from ase.db import connect
from ase.calculators.singlepoint import SinglePointCalculator
! rm -fr oxides.db # start clean when you run this cell
db = connect('oxides.db')
for oxide in oxides:
    for polymorph in polymorphs:
        for c in d[oxide][polymorph]['PBE']['EOS']['calculations']:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            # Attach the DFT energy and forces; the name spc avoids shadowing the OCP calc.
            spc = SinglePointCalculator(atoms,
                                        energy=c['data']['total_energy'],
                                        forces=c['data']['forces'])
            atoms.calc = spc  # atoms.set_calculator() is deprecated
            db.write(atoms)
Let’s see what we made.
! ase db oxides.db
id|age|user |formula|calculator| energy|natoms| fmax|pbc| volume|charge| mass
1| 1s|runner|Sn2O4 |unknown |-41.359| 6|0.045|TTT| 64.258| 0.000|301.416
2| 1s|runner|Sn2O4 |unknown |-41.853| 6|0.025|TTT| 66.526| 0.000|301.416
3| 1s|runner|Sn2O4 |unknown |-42.199| 6|0.010|TTT| 68.794| 0.000|301.416
4| 1s|runner|Sn2O4 |unknown |-42.419| 6|0.006|TTT| 71.062| 0.000|301.416
5| 1s|runner|Sn2O4 |unknown |-42.534| 6|0.011|TTT| 73.330| 0.000|301.416
6| 1s|runner|Sn2O4 |unknown |-42.562| 6|0.029|TTT| 75.598| 0.000|301.416
7| 1s|runner|Sn2O4 |unknown |-42.518| 6|0.033|TTT| 77.866| 0.000|301.416
8| 1s|runner|Sn2O4 |unknown |-42.415| 6|0.010|TTT| 80.134| 0.000|301.416
9| 1s|runner|Sn2O4 |unknown |-42.266| 6|0.006|TTT| 82.402| 0.000|301.416
10| 1s|runner|Sn2O4 |unknown |-42.083| 6|0.017|TTT| 84.670| 0.000|301.416
11| 1s|runner|Sn4O8 |unknown |-81.424| 12|0.012|TTT|117.473| 0.000|602.832
12| 1s|runner|Sn4O8 |unknown |-82.437| 12|0.005|TTT|121.620| 0.000|602.832
13| 1s|runner|Sn4O8 |unknown |-83.147| 12|0.015|TTT|125.766| 0.000|602.832
14| 1s|runner|Sn4O8 |unknown |-83.599| 12|0.047|TTT|129.912| 0.000|602.832
15| 1s|runner|Sn4O8 |unknown |-83.831| 12|0.081|TTT|134.058| 0.000|602.832
16| 1s|runner|Sn4O8 |unknown |-83.898| 12|0.001|TTT|138.204| 0.000|602.832
17| 1s|runner|Sn4O8 |unknown |-83.805| 12|0.001|TTT|142.350| 0.000|602.832
18| 1s|runner|Sn4O8 |unknown |-83.586| 12|0.002|TTT|146.496| 0.000|602.832
19| 1s|runner|Sn4O8 |unknown |-83.262| 12|0.002|TTT|150.642| 0.000|602.832
20| 1s|runner|Sn4O8 |unknown |-82.851| 12|0.013|TTT|154.788| 0.000|602.832
Rows: 295 (showing first 20)
Make the train, test, val splits
We need to split the ase-db into three separate databases, one for training (80%), one for testing (10%) and one for validation (the remaining 10%). We generate a list of ids and then shuffle them. Then we write the first 80% into train.db, the next 10% into test.db, and the remainder into val.db.
The train set is used for training. The test and val sets are used to check for overfitting.
You choose the splits you want; 80:10:10 is common. We take a simple approach to split the database here. We make an array of integers that correspond to the ids, randomly shuffle them, and then get each row in the randomized order and write them to a new db.
We provide some helper functions in fairchem.core.common.tutorial_utils to streamline this process.
from fairchem.core.common.tutorial_utils import train_test_val_split
! rm -fr train.db test.db val.db
train, test, val = train_test_val_split('oxides.db')
train, test, val
(PosixPath('/home/runner/work/fairchem/fairchem/docs/core/fine-tuning/train.db'),
PosixPath('/home/runner/work/fairchem/fairchem/docs/core/fine-tuning/test.db'),
PosixPath('/home/runner/work/fairchem/fairchem/docs/core/fine-tuning/val.db'))
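The shuffle-and-slice idea described above can be sketched in a few lines. This illustrates the approach only and is not the implementation of train_test_val_split; the function name split_ids, the percentage arguments, and the fixed seed are assumptions made for the example.

```python
import random


def split_ids(ids, train_pct=80, test_pct=10, seed=0):
    """Shuffle ids and split them into train/test/val lists."""
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is repeatable
    n_train = len(shuffled) * train_pct // 100
    n_test = len(shuffled) * test_pct // 100
    train = shuffled[:n_train]
    test = shuffled[n_train:n_train + n_test]
    val = shuffled[n_train + n_test:]  # whatever remains
    return train, test, val


# ase db ids start at 1; oxides.db above has 295 rows.
train_ids, test_ids, val_ids = split_ids(range(1, 296))
print(len(train_ids), len(test_ids), len(val_ids))
```

Each list of ids could then be used to fetch rows from the source database and write them into the corresponding new database. Fixing the seed keeps the split the same between runs, which matters when you compare models.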
Setting up the configuration yaml file
We have to create a yaml configuration file for the model we are using. The pre-trained checkpoints contain their config data, so we use this to get the base configuration, and then remove pieces we don’t need, and update pieces we do need.
from fairchem.core.common.tutorial_utils import generate_yml_config
yml = generate_yml_config(checkpoint_path, 'config.yml',
                          delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes',
                                  'optim.loss_force',  # the checkpoint setting causes an error
                                  'optim.load_balancing',
                                  'dataset', 'test_dataset', 'val_dataset'],
                          update={'gpus': 1,
                                  'optim.eval_every': 10,
                                  'optim.max_epochs': 1,
                                  'optim.batch_size': 4,
                                  'logger': 'tensorboard',  # don't use wandb!
                                  # Train data
                                  'dataset.train.src': 'train.db',
                                  'dataset.train.format': 'ase_db',
                                  'dataset.train.a2g_args.r_energy': True,
                                  'dataset.train.a2g_args.r_forces': True,
                                  # Test data - prediction only so no regression
                                  'dataset.test.src': 'test.db',
                                  'dataset.test.format': 'ase_db',
                                  'dataset.test.a2g_args.r_energy': False,
                                  'dataset.test.a2g_args.r_forces': False,
                                  # val data
                                  'dataset.val.src': 'val.db',
                                  'dataset.val.format': 'ase_db',
                                  'dataset.val.a2g_args.r_energy': True,
                                  'dataset.val.a2g_args.r_forces': True,
                                  })
yml
PosixPath('/home/runner/work/fairchem/fairchem/docs/core/fine-tuning/config.yml')
! cat config.yml
amp: true
checkpoint: /tmp/fairchem_checkpoints/gnoc_oc22_oc20_all_s2ef.pt
dataset:
  test:
    a2g_args:
      r_energy: false
      r_forces: false
    format: ase_db
    src: test.db
  train:
    a2g_args:
      r_energy: true
      r_forces: true
    format: ase_db
    src: train.db
  val:
    a2g_args:
      r_energy: true
      r_forces: true
    format: ase_db
    src: val.db
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - forcesx_mae
    - forcesy_mae
    - forcesz_mae
    - mae
    - cosine_similarity
    - magnitude_error
    misc:
    - energy_forces_within_threshold
  primary_metric: forces_mae
gpus: 1
logger: tensorboard
loss_functions:
- energy:
    coefficient: 1
    fn: mae
- forces:
    coefficient: 1
    fn: l2mae
model:
  activation: silu
  atom_edge_interaction: true
  atom_interaction: true
  cbf:
    name: spherical_harmonics
  cutoff: 12.0
  cutoff_aeaint: 12.0
  cutoff_aint: 12.0
  cutoff_qint: 12.0
  direct_forces: true
  edge_atom_interaction: true
  emb_size_aint_in: 64
  emb_size_aint_out: 64
  emb_size_atom: 256
  emb_size_cbf: 16
  emb_size_edge: 512
  emb_size_quad_in: 32
  emb_size_quad_out: 32
  emb_size_rbf: 16
  emb_size_sbf: 32
  emb_size_trip_in: 64
  emb_size_trip_out: 64
  envelope:
    exponent: 5
    name: polynomial
  extensive: true
  forces_coupled: false
  max_neighbors: 30
  max_neighbors_aeaint: 20
  max_neighbors_aint: 1000
  max_neighbors_qint: 8
  name: gemnet_oc
  num_after_skip: 2
  num_atom: 3
  num_atom_emb_layers: 2
  num_before_skip: 2
  num_blocks: 4
  num_concat: 1
  num_global_out_layers: 2
  num_output_afteratom: 3
  num_radial: 128
  num_spherical: 7
  otf_graph: true
  output_init: HeOrthogonal
  qint_tags:
  - 1
  - 2
  quad_interaction: true
  rbf:
    name: gaussian
  regress_forces: true
  sbf:
    name: legendre_outer
  symmetric_edge_symmetrization: false
noddp: false
optim:
  batch_size: 4
  clip_grad_norm: 10
  ema_decay: 0.999
  energy_coefficient: 1
  eval_batch_size: 16
  eval_every: 10
  factor: 0.8
  force_coefficient: 1
  loss_energy: mae
  lr_initial: 0.0005
  max_epochs: 1
  mode: min
  num_workers: 2
  optimizer: AdamW
  optimizer_params:
    amsgrad: true
  patience: 3
  scheduler: ReduceLROnPlateau
  weight_decay: 0
outputs:
  energy:
    level: system
  forces:
    eval_on_free_atoms: true
    level: atom
    train_on_free_atoms: true
trainer: ocp
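The delete and update arguments above address nested settings with dotted keys such as optim.max_epochs. How generate_yml_config handles them internally is not shown here, so the following is only a sketch of the idea. The helpers set_dotted and delete_dotted and the toy base config are hypothetical.

```python
def set_dotted(config, dotted_key, value):
    """Set config['a']['b']['c'] = value for 'a.b.c', creating levels as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value


def delete_dotted(config, dotted_key):
    """Remove the entry addressed by dotted_key if it exists."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.get(key)
        if not isinstance(node, dict):
            return  # nothing to delete along this path
    node.pop(leaf, None)


# Toy base config standing in for the one stored in a checkpoint.
config = {"optim": {"max_epochs": 80, "loss_force": "l2mae"}, "slurm": {}}
delete_dotted(config, "slurm")
delete_dotted(config, "optim.loss_force")
set_dotted(config, "optim.max_epochs", 1)
set_dotted(config, "dataset.train.src", "train.db")
print(config)
```

The nested dictionary that results maps directly onto the indented yaml shown above, which is why a key like dataset.train.src ends up under dataset, then train, then src.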
Running the training job
fairchem provides a main.py file that is used for training. Here we construct the Python command you need to run, and run it. main.py is not executable, so we have to run it with python, and you need the absolute path to it, which we get from the fairchem_main() function defined in fairchem.core.common.tutorial_utils.
You must set a mode and provide a config-yml. We provide a checkpoint as a starting point; if you do not do this, training will start from scratch.
By default the files that are created go in directories called checkpoints, logs and results. You can change that with a --run-dir preferred-place option.
You can also add an identifier tag to the end of the timestamps with --identifier tag. This can make it easier to find them later.
The cell below uses some IPython magic to put Python variables in the shell command.
This command takes some time to run even on a GPU, e.g. about 30 minutes.
It is advisable to redirect the outputs to files. The reason is that if the outputs are very large, the notebook may not be able to be saved. This also makes your notebooks more reproducible. The checkpoints are stored in time-stamped directories that change every time you run them. Below we show how to reproducibly retrieve this directory name in a way that allows you to run the notebook again later, while automatically updating the directory name.
You can follow how the training is going by opening a terminal and running
tail -f train.txt
You can also visit it in a browser at train.txt. You have to periodically refresh the view to see updates though.
Training for 80 epochs can take up to 30 minutes, so we run only a single epoch here (optim.max_epochs is set to 1 in the config above) to see what happens. If you have a GPU or multiple GPUs, you should use the --num-gpus= flag.
import time
from fairchem.core.common.tutorial_utils import fairchem_main
t0 = time.time()
! python {fairchem_main()} --mode train --config-yml {yml} --checkpoint {checkpoint_path} --run-dir fine-tuning --identifier ft-oxides --cpu > train.txt 2>&1
print(f'Elapsed time = {time.time() - t0:1.1f} seconds')
Elapsed time = 214.4 seconds
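Outside of a notebook, the IPython shell magic is not available. One way to launch the same job is to build an argument list and pass it to subprocess. The flags below are the ones used in the shell command above. The helper names build_train_command and run_and_log are hypothetical, and the main.py and checkpoint.pt paths are placeholders (in the notebook they come from fairchem_main() and checkpoint_path).

```python
import subprocess


def build_train_command(main_py, config_yml, checkpoint, run_dir, identifier, cpu=True):
    """Assemble the training command line as a list of arguments."""
    cmd = ["python", str(main_py),
           "--mode", "train",
           "--config-yml", str(config_yml),
           "--checkpoint", str(checkpoint),
           "--run-dir", run_dir,
           "--identifier", identifier]
    if cpu:
        cmd.append("--cpu")
    return cmd


def run_and_log(cmd, log_path):
    """Run cmd and send stdout and stderr to log_path, like `> train.txt 2>&1`."""
    with open(log_path, "w") as log:
        return subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT).returncode


# Placeholder paths; nothing is executed here, the command is only printed.
cmd = build_train_command("main.py", "config.yml", "checkpoint.pt",
                          "fine-tuning", "ft-oxides")
print(" ".join(cmd))
```

Calling run_and_log(cmd, "train.txt") would then start the job and write the log in the same way as the shell redirection.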
Now since we have a file, we can find the training results in it. See train.txt. At the top, the config is printed, so we can get the checkpoint directory. I use shell commands and Python to get the line, split and strip it here.
cpline = !grep "checkpoint_dir:" train.txt
cpdir = cpline[0].split(':')[-1].strip()
cpdir
'fine-tuning/checkpoints/2024-09-18-21-13-36-ft-oxides'
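The same lookup can be done in plain Python, which is handy when the notebook code is moved into a script. A small sketch: the function name find_checkpoint_dir is made up, and the log text is a toy excerpt containing only the line that the grep above matches.

```python
def find_checkpoint_dir(log_text):
    """Return the value on the first 'checkpoint_dir:' line, or None if absent."""
    for line in log_text.splitlines():
        if "checkpoint_dir:" in line:
            # Split once, so any later colon in the value is preserved.
            return line.split(":", 1)[1].strip()
    return None


# Toy log excerpt; the directory name is the one printed above.
log_text = (
    "some other line\n"
    "  checkpoint_dir: fine-tuning/checkpoints/2024-09-18-21-13-36-ft-oxides\n"
)
cpdir = find_checkpoint_dir(log_text)
print(cpdir)
```

With the real log you would read the file first, for example with open('train.txt').read(), and pass the text to the function.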
There will be two files in there: checkpoint.pt and best_checkpoint.pt.
The best_checkpoint.pt is the one that performs best on the validation dataset. The checkpoint.pt is the most recently saved one. It probably has the lowest training loss, but this could be an indication of overfitting. You have to use some judgement in determining which one to use, and whether it is sufficiently accurate for your needs. It may need additional training to further improve it.
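The choice between the two files can be made explicit in code. The sketch below prefers best_checkpoint.pt and falls back to checkpoint.pt if only that one exists. The helper pick_checkpoint is hypothetical, and the demonstration uses empty placeholder files in a temporary directory.

```python
from pathlib import Path
import tempfile


def pick_checkpoint(checkpoint_dir, prefer_best=True):
    """Return best_checkpoint.pt or checkpoint.pt, whichever preferred file exists."""
    directory = Path(checkpoint_dir)
    names = ["best_checkpoint.pt", "checkpoint.pt"]
    if not prefer_best:
        names.reverse()
    for name in names:
        candidate = directory / name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"no checkpoint file found in {directory}")


# Demonstrate with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "checkpoint.pt").touch()
    (Path(tmp) / "best_checkpoint.pt").touch()
    chosen_best = pick_checkpoint(tmp).name
    chosen_latest = pick_checkpoint(tmp, prefer_best=False).name

print(chosen_best, chosen_latest)
```

In this tutorial the most recent checkpoint.pt is loaded, which corresponds to prefer_best=False.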
newckpt = cpdir + '/checkpoint.pt'
newcalc = OCPCalculator(checkpoint_path=newckpt, cpu=False)
eos_data = {}
for oxide in oxides:
    eos_data[oxide] = {}
    for polymorph in polymorphs:
        dft = []
        ocp = []
        vols = []
        calculations = d[oxide][polymorph]['PBE']['EOS']['calculations']
        for c in calculations:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            atoms.calc = newcalc
            ocp += [atoms.get_potential_energy() / len(atoms)]
            dft += [c['data']['total_energy'] / len(atoms)]
            vols += [atoms.get_volume()]
        plt.plot(dft, ocp, marker='s' if oxide == 'VO2' else '.',
                 alpha=0.5, label=f'{oxide}-{polymorph}')
        eos_data[oxide][polymorph] = (vols, dft, ocp)
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=3);
# Note: as before, dft and ocp hold only the last set processed in the loop (VO2, anatase).
mae = np.mean(np.abs(np.array(dft) - np.array(ocp)))
print(f'New MAE = {mae:1.3f} eV/atom')
New MAE = 0.029 eV/atom
This is a substantial improvement! We can see what it means in terms of the Equations of State. There is still not perfect agreement, but the curves are closer together. Additional fine tuning, or a better model could probably still improve this.
oxide, polymorph = 'VO2', 'fluorite'
V, D, O = eos_data[oxide][polymorph]
plt.plot(V, D, label='dft')
plt.plot(V, O, label='ocp')
plt.title(f'fine-tuned {oxide} - {polymorph}')
plt.legend();
It is possible to continue refining the fit. The simplest thing to do is to train for more epochs. Eventually the MAE will stabilize, and then it may be necessary to adjust other optimization parameters like the learning rate (usually you decrease it).
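The config above already uses a ReduceLROnPlateau scheduler with factor 0.8, patience 3 and mode min, which lowers the learning rate automatically once the monitored error stops improving. The sketch below is a simplified stand-in for that rule, written in plain Python. It ignores the improvement threshold, cooldown and minimum learning rate that the real PyTorch scheduler supports, and the validation errors are made up.

```python
def plateau_schedule(val_errors, lr_initial=0.0005, factor=0.8, patience=3):
    """Return the learning rate in effect after each evaluation.

    Simplified reduce-on-plateau rule (mode "min"): when the monitored value
    has failed to improve for more than `patience` evaluations in a row,
    multiply the learning rate by `factor` and reset the counter.
    """
    lr = lr_initial
    best = float("inf")
    bad = 0
    history = []
    for err in val_errors:
        if err < best:
            best = err
            bad = 0
        else:
            bad += 1
        if bad > patience:
            lr *= factor
            bad = 0
        history.append(lr)
    return history


# Toy validation MAEs (made up): quick improvement, then a plateau.
errors = [0.50, 0.30, 0.20, 0.21, 0.20, 0.22, 0.21, 0.19]
lrs = plateau_schedule(errors)
print(lrs)
```

In this toy run the rate stays at its initial value until the fourth consecutive evaluation without improvement, and is then multiplied by 0.8. A smaller factor or a shorter patience makes the schedule more aggressive.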
Depending on what is important to you, you may consider changing the relative importances of energy and forces; you can often trade off accuracy of one for the other. It may be necessary to add additional data to cover the composition and configuration space more thoroughly.
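The config above lists an energy term (mae) and a forces term (l2mae), each with a coefficient. The sketch below shows how such a weighted sum responds when one coefficient is raised. It illustrates the idea and is not fairchem's implementation. Reading l2mae as the mean per-atom Euclidean norm of the force error is an assumption, and all numbers are made up.

```python
import math


def energy_mae(pred, ref):
    """Mean absolute error over per-structure energies."""
    return sum(abs(p - r) for p, r in zip(pred, ref)) / len(ref)


def forces_l2mae(pred, ref):
    """Mean over atoms of the Euclidean norm of the force error vector."""
    norms = [math.dist(p, r) for p, r in zip(pred, ref)]
    return sum(norms) / len(norms)


def combined_loss(e_pred, e_ref, f_pred, f_ref,
                  energy_coefficient=1.0, force_coefficient=1.0):
    """Weighted sum of the energy and force terms."""
    return (energy_coefficient * energy_mae(e_pred, e_ref)
            + force_coefficient * forces_l2mae(f_pred, f_ref))


# Toy values (made up): one structure with two atoms.
e_pred, e_ref = [-10.0], [-10.5]
f_pred = [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]
f_ref = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

balanced = combined_loss(e_pred, e_ref, f_pred, f_ref)
force_heavy = combined_loss(e_pred, e_ref, f_pred, f_ref, force_coefficient=10.0)
print(balanced, force_heavy)
```

Raising the force coefficient makes force errors dominate the total, so the optimizer spends more of its effort on forces at the expense of energies.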
There are also other models you could consider. Newer models tend to be more accurate, but they may also be more expensive to run. These are all compromises you can consider.