AdsorbML tutorial

from fairchem.core.common.relaxation.ase_utils import OCPCalculator
import ase.io
from ase.optimize import BFGS

from fairchem.data.oc.core import Adsorbate, AdsorbateSlabConfig, Bulk, Slab
import os
from glob import glob
import pandas as pd
from fairchem.data.oc.utils import DetectTrajAnomaly
from fairchem.data.oc.utils.vasp import write_vasp_input_files

# Optional - see below
import numpy as np
from dscribe.descriptors import SOAP
from scipy.spatial.distance import pdist, squareform
from x3dase.visualize import view_x3d_n

Enumerate the adsorbate-slab configurations to run relaxations on

AdsorbML incorporates random placement, which is especially useful for more complicated adsorbates that may have many degrees of freedom. I have opted to sample a few random placements and a few heuristic ones. Here I am using *CO on copper (111) as an example.

# Copper bulk (Materials Project id mp-30) with a *CO adsorbate
bulk_src_id = "mp-30"
adsorbate_smiles = "*CO"

bulk = Bulk(bulk_src_id_from_db=bulk_src_id)
adsorbate = Adsorbate(adsorbate_smiles_from_db=adsorbate_smiles)
slabs = Slab.from_bulk_get_specific_millers(
    bulk=bulk,
    specific_millers=(1, 1, 1),
)

# More than one slab can be returned for a single Miller index;
# this demonstration simply works with the first of them.
slab = slabs[0]
Downloading /home/runner/work/fairchem/fairchem/src/fairchem/data/oc/databases/pkls/bulks.pkl...
# Perform heuristic placements on the slab selected above
heuristic_adslabs = AdsorbateSlabConfig(slab, adsorbate, mode="heuristic")

# Perform random placements
# (for AdsorbML we use `num_sites = 100` but we will use 4 for brevity here)
random_adslabs = AdsorbateSlabConfig(
    slab, adsorbate, mode="random_site_heuristic_placement", num_sites=4
)

Run ML relaxations:

There are 2 options for how to do this.

  1. Using OCPCalculator as the calculator within the ASE framework

  2. By writing objects to lmdb and relaxing them using main.py in the ocp repo

(1) is really only adequate for a small number of relaxations, and it is what I will show here, but if you plan to run many relaxations, you should definitely use (2). More details about writing lmdbs have been provided here - follow the IS2RS/IS2RE instructions. More information about running relaxations once the lmdb has been written is here.

You need to provide the calculator with a path to a model checkpoint file. It can be downloaded here, or, as shown below, fetched by model name with `model_name_to_local_file`.

from fairchem.core.common.relaxation.ase_utils import OCPCalculator
from fairchem.core.models.model_registry import model_name_to_local_file
import os

checkpoint_path = model_name_to_local_file(
    "EquiformerV2-31M-S2EF-OC20-All+MD", local_cache="/tmp/fairchem_checkpoints/"
)

os.makedirs(f"data/{bulk}_{adsorbate}", exist_ok=True)

# Define the calculator
calc = OCPCalculator(checkpoint_path=checkpoint_path)  # if you have a gpu, add `cpu=False` to speed up calculations

adslabs = [*heuristic_adslabs.atoms_list, *random_adslabs.atoms_list]

# Relax every configuration with the shared calculator; each run writes its
# trajectory to "<idx>.traj", which the post-processing step reads back.
for idx, adslab in enumerate(adslabs):
    adslab.calc = calc
    # Using the optimizer as a context manager closes its trajectory file as
    # soon as the relaxation finishes instead of leaving it open.
    with BFGS(adslab, trajectory=f"data/{bulk}_{adsorbate}/{idx}.traj") as opt:
        # For the AdsorbML results we used fmax = 0.02 and steps = 300,
        # but we use less strict values here for brevity.
        opt.run(fmax=0.05, steps=100)
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/scn/spherical_harmonics.py:23: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/wigner.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/escn/so3.py:23: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/common/relaxation/ase_utils.py:200: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
INFO:root:amp: true
cmd:
  checkpoint_dir: /home/runner/work/fairchem/fairchem/docs/tutorials/checkpoints/2025-03-29-05-13-36
  commit: core:b88b66e,experimental:NA
  identifier: ''
  logs_dir: /home/runner/work/fairchem/fairchem/docs/tutorials/logs/wandb/2025-03-29-05-13-36
  print_every: 100
  results_dir: /home/runner/work/fairchem/fairchem/docs/tutorials/results/2025-03-29-05-13-36
  seed: null
  timestamp_id: 2025-03-29-05-13-36
  version: 0.1.dev1+gb88b66e
