core.models.uma.common.rotation_cuda_graph

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

- RotMatWignerCudaGraph

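Functions

- capture_rotmat_and_wigner_with_make_graph_callable
- edge_rot_and_wigner_graph_capture_region
- init_edge_rot_euler_angles_wigner_cuda_graph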

Module Contents

class core.models.uma.common.rotation_cuda_graph.RotMatWignerCudaGraph
    graph_mod = None
    graph_capture_count = 0
    max_edge_size = None
    _capture_graph(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor])
    get_rotmat_and_wigner(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
core.models.uma.common.rotation_cuda_graph.capture_rotmat_and_wigner_with_make_graph_callable(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor])
core.models.uma.common.rotation_cuda_graph.edge_rot_and_wigner_graph_capture_region(edge_distance_vecs: torch.Tensor, Jd_buffers: list[torch.Tensor])
core.models.uma.common.rotation_cuda_graph.init_edge_rot_euler_angles_wigner_cuda_graph(edge_distance_vec)