Fast batched inference#
The ASE calculator is not necessarily the most efficient way to run a large number of computations; for that it is better to do a "mass inference" using a command-line utility. We illustrate how to do that here.
In the paper below we computed about 10K different gold structures:
Boes, J. R., Groenenboom, M. C., Keith, J. A., & Kitchin, J. R. (2016). Neural network and Reaxff comparison for Au properties. Int. J. Quantum Chem., 116(13), 979–987. http://dx.doi.org/10.1002/qua.25115
You can retrieve the dataset below. In this notebook we learn how to do "mass inference" without an ASE calculator: you create a config.yml file and run the main.py command-line utility.
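In outline: generate a config.yml from a pretrained checkpoint, then run main.py in predict mode against it. A comment sketch of the two steps this notebook walks through (names come from the cells below):

# Step 1: write a config derived from a checkpoint (see generate_yml_config below)
#     generate_yml_config(checkpoint_path, 'config.yml', delete=[...], update={...})
# Step 2: run batched predictions from the shell (see the %%capture cell below)
#     python main.py --mode predict --config-yml config.yml --checkpoint <ckpt> --cpu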
from fairchem.core.scripts.download_large_files import download_file_group
download_file_group('docs')
docs/tutorials/NRR/NRR_example_bulks.pkl already exists
docs/core/fine-tuning/supporting-information.json already exists
docs/core/data.db already exists
Inference on this file will be fast if we have a GPU, but without one it could take a while. To keep things fast for the automated builds, we select a random subset of 50 structures so it is still approachable with just a CPU. Comment out or skip this block to use the whole dataset!
! mv data.db full_data.db
import ase.db
import numpy as np

with ase.db.connect('full_data.db') as full_db:
    with ase.db.connect('data.db', append=False) as subset_db:
        # Select 50 random structures for the subset; ASE DB ids start at 1
        for i in np.random.choice(list(range(1, len(full_db) + 1)), size=50, replace=False):
            atoms = full_db.get_atoms(f'id={i}', add_additional_information=True)
            # The tag and boolean key/value pairs come back as strings;
            # coerce them so they round-trip cleanly into the subset db.
            if 'tag' in atoms.info['key_value_pairs']:
                atoms.info['key_value_pairs']['tag'] = int(atoms.info['key_value_pairs']['tag'])
            for key in atoms.info['key_value_pairs']:
                if atoms.info['key_value_pairs'][key] == 'True':
                    atoms.info['key_value_pairs'][key] = True
                elif atoms.info['key_value_pairs'][key] == 'False':
                    atoms.info['key_value_pairs'][key] = False
            subset_db.write(atoms, **atoms.info['key_value_pairs'])
! ase db data.db
id|age|user |formula|calculator| energy|natoms| fmax|pbc| volume|charge| mass
1| 0s|runner|Au33 |unknown | -85.751| 33|0.123|TTT|5480.808| 0.000| 6499.897
2| 0s|runner|Au10 |unknown | -21.713| 10|0.701|TTT|4112.512| 0.000| 1969.666
3| 0s|runner|Au4 |unknown | -12.213| 4|0.034|TTT| 143.096| 0.000| 787.866
4| 0s|runner|Au6 |unknown | -12.705| 6|0.040|TTT|3511.808| 0.000| 1181.799
5| 0s|runner|Au26 |unknown | -82.842| 26|0.122|TTT| 475.188| 0.000| 5121.131
6| 0s|runner|Au4 |unknown | -8.292| 4|6.615|TTT| 156.940| 0.000| 787.866
7| 0s|runner|Au40 |unknown |-104.276| 40|0.244|TTT|5868.580| 0.000| 7878.663
8| 0s|runner|Au37 |unknown | -96.209| 37|0.086|TTT|5771.976| 0.000| 7287.763
9| 0s|runner|Au10 |unknown | -21.668| 10|0.034|TTT|5639.752| 0.000| 1969.666
10| 0s|runner|Au31 |unknown | -79.160| 31|0.096|TTT|5304.824| 0.000| 6105.964
11| 0s|runner|Au52 |unknown |-137.922| 52|0.187|TTT|6744.510| 0.000|10242.262
12| 0s|runner|Au41 |unknown |-105.733| 41|0.184|TTT|6012.425| 0.000| 8075.629
13| 0s|runner|Au43 |unknown |-112.016| 43|0.179|TTT|6014.312| 0.000| 8469.562
14| 0s|runner|Au15 |unknown | -47.400| 15|0.326|TTT| 286.452| 0.000| 2954.499
15| 0s|runner|Au18 |unknown | -40.868| 18|0.674|TTT|4250.202| 0.000| 3545.398
16| 0s|runner|Au37 |unknown | -96.184| 37|0.186|TTT|5771.976| 0.000| 7287.763
17| 0s|runner|Au26 |unknown | -78.477| 26|0.178|TTT| 539.101| 0.000| 5121.131
18| 0s|runner|Au24 |unknown | -59.263| 24|0.096|TTT|4672.311| 0.000| 4727.198
19| 0s|runner|Au |unknown | -0.700| 1|0.000|TTT| 60.301| 0.000| 196.967
20| 0s|runner|Au48 |unknown |-125.852| 48|0.176|TTT|6440.631| 0.000| 9454.395
Rows: 50 (showing first 20)
Keys: NEB, bulk, calc_time, cluster, concentration, config, diffusion, ediff, ediffg, encut, factor, fermi, gga, group, image, kpt1, kpt2, kpt3, lattice, miller, neural_energy, reax_energy, relaxed, strain, structure, surf, train_set, traj, type, vacuum, volume, xc
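The selection string we will use later in the config ('natoms>5,xc=PBE') is ordinary ASE db selection syntax, so you can preview what it matches directly. A quick sketch (assumes the data.db built above):

from ase.db import connect

# Preview the selection used later in config.yml; plain ASE db syntax.
with connect('data.db') as db:
    print(sum(1 for _ in db.select('natoms>5,xc=PBE')), 'rows match')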
You have to choose a checkpoint to start with. The newer checkpoints may require too much memory for this environment.
from fairchem.core.models.model_registry import available_pretrained_models
print(available_pretrained_models)
('CGCNN-S2EF-OC20-200k', 'CGCNN-S2EF-OC20-2M', 'CGCNN-S2EF-OC20-20M', 'CGCNN-S2EF-OC20-All', 'DimeNet-S2EF-OC20-200k', 'DimeNet-S2EF-OC20-2M', 'SchNet-S2EF-OC20-200k', 'SchNet-S2EF-OC20-2M', 'SchNet-S2EF-OC20-20M', 'SchNet-S2EF-OC20-All', 'DimeNet++-S2EF-OC20-200k', 'DimeNet++-S2EF-OC20-2M', 'DimeNet++-S2EF-OC20-20M', 'DimeNet++-S2EF-OC20-All', 'SpinConv-S2EF-OC20-2M', 'SpinConv-S2EF-OC20-All', 'GemNet-dT-S2EF-OC20-2M', 'GemNet-dT-S2EF-OC20-All', 'PaiNN-S2EF-OC20-All', 'GemNet-OC-S2EF-OC20-2M', 'GemNet-OC-S2EF-OC20-All', 'GemNet-OC-S2EF-OC20-All+MD', 'GemNet-OC-Large-S2EF-OC20-All+MD', 'SCN-S2EF-OC20-2M', 'SCN-t4-b2-S2EF-OC20-2M', 'SCN-S2EF-OC20-All+MD', 'eSCN-L4-M2-Lay12-S2EF-OC20-2M', 'eSCN-L6-M2-Lay12-S2EF-OC20-2M', 'eSCN-L6-M2-Lay12-S2EF-OC20-All+MD', 'eSCN-L6-M3-Lay20-S2EF-OC20-All+MD', 'EquiformerV2-83M-S2EF-OC20-2M', 'EquiformerV2-31M-S2EF-OC20-All+MD', 'EquiformerV2-153M-S2EF-OC20-All+MD', 'SchNet-S2EF-force-only-OC20-All', 'DimeNet++-force-only-OC20-All', 'DimeNet++-Large-S2EF-force-only-OC20-All', 'DimeNet++-S2EF-force-only-OC20-20M+Rattled', 'DimeNet++-S2EF-force-only-OC20-20M+MD', 'CGCNN-IS2RE-OC20-10k', 'CGCNN-IS2RE-OC20-100k', 'CGCNN-IS2RE-OC20-All', 'DimeNet-IS2RE-OC20-10k', 'DimeNet-IS2RE-OC20-100k', 'DimeNet-IS2RE-OC20-all', 'SchNet-IS2RE-OC20-10k', 'SchNet-IS2RE-OC20-100k', 'SchNet-IS2RE-OC20-All', 'DimeNet++-IS2RE-OC20-10k', 'DimeNet++-IS2RE-OC20-100k', 'DimeNet++-IS2RE-OC20-All', 'PaiNN-IS2RE-OC20-All', 'GemNet-dT-S2EFS-OC22', 'GemNet-OC-S2EFS-OC22', 'GemNet-OC-S2EFS-OC20+OC22', 'GemNet-OC-S2EFS-nsn-OC20+OC22', 'GemNet-OC-S2EFS-OC20->OC22', 'EquiformerV2-lE4-lF100-S2EFS-OC22', 'SchNet-S2EF-ODAC', 'DimeNet++-S2EF-ODAC', 'PaiNN-S2EF-ODAC', 'GemNet-OC-S2EF-ODAC', 'eSCN-S2EF-ODAC', 'EquiformerV2-S2EF-ODAC', 'EquiformerV2-Large-S2EF-ODAC', 'Gemnet-OC-IS2RE-ODAC', 'eSCN-IS2RE-ODAC', 'EquiformerV2-IS2RE-ODAC')
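Since available_pretrained_models is a plain tuple of strings, you can filter it; for example, to see only the OC22-trained checkpoints:

print([m for m in available_pretrained_models if 'OC22' in m])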
from fairchem.core.models.model_registry import model_name_to_local_file
checkpoint_path = model_name_to_local_file('GemNet-dT-S2EFS-OC22', local_cache='/tmp/fairchem_checkpoints/')
checkpoint_path
'/tmp/fairchem_checkpoints/gndt_oc22_all_s2ef.pt'
We have to update our configuration yml file with the test dataset.
from fairchem.core.common.tutorial_utils import generate_yml_config

yml = generate_yml_config(checkpoint_path, 'config.yml',
                          delete=['cmd', 'logger', 'task', 'model_attributes',
                                  'dataset', 'slurm'],
                          update={'amp': True,
                                  'gpus': 1,
                                  'task.prediction_dtype': 'float32',
                                  'logger': 'tensorboard',  # don't use wandb!
                                  # Test data - prediction only so no regression
                                  'dataset.test.src': 'data.db',
                                  'dataset.test.format': 'ase_db',
                                  'dataset.test.a2g_args.r_energy': False,
                                  'dataset.test.a2g_args.r_forces': False,
                                  'dataset.test.select_args.selection': 'natoms>5,xc=PBE',
                                  })

yml
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/escn/so3.py:23: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/scn/spherical_harmonics.py:23: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/wigner.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/layer_norm.py:75: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
@torch.cuda.amp.autocast(enabled=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/layer_norm.py:175: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
@torch.cuda.amp.autocast(enabled=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/layer_norm.py:263: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
@torch.cuda.amp.autocast(enabled=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/layer_norm.py:357: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
@torch.cuda.amp.autocast(enabled=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/common/relaxation/ase_utils.py:150: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
INFO:root:amp: true
cmd:
  checkpoint_dir: /home/runner/work/fairchem/fairchem/docs/core/checkpoints/2024-11-19-06-21-52
  commit: aa298ac
  identifier: ''
  logs_dir: /home/runner/work/fairchem/fairchem/docs/core/logs/wandb/2024-11-19-06-21-52
  print_every: 100
  results_dir: /home/runner/work/fairchem/fairchem/docs/core/results/2024-11-19-06-21-52
  seed: null
  timestamp_id: 2024-11-19-06-21-52
  version: 0.1.dev1+gaa298ac
dataset:
  format: oc22_lmdb
  key_mapping:
    force: forces
    y: energy
  normalize_labels: false
  oc20_ref: /checkpoint/janlan/ocp/other_data/final_ref_energies_02_07_2021.pkl
  raw_energy_target: true
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - forcesx_mae
    - forcesy_mae
    - forcesz_mae
    - mae
    - cosine_similarity
    - magnitude_error
    misc:
    - energy_forces_within_threshold
  primary_metric: forces_mae
gp_gpus: null
gpus: 0
logger: wandb
loss_functions:
- energy:
    coefficient: 1
    fn: mae
- forces:
    coefficient: 1
    fn: l2mae
model:
  activation: silu
  cbf:
    name: spherical_harmonics
  cutoff: 6.0
  direct_forces: true
  emb_size_atom: 512
  emb_size_bil_trip: 64
  emb_size_cbf: 16
  emb_size_edge: 512
  emb_size_rbf: 16
  emb_size_trip: 64
  envelope:
    exponent: 5
    name: polynomial
  extensive: true
  max_neighbors: 50
  name: gemnet_t
  num_after_skip: 2
  num_atom: 3
  num_before_skip: 1
  num_blocks: 3
  num_concat: 1
  num_radial: 128
  num_spherical: 7
  otf_graph: true
  output_init: HeOrthogonal
  rbf:
    name: gaussian
  regress_forces: true
optim:
  batch_size: 16
  clip_grad_norm: 10
  ema_decay: 0.999
  energy_coefficient: 1
  eval_batch_size: 16
  eval_every: 5000
  force_coefficient: 1
  loss_energy: mae
  loss_force: atomwisel2
  lr_gamma: 0.8
  lr_initial: 0.0005
  lr_milestones:
  - 64000
  - 96000
  - 128000
  - 160000
  - 192000
  max_epochs: 80
  num_workers: 2
  optimizer: AdamW
  optimizer_params:
    amsgrad: true
  warmup_steps: -1
outputs:
  energy:
    level: system
  forces:
    eval_on_free_atoms: true
    level: atom
    train_on_free_atoms: true
relax_dataset: {}
slurm:
  additional_parameters:
    constraint: volta32gb
  cpus_per_task: 3
  folder: /checkpoint/abhshkdz/ocp_oct1_logs/57864354
  gpus_per_node: 8
  job_id: '57864354'
  job_name: gndt_oc22_s2ef
  mem: 480GB
  nodes: 2
  ntasks_per_node: 8
  partition: ocp
  time: 4320
task:
  dataset: oc22_lmdb
  eval_on_free_atoms: true
  primary_metric: forces_mae
  train_on_free_atoms: true
test_dataset: {}
trainer: ocp
val_dataset: {}
INFO:root:Loading model: gemnet_t
INFO:root:Loaded GemNetT with 31671825 parameters.
INFO:root:Loading checkpoint in inference-only mode, not loading keys associated with trainer state!
INFO:root:Overwriting scaling factors with those loaded from checkpoint. If you're generating predictions with a pretrained checkpoint, this is the correct behavior. To disable this, delete `scale_dict` from the checkpoint.
WARNING:root:Scale factor comment not found in model
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
PosixPath('/home/runner/work/fairchem/fairchem/docs/core/config.yml')
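If you want to verify what was written, config.yml is ordinary YAML. A minimal sketch of reading it back (assumes the pyyaml package; the dotted update keys above become nested sections):

import yaml

# Read back the generated config; the update keys above should be present.
with open('config.yml') as f:
    config = yaml.safe_load(f)
print(config.get('amp'), config.get('dataset'))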
It is a good idea to redirect the output to a file; if the output gets too large here, the notebook may fail to save. Normally I would use a redirect like 2>&1, but this does not work with the main.py method. An alternative is to open a terminal and run the command there. If you have one or more GPUs, you should use the flag --num-gpus=N (N = number of GPUs).
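If you would rather not rely on the notebook capture used below, a hedged alternative is to launch the same command with subprocess and stream everything to a log file (a sketch; it assumes the config.yml and checkpoint_path from above):

import subprocess
from fairchem.core.common.tutorial_utils import fairchem_main

# Run the same prediction command, sending stdout and stderr to a file.
with open('mass-inference.txt', 'w') as f:
    subprocess.run(['python', str(fairchem_main()),
                    '--mode', 'predict',
                    '--config-yml', 'config.yml',
                    '--checkpoint', str(checkpoint_path),
                    '--cpu'],
                   stdout=f, stderr=subprocess.STDOUT)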
%%capture inference
import time
from fairchem.core.common.tutorial_utils import fairchem_main
t0 = time.time()
! python {fairchem_main()} --mode predict --config-yml {yml} --checkpoint {checkpoint_path} --cpu
print(f'Elapsed time = {time.time() - t0:1.1f} seconds')
with open('mass-inference.txt', 'wb') as f:
    f.write(inference.stdout.encode('utf-8'))
! grep "Total time taken:" 'mass-inference.txt'
2024-11-19 06:22:22 (INFO): Total time taken: 8.455804347991943
The mass-inference approach takes 1-2 minutes to run. See the output in mass-inference.txt.
results = ! grep " results_dir:" mass-inference.txt
d = results[0].split(':')[-1].strip()
import numpy as np
results = np.load(f'{d}/ocp_predictions.npz', allow_pickle=True)
results.files
['energy', 'forces', 'chunk_idx', 'ids']
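The energy array has one entry per structure, while the forces for all structures are concatenated; chunk_idx marks the boundaries. A sketch of splitting them back per structure (this assumes chunk_idx holds np.split offsets, the usual fairchem convention):

# One energy per structure; forces concatenated over all atoms.
print(results['energy'].shape, results['forces'].shape)

# Split the concatenated forces back into per-structure arrays.
per_structure_forces = np.split(results['forces'], results['chunk_idx'])
print(len(per_structure_forces))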
It is not obvious, but the data from the mass inference is not in the same order as the database. We have to parse the ids from the mass-inference results and then re-sort the predictions so they line up with the database rows.
inds = np.array([int(r.split('_')[0]) for r in results['ids']])
sind = np.argsort(inds)
inds[sind]
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38])
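Equivalently, you can make the re-sort explicit by mapping each parsed id to its prediction (a sketch, assuming one scalar energy per id):

# Map each integer id to its predicted energy, then read them out in order.
id_to_energy = {int(r.split('_')[0]): e
                for r, e in zip(results['ids'], results['energy'])}
sorted_energies = np.array([id_to_energy[i] for i in sorted(id_to_energy)])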
To compare this with the results, we need to get the energy data from the ase db.
from ase.db import connect
db = connect('data.db')
energies = np.array([row.energy for row in db.select('natoms>5,xc=PBE')])
natoms = np.array([row.natoms for row in db.select('natoms>5,xc=PBE')])
Now we can see the predictions. They are only OK here; that is not surprising, as the dataset has lots of Au configurations that this model has never seen. Fine-tuning would certainly help improve them.
import matplotlib.pyplot as plt
plt.plot(energies / natoms, results['energy'][sind] / natoms, 'b.')
plt.xlabel('DFT')
plt.ylabel('OCP');
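To put a number on "only OK", you can compute the per-atom mean absolute error for the same arrays (a quick check, not a rigorous benchmark; the reshape guards against a trailing singleton dimension):

# Per-atom MAE between DFT energies and the sorted predictions.
pred = np.asarray(results['energy'][sind]).reshape(-1)
mae = np.mean(np.abs(energies / natoms - pred / natoms))
print(f'MAE = {mae:.3f} eV/atom')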
The ASE calculator way#
We include this here just to show that:
we get essentially the same results, and
this approach is much slower.
from fairchem.core.common.relaxation.ase_utils import OCPCalculator
calc = OCPCalculator(checkpoint_path=checkpoint_path, cpu=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/common/relaxation/ase_utils.py:150: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
INFO:root:amp: true
cmd:
  checkpoint_dir: /home/runner/work/fairchem/fairchem/docs/core/checkpoints/2024-11-19-06-21-52
  commit: aa298ac
  identifier: ''
  logs_dir: /home/runner/work/fairchem/fairchem/docs/core/logs/wandb/2024-11-19-06-21-52
  print_every: 100
  results_dir: /home/runner/work/fairchem/fairchem/docs/core/results/2024-11-19-06-21-52
  seed: null
  timestamp_id: 2024-11-19-06-21-52
  version: 0.1.dev1+gaa298ac
dataset:
  format: oc22_lmdb
  key_mapping:
    force: forces
    y: energy
  normalize_labels: false
  oc20_ref: /checkpoint/janlan/ocp/other_data/final_ref_energies_02_07_2021.pkl
  raw_energy_target: true
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - forcesx_mae
    - forcesy_mae
    - forcesz_mae
    - mae
    - cosine_similarity
    - magnitude_error
    misc:
    - energy_forces_within_threshold
  primary_metric: forces_mae
gp_gpus: null
gpus: 0
logger: wandb
loss_functions:
- energy:
    coefficient: 1
    fn: mae
- forces:
    coefficient: 1
    fn: l2mae
model:
  activation: silu
  cbf:
    name: spherical_harmonics
  cutoff: 6.0
  direct_forces: true
  emb_size_atom: 512
  emb_size_bil_trip: 64
  emb_size_cbf: 16
  emb_size_edge: 512
  emb_size_rbf: 16
  emb_size_trip: 64
  envelope:
    exponent: 5
    name: polynomial
  extensive: true
  max_neighbors: 50
  name: gemnet_t
  num_after_skip: 2
  num_atom: 3
  num_before_skip: 1
  num_blocks: 3
  num_concat: 1
  num_radial: 128
  num_spherical: 7
  otf_graph: true
  output_init: HeOrthogonal
  rbf:
    name: gaussian
  regress_forces: true
optim:
  batch_size: 16
  clip_grad_norm: 10
  ema_decay: 0.999
  energy_coefficient: 1
  eval_batch_size: 16
  eval_every: 5000
  force_coefficient: 1
  loss_energy: mae
  loss_force: atomwisel2
  lr_gamma: 0.8
  lr_initial: 0.0005
  lr_milestones:
  - 64000
  - 96000
  - 128000
  - 160000
  - 192000
  max_epochs: 80
  num_workers: 2
  optimizer: AdamW
  optimizer_params:
    amsgrad: true
  warmup_steps: -1
outputs:
  energy:
    level: system
  forces:
    eval_on_free_atoms: true
    level: atom
    train_on_free_atoms: true
relax_dataset: {}
slurm:
  additional_parameters:
    constraint: volta32gb
  cpus_per_task: 3
  folder: /checkpoint/abhshkdz/ocp_oct1_logs/57864354
  gpus_per_node: 8
  job_id: '57864354'
  job_name: gndt_oc22_s2ef
  mem: 480GB
  nodes: 2
  ntasks_per_node: 8
  partition: ocp
  time: 4320
task:
  dataset: oc22_lmdb
  eval_on_free_atoms: true
  primary_metric: forces_mae
  train_on_free_atoms: true
test_dataset: {}
trainer: ocp
val_dataset: {}
INFO:root:Loading model: gemnet_t
INFO:root:Loaded GemNetT with 31671825 parameters.
INFO:root:Loading checkpoint in inference-only mode, not loading keys associated with trainer state!
INFO:root:Overwriting scaling factors with those loaded from checkpoint. If you're generating predictions with a pretrained checkpoint, this is the correct behavior. To disable this, delete `scale_dict` from the checkpoint.
WARNING:root:Scale factor comment not found in model
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
import time
from tqdm import tqdm

t0 = time.time()
OCP, DFT = [], []
for row in tqdm(db.select('natoms>5,xc=PBE')):
    atoms = row.toatoms()
    atoms.calc = calc  # attach the OCPCalculator (set_calculator is deprecated)
    DFT += [row.energy / len(atoms)]
    OCP += [atoms.get_potential_energy() / len(atoms)]
print(f'Elapsed time {time.time() - t0:1.1f} seconds')
0it [00:00, ?it/s]
/home/runner/work/fairchem/fairchem/src/fairchem/core/trainers/ocp_trainer.py:461: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=self.scaler is not None):
39it [00:10, 3.58it/s]
Elapsed time 1e+01 seconds
This takes at least twice as long as the mass-inference approach above, largely because main.py evaluates structures in batches (eval_batch_size: 16 in the config above) while the calculator evaluates them one at a time. It is conceptually simpler though, and does not require re-sorting.
plt.plot(DFT, OCP, 'b.')
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)');
Comparing ASE calculator and main.py#
The results should be the same.
It is worth noting that the default precision of predictions is float16 with main.py, but with the ASE calculator the default precision is float32. Supposedly you can specify --task.prediction_dtype=float32 at the command line, or specify it in the config.yml as we do above, but as of this tutorial that does not resolve the issue.
As noted above (see also Issue 542), the ASE calculator and main.py use different precisions by default, which can lead to small differences.
np.mean(np.abs(results['energy'][sind] - OCP * natoms)) # MAE
54.2420292063656
np.min(results['energy'][sind] - OCP * natoms), np.max(results['energy'][sind] - OCP * natoms)
(-281.943151473999, 281.943151473999)
plt.hist(results['energy'][sind] - OCP * natoms, bins=20);
Here we see that many of the differences are very small. 0.0078125 = 1/128, and these errors strongly suggest some kind of mixed precision is responsible for the differences. It is an open issue to identify the cause and eliminate them.
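A quick numpy check makes the 1/128 figure plausible: float16 has a 10-bit mantissa, so values in [8, 16) are only representable in steps of 1/128.

import numpy as np

# Spacing between adjacent float16 values at magnitudes 8-16 is 1/128.
print(np.spacing(np.float16(8.0)))     # 0.0078125
print(np.float64(np.float16(10.004)))  # rounds to 10.0078125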
(results['energy'][sind] - OCP * natoms)[0:400]
array([[ 0. , -67.61706161, -78.87683678, ..., -15.60034943,
-44.52700806, -58.80632401],
[ 67.61706161, 0. , -11.25977516, ..., 52.01671219,
23.09005356, 8.81073761],
[ 78.87683678, 11.25977516, 0. , ..., 63.27648735,
34.34982872, 20.07051277],
...,
[ 15.60034943, -52.01671219, -63.27648735, ..., 0. ,
-28.92665863, -43.20597458],
[ 44.52700806, -23.09005356, -34.34982872, ..., 28.92665863,
0. , -14.27931595],
[ 58.80632401, -8.81073761, -20.07051277, ..., 43.20597458,
14.27931595, 0. ]])