dataset:
  format: trajectory_lmdb_v2
  grad_target_mean: 0.0
  grad_target_std: 2.887317180633545
  key_mapping:
    force: forces
    y: energy
  normalize_labels: true
  target_mean: -0.7554450631141663
  target_std: 2.887317180633545
  transforms:
    normalizer:
      energy:
        mean: -0.7554450631141663
        stdev: 2.887317180633545
      forces:
        mean: 0.0
        stdev: 2.887317180633545
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - forcesx_mae
    - forcesy_mae
    - forcesz_mae
    - mae
    - cosine_similarity
    - magnitude_error
    misc:
    - energy_forces_within_threshold
  primary_metric: forces_mae
gp_gpus: null
gpus: 0
logger: wandb
loss_functions:
- energy:
    coefficient: 4
    fn: mae
- forces:
    coefficient: 100
    fn: l2mae
model:
  alpha_drop: 0.1
  attn_activation: silu
  attn_alpha_channels: 64
  attn_hidden_channels: 64
  attn_value_channels: 16
  distance_function: gaussian
  drop_path_rate: 0.1
  edge_channels: 128
  ffn_activation: silu
  ffn_hidden_channels: 128
  grid_resolution: 18
  lmax_list:
  - 4
  max_neighbors: 20
  max_num_elements: 90
  max_radius: 12.0
  mmax_list:
  - 2
  name: equiformer_v2
  norm_type: layer_norm_sh
  num_distance_basis: 512
  num_heads: 8
  num_layers: 8
  num_sphere_samples: 128
  otf_graph: true
  proj_drop: 0.0
  regress_forces: true
  sphere_channels: 128
  use_atom_edge_embedding: true
  use_gate_act: false
  use_grid_mlp: true
  use_pbc: true
  use_s2_act_attn: false
  weight_init: uniform
optim:
  batch_size: 8
  clip_grad_norm: 100
  ema_decay: 0.999
  energy_coefficient: 4
  eval_batch_size: 8
  eval_every: 10000
  force_coefficient: 100
  grad_accumulation_steps: 1
  load_balancing: atoms
  loss_energy: mae
  loss_force: l2mae
  lr_initial: 0.0004
  max_epochs: 3
  num_workers: 8
  optimizer: AdamW
  optimizer_params:
    weight_decay: 0.001
  scheduler: LambdaLR
  scheduler_params:
    epochs: 1009275
    lambda_type: cosine
    lr: 0.0004
    lr_min_factor: 0.01
    warmup_epochs: 3364.25
    warmup_factor: 0.2
outputs:
  energy:
    level: system
  forces:
    eval_on_free_atoms: true
    level: atom
    train_on_free_atoms: true
relax_dataset: {}
slurm:
  additional_parameters:
    constraint: volta32gb
  cpus_per_task: 9
  folder: /checkpoint/abhshkdz/open-catalyst-project/logs/equiformer_v2/8307793
  gpus_per_node: 8
  job_id: '8307793'
  job_name: eq2s_051701_allmd
  mem: 480GB
  nodes: 8
  ntasks_per_node: 8
  partition: learnaccel
  time: 4320
task:
  dataset: trajectory_lmdb_v2
  eval_on_free_atoms: true
  grad_input: atomic forces
  labels:
  - potential energy
  primary_metric: forces_mae
  train_on_free_atoms: true
test_dataset: {}
trainer: ocp
val_dataset: {}
INFO:root:Loading model: equiformer_v2
WARNING:root:equiformer_v2 (EquiformerV2) class is deprecated in favor of equiformer_v2_backbone_and_heads  (EquiformerV2BackboneAndHeads)
INFO:root:Loaded EquiformerV2 with 31058690 parameters.
INFO:root:Loading checkpoint in inference-only mode, not loading keys associated with trainer state!
/home/runner/work/fairchem/fairchem/src/fairchem/core/modules/normalization/normalizer.py:69: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  "mean": torch.tensor(state_dict["mean"]),
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
      Step     Time          Energy          fmax
