AdsorbML tutorial#

from fairchem.core.common.relaxation.ase_utils import OCPCalculator
import ase.io
from ase.optimize import BFGS

from fairchem.data.oc.core import Adsorbate, AdsorbateSlabConfig, Bulk, Slab
import os
from glob import glob
import pandas as pd
from fairchem.data.oc.utils import DetectTrajAnomaly
from fairchem.data.oc.utils.vasp import write_vasp_input_files

# Optional - see below
import numpy as np
from dscribe.descriptors import SOAP
from scipy.spatial.distance import pdist, squareform
from x3dase.visualize import view_x3d_n

Enumerate the adsorbate-slab configurations to run relaxations on#

AdsorbML incorporates random placement, which is especially useful for more complicated adsorbates that may have many degrees of freedom. I have opted to sample a few random placements and a few heuristic ones. Here I am using *CO on copper (1,1,1) as an example.

# Identifiers for the system: the bulk's source ID (copper, per the text above)
# and the adsorbate's SMILES string.
bulk_src_id = "mp-30"
adsorbate_smiles = "*CO"

# Look both up in the bundled databases.
bulk = Bulk(bulk_src_id_from_db=bulk_src_id)
adsorbate = Adsorbate(adsorbate_smiles_from_db=adsorbate_smiles)

# Build the slabs with the requested Miller index; the returned list may
# contain more than one slab for the same index.
slabs = Slab.from_bulk_get_specific_millers(
    bulk=bulk,
    specific_millers=(1, 1, 1),
)

# Keep only the first slab for this demonstration.
slab = slabs[0]
Downloading src/fairchem/data/oc/databases/pkls/bulks.pkl...
# Perform heuristic placements on the slab selected above.
heuristic_adslabs = AdsorbateSlabConfig(slab, adsorbate, mode="heuristic")

# Perform random placements
# (for AdsorbML we use `num_sites = 100` but we will use 4 for brevity here)
random_adslabs = AdsorbateSlabConfig(
    slab,
    adsorbate,
    mode="random_site_heuristic_placement",
    num_sites=4,
)

Run ML relaxations:#

There are 2 options for how to do this.

  1. Using OCPCalculator as the calculator within the ASE framework

  2. By writing objects to lmdb and relaxing them using main.py in the ocp repo

(1) is really only adequate for small jobs, and it is what I will show here, but if you plan to run many relaxations, you should definitely use (2). More details about writing lmdbs have been provided here - follow the IS2RS/IS2RE instructions. More information about running relaxations once the lmdb has been written is here.

You need to provide the calculator with a path to a model checkpoint file, which can be downloaded here. The code below fetches it automatically with `model_name_to_local_file`.

from fairchem.core.common.relaxation.ase_utils import OCPCalculator
from fairchem.core.models.model_registry import model_name_to_local_file
import os

# Fetch the pretrained checkpoint into a local cache directory.
checkpoint_path = model_name_to_local_file(
    "EquiformerV2-31M-S2EF-OC20-All+MD", local_cache="/tmp/fairchem_checkpoints/"
)

# One trajectory file per relaxation is written into this directory.
traj_dir = f"data/{bulk}_{adsorbate}"
os.makedirs(traj_dir, exist_ok=True)

# Define the calculator (if you have a gpu, add `cpu=False` to speed up calculations).
calc = OCPCalculator(checkpoint_path=checkpoint_path)

adslabs = [*heuristic_adslabs.atoms_list, *random_adslabs.atoms_list]

