core.models.equiformer_v2.prediction_heads

Submodules

Classes

Rank2SymmetricTensorHead

A rank 2 symmetric tensor prediction head.

Package Contents

class core.models.equiformer_v2.prediction_heads.Rank2SymmetricTensorHead(backbone: fairchem.core.models.base.BackboneInterface, output_name: str, decompose: bool = False, edge_level_mlp: bool = False, num_mlp_layers: int = 2, use_source_target_embedding: bool = False, extensive: bool = False, avg_num_nodes: int = 1.0, default_norm_type: str = 'layer_norm_sh')

Bases: torch.nn.Module, fairchem.core.models.base.HeadInterface

A rank 2 symmetric tensor prediction head.
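A symmetric 3x3 tensor such as stress has six independent components, which split into a one-component isotropic part (the trace, an l = 0 scalar) and a five-component traceless anisotropic part (l = 2). The plain-Python sketch below shows that split on a single tensor. The helper `decompose_symmetric` is hypothetical, and reading the `decompose` flag as this isotropic/anisotropic split is an assumption; this page does not state what the flag does.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def decompose_symmetric(t: Matrix) -> Tuple[Matrix, Matrix]:
    """Split a symmetric 3x3 tensor into isotropic and anisotropic parts.

    Hypothetical helper for illustration only:
      isotropic   = (trace / 3) * I           -> 1 component  (l = 0)
      anisotropic = t - isotropic (traceless)  -> 5 components (l = 2)
    """
    # Reject non-symmetric input (up to a small floating-point tolerance)
    for i in range(3):
        for j in range(3):
            if abs(t[i][j] - t[j][i]) > 1e-9:
                raise ValueError("tensor is not symmetric")
    mean_diag = (t[0][0] + t[1][1] + t[2][2]) / 3.0
    iso = [[mean_diag if i == j else 0.0 for j in range(3)] for i in range(3)]
    aniso = [[t[i][j] - iso[i][j] for j in range(3)] for i in range(3)]
    return iso, aniso


stress = [
    [3.0, 1.0, 0.0],
    [1.0, 6.0, 2.0],
    [0.0, 2.0, 9.0],
]
iso, aniso = decompose_symmetric(stress)
print(iso[0][0])                                # 6.0
print(aniso[0][0] + aniso[1][1] + aniso[2][2])  # 0.0
```

Adding the two parts back together recovers the original tensor, so the split loses no information; it only separates the rotation-invariant part from the part that rotates like an l = 2 spherical harmonic.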

output_name

Name of the output prediction property (e.g., stress).

sphharm_norm

Layer normalization for the spherical harmonic edge weights.

xedge_layer_norm

Embedding layer norm.

block

Rank 2 equivariant symmetric tensor block.

decompose
use_source_target_embedding
avg_num_nodes
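The `extensive` and `avg_num_nodes` parameters are listed without a description. A plausible reading, which is an assumption rather than documented behavior, is that `extensive=True` sum-pools node contributions (scaled by `avg_num_nodes`) for size-extensive properties, while the default mean-pools them for intensive properties such as stress. The hypothetical helper `pool_per_system` sketches that distinction on scalar node values.

```python
from typing import Dict, List


def pool_per_system(
    values: List[float],
    batch: List[int],
    extensive: bool = False,
    avg_num_nodes: float = 1.0,
) -> Dict[int, float]:
    """Pool per-node values into one value per system (hypothetical helper).

    extensive=True  -> sum over each system's nodes, divided by avg_num_nodes
    extensive=False -> mean over each system's nodes
    """
    sums: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for v, b in zip(values, batch):
        sums[b] = sums.get(b, 0.0) + v
        counts[b] = counts.get(b, 0) + 1
    if extensive:
        return {b: s / avg_num_nodes for b, s in sums.items()}
    return {b: s / counts[b] for b, s in sums.items()}


# Two systems: system 0 has 2 nodes, system 1 has 3 nodes
node_values = [1.0, 3.0, 2.0, 2.0, 5.0]
batch_index = [0, 0, 1, 1, 1]
print(pool_per_system(node_values, batch_index))  # {0: 2.0, 1: 3.0}
print(pool_per_system(node_values, batch_index, extensive=True, avg_num_nodes=2.0))  # {0: 2.0, 1: 4.5}
```

With mean pooling the result does not grow with the number of atoms, which is what an intensive quantity needs; the fixed `avg_num_nodes` divisor keeps a summed result on a comparable scale while still letting it grow with system size.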
forward(data: dict[str, torch.Tensor] | torch.Tensor, emb: dict[str, torch.Tensor]) → dict[str, torch.Tensor]
Parameters:
  • data – data batch

  • emb – dictionary with embedding object and graph data

Returns: dict of {output property name: predicted value}
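The return contract of `forward` (a dict keyed by the output property name) can be illustrated with a toy stand-in. `ToySymmetricTensorHead` and the input keys `"batch"` and `"node_vectors"` are hypothetical, and the model is not the equivariant block used by this head: each node simply contributes the outer product v vᵀ of a 3-vector, which is symmetric by construction, and contributions are averaged per system. Only the shape of the interface is meant to match.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


class ToySymmetricTensorHead:
    """Hypothetical stand-in that mimics only the forward() return contract."""

    def __init__(self, output_name: str) -> None:
        self.output_name = output_name

    def forward(
        self, data: Dict[str, List[int]], emb: Dict[str, List[Vector]]
    ) -> Dict[str, List[Matrix]]:
        batch = data["batch"]          # hypothetical: system index of each node
        vectors = emb["node_vectors"]  # hypothetical: one 3-vector per node
        num_systems = max(batch) + 1
        totals = [[[0.0] * 3 for _ in range(3)] for _ in range(num_systems)]
        counts = [0] * num_systems
        for b, v in zip(batch, vectors):
            for i in range(3):
                for j in range(3):
                    totals[b][i][j] += v[i] * v[j]  # outer product is symmetric
            counts[b] += 1
        # Mean over each system's nodes; max(..., 1) avoids dividing by zero
        out = [
            [[x / max(counts[b], 1) for x in row] for row in totals[b]]
            for b in range(num_systems)
        ]
        return {self.output_name: out}


head = ToySymmetricTensorHead(output_name="stress")
pred = head.forward(
    data={"batch": [0, 0, 1]},
    emb={"node_vectors": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]},
)
print(list(pred))            # ['stress']
print(pred["stress"][1][0])  # [1.0, 1.0, 0.0]
```

The caller looks up the prediction by the same `output_name` the head was constructed with, so several heads can write into one combined results dict without key collisions.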