BFGS:    0 05:14:08       -0.237320        1.108962
BFGS:    1 05:14:10       -0.268118        1.496098
BFGS:    2 05:14:13       -0.365250        1.830217
BFGS:    3 05:14:16       -0.440508        1.458719
BFGS:    4 05:14:18       -0.477428        0.633667
BFGS:    5 05:14:21       -0.495394        0.289875
BFGS:    6 05:14:24       -0.501293        0.166063
BFGS:    7 05:14:26       -0.502321        0.153143
BFGS:    8 05:14:29       -0.516564        0.188849
BFGS:    9 05:14:31       -0.520255        0.186137
BFGS:   10 05:14:34       -0.521796        0.083719
BFGS:   11 05:14:37       -0.524095        0.202445
BFGS:   12 05:14:39       -0.526818        0.303606
BFGS:   13 05:14:42       -0.528846        0.263245
BFGS:   14 05:14:45       -0.529310        0.177213
BFGS:   15 05:14:47       -0.527592        0.076353
BFGS:   16 05:14:50       -0.528426        0.071222
BFGS:   17 05:14:52       -0.532503        0.158865
BFGS:   18 05:14:55       -0.533924        0.117473
BFGS:   19 05:14:58       -0.536292        0.012965
      Step     Time          Energy          fmax
BFGS:    0 05:15:00       -0.298822        1.325029
BFGS:    1 05:15:03       -0.310345        1.677442
BFGS:    2 05:15:06       -0.344664        1.117966
BFGS:    3 05:15:08       -0.477595        0.447185
BFGS:    4 05:15:11       -0.480132        0.253913
BFGS:    5 05:15:13       -0.483576        0.559245
BFGS:    6 05:15:16       -0.487928        0.725295
BFGS:    7 05:15:19       -0.495423        0.636133
BFGS:    8 05:15:21       -0.502394        0.313067
BFGS:    9 05:15:24       -0.506148        0.165851
BFGS:   10 05:15:27       -0.508851        0.230729
BFGS:   11 05:15:29       -0.511312        0.269364
BFGS:   12 05:15:32       -0.511888        0.212668
BFGS:   13 05:15:34       -0.513724        0.244636
BFGS:   14 05:15:37       -0.515430        0.153254
BFGS:   15 05:15:40       -0.517271        0.067139
BFGS:   16 05:15:42       -0.518542        0.106979
BFGS:   17 05:15:45       -0.519912        0.143927
BFGS:   18 05:15:47       -0.520899        0.129973
BFGS:   19 05:15:50       -0.521654        0.061610
BFGS:   20 05:15:53       -0.521514        0.051400
BFGS:   21 05:15:55       -0.519026        0.098905
BFGS:   22 05:15:58       -0.518120        0.103895
BFGS:   23 05:16:01       -0.517997        0.085759
BFGS:   24 05:16:03       -0.518382        0.059654
BFGS:   25 05:16:06       -0.519510        0.059368
BFGS:   26 05:16:08       -0.519461        0.066780
BFGS:   27 05:16:11       -0.520392        0.043860
      Step     Time          Energy          fmax
BFGS:    0 05:16:14       -0.371668        1.894466
BFGS:    1 05:16:16       -0.374110        2.303658
BFGS:    2 05:16:19       -0.409443        0.839549
BFGS:    3 05:16:21       -0.439106        0.835563
BFGS:    4 05:16:24       -0.496803        1.263908
BFGS:    5 05:16:26       -0.537505        0.227330
BFGS:    6 05:16:29       -0.537940        0.188922
BFGS:    7 05:16:32       -0.539342        0.244935
BFGS:    8 05:16:34       -0.543446        0.341750
BFGS:    9 05:16:37       -0.550959        0.379196
BFGS:   10 05:16:39       -0.554913        0.266122
BFGS:   11 05:16:42       -0.557846        0.115706
BFGS:   12 05:16:45       -0.558646        0.153109
BFGS:   13 05:16:47       -0.559208        0.207120
BFGS:   14 05:16:50       -0.560608        0.221718
BFGS:   15 05:16:52       -0.561000        0.187337
BFGS:   16 05:16:55       -0.562644        0.084098
BFGS:   17 05:16:58       -0.563746        0.062689
BFGS:   18 05:17:00       -0.563813        0.106593
BFGS:   19 05:17:03       -0.563797        0.118281
BFGS:   20 05:17:05       -0.564982        0.101520
BFGS:   21 05:17:08       -0.565487        0.052436
BFGS:   22 05:17:11       -0.565161        0.049262
      Step     Time          Energy          fmax
