core.common.utils

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Attributes

Classes

UniqueKeyLoader

Complete

SeverityLevelBetween

Filter instances are used to perform arbitrary filtering of LogRecords.

Functions

pyg2_data_transform(data)

If running on the new PyG (2.0 or later) and the stored Data is in the older format, convert it to the new format.

save_checkpoint(→ str)

warmup_lr_lambda(current_step, optim_config)

Returns a learning rate multiplier.

print_cuda_usage(→ None)

conditional_grad(dec)

Decorator to enable/disable grad depending on whether force/energy predictions are being made

plot_histogram(data[, xlabel, ylabel, title])

collate(data_list)

add_edge_distance_to_graph(batch[, device, dmin, ...])

_import_local_file(→ None)

Imports a Python file as a module

setup_experimental_imports(→ None)

Import selected directories of modules from the "experimental" subdirectory.

_get_project_root(→ pathlib.Path)

Gets the root folder of the project (the "ocp" folder)

setup_imports(→ None)

dict_set_recursively(→ None)

parse_value(value)

Parse string as Python literal if possible and fallback to string.

create_dict_from_args(args[, sep])

Create a (nested) dictionary from console arguments.

find_relative_file_in_paths(filename, include_paths)

load_config(path[, files_previously_included, ...])

Load a given config with any defined imports

build_config(args, args_override[, include_paths])

create_grid(base_config, sweep_file)

save_experiment_log(args, jobs, configs)

get_pbc_distances(pos, edge_index, cell, cell_offsets, ...)

radius_graph_pbc(data, radius, max_num_neighbors_threshold)

get_max_neighbors_mask(natoms, index, atom_distance, ...)

Give a mask that filters out edges so that each atom has at most max_num_neighbors_threshold neighbors.

get_pruned_edge_idx(→ torch.Tensor)

merge_dicts(dict1, dict2)

Recursively merge two dictionaries.

debug_log_entry_exit(func)

setup_logging(→ None)

compute_neighbors(data, edge_index)

check_traj_files(→ bool)

setup_env_vars(→ None)

new_trainer_context(*, config)

_resolve_scale_factor_submodule(model, name)

_report_incompat_keys(→ tuple[list[str], list[str]])

match_state_dict(→ dict)

load_state_dict(→ tuple[list[str], list[str]])

scatter_det(*args, **kwargs)

get_commit_hash()

cg_change_mat(→ torch.tensor)

irreps_sum(→ int)

Returns the sum of the dimensions of the irreps up to the specified angular momentum.

update_config(base_config)

Update configs created prior to FAIRChem/OCP 2.0 to the new expected structure.

load_model_and_weights_from_checkpoint(→ torch.nn.Module)

get_timestamp_uid(→ str)

tensor_stats(→ dict)

get_weight_table(→ tuple[list, list])

get_checkpoint_format(→ str)

Module Contents

core.common.utils.DEFAULT_ENV_VARS
class core.common.utils.UniqueKeyLoader(stream)

Bases: yaml.SafeLoader

construct_mapping(node, deep=False)
core.common.utils.pyg2_data_transform(data: torch_geometric.data.Data)

If running on the new PyG (2.0 or later) and the stored Data is in the older format, convert the data to the new format.

core.common.utils.save_checkpoint(state, checkpoint_dir: str = 'checkpoints/', checkpoint_file: str = 'checkpoint.pt') → str
core.common.utils.multitask_required_keys
class core.common.utils.Complete
__call__(data)
core.common.utils.warmup_lr_lambda(current_step: int, optim_config)

Returns a learning rate multiplier. Until warmup_steps is reached, the learning rate increases linearly to initial_lr; after that, it is multiplied by lr_gamma each time a milestone is crossed.
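A minimal sketch of the schedule described above, written as a plain function. The config keys (warmup_steps, warmup_factor, lr_milestones, lr_gamma) and the name warmup_lr_multiplier are assumptions for illustration: warmup_factor is taken to be the multiplier at step 0, and lr_milestones a sorted list of step counts, so the learning rate ramps from warmup_factor × initial_lr up to initial_lr.

```python
from bisect import bisect_right


def warmup_lr_multiplier(current_step: int, optim_config: dict) -> float:
    """Hypothetical sketch: linear warmup followed by step decay at milestones.

    Assumed keys: warmup_steps, warmup_factor, lr_milestones (sorted), lr_gamma.
    """
    warmup_steps = optim_config["warmup_steps"]
    if warmup_steps > 0 and current_step <= warmup_steps:
        # Linear ramp from warmup_factor (step 0) to 1.0 (step == warmup_steps).
        alpha = current_step / float(warmup_steps)
        return optim_config["warmup_factor"] * (1.0 - alpha) + alpha
    # Count the milestones already crossed; lr_gamma is applied once per milestone.
    crossed = bisect_right(optim_config["lr_milestones"], current_step)
    return optim_config["lr_gamma"] ** crossed
```

Because the function only maps a step count to a multiplier, it can be bound with functools.partial(warmup_lr_multiplier, optim_config=cfg) and passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR, which scales each parameter group's base learning rate by the returned value.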

core.common.utils.print_cuda_usage() → None
core.common.utils.conditional_grad(dec)

Decorator to enable/disable grad depending on whether force/energy predictions are being made
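A minimal sketch of how such a decorator can be built, assuming the decorated method belongs to a model exposing a boolean attribute named regress_forces (a hypothetical name here). The decorator dec is applied only when that flag is set, and the check happens on every call. The name conditional_grad_sketch is illustrative.

```python
import functools
from typing import Any, Callable


def conditional_grad_sketch(dec: Callable[[Callable], Callable]) -> Callable[[Callable], Callable]:
    """Wrap a method with `dec` only when self.regress_forces is truthy (assumed attribute)."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def method(self: Any, *args: Any, **kwargs: Any) -> Any:
            # Decide per call, so toggling the flag on an instance takes effect immediately.
            f = dec(func) if getattr(self, "regress_forces", False) else func
            return f(self, *args, **kwargs)

        return method

    return decorator
```

In PyTorch, dec would typically be a grad-mode context manager such as torch.enable_grad(), whose instances can also be applied as function decorators; the sketch itself accepts any decorator.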

core.common.utils.plot_histogram(data, xlabel: str = '', ylabel: str = '', title: str = '')
core.common.utils.collate(data_list)
core.common.utils.add_edge_distance_to_graph(batch, device='cpu', dmin: float = 0.0, dmax: float = 6.0, num_gaussians: int = 50)
core.common.utils._import_local_file(path: pathlib.Path, *, project_root: pathlib.Path) → None

Imports a Python file as a module

Parameters:
  • path (Path) – The path to the file to import

  • project_root (Path) – The root directory of the project (i.e., the “ocp” folder)
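A minimal sketch of one way to import a file as a module: derive a dotted module name from the file's path relative to the parent of the project root, then call importlib.import_module. The function name and the assumption that the project root's parent directory is on sys.path are illustrative; unlike the documented function, the sketch returns the module so the result can be inspected.

```python
import importlib
from pathlib import Path
from types import ModuleType


def import_local_file_sketch(path: Path, *, project_root: Path) -> ModuleType:
    """Import `path` as a module named after its location under project_root."""
    # Assumption: the parent of project_root is on sys.path, so
    # <parent>/ocp/models/foo.py becomes the module "ocp.models.foo".
    base = project_root.absolute().parent
    relative = path.absolute().relative_to(base)
    module_name = ".".join(relative.with_suffix("").parts)
    return importlib.import_module(module_name)
```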

core.common.utils.setup_experimental_imports(project_root: pathlib.Path) → None

Import selected directories of modules from the “experimental” subdirectory.

If a file named “.include” is present in the “experimental” subdirectory, it is read as a list of experimental subdirectories whose modules (including those in any nested subdirectories) should be imported.

Parameters:

project_root – The root directory of the project (i.e., the “ocp” folder)
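A minimal sketch of the “.include” lookup described above, assuming one subdirectory name per non-blank line. The helper name experimental_files_to_import is hypothetical, and it only collects the files; each one could then be handed to a module importer such as _import_local_file.

```python
from pathlib import Path


def experimental_files_to_import(project_root: Path) -> list[Path]:
    """List the .py files selected by experimental/.include (hypothetical helper)."""
    experimental_dir = (project_root / "experimental").absolute()
    include_file = experimental_dir / ".include"
    if not experimental_dir.is_dir() or not include_file.exists():
        return []
    # One subdirectory name per non-blank line.
    include_dirs = [
        line.strip() for line in include_file.read_text().splitlines() if line.strip()
    ]
    files: list[Path] = []
    for inc_dir in include_dirs:
        # rglob also descends into nested subdirectories.
        files.extend(sorted((experimental_dir / inc_dir).rglob("*.py")))
    return files
```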

core.common.utils._get_project_root() → pathlib.Path

Gets the root folder of the project (the “ocp” folder).

Returns:

The absolute path to the project root.

core.common.utils.setup_imports(config: dict | None = None) → None
core.common.utils.dict_set_recursively(dictionary, key_sequence, val) → None
core.common.utils.parse_value(value)

Parse string as Python literal if possible and fallback to string.
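A minimal sketch of this behavior, assuming the parsing is done with ast.literal_eval from the standard library (the page does not say which parser is used). Anything that is not a valid literal, such as a bare word, is returned unchanged. The name parse_value_sketch is illustrative.

```python
import ast
from typing import Any


def parse_value_sketch(value: str) -> Any:
    """Parse `value` as a Python literal, falling back to the raw string."""
    try:
        # literal_eval accepts only literals: numbers, strings, lists, dicts, True/None, ...
        return ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        # Bare words such as "adam" are not literals, so keep them as strings.
        return value
```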

core.common.utils.create_dict_from_args(args: list, sep: str = '.')

Create a (nested) dictionary from console arguments. Keys in different dictionary levels are separated by sep.
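A minimal sketch, assuming each argument has the form --key.subkey=value (the exact command-line syntax is an assumption) and that values are parsed as Python literals with a string fallback. The function names are hypothetical; intermediate dictionary levels are created on demand.

```python
import ast
from typing import Any


def _parse_literal(value: str) -> Any:
    # Python literal if possible, otherwise the raw string.
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        return value


def create_dict_from_args_sketch(args: list[str], sep: str = ".") -> dict:
    """Build a nested dict from arguments shaped like --level1.level2=value (assumed form)."""
    result: dict = {}
    for arg in args:
        keys, value = arg.removeprefix("--").split("=", 1)
        *parents, leaf = keys.split(sep)
        node = result
        for key in parents:
            # Create intermediate levels on demand.
            node = node.setdefault(key, {})
        node[leaf] = _parse_literal(value)
    return result
```

For example, ["--optim.lr=0.001", "--optim.batch_size=8"] becomes {"optim": {"lr": 0.001, "batch_size": 8}}.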

core.common.utils.find_relative_file_in_paths(filename, include_paths)
core.common.utils.load_config(path: str, files_previously_included: list | None = None, include_paths: list | None = None)

Load a given config with any defined imports

When imports are present, this function is called recursively on each imported file. To prevent cyclic imports, the YAML files that have already been imported are tracked in files_previously_included.
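A minimal sketch of recursive include resolution with cycle detection. To keep it self-contained, configs are read from an in-memory dict of already-parsed YAML documents instead of from disk, and the list of includes is assumed to live under an includes key; both choices, and the function names, are assumptions. Included files are merged first and the including file's own values override them. Each branch of the recursion gets its own copy of the chain of previously included files, so a file included from two siblings is accepted while a true cycle raises an error.

```python
from __future__ import annotations

import copy
from typing import Any


def _merge(base: dict, override: dict) -> dict:
    """Recursive merge in which values from `override` win."""
    out = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = _merge(out[key], value)
        else:
            out[key] = copy.deepcopy(value)
    return out


def load_config_sketch(
    name: str, files: dict[str, dict[str, Any]], previously_included: list[str] | None = None
) -> dict:
    """Resolve `includes` recursively; `files` maps file names to parsed YAML dicts."""
    ancestors = list(previously_included or [])
    if name in ancestors:
        raise ValueError(f"Cyclic config include detected: {' -> '.join(ancestors + [name])}")
    ancestors.append(name)

    config = copy.deepcopy(files[name])
    includes = config.pop("includes", [])
    merged: dict = {}
    for include in includes:
        # Each include receives the chain of ancestors, so only genuine cycles are rejected.
        merged = _merge(merged, load_config_sketch(include, files, ancestors))
    # The including file's own values override anything it pulled in.
    return _merge(merged, config)
```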

core.common.utils.build_config(args, args_override, include_paths=None)
core.common.utils.create_grid(base_config, sweep_file: str)
core.common.utils.save_experiment_log(args, jobs, configs)
core.common.utils.get_pbc_distances(pos, edge_index, cell, cell_offsets, neighbors, return_offsets: bool = False, return_distance_vec: bool = False)
core.common.utils.radius_graph_pbc(data, radius, max_num_neighbors_threshold, enforce_max_neighbors_strictly: bool = False, pbc=None)
core.common.utils.get_max_neighbors_mask(natoms, index, atom_distance, max_num_neighbors_threshold, degeneracy_tolerance: float = 0.01, enforce_max_strictly: bool = False)

Give a mask that filters out edges so that each atom has at most max_num_neighbors_threshold neighbors. Assumes that index is sorted.

Enforcing the max strictly can force the arbitrary choice between degenerate edges. This can lead to undesired behaviors; for example, bulk formation energies which are not invariant to unit cell choice.

A degeneracy tolerance can help prevent sudden changes in edge existence from small changes in atom position, for example, rounding errors, slab relaxation, temperature, etc.
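A simplified, pure-Python sketch of the masking idea, written for a handful of edges rather than batched tensors. It assumes that, when the maximum is not enforced strictly, the cutoff is the distance of the max_neighbors-th nearest edge plus degeneracy_tolerance, so edges nearly tied with the last kept edge survive as well. This cutoff rule and the function name are assumptions, not a description of the actual tensor implementation.

```python
from itertools import groupby


def max_neighbors_mask_sketch(
    index: list[int],
    distances: list[float],
    max_neighbors: int,
    degeneracy_tolerance: float = 0.01,
    enforce_max_strictly: bool = False,
) -> list[bool]:
    """Return one keep/drop flag per edge; `index` (center atom per edge) must be sorted."""
    if max_neighbors < 1:
        raise ValueError("max_neighbors must be at least 1")
    mask = [False] * len(index)
    # Sorted `index` makes each atom's edges contiguous, so groupby yields one group per atom.
    for _, group in groupby(range(len(index)), key=lambda i: index[i]):
        ranked = sorted(group, key=lambda i: distances[i])
        if enforce_max_strictly or len(ranked) <= max_neighbors:
            # Keep the nearest max_neighbors edges; ties fall back to original edge order.
            keep = ranked[:max_neighbors]
        else:
            # Tolerant mode (assumed rule): keep every edge within tolerance of the
            # max_neighbors-th nearest one, so near-degenerate edges are not split.
            cutoff = distances[ranked[max_neighbors - 1]] + degeneracy_tolerance
            keep = [i for i in ranked if distances[i] <= cutoff]
        for i in keep:
            mask[i] = True
    return mask
```

With max_neighbors=2 and atom 0 having edges at distances 1.0, 2.0 and 2.005, the third edge lies within tolerance of the second: the strict mask drops it, while the tolerant mask keeps it.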

core.common.utils.get_pruned_edge_idx(edge_index, num_atoms: int, max_neigh: float = 1000000000.0) → torch.Tensor
core.common.utils.merge_dicts(dict1: dict, dict2: dict)

Recursively merge two dictionaries. Values in dict2 override values in dict1. If dict1 and dict2 contain a dictionary as a value, this will call itself recursively to merge these dictionaries. This does not modify the input dictionaries (creates an internal copy). Additionally returns a list of detected duplicates. Adapted from TUM-DAML/seml

Parameters:
  • dict1 (dict) – First dict.

  • dict2 (dict) – Second dict. Values in dict2 will override values from dict1 in case they share the same key.

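Returns:
  • return_dict – Merged dictionaries.

  • duplicates – List of detected duplicate keys, i.e. keys present in both dictionaries.

Return type:

tuple[dict, list]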
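A minimal sketch that follows the description above: values from dict2 win, nested dictionaries are merged recursively, the inputs are deep-copied rather than modified, and keys found in both inputs are collected as duplicates. Reporting nested duplicates as dot-joined paths such as "b.c" is an assumption of this sketch, as is the function name.

```python
import copy


def merge_dicts_sketch(dict1: dict, dict2: dict) -> tuple[dict, list[str]]:
    """Recursively merge dict2 into a copy of dict1 and report overlapping keys."""
    if not isinstance(dict1, dict) or not isinstance(dict2, dict):
        raise ValueError("Both arguments must be dicts.")
    merged = copy.deepcopy(dict1)
    duplicates: list[str] = []
    for key, value in dict2.items():
        if key in dict1 and isinstance(value, dict) and isinstance(dict1[key], dict):
            # Both sides hold a dict: merge one level deeper and prefix nested duplicates.
            merged[key], nested = merge_dicts_sketch(dict1[key], value)
            duplicates += [f"{key}.{dup}" for dup in nested]
        else:
            if key in dict1:
                duplicates.append(key)
            merged[key] = copy.deepcopy(value)
    return merged, duplicates
```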

class core.common.utils.SeverityLevelBetween(min_level: int, max_level: int)

Bases: logging.Filter

Filter instances are used to perform arbitrary filtering of LogRecords.

Loggers and Handlers can optionally use Filter instances to filter records as desired. The base filter class only allows events which are below a certain point in the logger hierarchy. For example, a filter initialized with “A.B” will allow events logged by loggers “A.B”, “A.B.C”, “A.B.C.D”, “A.B.D” etc. but not “A.BB”, “B.A.B” etc. If initialized with the empty string, all events are passed.

min_level
max_level
filter(record) → bool

Determine if the specified record is to be logged.

Returns True if the record should be logged, or False otherwise. If deemed appropriate, the record may be modified in-place.
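A minimal sketch of a level-range filter in the spirit of SeverityLevelBetween. Whether max_level is inclusive or exclusive is not stated on this page; the sketch assumes the half-open range min_level <= levelno < max_level, and the class name LevelRangeFilter is hypothetical.

```python
import logging


class LevelRangeFilter(logging.Filter):
    """Pass only records with min_level <= levelno < max_level (half-open range assumed)."""

    def __init__(self, min_level: int, max_level: int) -> None:
        super().__init__()
        self.min_level = min_level
        self.max_level = max_level

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False drops the record for the handler or logger this filter is attached to.
        return self.min_level <= record.levelno < self.max_level
```

A filter like this is typically attached with Handler.addFilter, for example LevelRangeFilter(logging.INFO, logging.WARNING) on a logging.StreamHandler(sys.stdout), so that informational messages go to stdout while a second handler at level WARNING sends warnings and errors to stderr.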

core.common.utils.debug_log_entry_exit(func)
core.common.utils.setup_logging() → None
core.common.utils.compute_neighbors(data, edge_index)
core.common.utils.check_traj_files(batch, traj_dir) → bool
core.common.utils.setup_env_vars() → None
core.common.utils.new_trainer_context(*, config: dict[str, Any])
core.common.utils._resolve_scale_factor_submodule(model: torch.nn.Module, name: str)
core.common.utils._report_incompat_keys(model: torch.nn.Module, keys: torch.nn.modules.module._IncompatibleKeys, strict: bool = False) → tuple[list[str], list[str]]
core.common.utils.match_state_dict(model_state_dict: collections.abc.Mapping[str, torch.Tensor], checkpoint_state_dict: collections.abc.Mapping[str, torch.Tensor]) → dict
core.common.utils.load_state_dict(module: torch.nn.Module, state_dict: collections.abc.Mapping[str, torch.Tensor], strict: bool = True) → tuple[list[str], list[str]]
core.common.utils.scatter_det(*args, **kwargs)
core.common.utils.get_commit_hash()
core.common.utils.cg_change_mat(ang_mom: int, device: str = 'cpu') → torch.tensor
core.common.utils.irreps_sum(ang_mom: int) → int

Returns the sum of the dimensions of the irreps up to the specified angular momentum.

Parameters:

ang_mom – maximum angular momentum up to which the dimensions of the irreps are summed
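Since an irrep of degree l has dimension 2l + 1, the sum up to L = ang_mom is 1 + 3 + ... + (2L + 1) = (L + 1)². A minimal sketch (the function name is hypothetical):

```python
def irreps_sum_sketch(ang_mom: int) -> int:
    """Sum of irrep dimensions (2l + 1) for l = 0, 1, ..., ang_mom."""
    return sum(2 * ell + 1 for ell in range(ang_mom + 1))
```

For example, ang_mom = 2 gives 1 + 3 + 5 = 9.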

core.common.utils.update_config(base_config)

Configs created prior to FAIRChem/OCP 2.0 are organized a little differently than they are now. This function updates old configs to fit the new expected structure.

core.common.utils.load_model_and_weights_from_checkpoint(checkpoint_path: str) → torch.nn.Module
core.common.utils.get_timestamp_uid() → str
core.common.utils.tensor_stats(name: str, x: torch.Tensor) → dict
core.common.utils.get_weight_table(model: torch.nn.Module) → tuple[list, list]
core.common.utils.get_checkpoint_format(config: dict) → str