core.models.uma.common.rotation_cuda_graph

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Attributes

YTOL

Classes

RotMatWignerCudaGraph

Functions

capture_rotmat_and_wigner_with_make_graph_callable(...)

edge_rot_and_wigner_graph_capture_region(...)

create_masks(→ tuple[torch.Tensor, torch.Tensor])

init_edge_rot_mat_cuda_graph(→ torch.Tensor)

euler_from_edge_rot_mat(→ tuple[torch.Tensor, torch.Tensor, torch.Tensor])

eulers_to_wigner(→ torch.Tensor)

Module Contents

core.models.uma.common.rotation_cuda_graph.YTOL = 0.999999
class core.models.uma.common.rotation_cuda_graph.RotMatWignerCudaGraph
graph_mod = None
graph_capture_count = 0
max_edge_size = None
_capture_graph(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor])
get_rotmat_and_wigner(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor]) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
core.models.uma.common.rotation_cuda_graph.capture_rotmat_and_wigner_with_make_graph_callable(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor])
core.models.uma.common.rotation_cuda_graph.edge_rot_and_wigner_graph_capture_region(edge_distance_vecs: torch.Tensor, Jd_buffers: list[torch.Tensor], x_hat: torch.Tensor, mask: torch.Tensor, neg_mask: torch.Tensor)
core.models.uma.common.rotation_cuda_graph.create_masks(edge_distance_vec: torch.Tensor, x_hat: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]
core.models.uma.common.rotation_cuda_graph.init_edge_rot_mat_cuda_graph(edge_distance_vec: torch.Tensor, mask: torch.Tensor, neg_mask: torch.Tensor, x_hat: torch.Tensor) → torch.Tensor
core.models.uma.common.rotation_cuda_graph.euler_from_edge_rot_mat(edge_rot_mat: torch.Tensor, x_hat: torch.Tensor) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
core.models.uma.common.rotation_cuda_graph.eulers_to_wigner(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor, start_lmax: int, end_lmax: int, Jd: list[torch.Tensor]) → torch.Tensor