BFGS:    0 05:17:13       -0.393698        1.893087
BFGS:    1 05:17:16       -0.395853        2.303321
BFGS:    2 05:17:19       -0.430650        0.807664
BFGS:    3 05:17:21       -0.457029        0.788132
BFGS:    4 05:17:24       -0.509419        1.216115
BFGS:    5 05:17:27       -0.545435        0.284832
BFGS:    6 05:17:29       -0.546629        0.189050
BFGS:    7 05:17:32       -0.547702        0.340235
BFGS:    8 05:17:34       -0.549684        0.351311
BFGS:    9 05:17:37       -0.555498        0.250267
BFGS:   10 05:17:40       -0.558048        0.110556
BFGS:   11 05:17:42       -0.559785        0.106885
BFGS:   12 05:17:45       -0.560437        0.180129
BFGS:   13 05:17:47       -0.561112        0.212880
BFGS:   14 05:17:50       -0.562818        0.198912
BFGS:   15 05:17:53       -0.563637        0.149351
BFGS:   16 05:17:55       -0.565444        0.068107
BFGS:   17 05:17:58       -0.565625        0.042632
      Step     Time          Energy          fmax
BFGS:    0 05:18:00       -0.339144        1.450435
BFGS:    1 05:18:03       -0.346280        1.809480
BFGS:    2 05:18:06       -0.375137        1.068677
BFGS:    3 05:18:08       -0.473164        1.639453
BFGS:    4 05:18:11       -0.497438        0.960804
BFGS:    5 05:18:13       -0.501646        0.372217
BFGS:    6 05:18:16       -0.504997        0.435051
BFGS:    7 05:18:19       -0.516770        0.834697
BFGS:    8 05:18:21       -0.531439        0.952980
BFGS:    9 05:18:24       -0.542019        0.631848
BFGS:   10 05:18:26       -0.548329        0.199602
BFGS:   11 05:18:29       -0.549886        0.391562
BFGS:   12 05:18:32       -0.554345        0.617441
BFGS:   13 05:18:34       -0.559074        0.543721
BFGS:   14 05:18:37       -0.561572        0.215104
BFGS:   15 05:18:39       -0.562589        0.093295
BFGS:   16 05:18:42       -0.562970        0.192550
BFGS:   17 05:18:45       -0.564565        0.238265
BFGS:   18 05:18:47       -0.564704        0.183571
BFGS:   19 05:18:50       -0.565238        0.078946
BFGS:   20 05:18:52       -0.564843        0.042639
      Step     Time          Energy          fmax
BFGS:    0 05:18:55       -0.216587        0.935367
BFGS:    1 05:18:58       -0.242902        0.903389
BFGS:    2 05:19:00       -0.405015        2.168278
BFGS:    3 05:19:03       -0.430394        1.514976
BFGS:    4 05:19:05       -0.445190        0.286703
BFGS:    5 05:19:08       -0.452484        0.301474
BFGS:    6 05:19:11       -0.455610        0.294461
BFGS:    7 05:19:13       -0.469697        0.731202
BFGS:    8 05:19:16       -0.476438        0.462668
BFGS:    9 05:19:18       -0.480633        0.098986
BFGS:   10 05:19:21       -0.481393        0.137786
BFGS:   11 05:19:24       -0.482715        0.212172
BFGS:   12 05:19:26       -0.484034        0.199055
BFGS:   13 05:19:29       -0.485597        0.122450
BFGS:   14 05:19:32       -0.486556        0.076524
BFGS:   15 05:19:34       -0.487420        0.093154
BFGS:   16 05:19:37       -0.488743        0.153082
BFGS:   17 05:19:39       -0.489861        0.191178
BFGS:   18 05:19:42       -0.490602        0.150069
BFGS:   19 05:19:45       -0.491825        0.063978
BFGS:   20 05:19:47       -0.493306        0.071337
BFGS:   21 05:19:50       -0.494668        0.116884
BFGS:   22 05:19:52       -0.496598        0.141868
BFGS:   23 05:19:55       -0.498546        0.108177
BFGS:   24 05:19:58       -0.499616        0.046654
      Step     Time          Energy          fmax
