core.models.escaip.modules.graph_attention_block#

Classes#

EfficientGraphAttentionBlock: Efficient Graph Attention Block module.
EfficientGraphAttention: Efficient Graph Attention module.
FeedForwardNetwork: Feed Forward Network module.

Module Contents#

class core.models.escaip.modules.graph_attention_block.EfficientGraphAttentionBlock(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, molecular_graph_cfg: fairchem.core.models.escaip.configs.MolecularGraphConfigs, gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs, is_last: bool = False)#

Bases: torch.nn.Module

Efficient Graph Attention Block module. Reference: Swin Transformer.

backbone_dtype#
graph_attention#
feedforward#
norm_attn_node#
norm_attn_edge#
norm_ffn_node#
stochastic_depth_attn#
stochastic_depth_ffn#
forward(data: fairchem.core.models.escaip.custom_types.GraphAttentionData, node_features: torch.Tensor, edge_features: torch.Tensor)#
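The attribute names above (paired norms and stochastic-depth modules for the attention and feed-forward paths, plus the Swin Transformer reference) suggest the usual pre-norm residual block pattern. The sketch below is a minimal illustration of that pattern only: it omits edge-feature normalization and the GraphAttentionData argument, and the class and helper names here are hypothetical, not fairchem's implementation.

```python
import torch
import torch.nn as nn

def drop_path(x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """Stochastic depth: randomly zero the whole residual branch per sample."""
    if p == 0.0 or not training:
        return x
    keep = 1.0 - p
    mask = x.new_empty(x.shape[0], *([1] * (x.ndim - 1))).bernoulli_(keep)
    return x * mask / keep

class BlockSketch(nn.Module):
    """Illustrative pre-norm block: attention and feed-forward sub-layers,
    each with normalization, a residual connection, and stochastic depth.
    Names mirror the attributes listed above; the wiring is an assumption."""

    def __init__(self, hidden_dim: int, drop_path_p: float = 0.1):
        super().__init__()
        self.graph_attention = nn.Identity()  # stand-in for EfficientGraphAttention
        self.feedforward = nn.Identity()      # stand-in for FeedForwardNetwork
        self.norm_attn_node = nn.LayerNorm(hidden_dim)
        self.norm_ffn_node = nn.LayerNorm(hidden_dim)
        self.p = drop_path_p

    def forward(self, node_features: torch.Tensor) -> torch.Tensor:
        # Attention sub-layer: normalize, attend, drop-path, add residual.
        x = node_features + drop_path(
            self.graph_attention(self.norm_attn_node(node_features)),
            self.p, self.training,
        )
        # Feed-forward sub-layer: same pattern.
        return x + drop_path(
            self.feedforward(self.norm_ffn_node(x)), self.p, self.training
        )
```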
class core.models.escaip.modules.graph_attention_block.EfficientGraphAttention(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, molecular_graph_cfg: fairchem.core.models.escaip.configs.MolecularGraphConfigs, gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs)#

Bases: fairchem.core.models.escaip.modules.base_block.BaseGraphNeuralNetworkLayer

Efficient Graph Attention module.

backbone_dtype#
repeating_dimensions_list#
rep_dim_len#
use_frequency_embedding#
edge_attr_linear#
edge_attr_norm#
node_hidden_linear#
edge_hidden_linear#
message_norm#
use_message_gate#
attn_in_proj_q#
attn_in_proj_k#
attn_in_proj_v#
attn_out_proj#
attn_num_heads#
attn_dropout#
use_angle_embedding#
use_graph_attention#
forward(data: fairchem.core.models.escaip.custom_types.GraphAttentionData, node_features: torch.Tensor, edge_features: torch.Tensor)#
multi_head_self_attention(input, attn_mask, frequency_vectors=None)#
get_attn_bias(angle_embedding, edge_distance_expansion)#
graph_attention_aggregate(edge_output, neighbor_mask)#
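The method names suggest attention computed over each node's padded neighbor list, with an additive bias derived from angle and distance embeddings and a mask covering padding entries. Below is a minimal sketch of the masked multi-head attention step under those assumptions; all shapes, argument names, and the overall layout are hypothetical, not fairchem's implementation.

```python
import math
import torch

def masked_mha_sketch(
    x: torch.Tensor,          # (num_nodes, max_neighbors, hidden_dim)
    attn_mask: torch.Tensor,  # (num_nodes, max_neighbors) bool; True = valid neighbor
    num_heads: int,
    w_q: torch.Tensor,        # (hidden_dim, hidden_dim) projection weights,
    w_k: torch.Tensor,        # standing in for attn_in_proj_q/k/v above
    w_v: torch.Tensor,
) -> torch.Tensor:
    """Illustrative masked multi-head self-attention over each node's
    neighbor list. Assumes every node has at least one valid neighbor."""
    n, m, d = x.shape
    dh = d // num_heads
    # Project and split into heads: (num_nodes, heads, max_neighbors, dh).
    q = (x @ w_q).view(n, m, num_heads, dh).transpose(1, 2)
    k = (x @ w_k).view(n, m, num_heads, dh).transpose(1, 2)
    v = (x @ w_v).view(n, m, num_heads, dh).transpose(1, 2)
    # Scaled dot-product scores; an additive bias (cf. get_attn_bias)
    # would be summed onto `scores` here.
    scores = q @ k.transpose(-2, -1) / math.sqrt(dh)
    # Mask out padded neighbors before the softmax.
    scores = scores.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2).reshape(n, m, d)
```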
class core.models.escaip.modules.graph_attention_block.FeedForwardNetwork(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs, is_last: bool = False)#

Bases: torch.nn.Module

Feed Forward Network module.

backbone_dtype#
mlp_node#
forward(node_features: torch.Tensor, edge_features: torch.Tensor)#
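Given the single mlp_node attribute, the feed-forward sub-layer is presumably a small per-node MLP in the standard transformer style. A minimal sketch under that assumption; the dimensions, expansion factor, activation, and dropout are hypothetical, not read from the configs:

```python
import torch.nn as nn

# Hypothetical values; in fairchem these would come from gnn_cfg/reg_cfg.
hidden_dim, ffn_mult, dropout = 256, 4, 0.1

# A two-layer MLP applied independently to each node's features,
# the usual transformer feed-forward pattern.
mlp_node = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim * ffn_mult),
    nn.GELU(),
    nn.Dropout(dropout),
    nn.Linear(hidden_dim * ffn_mult, hidden_dim),
)
```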