core.models.uma.common.rotation_cuda_graph
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Module Contents
- class core.models.uma.common.rotation_cuda_graph.RotMatWignerCudaGraph
  - graph_mod = None
  - graph_capture_count = 0
  - max_edge_size = None
  - _capture_graph(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor])
  - get_rotmat_and_wigner(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- core.models.uma.common.rotation_cuda_graph.capture_rotmat_and_wigner_with_make_graph_callable(edge_dist_vec: torch.Tensor, jds: list[torch.Tensor])
- core.models.uma.common.rotation_cuda_graph.edge_rot_and_wigner_graph_capture_region(edge_distance_vecs: torch.Tensor, Jd_buffers: list[torch.Tensor])
- core.models.uma.common.rotation_cuda_graph.init_edge_rot_euler_angles_wigner_cuda_graph(edge_distance_vec)
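The attribute names of `RotMatWignerCudaGraph` (`graph_mod`, `graph_capture_count`, `max_edge_size`) suggest a lazy capture-and-replay cache. Under that reading, the class captures on first use and recaptures when a batch has more edges than the captured size. Smaller batches are padded up to that size so the captured shapes never change. The sketch below illustrates this reading in plain Python. It is an assumption, not the module's code. `PaddedReplayCache`, `make_norm_fn` and the padding row are hypothetical. A stand-in `capture_fn` takes the place of real CUDA graph capture, which, judging by the name `capture_rotmat_and_wigner_with_make_graph_callable`, probably goes through `torch.cuda.make_graphed_callables`.

```python
from typing import Callable, Optional, Sequence

Row = tuple[float, float, float]
Replay = Callable[[list[Row]], list[float]]


class PaddedReplayCache:
    """Hypothetical lazy capture/replay cache that pads inputs to a static size."""

    def __init__(self, capture_fn: Callable[[int], Replay]):
        # capture_fn(size) stands in for graph capture at a fixed input size.
        self.capture_fn = capture_fn
        self.graph_mod: Optional[Replay] = None
        self.graph_capture_count = 0
        self.max_edge_size: Optional[int] = None

    def _capture(self, size: int) -> None:
        self.max_edge_size = size
        self.graph_mod = self.capture_fn(size)
        self.graph_capture_count += 1

    def __call__(self, edges: Sequence[Row]) -> list[float]:
        n = len(edges)
        # Capture lazily on first use; recapture only when the batch is larger
        # than the size the current "graph" was captured for.
        if self.graph_mod is None or n > self.max_edge_size:
            self._capture(n)
        # Pad to the captured size with a dummy non-zero vector (an assumption,
        # chosen so per-edge normalisation would stay finite), then drop the
        # outputs that belong to padding rows.
        padded = list(edges) + [(0.0, 1.0, 0.0)] * (self.max_edge_size - n)
        return self.graph_mod(padded)[:n]


def make_norm_fn(size: int) -> Replay:
    # Stand-in for a captured graph: it only accepts inputs of exactly `size` rows.
    def run(rows: list[Row]) -> list[float]:
        if len(rows) != size:
            raise ValueError("static shape mismatch")
        return [(x * x + y * y + z * z) ** 0.5 for x, y, z in rows]

    return run


if __name__ == "__main__":
    cache = PaddedReplayCache(make_norm_fn)
    print(cache([(3.0, 4.0, 0.0)] * 5), cache.graph_capture_count)  # captured at size 5
    print(cache([(1.0, 0.0, 0.0)] * 2), cache.graph_capture_count)  # replayed, padded to 5
    print(cache([(0.0, 0.0, 2.0)] * 8), cache.graph_capture_count)  # recaptured at size 8
```

Recapturing only when the batch grows keeps the number of captures small when batch sizes fluctuate. The cost is extra padded work on smaller batches.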
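`init_edge_rot_euler_angles_wigner_cuda_graph` takes only the edge distance vectors, so it presumably turns each edge direction into Euler angles. Below is a minimal sketch of one such conversion. It assumes the e3nn convention R(α, β, γ) = R_y(α) R_x(β) R_y(γ), with the y axis as the reference direction. Under that convention, α = atan2(x, z) and β = arccos(y) for the normalized edge (x, y, z). The module's actual convention, signs, and choice of the third angle may differ. `edge_to_angles`, `rot_matrix` and the other helpers are hypothetical.

```python
import math

Vec = tuple[float, float, float]
Mat = list[list[float]]


def matmul(a: Mat, b: Mat) -> Mat:
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def apply(m: Mat, v: Vec) -> Vec:
    return tuple(sum(m[i][k] * v[k] for k in range(3)) for i in range(3))


def rot_x(t: float) -> Mat:
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def rot_y(t: float) -> Mat:
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def rot_matrix(alpha: float, beta: float, gamma: float) -> Mat:
    # Assumed Y-X-Y convention: R = R_y(alpha) @ R_x(beta) @ R_y(gamma).
    return matmul(matmul(rot_y(alpha), rot_x(beta)), rot_y(gamma))


def edge_to_angles(v: Vec) -> tuple[float, float]:
    # Hypothetical helper: (alpha, beta) such that rot_matrix(alpha, beta, g)
    # maps +y onto the unit edge direction for any g. v must be non-zero.
    norm = math.sqrt(sum(c * c for c in v))
    x, y, z = (c / norm for c in v)
    beta = math.acos(max(-1.0, min(1.0, y)))  # clamp guards against rounding
    alpha = math.atan2(x, z)
    return alpha, beta


if __name__ == "__main__":
    edge = (1.0, -2.0, 0.5)
    a, b = edge_to_angles(edge)
    print(apply(rot_matrix(a, b, 0.7), (0.0, 1.0, 0.0)))  # approximately edge / |edge|
```

R_y(γ) leaves the y axis fixed, so any γ gives a rotation that maps +y onto the edge. The third angle is therefore a free roll about the edge axis. The transpose of that rotation maps the edge direction back onto +y.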
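`edge_rot_and_wigner_graph_capture_region` receives a list `Jd_buffers`. In e3nn-style code, such a list holds one fixed matrix per degree l. Assuming that layout, the Wigner matrix for all degrees up to lmax is block diagonal, with one (2l+1)×(2l+1) block per degree. Its side length is therefore the sum of 2l+1 over l = 0..lmax, which equals (lmax+1)². The sketch below shows only that assembly. `block_offsets`, `block_diag` and `identity` are hypothetical helpers. Identity blocks, which are the Wigner blocks of the zero rotation, stand in for the real per-degree blocks.

```python
Mat = list[list[float]]


def block_offsets(lmax: int) -> list[tuple[int, int]]:
    # (start row, block size) of the degree-l block for l = 0..lmax.
    offsets, start = [], 0
    for l in range(lmax + 1):
        size = 2 * l + 1
        offsets.append((start, size))
        start += size
    return offsets


def block_diag(blocks: list[Mat]) -> Mat:
    # Place each square block on the diagonal; everything else stays zero.
    n = sum(len(b) for b in blocks)
    out = [[0.0] * n for _ in range(n)]
    start = 0
    for b in blocks:
        for i, row in enumerate(b):
            for j, val in enumerate(row):
                out[start + i][start + j] = val
        start += len(b)
    return out


def identity(n: int) -> Mat:
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


if __name__ == "__main__":
    lmax = 2
    print(block_offsets(lmax))
    W = block_diag([identity(2 * l + 1) for l in range(lmax + 1)])
    print(len(W))
```

For lmax = 2, the blocks start at rows 0, 1 and 4 with sizes 1, 3 and 5, giving a 9×9 matrix. Each real Wigner block is orthogonal, so the transpose of the assembled matrix is its inverse.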