BFGS:    0 05:20:00       -0.288106        1.538582
BFGS:    1 05:20:03       -0.295145        1.893453
BFGS:    2 05:20:06       -0.325364        1.035357
BFGS:    3 05:20:08       -0.416481        1.425735
BFGS:    4 05:20:11       -0.495297        0.588352
BFGS:    5 05:20:13       -0.498872        0.300176
BFGS:    6 05:20:16       -0.502449        0.499205
BFGS:    7 05:20:19       -0.505996        0.663448
BFGS:    8 05:20:21       -0.514951        0.612536
BFGS:    9 05:20:24       -0.522226        0.315855
BFGS:   10 05:20:26       -0.527108        0.189013
BFGS:   11 05:20:29       -0.530116        0.391530
BFGS:   12 05:20:32       -0.532272        0.506257
BFGS:   13 05:20:34       -0.533588        0.500013
BFGS:   14 05:20:37       -0.536025        0.358184
BFGS:   15 05:20:39       -0.538607        0.152759
BFGS:   16 05:20:42       -0.540549        0.178378
BFGS:   17 05:20:45       -0.541748        0.295552
BFGS:   18 05:20:47       -0.543561        0.360881
BFGS:   19 05:20:50       -0.549489        0.257151
BFGS:   20 05:20:52       -0.552111        0.100925
BFGS:   21 05:20:55       -0.551678        0.118759
BFGS:   22 05:20:58       -0.552006        0.242973
BFGS:   23 05:21:00       -0.552935        0.277257
BFGS:   24 05:21:03       -0.556683        0.216423
BFGS:   25 05:21:05       -0.560582        0.091024
BFGS:   26 05:21:08       -0.560497        0.060937
BFGS:   27 05:21:11       -0.561251        0.095772
BFGS:   28 05:21:13       -0.561554        0.133474
BFGS:   29 05:21:16       -0.562724        0.116287
BFGS:   30 05:21:18       -0.563512        0.049782
      Step     Time          Energy          fmax
BFGS:    0 05:21:21       -0.245962        1.293238
BFGS:    1 05:21:23       -0.258293        1.513580
BFGS:    2 05:21:26       -0.297281        1.279717
BFGS:    3 05:21:29       -0.357780        3.182815
BFGS:    4 05:21:31       -0.436799        0.764074
BFGS:    5 05:21:34       -0.457999        0.329738
BFGS:    6 05:21:37       -0.470130        0.459920
BFGS:    7 05:21:39       -0.477425        0.451763
BFGS:    8 05:21:42       -0.506299        0.473308
BFGS:    9 05:21:45       -0.513471        0.146628
BFGS:   10 05:21:47       -0.515903        0.164277
BFGS:   11 05:21:50       -0.518632        0.179617
BFGS:   12 05:21:52       -0.522371        0.144118
BFGS:   13 05:21:55       -0.524994        0.085558
BFGS:   14 05:21:58       -0.525516        0.087178
BFGS:   15 05:22:00       -0.525043        0.072836
BFGS:   16 05:22:03       -0.525620        0.052326
BFGS:   17 05:22:05       -0.526401        0.042001

Parse the trajectories and post-process

As a post-processing step we check to see if:

  1. the adsorbate desorbed

  2. the adsorbate dissociated

  3. the adsorbate intercalated

  4. the surface has changed

We check these because they affect our referencing scheme and may result in erroneous energies. For (4), the relaxed surface should really be supplied as well; it will be necessary when correcting the SP / RX energies later. Since we don't have it here, we will omit it, and the detector will instead compare the initial and final slab from the adsorbate-slab relaxation trajectory. If a relaxed slab is provided, the detector will compare it with the slab after the adsorbate-slab relaxation. The latter is more correct! Note: for the results in the AdsorbML paper, we did not check whether the adsorbate was intercalated (is_adsorbate_intercalated()) because it is a new addition.

