core.models.uma.common.rotation_cuda_graph
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
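Attributes
- YTOL
Classes
- RotMatWignerCudaGraph
Functions
- capture_rotmat_and_wigner_with_make_graph_callable(edge_dist_vec, jds)
- edge_rot_and_wigner_graph_capture_region(edge_distance_vecs, Jd_buffers, x_hat, mask, neg_mask)
- create_masks(edge_distance_vec, x_hat)
- init_edge_rot_mat_cuda_graph(edge_distance_vec, mask, neg_mask, x_hat)
- euler_from_edge_rot_mat(edge_rot_mat, x_hat)
- eulers_to_wigner(alpha, beta, gamma, start_lmax, end_lmax, Jd)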
Module Contents
- core.models.uma.common.rotation_cuda_graph.YTOL = 0.999999
- class core.models.uma.common.rotation_cuda_graph.RotMatWignerCudaGraph
  - graph_mod = None
  - graph_capture_count = 0
  - max_edge_size = None
  - _capture_graph(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor])
  - get_rotmat_and_wigner(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- core.models.uma.common.rotation_cuda_graph.capture_rotmat_and_wigner_with_make_graph_callable(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor])
- core.models.uma.common.rotation_cuda_graph.edge_rot_and_wigner_graph_capture_region(edge_distance_vecs: torch.Tensor, Jd_buffers: list[torch.Tensor], x_hat: torch.Tensor, mask: torch.Tensor, neg_mask: torch.Tensor)
- core.models.uma.common.rotation_cuda_graph.create_masks(edge_distance_vec: torch.Tensor, x_hat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
- core.models.uma.common.rotation_cuda_graph.init_edge_rot_mat_cuda_graph(edge_distance_vec: torch.Tensor, mask: torch.Tensor, neg_mask: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor
- core.models.uma.common.rotation_cuda_graph.euler_from_edge_rot_mat(edge_rot_mat: torch.Tensor, x_hat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- core.models.uma.common.rotation_cuda_graph.eulers_to_wigner(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor, start_lmax: int, end_lmax: int, Jd: list[torch.Tensor]) -> torch.Tensor
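The class attributes of `RotMatWignerCudaGraph` (`graph_mod = None`, `graph_capture_count = 0`, `max_edge_size = None`) point to a graph that is captured lazily and reused until an input outgrows it. A captured CUDA graph replays with fixed tensor shapes. One plausible policy, assumed here rather than read from the source, therefore runs as follows:

- capture on first use;
- pad smaller edge batches up to `max_edge_size`;
- recapture only when a batch is larger.

The stdlib-only sketch below models just that bookkeeping. `GraphCacheSketch`, its `run` method, and `fake_capture` are hypothetical stand-ins, and no GPU or PyTorch call is involved. The real capture presumably goes through `torch.cuda.make_graphed_callables`, as the name `capture_rotmat_and_wigner_with_make_graph_callable` suggests.

```python
from typing import Callable, List, Optional, Sequence

Runner = Callable[[List[float]], List[float]]


class GraphCacheSketch:
    """Hypothetical model of a capture-once, recapture-on-growth policy.

    A captured graph only replays with its captured shape, so smaller inputs
    are padded up to that size and only larger inputs trigger a new capture.
    """

    def __init__(self, capture_fn: Callable[[int], Runner], pad_value: float = 0.0):
        self.capture_fn = capture_fn  # stand-in for the real graph capture
        self.pad_value = pad_value    # hypothetical filler for unused slots
        self.graph_mod: Optional[Runner] = None
        self.graph_capture_count = 0
        self.max_edge_size: Optional[int] = None

    def _capture_graph(self, n_edges: int) -> None:
        self.max_edge_size = n_edges
        self.graph_mod = self.capture_fn(n_edges)
        self.graph_capture_count += 1

    def run(self, edges: Sequence[float]) -> List[float]:
        n = len(edges)
        if self.graph_mod is None or n > self.max_edge_size:
            self._capture_graph(n)
        # Pad up to the static captured size, replay, then drop the padded rows.
        padded = list(edges) + [self.pad_value] * (self.max_edge_size - n)
        return self.graph_mod(padded)[:n]


def fake_capture(size: int) -> Runner:
    """Hypothetical capture: the returned runner only accepts `size` inputs."""
    def runner(xs: List[float]) -> List[float]:
        if len(xs) != size:
            raise ValueError("captured graph replayed with a different shape")
        return [2.0 * x for x in xs]
    return runner


cache = GraphCacheSketch(fake_capture)
print(cache.run([1.0, 2.0, 3.0]))  # first call captures at size 3
print(cache.run([4.0]))            # padded to size 3 and replayed
print(cache.run([1.0] * 5))        # larger input triggers a recapture
print(cache.graph_capture_count)   # 2
```

Padding is what makes reuse possible, but the filler value matters. Zero-length edge vectors would make normalization divide by zero, so a real implementation would need a harmless nonzero pad. The sketch leaves that choice to the `pad_value` parameter.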
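`YTOL = 0.999999` and the `(mask, neg_mask)` pair returned by `create_masks` suggest a cosine test of each normalized edge direction against the reference axis `x_hat`. Such a test is needed because a rotation that carries one unit vector onto another degenerates when the two are anti-parallel. The sketch below assumes that reading:

- `mask` flags generic edges;
- `neg_mask` flags edges within tolerance of `-x_hat`;
- edges within tolerance of `+x_hat` get the identity.

Each generic frame uses the standard formula R = I + K + K^2 / (1 + c), where K is the cross-product matrix of `x_hat × u` and c = `x_hat · u`. The anti-parallel case uses a half-turn instead. All function names are hypothetical. The source also does not state whether `init_edge_rot_mat_cuda_graph` returns this matrix or its transpose.

```python
import math
from typing import List, Tuple

Vec = Tuple[float, float, float]
Mat = List[List[float]]

YTOL = 0.999999  # same value as the module-level constant


def normalize(v: Vec) -> Vec:
    n = math.sqrt(v[0] ** 2 + v[1] ** 2 + v[2] ** 2)
    return (v[0] / n, v[1] / n, v[2] / n)


def cross(a: Vec, b: Vec) -> Vec:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def create_masks_sketch(edges: List[Vec], x_hat: Vec) -> Tuple[List[bool], List[bool]]:
    """Hypothetical: split edges by their cosine with the reference axis x_hat."""
    cosines = [dot(normalize(e), x_hat) for e in edges]
    mask = [-YTOL < c < YTOL for c in cosines]  # generic direction
    neg_mask = [c <= -YTOL for c in cosines]    # (nearly) anti-parallel to x_hat
    return mask, neg_mask


def align_axis_to(x_hat: Vec, u: Vec) -> Mat:
    """Rotation taking unit x_hat onto unit u: R = I + K + K^2 / (1 + c).

    K is the cross-product matrix of v = x_hat x u and c = x_hat . u;
    the formula breaks down only when c == -1 (anti-parallel).
    """
    v, c = cross(x_hat, u), dot(x_hat, u)
    K = [[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]]
    K2 = [[sum(K[i][k] * K[k][j] for k in range(3)) for j in range(3)] for i in range(3)]
    return [[(1.0 if i == j else 0.0) + K[i][j] + K2[i][j] / (1.0 + c)
             for j in range(3)] for i in range(3)]


def half_turn(x_hat: Vec) -> Mat:
    """180-degree rotation about an axis p perpendicular to x_hat: R = 2 p p^T - I."""
    helper = (0.0, 0.0, 1.0) if abs(x_hat[0]) > 0.9 else (1.0, 0.0, 0.0)
    p = normalize(cross(x_hat, helper))
    return [[2.0 * p[i] * p[j] - (1.0 if i == j else 0.0) for j in range(3)] for i in range(3)]


def init_edge_rot_mat_sketch(edges, mask, neg_mask, x_hat) -> List[Mat]:
    """Hypothetical per-edge frames; Python branches stand in for masked selects."""
    identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
    mats = []
    for e, generic, anti in zip(edges, mask, neg_mask):
        if generic:
            mats.append(align_axis_to(x_hat, normalize(e)))
        elif anti:
            mats.append(half_turn(x_hat))
        else:
            mats.append(identity)  # already (nearly) parallel to x_hat
    return mats


def apply(R: Mat, v: Vec) -> Vec:
    return tuple(sum(R[i][j] * v[j] for j in range(3)) for i in range(3))


x_hat = (0.0, 1.0, 0.0)  # example reference axis, assumed for illustration
edges = [(1.0, 2.0, 3.0), (0.0, -5.0, 0.0), (0.0, 2.0, 0.0)]
mask, neg_mask = create_masks_sketch(edges, x_hat)
print(mask, neg_mask)  # [True, False, False] [False, True, False]
for R, e in zip(init_edge_rot_mat_sketch(edges, mask, neg_mask, x_hat), edges):
    print(apply(R, x_hat), normalize(e))  # each pair should agree
```

`edge_rot_and_wigner_graph_capture_region` takes `mask` and `neg_mask` as inputs, which hints that the masks are computed before the captured region, presumably to keep data-dependent shapes out of the graph. In PyTorch, the per-edge branches above would then become `torch.where` selections over fixed-size tensors.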
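`euler_from_edge_rot_mat` returns three angles, and `eulers_to_wigner` consumes angles with the same names. The listing does not state which Euler convention the module uses. The sketch below assumes e3nn's Y-X-Y convention, R = R_y(α) R_x(β) R_y(γ), and fixes the reference axis to y instead of taking `x_hat` as an argument. Under that assumption, the image of the y axis gives α and β directly, and γ is whatever rotation about y remains. `angles_to_matrix` and `euler_from_rot_sketch` are hypothetical helpers.

```python
import math
from typing import List, Tuple

Mat = List[List[float]]


def matmul(A: Mat, B: Mat) -> Mat:
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def transpose(A: Mat) -> Mat:
    return [list(row) for row in zip(*A)]


def rot_x(t: float) -> Mat:
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def rot_y(t: float) -> Mat:
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def angles_to_matrix(alpha: float, beta: float, gamma: float) -> Mat:
    """Assumed Y-X-Y convention: R = Ry(alpha) @ Rx(beta) @ Ry(gamma)."""
    return matmul(matmul(rot_y(alpha), rot_x(beta)), rot_y(gamma))


def euler_from_rot_sketch(R: Mat) -> Tuple[float, float, float]:
    """Hypothetical inverse of angles_to_matrix, valid for 0 < beta < pi."""
    # R maps the y axis to (sin a sin b, cos b, cos a sin b), i.e. column 1 of R.
    x, y, z = R[0][1], R[1][1], R[2][1]
    beta = math.acos(max(-1.0, min(1.0, y)))  # clamp guards against rounding
    alpha = math.atan2(x, z)
    # Undo the alpha/beta part; what remains is a pure Ry(gamma).
    M = matmul(transpose(matmul(rot_y(alpha), rot_x(beta))), R)
    gamma = math.atan2(M[0][2], M[0][0])
    return alpha, beta, gamma


print(euler_from_rot_sketch(angles_to_matrix(0.3, 1.1, -2.0)))  # close to (0.3, 1.1, -2.0)
```

The decomposition is singular at β = 0 and β = π. There, α and γ rotate about the same axis, and only their sum (at β = 0) or difference (at β = π) is determined.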
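The signature `eulers_to_wigner(alpha, beta, gamma, start_lmax, end_lmax, Jd)` matches a common construction for real Wigner D-matrices. For each degree l:

- compute D_l = Z_l(α) J_l Z_l(β) J_l Z_l(γ);
- Z_l is a rotation about one fixed axis acting on the 2l+1 real spherical harmonics of degree l;
- J_l is a precomputed constant matrix that exchanges axes.

The per-degree blocks are then stacked into one block-diagonal matrix. The sketch below assumes this construction, assumes that `Jd[l]` holds J_l, and assumes that `end_lmax` is inclusive; the listing confirms none of these. `z_rot_mat`, `block_diag`, and `eulers_to_wigner_sketch` are hypothetical helpers. The identity J in the usage line is only a placeholder for real `Jd` values.

```python
import math
from typing import List

Mat = List[List[float]]


def matmul(A: Mat, B: Mat) -> Mat:
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def z_rot_mat(angle: float, l: int) -> Mat:
    """(2l+1) x (2l+1) single-axis rotation of degree-l real spherical harmonics."""
    n = 2 * l + 1
    M = [[0.0] * n for _ in range(n)]
    for i in range(n):
        f = l - i  # frequencies l, l-1, ..., -l
        M[i][n - 1 - i] = math.sin(f * angle)  # anti-diagonal
        M[i][i] = math.cos(f * angle)          # diagonal; the centre (f = 0) ends up 1
    return M


def block_diag(blocks: List[Mat]) -> Mat:
    size = sum(len(b) for b in blocks)
    out = [[0.0] * size for _ in range(size)]
    offset = 0
    for b in blocks:
        for i, row in enumerate(b):
            out[offset + i][offset:offset + len(row)] = row
        offset += len(b)
    return out


def eulers_to_wigner_sketch(alpha, beta, gamma, start_lmax, end_lmax, Jd) -> Mat:
    """Hypothetical: block-diagonal of Z(alpha) J_l Z(beta) J_l Z(gamma) per degree l."""
    blocks = []
    for l in range(start_lmax, end_lmax + 1):  # end_lmax assumed inclusive
        J = Jd[l]
        D = matmul(z_rot_mat(alpha, l), J)
        D = matmul(matmul(D, z_rot_mat(beta, l)), J)
        blocks.append(matmul(D, z_rot_mat(gamma, l)))
    return block_diag(blocks)


# Placeholder J = identity, so each block collapses to Z_l(alpha + beta + gamma).
identity_J = [[[1.0 if i == j else 0.0 for j in range(2 * l + 1)]
               for i in range(2 * l + 1)] for l in range(3)]
D = eulers_to_wigner_sketch(0.2, 0.5, -0.4, 0, 2, identity_J)
print(len(D))  # 9 = 1 + 3 + 5 rows for l = 0, 1, 2
```

Because every factor is orthogonal whenever each J_l is orthogonal, the assembled matrix is orthogonal too. Its inverse is therefore just its transpose.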