core.models.escaip.modules.base_block

Classes

BaseGraphNeuralNetworkLayer

Base class for Graph Neural Network layers.

Module Contents

class core.models.escaip.modules.base_block.BaseGraphNeuralNetworkLayer(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, molecular_graph_cfg: fairchem.core.models.escaip.configs.MolecularGraphConfigs, gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs)

Bases: torch.nn.Module

Base class for Graph Neural Network layers. Used in InputLayer and EfficientGraphAttention.

source_atomic_embedding
target_atomic_embedding
source_direction_embedding
target_direction_embedding
edge_distance_embedding
get_edge_linear(gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs)
get_node_linear(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs)
get_edge_features(x: fairchem.core.models.escaip.custom_types.GraphAttentionData) -> torch.Tensor
get_node_features(node_features: torch.Tensor, neighbor_list: torch.Tensor) -> torch.Tensor
aggregate(edge_features, neighbor_mask)
abstract forward()