# Iterate over the trajectories to extract results.
# The relaxation loop writes one "<idx>.traj" per configuration, so the file
# stem is the relaxation index. Sort numerically so the row order is
# reproducible (glob returns files in a filesystem-dependent order).
traj_dir = f"data/{bulk}_{adsorbate}"
traj_files = sorted(
    glob(os.path.join(traj_dir, "*.traj")),
    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]),
)

results = []
for traj_file in traj_files:
    rx_id = os.path.splitext(os.path.basename(traj_file))[0]
    traj = ase.io.read(traj_file, ":")

    # Check whether the trajectory is anomalous (dissociation, desorption,
    # surface change, or intercalation)
    initial_atoms = traj[0]
    final_atoms = traj[-1]
    atom_tags = initial_atoms.get_tags()
    detector = DetectTrajAnomaly(initial_atoms, final_atoms, atom_tags)
    anom = (
        detector.is_adsorbate_dissociated()
        or detector.is_adsorbate_desorbed()
        or detector.has_surface_changed()
        or detector.is_adsorbate_intercalated()
    )
    results.append({
        "relaxation_idx": rx_id,
        "relaxed_atoms": final_atoms,
        "relaxed_energy_ml": final_atoms.get_potential_energy(),
        # Key spelling kept as-is: later cells filter on `df.anomolous`.
        "anomolous": anom,
    })
df = pd.DataFrame(results)
df
relaxation_idx relaxed_atoms relaxed_energy_ml anomolous
0 1 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.520392 False
1 3 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.565625 False
2 4 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.564843 False
3 7 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.526401 False
4 5 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.499616 False
5 0 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.536292 False
6 6 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.563512 False
7 2 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.565161 False
# Drop anomalous relaxations. `drop=True` discards the old index instead of
# inserting it into the frame as an extra "index" column.
df = df[~df.anomolous].reset_index(drop=True)

(Optional) Deduplicate structures

We may have enumerated very similar structures, or structures may have relaxed to the same configuration. For this reason, it is advantageous to cull systems that are very similar. This results in only marginal improvements in the recall metrics we calculated for AdsorbML, so it wasn't implemented there. It is, however, a good way to prevent wasteful VASP calculations. You can also imagine that if we had enumerated 1000 configs per slab-adsorbate combination rather than the 100 used for AdsorbML, redundant systems would be more likely to reduce performance, so it's a good thing to keep in mind. This may be done by eye for a small number of systems, but with many systems it is easier to use an automated approach. Here is an example of one such approach, which uses a SOAP descriptor to find similar systems.

# Extract the configs and their energies
def _keep_lowest_energy(distance, energies, threshold):
    """Pick the lowest-energy member of each group of near-duplicate configs.

    Args:
        distance: square (n, n) matrix of pairwise distances between configs.
        energies: length-n sequence of energies, one per config.
        threshold: pairs with ``distance <= threshold`` are treated as
            duplicates.

    Returns:
        (list[int]): indices of the configs to keep, in the order their groups
        are first encountered. Each index appears at most once.
    """
    is_duplicate = np.asarray(distance) <= threshold
    energies = np.asarray(energies)
    idxs_to_keep = []
    claimed = set()
    for idx, row in enumerate(is_duplicate):
        if idx in claimed:
            continue
        # Every index below `idx` has already been claimed by some group, so
        # this group is `idx` plus any later, still-unclaimed configs within
        # the threshold. Skipping claimed indices prevents a config from being
        # selected twice when similarity is not transitive (a~b, b~c, a!~c).
        group = [idx] + [
            int(j) for j in np.flatnonzero(row) if j != idx and j not in claimed
        ]
        claimed.update(group)
        # Take the argmin within the group rather than searching the full
        # energy array, so an equal-energy config elsewhere is never chosen.
        idxs_to_keep.append(group[int(np.argmin(energies[group]))])
    return idxs_to_keep