# Relax every placement with BFGS using the shared ML calculator.
# The optimizer is used as a context manager so its trajectory file is closed
# as soon as each relaxation finishes, rather than left open inside the loop.
for idx, adslab in enumerate(adslabs):
    adslab.calc = calc
    with BFGS(adslab, trajectory=f"{traj_dir}/{idx}.traj") as opt:
        # For the AdsorbML results we used fmax = 0.02 and steps = 300,
        # but we will use less strict values for brevity.
        opt.run(fmax=0.05, steps=100)
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/escn/so3.py:23: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/scn/spherical_harmonics.py:23: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/wigner.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/layer_norm.py:75: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  @torch.cuda.amp.autocast(enabled=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/layer_norm.py:175: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  @torch.cuda.amp.autocast(enabled=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/layer_norm.py:263: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  @torch.cuda.amp.autocast(enabled=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/models/equiformer_v2/layer_norm.py:357: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  @torch.cuda.amp.autocast(enabled=False)
/home/runner/work/fairchem/fairchem/src/fairchem/core/common/relaxation/ase_utils.py:191: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.
INFO:root:amp: true
cmd:
  checkpoint_dir: /home/runner/work/fairchem/fairchem/docs/tutorials/checkpoints/2024-12-19-04-35-12
  commit: 83e1a53
  identifier: ''
  logs_dir: /home/runner/work/fairchem/fairchem/docs/tutorials/logs/wandb/2024-12-19-04-35-12
  print_every: 100
  results_dir: /home/runner/work/fairchem/fairchem/docs/tutorials/results/2024-12-19-04-35-12
  seed: null
  timestamp_id: 2024-12-19-04-35-12
  version: 1.4.0
dataset:
  format: trajectory_lmdb_v2
  grad_target_mean: 0.0
  grad_target_std: 2.887317180633545
  key_mapping:
    force: forces
    y: energy
  normalize_labels: true
  target_mean: -0.7554450631141663
  target_std: 2.887317180633545
  transforms:
    normalizer:
      energy:
        mean: -0.7554450631141663
        stdev: 2.887317180633545
      forces:
        mean: 0.0
        stdev: 2.887317180633545
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - forcesx_mae
    - forcesy_mae
    - forcesz_mae
    - mae
    - cosine_similarity
    - magnitude_error
    misc:
    - energy_forces_within_threshold
  primary_metric: forces_mae
gp_gpus: null
gpus: 0
logger: wandb
loss_functions:
- energy:
    coefficient: 4
    fn: mae
- forces:
    coefficient: 100
    fn: l2mae
model:
  alpha_drop: 0.1
  attn_activation: silu
  attn_alpha_channels: 64
  attn_hidden_channels: 64
  attn_value_channels: 16
  distance_function: gaussian
  drop_path_rate: 0.1
  edge_channels: 128
  ffn_activation: silu
  ffn_hidden_channels: 128
  grid_resolution: 18
  lmax_list:
  - 4
  max_neighbors: 20
  max_num_elements: 90
  max_radius: 12.0
  mmax_list:
  - 2
  name: equiformer_v2
  norm_type: layer_norm_sh
  num_distance_basis: 512
  num_heads: 8
  num_layers: 8
  num_sphere_samples: 128
  otf_graph: true
  proj_drop: 0.0
  regress_forces: true
  sphere_channels: 128
  use_atom_edge_embedding: true
  use_gate_act: false
  use_grid_mlp: true
  use_pbc: true
  use_s2_act_attn: false
  weight_init: uniform
optim:
  batch_size: 8
  clip_grad_norm: 100
  ema_decay: 0.999
  energy_coefficient: 4
  eval_batch_size: 8
  eval_every: 10000
  force_coefficient: 100
  grad_accumulation_steps: 1
  load_balancing: atoms
  loss_energy: mae
  loss_force: l2mae
  lr_initial: 0.0004
  max_epochs: 3
  num_workers: 8
  optimizer: AdamW
  optimizer_params:
    weight_decay: 0.001
  scheduler: LambdaLR
  scheduler_params:
    epochs: 1009275
    lambda_type: cosine
    lr: 0.0004
    lr_min_factor: 0.01
    warmup_epochs: 3364.25
    warmup_factor: 0.2
outputs:
  energy:
    level: system
  forces:
    eval_on_free_atoms: true
    level: atom
    train_on_free_atoms: true
relax_dataset: {}
slurm:
  additional_parameters:
    constraint: volta32gb
  cpus_per_task: 9
  folder: /checkpoint/abhshkdz/open-catalyst-project/logs/equiformer_v2/8307793
  gpus_per_node: 8
  job_id: '8307793'
  job_name: eq2s_051701_allmd
  mem: 480GB
  nodes: 8
  ntasks_per_node: 8
  partition: learnaccel
  time: 4320
task:
  dataset: trajectory_lmdb_v2
  eval_on_free_atoms: true
  grad_input: atomic forces
  labels:
  - potential energy
  primary_metric: forces_mae
  train_on_free_atoms: true
test_dataset: {}
trainer: ocp
val_dataset: {}
INFO:root:Loading model: equiformer_v2
WARNING:root:equiformer_v2 (EquiformerV2) class is deprecated in favor of equiformer_v2_backbone_and_heads  (EquiformerV2BackboneAndHeads)
INFO:root:Loaded EquiformerV2 with 31058690 parameters.
INFO:root:Loading checkpoint in inference-only mode, not loading keys associated with trainer state!
/home/runner/work/fairchem/fairchem/src/fairchem/core/modules/normalization/normalizer.py:69: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  "mean": torch.tensor(state_dict["mean"]),
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
/home/runner/work/fairchem/fairchem/src/fairchem/core/trainers/ocp_trainer.py:472: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.amp.autocast(enabled=self.scaler is not None):
      Step     Time          Energy          fmax
BFGS:    0 04:34:51       -0.459410        1.868565
BFGS:    1 04:34:54       -0.456782        2.288704
BFGS:    2 04:34:56       -0.487793        0.683244
BFGS:    3 04:34:59       -0.504864        0.615533
BFGS:    4 04:35:02       -0.537666        1.127299
BFGS:    5 04:35:04       -0.556772        0.664653
BFGS:    6 04:35:07       -0.560836        0.180722
BFGS:    7 04:35:10       -0.561142        0.093567
BFGS:    8 04:35:12       -0.560791        0.177767
BFGS:    9 04:35:15       -0.561797        0.179635
BFGS:   10 04:35:18       -0.563631        0.108874
BFGS:   11 04:35:20       -0.565386        0.042180
      Step     Time          Energy          fmax
BFGS:    0 04:35:23       -0.438439        1.855576
BFGS:    1 04:35:26       -0.437413        2.273213
BFGS:    2 04:35:28       -0.469418        0.720866
BFGS:    3 04:35:31       -0.488605        0.678357
BFGS:    4 04:35:34       -0.526483        1.205885
BFGS:    5 04:35:36       -0.551586        0.633520
BFGS:    6 04:35:39       -0.555082        0.192150
BFGS:    7 04:35:42       -0.556277        0.137095
BFGS:    8 04:35:44       -0.556391        0.230259
BFGS:    9 04:35:47       -0.557781        0.251012
BFGS:   10 04:35:50       -0.561437        0.175503
BFGS:   11 04:35:52       -0.563956        0.074275
BFGS:   12 04:35:55       -0.564337        0.072650
BFGS:   13 04:35:58       -0.564774        0.201003
BFGS:   14 04:36:00       -0.565338        0.217049
BFGS:   15 04:36:03       -0.566763        0.125956
BFGS:   16 04:36:06       -0.567156        0.059064
BFGS:   17 04:36:08       -0.567241        0.020441
      Step     Time          Energy          fmax
BFGS:    0 04:36:11       -0.223862        1.068471
BFGS:    1 04:36:14       -0.255421        1.501444
BFGS:    2 04:36:16       -0.349241        1.781516
BFGS:    3 04:36:19       -0.432403        1.496218
BFGS:    4 04:36:22       -0.467393        0.711297
BFGS:    5 04:36:25       -0.486816        0.332839
BFGS:    6 04:36:27       -0.494269        0.225754
BFGS:    7 04:36:30       -0.496147        0.213818
BFGS:    8 04:36:33       -0.512914        0.370679
BFGS:    9 04:36:35       -0.515949        0.203652
BFGS:   10 04:36:38       -0.518369        0.110932
BFGS:   11 04:36:41       -0.519801        0.197030
BFGS:   12 04:36:44       -0.522370        0.284527
BFGS:   13 04:36:46       -0.525197        0.330019
BFGS:   14 04:36:49       -0.527369        0.250453
BFGS:   15 04:36:52       -0.528519        0.197420
BFGS:   16 04:36:54       -0.528999        0.061521
BFGS:   17 04:36:57       -0.529947        0.108310
BFGS:   18 04:37:00       -0.531005        0.131405
BFGS:   19 04:37:02       -0.532644        0.096263
BFGS:   20 04:37:05       -0.534341        0.018336
      Step     Time          Energy          fmax
BFGS:    0 04:37:08       -0.346452        1.331664
BFGS:    1 04:37:10       -0.355163        1.677248
BFGS:    2 04:37:13       -0.387176        1.073220
BFGS:    3 04:37:16       -0.495847        0.876390
BFGS:    4 04:37:18       -0.499013        0.253774
BFGS:    5 04:37:21       -0.500018        0.266305
BFGS:    6 04:37:23       -0.501739        0.469810
BFGS:    7 04:37:26       -0.504657        0.511919
BFGS:    8 04:37:29       -0.511111        0.323575
BFGS:    9 04:37:31       -0.514555        0.111588
BFGS:   10 04:37:34       -0.516308        0.186337
BFGS:   11 04:37:37       -0.517669        0.348970
BFGS:   12 04:37:39       -0.519247        0.332590
BFGS:   13 04:37:42       -0.519927        0.156889
BFGS:   14 04:37:45       -0.520484        0.073965
BFGS:   15 04:37:47       -0.522283        0.072732
BFGS:   16 04:37:50       -0.522703        0.103519
BFGS:   17 04:37:53       -0.523155        0.105630
BFGS:   18 04:37:55       -0.523336        0.070875
BFGS:   19 04:37:58       -0.523522        0.021420
      Step     Time          Energy          fmax
BFGS:    0 04:38:01       -0.204919        1.019968
BFGS:    1 04:38:03       -0.235486        1.044380
BFGS:    2 04:38:06       -0.447195        1.715526
BFGS:    3 04:38:09       -0.448526        0.919786
BFGS:    4 04:38:11       -0.463058        0.239707
BFGS:    5 04:38:14       -0.465922        0.207464
BFGS:    6 04:38:17       -0.480176        0.310411
BFGS:    7 04:38:20       -0.484393        0.260539
BFGS:    8 04:38:22       -0.490654        0.129870
BFGS:    9 04:38:25       -0.491494        0.100420
BFGS:   10 04:38:28       -0.492974        0.205766
BFGS:   11 04:38:30       -0.495508        0.274228
BFGS:   12 04:38:33       -0.498676        0.249517
BFGS:   13 04:38:36       -0.500498        0.179994
BFGS:   14 04:38:39       -0.501941        0.082222
BFGS:   15 04:38:41       -0.503042        0.128629
BFGS:   16 04:38:44       -0.505908        0.197258
BFGS:   17 04:38:47       -0.507952        0.227101
BFGS:   18 04:38:49       -0.509744        0.095838
BFGS:   19 04:38:52       -0.510938        0.045532
      Step     Time          Energy          fmax
BFGS:    0 04:38:55       -0.363422        1.582203
BFGS:    1 04:38:58       -0.367915        1.967607
BFGS:    2 04:39:00       -0.397860        0.969030
BFGS:    3 04:39:03       -0.458684        1.324727
BFGS:    4 04:39:06       -0.507919        1.341870
BFGS:    5 04:39:09       -0.521198        0.430499
BFGS:    6 04:39:11       -0.521596        0.237984
BFGS:    7 04:39:14       -0.522814        0.446505
BFGS:    8 04:39:17       -0.526364        0.610117
BFGS:    9 04:39:20       -0.536132        0.557420
BFGS:   10 04:39:22       -0.541791        0.280650
BFGS:   11 04:39:25       -0.546443        0.202002
BFGS:   12 04:39:28       -0.549416        0.435544
BFGS:   13 04:39:31       -0.552359        0.598911
BFGS:   14 04:39:33       -0.555305        0.495776
BFGS:   15 04:39:36       -0.557143        0.223133
BFGS:   16 04:39:39       -0.558239        0.090110
BFGS:   17 04:39:41       -0.557968        0.154109
BFGS:   18 04:39:44       -0.559640        0.227958
BFGS:   19 04:39:47       -0.560797        0.199341
BFGS:   20 04:39:50       -0.562607        0.095084
BFGS:   21 04:39:52       -0.562299        0.056508
BFGS:   22 04:39:55       -0.561852        0.105514
BFGS:   23 04:39:58       -0.561460        0.160962
BFGS:   24 04:40:01       -0.561511        0.200121
BFGS:   25 04:40:03       -0.561369        0.173253
BFGS:   26 04:40:06       -0.562499        0.091996
BFGS:   27 04:40:09       -0.563333        0.045304
      Step     Time          Energy          fmax
BFGS:    0 04:40:12       -0.260118        0.997321
BFGS:    1 04:40:14       -0.277445        1.208738
BFGS:    2 04:40:17       -0.342610        1.628580
BFGS:    3 04:40:20       -0.303270        2.802871
BFGS:    4 04:40:23       -0.421471        0.463981
BFGS:    5 04:40:25       -0.437133        0.260185
BFGS:    6 04:40:28       -0.446305        0.325454
BFGS:    7 04:40:31       -0.448396        0.315088
BFGS:    8 04:40:33       -0.472492        0.180168
BFGS:    9 04:40:36       -0.475543        0.187912
BFGS:   10 04:40:39       -0.485399        0.251360
BFGS:   11 04:40:41       -0.494492        0.289222
BFGS:   12 04:40:44       -0.501602        0.281359
BFGS:   13 04:40:47       -0.505959        0.256589
BFGS:   14 04:40:49       -0.510611        0.239812
BFGS:   15 04:40:52       -0.515758        0.332673
BFGS:   16 04:40:55       -0.521807        0.410106
BFGS:   17 04:40:57       -0.529884        0.161644
BFGS:   18 04:41:00       -0.531381        0.115839
BFGS:   19 04:41:03       -0.532575        0.180569
BFGS:   20 04:41:05       -0.535910        0.130149
BFGS:   21 04:41:08       -0.538829        0.167357
BFGS:   22 04:41:11       -0.540487        0.102388
BFGS:   23 04:41:13       -0.540858        0.120017
BFGS:   24 04:41:16       -0.540348        0.143728
BFGS:   25 04:41:19       -0.543803        0.113517
BFGS:   26 04:41:22       -0.548340        0.119420
BFGS:   27 04:41:24       -0.553366        0.190281
BFGS:   28 04:41:27       -0.555767        0.166831
BFGS:   29 04:41:30       -0.557731        0.076734
BFGS:   30 04:41:32       -0.559192        0.105692
BFGS:   31 04:41:35       -0.560685        0.214869
BFGS:   32 04:41:38       -0.561590        0.226059
BFGS:   33 04:41:40       -0.562678        0.136892
BFGS:   34 04:41:43       -0.562922        0.060237
BFGS:   35 04:41:46       -0.563709        0.072827
BFGS:   36 04:41:48       -0.565307        0.059519
BFGS:   37 04:41:51       -0.565615        0.050173
BFGS:   38 04:41:54       -0.567249        0.057912
BFGS:   39 04:41:56       -0.565712        0.043019
      Step     Time          Energy          fmax
BFGS:    0 04:41:59       -0.153929        0.774063
BFGS:    1 04:42:02       -0.175161        0.760828
BFGS:    2 04:42:04       -0.333972        3.025137
BFGS:    3 04:42:07       -0.361913        1.141759
BFGS:    4 04:42:10       -0.396643        0.482422
BFGS:    5 04:42:12       -0.408846        0.722350
BFGS:    6 04:42:15       -0.419923        0.691916
BFGS:    7 04:42:17       -0.458407        1.609237
BFGS:    8 04:42:20       -0.476559        0.608795
BFGS:    9 04:42:23       -0.479775        0.182417
BFGS:   10 04:42:25       -0.481104        0.348030
BFGS:   11 04:42:28       -0.483741        0.276739
BFGS:   12 04:42:31       -0.490235        0.124686
BFGS:   13 04:42:33       -0.493777        0.167654
BFGS:   14 04:42:36       -0.495376        0.173300
BFGS:   15 04:42:39       -0.497525        0.128310
BFGS:   16 04:42:41       -0.499120        0.049188

Parse the trajectories and post-process#

As a post-processing step we check to see if:

  1. the adsorbate desorbed

  2. the adsorbate dissociated

  3. the adsorbate intercalated

  4. the surface has changed

We check these because they affect our referencing scheme and may result in erroneous energies. For (4), the relaxed surface should really be supplied as well, because it will be needed when correcting the SP / RX energies later. Since we don't have it here, we will omit supplying it, and the detector will instead compare the initial and final slab from the adsorbate-slab relaxation trajectory. If a relaxed slab is provided, the detector will instead compare it with the slab after the adsorbate-slab relaxation. The latter is more correct! Note: for the results in the AdsorbML paper, we did not check whether the adsorbate was intercalated (is_adsorbate_intercalated()) because it is a new addition.

# Iterate over trajs to extract results. Paths are sorted so the row order
# (and therefore any positional indices used later) is reproducible.
results = []
for traj_path in sorted(glob(f"data/{bulk}_{adsorbate}/*.traj")):
    # The relaxation index is the trajectory's file name without its extension.
    rx_id = os.path.splitext(os.path.basename(traj_path))[0]
    traj = ase.io.read(traj_path, ":")

    # Check whether the trajectory is anomalous.
    initial_atoms = traj[0]
    final_atoms = traj[-1]
    atom_tags = initial_atoms.get_tags()
    detector = DetectTrajAnomaly(initial_atoms, final_atoms, atom_tags)
    anom = (
        detector.is_adsorbate_dissociated()
        or detector.is_adsorbate_desorbed()
        or detector.has_surface_changed()
        or detector.is_adsorbate_intercalated()
    )
    rx_energy = final_atoms.get_potential_energy()
    # NOTE: the "anomolous" column name (sic) is kept because later cells read df.anomolous.
    results.append({"relaxation_idx": rx_id, "relaxed_atoms": final_atoms,
                    "relaxed_energy_ml": rx_energy, "anomolous": anom})
df = pd.DataFrame(results)
df
relaxation_idx relaxed_atoms relaxed_energy_ml anomolous
0 1 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.567241 False
1 7 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.499120 False
2 0 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.565386 False
3 5 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.563333 False
4 2 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.534341 False
5 3 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.523522 False
6 4 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.510938 False
7 6 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.565712 False
# Scrap anomalous relaxations. drop=True renumbers the rows without carrying
# the old row labels along as an extra "index" column.
df = df[~df.anomolous].reset_index(drop=True)

(Optional) Deduplicate structures#

We may have enumerated very similar structures, or structures may have relaxed to the same configuration. For this reason, it is advantageous to cull systems that are very similar. This results in only marginal improvements in the recall metrics we calculated for AdsorbML, so it wasn't implemented there. It is, however, a good way to prevent wasteful VASP calculations. You can also imagine that if we had enumerated 1000 configs per slab-adsorbate combination rather than the 100 used for AdsorbML, redundant systems would be more likely to reduce performance, so it's a good thing to keep in mind. This may be done by eye for a small number of systems, but with many systems it is easier to use an automated approach. Here is an example of one such approach, which uses a SOAP descriptor to find similar systems.

# Deduplicate ML-relaxed configurations using SOAP descriptors centered on the binding atom
def deduplicate(configs_for_deduplication: list,
                adsorbate_binding_index: int,
                cosine_similarity=1e-3,
                ):
    """
    A function that may be used to deduplicate similar structures.

    A SOAP descriptor is computed around the adsorbate's binding atom in each
    configuration. Configurations whose descriptors lie within
    ``cosine_similarity`` cosine *distance* of each other are treated as
    duplicates, and among duplicate entries the one with the lowest energy is
    kept.

    Args:
        configs_for_deduplication: a list of ML relaxed adsorbate-
            surface configurations (ase.Atoms with an attached calculator
            result so get_potential_energy() works). Adsorbate atoms are
            assumed to carry tag 2 and to be the last atoms of each config.
        adsorbate_binding_index: index of the binding atom within the
            adsorbate (e.g. ``adsorbate.binding_indices[0]``).
        cosine_similarity: cosine distance threshold (1 - cosine similarity).
            Pairs whose distance is at or below this value are duplicates.
            The parameter name is kept for backward compatibility.

    Returns:
        (list): the indices of configs which should be kept as non-duplicate.
        Each index appears at most once.
    """
    if not configs_for_deduplication:
        return []

    energies_for_deduplication = np.array(
        [atoms.get_potential_energy() for atoms in configs_for_deduplication]
    )
    # Instantiate the soap descriptor. Species are taken from the first
    # config, so every config is assumed to contain the same elements.
    soap = SOAP(
        species=np.unique(configs_for_deduplication[0].get_chemical_symbols()),
        r_cut=2.0,
        n_max=6,
        l_max=3,
        periodic=True,
    )
    # Figure out which atom index corresponds to the binding atom. Adsorbate
    # atoms (tag 2) are assumed to follow the slab atoms, so a negative index
    # counted from the end of the Atoms object selects the binding atom.
    ads_len = list(configs_for_deduplication[0].get_tags()).count(2)
    position_idx = -1 * (ads_len - adsorbate_binding_index)

    # One SOAP row per configuration, centered on the binding atom.
    soap_descs = np.vstack(
        [soap.create(config, centers=[position_idx])
         for config in configs_for_deduplication]
    )

    # Use cosine distance to assess similarity (0 means identical direction).
    distance = squareform(pdist(soap_descs, metric="cosine"))
    is_duplicate = distance <= cosine_similarity

    return _keep_lowest_energy_per_group(is_duplicate, energies_for_deduplication)


def _keep_lowest_energy_per_group(is_duplicate, energies):
    """Pick the lowest-energy config from each greedy group of similar configs.

    Row ``i`` of the square boolean matrix ``is_duplicate`` marks the configs
    similar to config ``i``. Rows are scanned in order. Each unvisited config
    forms a group with its still-unvisited similar configs. Configs already
    assigned to a group are never reconsidered, so no index is returned twice.
    Ties in energy resolve to the lowest index in the group.
    """
    energies = np.asarray(energies)
    idxs_to_keep = []
    visited = set()
    for idx, row in enumerate(is_duplicate):
        if idx in visited:
            continue
        # idx always heads its own group. Every smaller index has already been
        # visited, so the group is in ascending order and argmin ties pick the
        # lowest index.
        group = [idx] + [
            int(j) for j in np.flatnonzero(row)
            if int(j) != idx and int(j) not in visited
        ]
        visited.update(group)
        # Pick the one with the lowest energy by ML, restricted to this group.
        idxs_to_keep.append(group[int(np.argmin(energies[group]))])
    return idxs_to_keep
configs_for_deduplication = df.relaxed_atoms.tolist()
idxs_to_keep = deduplicate(configs_for_deduplication, adsorbate.binding_indices[0])
# Flip through your configurations to check them out (and make sure
# deduplication looks good) by changing the index passed to view_x3d_n.
print(idxs_to_keep)
view_x3d_n(configs_for_deduplication[2].repeat((2, 2, 1)))
df = df.iloc[idxs_to_keep]
low_e_values = np.round(
    df.sort_values(by="relaxed_energy_ml").relaxed_energy_ml.tolist()[0:5], 3
)
# Fewer than 5 configurations may survive deduplication, so report the real count.
print(f"The lowest {len(low_e_values)} energies are: {low_e_values}")
df
The lowest 5 energies are: [-0.567 -0.534 -0.524 -0.511]
index relaxation_idx relaxed_atoms relaxed_energy_ml anomolous
0 0 1 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.567241 False
6 6 4 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.510938 False
4 4 2 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.534341 False
5 5 3 (Atom('Cu', [-1.3000465215529715, 2.2517466275... -0.523522 False

Write VASP input files#

This assumes you have access to VASP pseudopotentials and the right environment variables configured for ASE. The default VASP flags (which are equivalent to those used to make OC20) are located in `fairchem.data.oc.utils.vasp`. Alternatively, you may pass your own VASP flags to the `write_vasp_input_files` function as `vasp_flags`.

# Grab (up to) the 5 systems with the lowest ML energy. Sort once so the
# configs and their relaxation indices stay aligned.
lowest_e = df.sort_values(by="relaxed_energy_ml").head(5)
configs_for_dft = lowest_e.relaxed_atoms.tolist()
config_idxs = lowest_e.relaxation_idx.tolist()

# Write the inputs, one directory per relaxation index. exist_ok=True lets the
# cell be re-run without FileExistsError.
# NOTE: these directories are not namespaced by bulk/adsorbate; add a prefix
# if you write inputs for several systems into the same data/ folder.
for config_idx, config in zip(config_idxs, configs_for_dft):
    outdir = f"data/{config_idx}"
    os.makedirs(outdir, exist_ok=True)
    write_vasp_input_files(config, outdir=f"{outdir}/")