core.datasets.atomic_data#
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Modified from the torch_geometric Data class.
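Attributes#
- pmg_installed
- IndexType
- _REQUIRED_KEYS
- _OPTIONAL_KEYS
Classes#
- AtomicData
Functions#
- size_repr
- get_neighbors_pymatgen: Performs nearest neighbor search and returns edge index, distances, and cell offsets.
- reshape_features: Stacks the center and neighbor indices and reshapes the distances.
- atomicdata_list_to_batch: All data points must be single graphs and have the same set of keys.
- tensor_or_int_to_tensor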
Module Contents#
- core.datasets.atomic_data.pmg_installed = True#
- core.datasets.atomic_data.IndexType#
- core.datasets.atomic_data._REQUIRED_KEYS = ['pos', 'atomic_numbers', 'cell', 'pbc', 'natoms', 'edge_index', 'cell_offsets', 'nedges',...#
- core.datasets.atomic_data._OPTIONAL_KEYS = ['energy', 'forces', 'stress', 'dataset']#
- core.datasets.atomic_data.size_repr(key: str, item: torch.Tensor, indent=0) -> str#
- core.datasets.atomic_data.get_neighbors_pymatgen(atoms: ase.Atoms, cutoff, max_neigh)#
Performs nearest neighbor search and returns edge index, distances, and cell offsets.
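To illustrate the three outputs named above, here is a minimal brute-force sketch of a periodic neighbor search in pure Python. It is not the pymatgen-backed implementation: the function name `brute_force_neighbors`, the restriction to orthorhombic cells, the `[neighbor, center]` row order, and keeping the `max_neigh` closest neighbors per center are all assumptions made for illustration.

```python
import itertools
import math


def brute_force_neighbors(positions, cell_lengths, cutoff, max_neigh=None):
    """Hypothetical brute-force periodic neighbor search (orthorhombic cell only).

    Returns (edge_index, distances, cell_offsets). edge_index is a pair of lists
    [neighbor_indices, center_indices] (assumed order), and cell_offsets[k] is
    the integer image shift applied to the neighbor atom of edge k.
    """
    # Number of periodic images needed along each axis to cover the cutoff.
    reach = [math.ceil(cutoff / length) for length in cell_lengths]
    shifts = list(itertools.product(*(range(-r, r + 1) for r in reach)))

    neighbors, centers, distances, offsets = [], [], [], []
    for center, c_pos in enumerate(positions):
        found = []
        for neighbor, n_pos in enumerate(positions):
            for shift in shifts:
                if neighbor == center and shift == (0, 0, 0):
                    continue  # skip the atom itself, but keep its periodic images
                image = [n_pos[a] + shift[a] * cell_lengths[a] for a in range(3)]
                dist = math.dist(c_pos, image)
                if dist <= cutoff:
                    found.append((dist, neighbor, shift))
        # Closest neighbors first, optionally truncated to max_neigh entries.
        found.sort(key=lambda item: item[0])
        if max_neigh is not None:
            found = found[:max_neigh]
        for dist, neighbor, shift in found:
            neighbors.append(neighbor)
            centers.append(center)
            distances.append(dist)
            offsets.append(shift)
    return [neighbors, centers], distances, offsets


# Two atoms in a 4 Å cubic box, 1 Å apart along x.
edge_index, distances, cell_offsets = brute_force_neighbors(
    positions=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
    cell_lengths=(4.0, 4.0, 4.0),
    cutoff=1.5,
)
```

A non-zero cell offset appears only when the neighbor is reached through a periodic image, for example when two atoms sit on opposite sides of the cell boundary.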
- core.datasets.atomic_data.reshape_features(c_index: numpy.ndarray, n_index: numpy.ndarray, n_distance: numpy.ndarray, offsets: numpy.ndarray)#
Stacks the center and neighbor indices and reshapes the distances. Takes NumPy arrays and returns torch tensors.
- class core.datasets.atomic_data.AtomicData(pos: torch.Tensor, atomic_numbers: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, natoms: torch.Tensor, edge_index: torch.Tensor, cell_offsets: torch.Tensor, nedges: torch.Tensor, charge: torch.Tensor, spin: torch.Tensor, fixed: torch.Tensor, tags: torch.Tensor, energy: torch.Tensor | None = None, forces: torch.Tensor | None = None, stress: torch.Tensor | None = None, batch: torch.Tensor | None = None, sid: list[str] | None = None, dataset: list[str] | str | None = None)#
- __keys__#
- pos#
- atomic_numbers#
- cell#
- pbc#
- natoms#
- edge_index#
- cell_offsets#
- nedges#
- charge#
- spin#
- fixed#
- tags#
- sid#
- __slices__ = None#
- __cumsum__ = None#
- __cat_dims__ = None#
- __natoms_list__ = None#
- property task_name#
- assign_batch_stats(slices, cumsum, cat_dims, natoms_list)#
- get_batch_stats()#
- validate()#
- classmethod from_ase(input_atoms: ase.Atoms, r_edges: bool = False, radius: float = 6.0, max_neigh: int | None = None, sid: str | None = None, molecule_cell_size: float | None = None, r_energy: bool = True, r_forces: bool = True, r_stress: bool = True, r_data_keys: list[str] | None = None, task_name: str | None = None, target_dtype: torch.dtype = torch.float32) -> AtomicData#
- to_ase_single() -> ase.Atoms#
- to_ase() -> list[ase.Atoms]#
- classmethod from_dict(dictionary)#
Creates a data object from a python dictionary.
- to_dict()#
- values()#
- property num_nodes: int#
Returns or sets the number of nodes in the graph.
- property num_edges: int#
Returns the number of edges in the graph.
- property num_graphs: int#
Returns the number of graphs in the batch.
- __len__()#
- get(key, default)#
- __getitem__(idx)#
- __setitem__(key: str, value: torch.Tensor)#
Sets the attribute key to value.
- __setattr__(key: str, value: torch.Tensor)#
- __delitem__(key: str)#
Deletes the attribute key.
- keys()#
- __contains__(key)#
Returns True if the attribute key is present in the data.
- __iter__()#
Iterates over all present attributes in the data, yielding their attribute names and content.
- __call__(*keys)#
Iterates over all attributes *keys in the data, yielding their attribute names and content. If *keys is not given, this method iterates over all present attributes.
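The membership test and the two iteration styles described above can be sketched with a small stand-in class. `KeyedRecord` is hypothetical and is not `AtomicData`. Treating `None`-valued attributes as absent and iterating in sorted key order are assumptions borrowed from how torch_geometric-style data objects commonly behave.

```python
class KeyedRecord:
    """Hypothetical stand-in for a data object that stores attributes by name."""

    def __init__(self, **attributes):
        self._store = dict(attributes)

    def keys(self):
        # Assumption: only attributes that are not None count as present.
        return [k for k, v in self._store.items() if v is not None]

    def __contains__(self, key):
        return key in self.keys()

    def __iter__(self):
        # Yield (name, content) for every present attribute, in sorted order.
        for key in sorted(self.keys()):
            yield key, self._store[key]

    def __call__(self, *keys):
        # With no arguments, fall back to all present attributes.
        wanted = keys if keys else sorted(self.keys())
        for key in wanted:
            if key in self:
                yield key, self._store[key]


record = KeyedRecord(pos=[[0.0, 0.0, 0.0]], natoms=[1], energy=None)
selected = list(record("pos", "energy"))  # "energy" is skipped: it is not present
everything = list(record())               # same items as list(iter(record))
```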
- __cat_dim__(key, value) -> int#
Returns the dimension along which value of attribute key is concatenated when creating batches.
Note: This method is for internal use only and should only be overridden if the batch concatenation process is corrupted for a specific data attribute.
- __inc__(key, value) -> int#
Returns the incremental count by which the value of attribute key is cumulatively increased for the next sample when creating batches.
Note: This method is for internal use only and should only be overridden if the batch concatenation process is corrupted for a specific data attribute.
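The roles of the concatenation dimension and the increment can be sketched with plain Python lists in place of tensors. The helpers `cat_dim`, `inc`, and `collate` below are hypothetical. The sketch assumes the common graph-batching convention that `edge_index` is concatenated along its last axis and shifted by the number of atoms seen so far, while per-atom arrays such as `pos` are concatenated along axis 0 without a shift.

```python
def cat_dim(key):
    # Assumed convention: [2, E] index arrays grow along the last axis.
    return -1 if key == "edge_index" else 0


def inc(key, sample):
    # Assumed convention: edge indices are shifted by the sample's atom count.
    return len(sample["pos"]) if key == "edge_index" else 0


def collate(samples):
    """Hypothetical collate step for two keys, using lists instead of tensors."""
    batch = {"pos": [], "edge_index": [[], []]}
    shift = {"pos": 0, "edge_index": 0}
    for sample in samples:
        for key in batch:
            value = sample[key]
            if cat_dim(key) == 0:
                # Axis 0: append this sample's rows.
                batch[key].extend(value)
            else:
                # Last axis: append columns, shifted by the running increment.
                for row, new in zip(batch[key], value):
                    row.extend(i + shift[key] for i in new)
            shift[key] += inc(key, sample)
    return batch


a = {"pos": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], "edge_index": [[0, 1], [1, 0]]}
b = {
    "pos": [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    "edge_index": [[0, 2], [1, 0]],
}
batched = collate([a, b])
```

Without the increment, the edges of the second sample would point at atoms of the first one, which is why index-like attributes need a non-zero increment.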
- __apply__(item, func)#
- apply(func)#
Applies the function func to all tensor attributes.
- contiguous()#
Ensures a contiguous memory layout for all tensor attributes
- to(device, **kwargs)#
Performs tensor dtype and/or device conversion for all tensor attributes
- cpu()#
Copies all tensor attributes to CPU memory.
- cuda(device=None, non_blocking=False)#
Copies all tensor attributes to GPU memory.
- clone()#
Performs a deep-copy of the data object.
- __repr__()#
- get_example(idx: int) -> AtomicData#
Reconstructs the AtomicData object at index idx from a batched AtomicData object.
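One way to picture this reconstruction is as the inverse of batching, driven by per-key slice boundaries, cumulative index shifts, and concatenation dimensions (compare the `__slices__`, `__cumsum__`, and `__cat_dims__` attributes listed above). The function below is a hypothetical sketch on plain lists. The exact layout of those attributes in `AtomicData` is an assumption.

```python
def get_example(batch, slices, cumsum, cat_dims, idx):
    """Hypothetical inverse of a collate step for a single sample."""
    sample = {}
    for key, value in batch.items():
        start, end = slices[key][idx], slices[key][idx + 1]
        if cat_dims[key] == 0:
            # Concatenated along the first axis: take this sample's rows.
            sample[key] = value[start:end]
        else:
            # Concatenated along the last axis: take this sample's columns and
            # undo the index shift that was added when the batch was built.
            shift = cumsum[key][idx]
            sample[key] = [[i - shift for i in row[start:end]] for row in value]
    return sample


batch = {
    "pos": [
        [0.0, 0.0, 0.0], [1.0, 0.0, 0.0],                   # sample 0 (2 atoms)
        [0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],  # sample 1 (3 atoms)
    ],
    "edge_index": [[0, 1, 2, 4], [1, 0, 3, 2]],
}
slices = {"pos": [0, 2, 5], "edge_index": [0, 2, 4]}
cumsum = {"pos": [0, 0], "edge_index": [0, 2]}
cat_dims = {"pos": 0, "edge_index": -1}

second = get_example(batch, slices, cumsum, cat_dims, idx=1)
```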
- index_select(idx: IndexType) -> list[AtomicData]#
- batch_to_atomicdata_list() -> list[AtomicData]#
Reconstructs the list of torch_geometric.data.Data objects from the batch object. The batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.
- update_batch_edges(edge_index: torch.Tensor, cell_offsets: torch.Tensor, nedges: torch.Tensor) -> AtomicData#
Update the connectivity of each batched AtomicData sample.
- Parameters:
edge_index (torch.Tensor) – New batch edge_index (shape [2, total_edges]).
cell_offsets (torch.Tensor) – Cell offsets per edge (shape [total_edges, 3]).
nedges (torch.Tensor) – Number of edges per system (shape [num_systems]).
- Returns:
The updated batch object.
- Return type:
AtomicData
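The shape requirements on the three arguments can be checked with a short sketch that uses nested lists in place of tensors. `split_edges_by_system` is a hypothetical helper, not part of the module. It assumes that the edges of each system are stored contiguously, in batch order, so that `nedges` partitions the edge axis.

```python
def split_edges_by_system(edge_index, cell_offsets, nedges):
    """Hypothetical check of the documented shapes, plus a per-system split."""
    total = sum(nedges)
    if len(edge_index) != 2 or any(len(row) != total for row in edge_index):
        raise ValueError("edge_index must have shape [2, total_edges]")
    if len(cell_offsets) != total or any(len(o) != 3 for o in cell_offsets):
        raise ValueError("cell_offsets must have shape [total_edges, 3]")

    systems, start = [], 0
    for count in nedges:
        end = start + count
        systems.append(
            {
                "edge_index": [row[start:end] for row in edge_index],
                "cell_offsets": cell_offsets[start:end],
            }
        )
        start = end
    return systems


# Two systems with 2 and 1 edges; atom indices are batch-wide.
per_system = split_edges_by_system(
    edge_index=[[0, 1, 2], [1, 0, 3]],
    cell_offsets=[[0, 0, 0], [0, 0, 0], [1, 0, 0]],
    nedges=[2, 1],
)
```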
- core.datasets.atomic_data.atomicdata_list_to_batch(data_list: list[AtomicData], exclude_keys: list | None = None) -> AtomicData#
All data points must be single graphs and have the same set of keys. TODO: exclude keys?
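The "same set of keys" requirement can be sketched as a pre-flight check on dictionaries standing in for data objects. `check_same_keys` is a hypothetical helper. How the real function treats `exclude_keys` is not documented here, so ignoring those keys during the comparison is purely an assumption.

```python
def check_same_keys(data_list, exclude_keys=None):
    """Hypothetical check that every item exposes the same set of keys."""
    if not data_list:
        raise ValueError("data_list must not be empty")
    exclude = set(exclude_keys or [])
    reference = set(data_list[0]) - exclude
    for i, item in enumerate(data_list[1:], start=1):
        keys = set(item) - exclude
        if keys != reference:
            # Report the keys that appear in one item but not the other.
            mismatch = sorted(keys ^ reference)
            raise ValueError(f"item {i} has a different key set: {mismatch}")
    return sorted(reference)


samples = [
    {"pos": [[0.0, 0.0, 0.0]], "natoms": [1], "energy": [-1.0]},
    {"pos": [[0.0, 0.0, 0.0]], "natoms": [1]},
]
common = check_same_keys(samples, exclude_keys=["energy"])
```

Calling `check_same_keys(samples)` without the exclusion raises a `ValueError`, because only the first item carries `energy`.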
- core.datasets.atomic_data.tensor_or_int_to_tensor(x, dtype=torch.int)#