def deduplicate(configs_for_deduplication: list,
                adsorbate_binding_index: int,
                cosine_similarity = 1e-3,
               ):
    """
    A function that may be used to deduplicate similar structures.
    Among duplicate entries, the one with the lowest energy will be kept.

    Similarity is measured with a SOAP descriptor centred on the adsorbate's
    binding atom, and SOAP vectors are compared by cosine distance.

    Args:
        configs_for_deduplication: a list of ML relaxed adsorbate-
            surface configurations. Each must return an energy from
            ``get_potential_energy()``. The SOAP species are taken from the
            first config, so all configs are assumed to contain the same
            elements.
        adsorbate_binding_index: index of the binding atom within the
            adsorbate. Adsorbate atoms (tag 2) are assumed to be the last
            atoms of every config.
        cosine_similarity: despite its name, a cosine *distance* threshold:
            configurations whose SOAP vectors are within this cosine
            distance (1 - cosine similarity) of each other are considered
            duplicates.

    Returns:
        (list): the indices of configs which should be kept as non-duplicate
    """
    if not configs_for_deduplication:
        return []

    energies_for_deduplication = np.array(
        [atoms.get_potential_energy() for atoms in configs_for_deduplication]
    )
    # Instantiate the soap descriptor
    soap = SOAP(
        species=np.unique(configs_for_deduplication[0].get_chemical_symbols()),
        r_cut = 2.0,
        n_max=6,
        l_max=3,
        periodic=True,
    )
    # Locate the binding atom: adsorbate atoms are tagged 2 and appended after
    # the slab, so it sits (ads_len - binding_index) positions from the end.
    ads_len = list(configs_for_deduplication[0].get_tags()).count(2)
    position_idx = -1 * (ads_len - adsorbate_binding_index)
    # One SOAP vector per config, centred on the binding atom
    soap_descs = np.vstack(
        [
            soap.create(config, centers=[position_idx])
            for config in configs_for_deduplication
        ]
    )

    # Use cosine distance between SOAP vectors to assess similarity
    distance = squareform(pdist(soap_descs, metric="cosine"))

    # For configs that are found to be similar, keep only the lowest-energy one
    return _keep_lowest_energy(
        distance, energies_for_deduplication, cosine_similarity
    )
configs_for_deduplication = df.relaxed_atoms.tolist()
idxs_to_keep = deduplicate(configs_for_deduplication, adsorbate.binding_indices[0])
# Flip through your configurations to check them out (and make sure
# deduplication looks good); change the index below to view another one.
print(idxs_to_keep)
view_x3d_n(configs_for_deduplication[2].repeat((2, 2, 1)))
# `idxs_to_keep` are positions into configs_for_deduplication (= rows of df),
# so select them positionally.
df = df.iloc[idxs_to_keep]
# Fewer than 5 configurations may survive filtering and deduplication, so
# report how many values are actually shown.
low_e_values = np.round(df.relaxed_energy_ml.nsmallest(5).to_numpy(), 3)
print(f"The lowest {len(low_e_values)} energies are: {low_e_values}")
df
The lowest 5 energies are: [-0.566 -0.536 -0.526 -0.5  ]
index relaxation_idx relaxed_atoms relaxed_energy_ml anomolous
3 3 7 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.526401 False
1 1 3 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.565625 False
4 4 5 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.499616 False
5 5 0 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.536292 False

Write VASP input files

This assumes you have access to VASP pseudopotentials and the right environment variables configured for ASE. The default VASP flags (which are equivalent to those used to make OC20) are located in fairchem.data.oc.utils.vasp. Alternatively, you may pass your own VASP flags to the write_vasp_input_files function as vasp_flags.

# Grab the (up to) 5 systems with the lowest ML energy; sort once so the
# configs and their relaxation indices stay paired.
lowest_energy_df = df.sort_values(by="relaxed_energy_ml").head(5)
configs_for_dft = lowest_energy_df.relaxed_atoms.tolist()
config_idxs = lowest_energy_df.relaxation_idx.tolist()

# Write the inputs, one directory per relaxation. Directories are nested under
# the system's data folder so different bulk/adsorbate runs do not collide,
# and exist_ok=True lets this cell be re-run without raising.
for rx_id, config in zip(config_idxs, configs_for_dft):
    outdir = f"data/{bulk}_{adsorbate}/{rx_id}/"
    os.makedirs(outdir, exist_ok=True)
    write_vasp_input_files(config, outdir=outdir)