Fast batched inference

The ASE calculator is not necessarily the most efficient way to run many computations. For a large number of structures, it is usually faster to do a batched “mass inference” with a command-line utility. We illustrate how to do that here.

In the following paper, we computed about 10K different gold structures:

Boes, J. R., Groenenboom, M. C., Keith, J. A., & Kitchin, J. R. (2016). Neural network and Reaxff comparison for Au properties. Int. J. Quantum Chem., 116(13), 979–987. http://dx.doi.org/10.1002/qua.25115

You can retrieve the dataset with the code below. In this notebook we show how to do “mass inference” without an ASE calculator, by creating a config.yml file and running the main.py command-line utility.

from fairchem.core.scripts.download_large_files import download_file_group
download_file_group('docs')
docs/tutorials/NRR/NRR_example_bulks.pkl already exists
docs/core/fine-tuning/supporting-information.json already exists
docs/core/data.db already exists

Inference on this file will be fast if we have a GPU, but without one it could take a while. To keep the automated builds fast, we select a random subset of 50 structures so that it is still approachable with just a CPU. Comment out or skip this block to use the whole dataset!

! mv data.db full_data.db

import ase.db
import numpy as np

with ase.db.connect('full_data.db') as full_db:
    with ase.db.connect('data.db', append=False) as subset_db:

        # Select 50 random rows for the subset; ASE DB ids start at 1
        ids = np.random.choice(list(range(1, len(full_db) + 1)), size=50, replace=False)
        for i in ids:
            atoms = full_db.get_atoms(f'id={i}', add_additional_information=True)
            kvp = atoms.info['key_value_pairs']

            # Make sure the tag is stored as a plain int
            if 'tag' in kvp:
                kvp['tag'] = int(kvp['tag'])

            # Convert "True"/"False" strings back to real booleans
            for key, value in kvp.items():
                if value == "True":
                    kvp[key] = True
                elif value == "False":
                    kvp[key] = False

            subset_db.write(atoms, **kvp)
! ase db data.db
id|age|user  |formula|calculator|  energy|natoms| fmax|pbc|  volume|charge|     mass
 1| 0s|runner|Au38   |unknown   |-100.459|    38|0.349|TTT|5846.048| 0.000| 7484.730
 2| 0s|runner|Au39   |unknown   | -97.139|    39|0.829|TTT|5893.874| 0.000| 7681.696
 3| 0s|runner|Au15   |unknown   | -47.869|    15|0.028|TTT| 280.010| 0.000| 2954.499
 4| 0s|runner|Au16   |unknown   | -37.401|    16|0.156|TTT|3923.762| 0.000| 3151.465
 5| 0s|runner|Au35   |unknown   | -90.962|    35|0.126|TTT|5504.869| 0.000| 6893.830
 6| 0s|runner|Au35   |unknown   | -90.824|    35|0.169|TTT|5504.869| 0.000| 6893.830
 7| 0s|runner|Au15   |unknown   | -47.869|    15|0.024|TTT| 279.885| 0.000| 2954.499
 8| 0s|runner|Au24   |unknown   | -59.349|    24|0.042|TTT|4672.311| 0.000| 4727.198
 9| 0s|runner|Au7    |unknown   | -13.834|     7|0.259|TTT|2498.476| 0.000| 1378.766
10| 0s|runner|Au24   |unknown   | -59.318|    24|0.106|TTT|4672.311| 0.000| 4727.198
11| 0s|runner|Au33   |unknown   | -84.824|    33|0.149|TTT|5480.808| 0.000| 6499.897
12| 0s|runner|Au20   |unknown   | -48.159|    20|0.139|TTT|4413.396| 0.000| 3939.331
13| 0s|runner|Au25   |unknown   | -61.171|    25|0.199|TTT|4771.343| 0.000| 4924.164
14| 0s|runner|Au107  |vasp      |-325.310|   107|0.209|TTT|9178.098| 0.000|21075.423
15| 0s|runner|Au2    |unknown   |  -1.073|     2|0.000|TTT|  23.169| 0.000|  393.933
16| 0s|runner|Au7    |unknown   | -13.705|     7|1.983|TTT|3511.808| 0.000| 1378.766
17| 0s|runner|Au25   |unknown   | -60.797|    25|0.667|TTT|4771.343| 0.000| 4924.164
18| 0s|runner|Au8    |unknown   | -19.571|     8|0.000|TTT| 209.823| 0.000| 1575.733
19| 0s|runner|Au4    |unknown   | -12.005|     4|0.883|TTT| 143.096| 0.000|  787.866
20| 0s|runner|Au36   |unknown   | -93.633|    36|0.132|TTT|5663.627| 0.000| 7090.796
Rows: 50 (showing first 20)
Keys: NEB, bulk, calc_time, cluster, concentration, config, converged, diffusion, ediff, ... (31 more)
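The subset above is drawn with np.random.choice without a seed, so each build picks different structures. If you want a reproducible subset, here is a minimal sketch using NumPy's seeded Generator. The helper names pick_subset_ids and normalize_key_value_pairs are hypothetical. The normalization only mirrors the cleanup done in the loop above.

```python
import numpy as np


def pick_subset_ids(n_rows, size, seed=0):
    """Return `size` distinct 1-based ASE DB row ids out of `n_rows` (hypothetical helper)."""
    rng = np.random.default_rng(seed)  # seeded generator -> same subset every run
    # ASE DB ids start at 1, so sample from 1..n_rows inclusive
    return rng.choice(np.arange(1, n_rows + 1), size=size, replace=False)


def normalize_key_value_pairs(kvp):
    """Return a copy with 'tag' cast to int and "True"/"False" strings turned into bools."""
    out = dict(kvp)
    if 'tag' in out:
        out['tag'] = int(out['tag'])
    for key, value in out.items():
        if value == "True":
            out[key] = True
        elif value == "False":
            out[key] = False
    return out


ids = pick_subset_ids(n_rows=1000, size=50, seed=42)
print(len(ids), ids.min(), ids.max())
print(normalize_key_value_pairs({'tag': 3.0, 'converged': 'True', 'xc': 'PBE'}))
```

With a fixed seed, every build compares the same structures. This makes timing and accuracy numbers easier to compare between runs.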

You have to choose a checkpoint to start with. The newer checkpoints may require too much memory for this environment.

from fairchem.core.models.model_registry import available_pretrained_models
print(available_pretrained_models)
('CGCNN-S2EF-OC20-200k', 'CGCNN-S2EF-OC20-2M', 'CGCNN-S2EF-OC20-20M', 'CGCNN-S2EF-OC20-All', 'DimeNet-S2EF-OC20-200k', 'DimeNet-S2EF-OC20-2M', 'SchNet-S2EF-OC20-200k', 'SchNet-S2EF-OC20-2M', 'SchNet-S2EF-OC20-20M', 'SchNet-S2EF-OC20-All', 'DimeNet++-S2EF-OC20-200k', 'DimeNet++-S2EF-OC20-2M', 'DimeNet++-S2EF-OC20-20M', 'DimeNet++-S2EF-OC20-All', 'SpinConv-S2EF-OC20-2M', 'SpinConv-S2EF-OC20-All', 'GemNet-dT-S2EF-OC20-2M', 'GemNet-dT-S2EF-OC20-All', 'PaiNN-S2EF-OC20-All', 'GemNet-OC-S2EF-OC20-2M', 'GemNet-OC-S2EF-OC20-All', 'GemNet-OC-S2EF-OC20-All+MD', 'GemNet-OC-Large-S2EF-OC20-All+MD', 'SCN-S2EF-OC20-2M', 'SCN-t4-b2-S2EF-OC20-2M', 'SCN-S2EF-OC20-All+MD', 'eSCN-L4-M2-Lay12-S2EF-OC20-2M', 'eSCN-L6-M2-Lay12-S2EF-OC20-2M', 'eSCN-L6-M2-Lay12-S2EF-OC20-All+MD', 'eSCN-L6-M3-Lay20-S2EF-OC20-All+MD', 'EquiformerV2-83M-S2EF-OC20-2M', 'EquiformerV2-31M-S2EF-OC20-All+MD', 'EquiformerV2-153M-S2EF-OC20-All+MD', 'SchNet-S2EF-force-only-OC20-All', 'DimeNet++-force-only-OC20-All', 'DimeNet++-Large-S2EF-force-only-OC20-All', 'DimeNet++-S2EF-force-only-OC20-20M+Rattled', 'DimeNet++-S2EF-force-only-OC20-20M+MD', 'CGCNN-IS2RE-OC20-10k', 'CGCNN-IS2RE-OC20-100k', 'CGCNN-IS2RE-OC20-All', 'DimeNet-IS2RE-OC20-10k', 'DimeNet-IS2RE-OC20-100k', 'DimeNet-IS2RE-OC20-all', 'SchNet-IS2RE-OC20-10k', 'SchNet-IS2RE-OC20-100k', 'SchNet-IS2RE-OC20-All', 'DimeNet++-IS2RE-OC20-10k', 'DimeNet++-IS2RE-OC20-100k', 'DimeNet++-IS2RE-OC20-All', 'PaiNN-IS2RE-OC20-All', 'GemNet-dT-S2EFS-OC22', 'GemNet-OC-S2EFS-OC22', 'GemNet-OC-S2EFS-OC20+OC22', 'GemNet-OC-S2EFS-nsn-OC20+OC22', 'GemNet-OC-S2EFS-OC20->OC22', 'EquiformerV2-lE4-lF100-S2EFS-OC22', 'SchNet-S2EF-ODAC', 'DimeNet++-S2EF-ODAC', 'PaiNN-S2EF-ODAC', 'GemNet-OC-S2EF-ODAC', 'eSCN-S2EF-ODAC', 'EquiformerV2-S2EF-ODAC', 'EquiformerV2-Large-S2EF-ODAC', 'Gemnet-OC-IS2RE-ODAC', 'eSCN-IS2RE-ODAC', 'EquiformerV2-IS2RE-ODAC')
from fairchem.core.models.model_registry import model_name_to_local_file

checkpoint_path = model_name_to_local_file('GemNet-dT-S2EFS-OC22', local_cache='/tmp/fairchem_checkpoints/')
checkpoint_path
'/tmp/fairchem_checkpoints/gndt_oc22_all_s2ef.pt'
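available_pretrained_models is a tuple of name strings, so ordinary string filtering is enough to narrow it down, for example to models trained on OC22 that also predict stress (S2EFS). The sketch below does not import fairchem. Instead it uses a few names copied from the output above. The helper models_matching is hypothetical.

```python
def models_matching(names, *substrings):
    """Return the names that contain every given substring, in their original order."""
    return [name for name in names if all(s in name for s in substrings)]


# A few entries copied from the available_pretrained_models output above
names = (
    'SchNet-S2EF-OC20-200k',
    'GemNet-dT-S2EFS-OC22',
    'GemNet-OC-S2EFS-OC22',
    'GemNet-OC-S2EFS-OC20+OC22',
    'EquiformerV2-lE4-lF100-S2EFS-OC22',
    'EquiformerV2-S2EF-ODAC',
)

print(models_matching(names, 'S2EFS', 'OC22'))
```

In the notebook you would pass available_pretrained_models directly instead of the hand-copied tuple.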

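We have to write a configuration yml file that points to the test dataset. generate_yml_config builds on the configuration stored in the checkpoint. It deletes the listed sections and sets the dotted keys given in update. Here the test set is data.db in ase_db format, restricted to structures matching natoms>5,xc=PBE. r_energy and r_forces are False because we only predict and do not read reference energies or forces.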

from fairchem.core.common.tutorial_utils import generate_yml_config
yml = generate_yml_config(checkpoint_path, 'config.yml',
                   delete=['cmd', 'logger', 'task', 'model_attributes',
                           'dataset', 'slurm'],
                   update={'amp': True,
                           'gpus': 1,
                           'task.prediction_dtype': 'float32',
                           'logger': 'tensorboard',  # don't use wandb!
                           # Test data - prediction only so no regression
                           'dataset.test.src': 'data.db',
                           'dataset.test.format': 'ase_db',
                           'dataset.test.a2g_args.r_energy': False,
                           'dataset.test.a2g_args.r_forces': False,
                           'dataset.test.select_args.selection': 'natoms>5,xc=PBE',
                          })

yml
WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
WARNING:root:Scale factor comment not found in model
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
PosixPath('/home/runner/work/fairchem/fairchem/docs/core/config.yml')
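The update argument uses dotted keys such as 'dataset.test.src' to address nested sections of the YAML configuration. The sketch below is a hypothetical, simplified illustration of how such keys can be applied to a nested dict. It is not fairchem's generate_yml_config implementation, and the toy configuration values are made up.

```python
import copy


def apply_dotted_updates(config, delete=(), update=None):
    """Return a copy of `config` with top-level keys deleted and dotted keys set.

    Hypothetical illustration only: deletions happen first, then each dotted
    key like 'dataset.test.src' creates any missing intermediate dicts.
    """
    new = copy.deepcopy(config)
    for key in delete:
        new.pop(key, None)  # ignore keys that are not present
    for dotted, value in (update or {}).items():
        *parents, leaf = dotted.split('.')
        node = new
        for part in parents:
            node = node.setdefault(part, {})  # descend, creating sections as needed
        node[leaf] = value
    return new


# Toy stand-in for a checkpoint's stored config (values are made up)
checkpoint_config = {
    'model': {'name': 'toy_model'},
    'dataset': {'train': {'src': 'train.lmdb'}},
    'logger': 'wandb',
}
cfg = apply_dotted_updates(
    checkpoint_config,
    delete=['dataset', 'logger'],
    update={'logger': 'tensorboard',
            'dataset.test.src': 'data.db',
            'dataset.test.format': 'ase_db'},
)
print(cfg)
```

Deleting a whole section before setting dotted keys, as the notebook does with 'dataset' and 'task', ensures that only the keys you set explicitly end up in that section.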

It is a good idea to redirect the output to a file, because if the output gets too large here, the notebook may fail to save. Normally I would use a shell redirect like 2>&1, but this does not work with the main.py method. Instead, we capture the cell's output with the %%capture magic and write it to mass-inference.txt afterwards. An alternative is to open a terminal and run the command there. If you have one or more GPUs, use the --num-gpus flag (set to the number of GPUs) and remove the --cpu flag.

%%capture inference
import time
from fairchem.core.common.tutorial_utils import fairchem_main

t0 = time.time()
! python {fairchem_main()} --mode predict --config-yml {yml} --checkpoint {checkpoint_path} --cpu
print(f'Elapsed time = {time.time() - t0:1.1f} seconds')
with open('mass-inference.txt', 'wb') as f:
    f.write(inference.stdout.encode('utf-8'))
! grep "Total time taken:" 'mass-inference.txt'
2024-09-18 21:19:13 (INFO): Total time taken: 7.306535959243774

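For this small subset, the Total time taken line in the log reports about 7 seconds. The full cell, including startup and model loading, takes roughly 1-2 minutes. The complete log is in mass-inference.txt.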
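If you run main.py outside the notebook, Python's subprocess module can send both stdout and stderr to a single log file, which is what a shell 2>&1 redirect would do. The minimal sketch below makes a few assumptions. It assumes main_py is the path returned by fairchem_main(). It also assumes the flags are the ones used in the cell above (--mode, --config-yml, --checkpoint, --cpu), plus --num-gpus for GPU runs as described earlier. The helper names are hypothetical.

```python
import subprocess
import sys


def build_predict_command(main_py, config_yml, checkpoint, num_gpus=0):
    """Assemble the main.py prediction command used above (hypothetical helper).

    With num_gpus == 0 the command runs on the CPU; otherwise --cpu is dropped
    and --num-gpus is passed instead.
    """
    cmd = [sys.executable, str(main_py),
           '--mode', 'predict',
           '--config-yml', str(config_yml),
           '--checkpoint', str(checkpoint)]
    if num_gpus > 0:
        cmd += ['--num-gpus', str(num_gpus)]
    else:
        cmd.append('--cpu')
    return cmd


def run_to_log(cmd, log_path):
    """Run `cmd`, writing stdout and stderr (like 2>&1) into one log file."""
    with open(log_path, 'w') as log:
        return subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT, check=False)
```

In the notebook, the call would look like run_to_log(build_predict_command(fairchem_main(), yml, checkpoint_path), 'mass-inference.txt'). Nothing large is printed into the notebook, so it cannot fail to save.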

# Find the results directory in the saved log, then load the predictions from it
lines = ! grep "  results_dir:" mass-inference.txt
d = lines[0].split(':', 1)[1].strip()
import numpy as np
results = np.load(f'{d}/ocp_predictions.npz', allow_pickle=True)
results.files
['energy', 'forces', 'chunk_idx', 'ids']
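The grep call above pulls the results_dir line out of the log. The same thing can be done in plain Python, which avoids depending on the shell. Below is a minimal sketch. The function name and the sample log text are made up for illustration. The sample log only assumes that the line looks like "results_dir: <path>", as the grep pattern implies.

```python
import re


def find_results_dir(log_text):
    """Return the path on the first 'results_dir:' line of a log, or None."""
    match = re.search(r'^[ \t]*results_dir:[ \t]*(.+?)[ \t]*$', log_text, flags=re.MULTILINE)
    return match.group(1) if match else None


# Made-up excerpt of a log, for illustration only
sample_log = """\
cmd:
  checkpoint_dir: /tmp/run/checkpoints/2024-01-01
  results_dir: /tmp/run/results/2024-01-01
amp: true
"""
print(find_results_dir(sample_log))
```

Because the regex captures everything after the first "results_dir:", paths that contain a colon are kept intact.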

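It is not obvious, but the predictions from mass inference are not necessarily in the same order as the rows in the database. Each prediction has an id whose leading integer (before the underscore) is the structure's index. We parse that index and use np.argsort to “resort” the results into database order.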

inds = np.array([int(r.split('_')[0]) for r in results['ids']])
sind = np.argsort(inds)
inds[sind]
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42])
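The resorting step relies on each id starting with the structure's integer index followed by an underscore, as the parsing above assumes. Here is a self-contained sketch of the same idea on made-up ids and energies (the "_0" suffix is invented). It shows that indexing with np.argsort of the parsed indices puts the predictions back in database order.

```python
import numpy as np

# Made-up ids and predictions, in a scrambled order like a batched run might return
ids = np.array(['2_0', '0_0', '3_0', '1_0'])
energy = np.array([-30.0, -10.0, -40.0, -20.0])

inds = np.array([int(r.split('_')[0]) for r in ids])  # integer index before the '_'
sind = np.argsort(inds)                               # permutation that sorts the indices

print(inds[sind])    # indices now in order 0, 1, 2, 3
print(energy[sind])  # energies reordered to match: -10, -20, -30, -40
```

The same permutation sind must be applied to every per-structure array (energies, forces) so they stay aligned.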

To compare these predictions with DFT, we need the reference energies and atom counts from the ASE database, using the same selection as in config.yml.

from ase.db import connect
db = connect('data.db')

energies = np.array([row.energy for row in db.select('natoms>5,xc=PBE')])
natoms = np.array([row.natoms for row in db.select('natoms>5,xc=PBE')])

Now we can look at the predictions. They are only OK here, which is not surprising: the data set has many Au configurations that this model has never seen. Fine-tuning would likely improve this.

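import matplotlib.pyplot as plt

# results['energy'] has shape (N, 1); flatten it so it lines up with natoms
plt.plot(energies / natoms, results['energy'][sind].ravel() / natoms, 'b.')
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)');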
[Figure: parity plot of OCP-predicted vs. DFT energy per atom]

The ASE calculator way

We include this here just to show that:

  1. we get the same results, and

  2. this approach is slower.

from fairchem.core.common.relaxation.ase_utils import OCPCalculator
calc = OCPCalculator(checkpoint_path=checkpoint_path, cpu=False)
import time
from tqdm import tqdm
t0 = time.time()
OCP, DFT = [], []
for row in tqdm(db.select('natoms>5,xc=PBE')):
    atoms = row.toatoms()
    atoms.calc = calc  # set_calculator() is deprecated in ASE
    DFT += [row.energy / len(atoms)]
    OCP += [atoms.get_potential_energy() / len(atoms)]
print(f'Elapsed time {time.time() - t0:1.1f} seconds')
43it [00:10,  4.25it/s]
Elapsed time 1e+01 seconds

This is slower than the mass-inference approach above: about 10 seconds for 43 structures here, versus the roughly 7 seconds reported by main.py. It is conceptually simpler, though, and does not require resorting.

plt.plot(DFT, OCP, 'b.')
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)');
[Figure: parity plot of OCP vs. DFT energy per atom (eV/atom) from the ASE calculator]

Comparing ASE calculator and main.py

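The results should be the same. Note, however, that results['energy'] from main.py has shape (N, 1), while OCP * natoms has shape (N,). Subtracting them broadcasts to an N × N matrix of every main.py prediction minus every ASE calculator prediction. As a result, the MAE, minimum, maximum, and histogram below are computed over all pairs of structures rather than matched ones, which is why they look so large.

It is also worth noting that main.py and the ASE calculator can use different default precisions for predictions: float16 with main.py and float32 with the ASE calculator (see Issue 542). This can lead to small differences. You can specify --task.prediction_dtype=float32 on the command line, or set it in config.yml as we do above, but at the time this tutorial was written that did not fully resolve the issue.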

np.mean(np.abs(results['energy'][sind] - OCP * natoms))  # MAE
39.12066565699807
np.min(results['energy'][sind] - OCP * natoms), np.max(results['energy'][sind] - OCP * natoms)
(-134.28153038024902, 134.28153038024902)
plt.hist(results['energy'][sind] - OCP * natoms, bins=20);
[Figure: histogram of results['energy'][sind] - OCP * natoms]

Only the diagonal of the difference matrix printed below compares each structure with itself. The visible diagonal entries are zero to within about 1e-14, and the visible off-diagonal entries mirror each other with opposite signs. So in this run, main.py and the ASE calculator give the same energy for each structure.

(results['energy'][sind] - OCP * natoms)[0:400]
array([[-1.42108547e-14, -1.32502747e+00, -6.39400826e+01, ...,
        -7.13904953e+01, -1.99435120e+01, -7.34947090e+01],
       [ 1.32502747e+00,  0.00000000e+00, -6.26150551e+01, ...,
        -7.00654678e+01, -1.86184845e+01, -7.21696815e+01],
       [ 6.39400826e+01,  6.26150551e+01,  0.00000000e+00, ...,
        -7.45041275e+00,  4.39965706e+01, -9.55462646e+00],
       ...,
       [ 7.13904953e+01,  7.00654678e+01,  7.45041275e+00, ...,
         0.00000000e+00,  5.14469833e+01, -2.10421371e+00],
       [ 1.99435120e+01,  1.86184845e+01, -4.39965706e+01, ...,
        -5.14469833e+01,  0.00000000e+00, -5.35511971e+01],
       [ 7.34947090e+01,  7.21696815e+01,  9.55462646e+00, ...,
         2.10421371e+00,  5.35511971e+01,  0.00000000e+00]])
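The size of the statistics above comes from NumPy broadcasting: an array of shape (N, 1) minus an array of shape (N,) gives an (N, N) array. The self-contained sketch below uses made-up energies to show the effect and the fix, which is to flatten first. In the notebook, the equivalent would be results['energy'][sind].ravel() - OCP * natoms.

```python
import numpy as np

# Made-up total energies (eV) for three structures, identical for both methods
pred_column = np.array([[-100.0], [-50.0], [-20.0]])  # shape (3, 1), like results['energy']
reference = np.array([-100.0, -50.0, -20.0])          # shape (3,), like OCP * natoms

pairwise = pred_column - reference         # broadcasts to (3, 3): every pair of structures
matched = pred_column.ravel() - reference  # shape (3,): one difference per structure

print(pairwise.shape, matched.shape)
print(np.mean(np.abs(pairwise)))  # nonzero: dominated by mismatched pairs
print(np.mean(np.abs(matched)))   # 0.0, since the two methods agree here
print(np.allclose(np.diag(pairwise), matched))  # the diagonal holds the matched pairs
```

Checking .shape before reducing with np.mean is a cheap way to catch this kind of mistake.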