core.common.profiler_utils#

Classes#

ProfilerCallback

A Callback is an optional extension that can be used to supplement your loop with additional functionality.

Functions#

get_default_profiler_handler(run_id, output_dir, logger)

Get a standard callback handle for the PyTorch profiler.

get_profile_schedule([wait, warmup, active])

Get a profile schedule and total number of steps to run

Module Contents#

core.common.profiler_utils.get_default_profiler_handler(run_id: str, output_dir: str, logger: fairchem.core.common.logger.Logger, all_ranks: bool = False)#

Get a standard callback handle for the PyTorch profiler.
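
As a rough illustration of what such a handle looks like: the `on_trace_ready` argument of `torch.profiler.profile` accepts any callable that receives the profiler object. The sketch below builds one with a closure. The name `make_trace_handler` and the trace file naming are hypothetical, and the `logger` and `all_ranks` handling of the real function is left out; only `export_chrome_trace` is an actual profiler method.

```python
import os


def make_trace_handler(run_id: str, output_dir: str):
    """Return a callable suitable for torch.profiler's on_trace_ready argument.

    Hypothetical sketch: the file name pattern below is an assumption,
    not necessarily what get_default_profiler_handler produces.
    """

    def trace_handler(prof) -> str:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"{run_id}.pt.trace.json")
        # export_chrome_trace is provided by torch.profiler.profile objects.
        prof.export_chrome_trace(output_path)
        return output_path

    return trace_handler
```

Because the handler only relies on `export_chrome_trace`, it can be exercised with a small stand-in object before wiring it into a real profiling run.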

core.common.profiler_utils.get_profile_schedule(wait: int = 5, warmup: int = 5, active: int = 2)#

Get a profile schedule and the total number of steps to run. See the PyTorch docs for the meaning of these parameters: https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule

Example usage:

```
from torch.profiler import ProfilerActivity, profile

trace_handler = get_default_profiler_handler(
    run_id=self.config["cmd"]["timestamp_id"],
    output_dir=self.config["cmd"]["results_dir"],
    logger=self.logger,
)
profile_schedule, total_profile_steps = get_profile_schedule()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=profile_schedule,
    on_trace_ready=trace_handler,
) as p:
    for i in steps:
        ...  # code block to profile
        if i < total_profile_steps:
            p.step()
```
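
To make the three parameters concrete, here is a pure-Python sketch of how `torch.profiler.schedule` maps a step index to a phase when `skip_first` and `repeat` keep their defaults of 0. The helper `phase_for_step` and its string labels are hypothetical (PyTorch itself returns `ProfilerAction` values), and treating the total step count as `wait + warmup + active` is an assumption about what `get_profile_schedule` returns.

```python
def phase_for_step(step: int, wait: int = 5, warmup: int = 5, active: int = 2) -> str:
    """Name the profiler phase for a 0-based step (hypothetical helper)."""
    cycle_length = wait + warmup + active
    position = step % cycle_length  # with repeat=0 the cycle starts over
    if position < wait:
        return "wait"
    if position < wait + warmup:
        return "warmup"
    # The last active step of a cycle is where the trace gets saved.
    return "record_and_save" if position == cycle_length - 1 else "record"


total_profile_steps = 5 + 5 + 2  # assumed: wait + warmup + active
phases = [phase_for_step(i) for i in range(total_profile_steps)]
print(phases)
```

With the default arguments, the first five steps are skipped, the next five warm the profiler up, and the last two are recorded, with the trace handed to `on_trace_ready` at the end of the second.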

class core.common.profiler_utils.ProfilerCallback(job_config: omegaconf.DictConfig, wait_steps: int = 5, warmup_steps: int = 5, active_steps: int = 2, all_ranks: bool = False, activities: tuple = (ProfilerActivity.CPU, ProfilerActivity.CUDA))#

Bases: torchtnt.framework.callback.Callback

A Callback is an optional extension that can be used to supplement your loop with additional functionality. Good candidates for such logic are ones that can be re-used across units. Callbacks are generally not intended for modeling code; this should go in your Unit. To write your own callback, subclass the Callback class and add your own code into the hooks.

Below is an example of a basic callback which prints a message at various points during execution.

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit

class PrintingCallback(Callback):
    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        print("Starting training")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        print("Ending training")

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        print("Starting evaluation")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        print("Ending evaluation")

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        print("Starting prediction")

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        print("Ending prediction")

To use a callback, instantiate the class and pass it in the callbacks parameter to the train(), evaluate(), predict(), or fit() entry point.

printing_callback = PrintingCallback()
train(train_unit, train_dataloader, callbacks=[printing_callback])
profiler#
on_train_start(state: torchtnt.framework.State, unit: torchtnt.framework.TTrainUnit) -> None#

Hook called before training starts.

on_train_step_end(state: torchtnt.framework.State, unit: torchtnt.framework.TTrainUnit) -> None#

Hook called after a train step ends.
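
This page does not show how the two hooks are implemented. One plausible arrangement, sketched here purely as an assumption, is to start the profiler in `on_train_start` and advance it once per step in `on_train_step_end` until the scheduled window has passed. The class `StepLimitedProfilerHooks` is hypothetical and deliberately does not subclass `Callback`, so it runs without torchtnt; `profiler` can be any object exposing `start()`, `step()` and `stop()`, which `torch.profiler.profile` does.

```python
class StepLimitedProfilerHooks:
    """Hypothetical stand-in showing how the two hooks could drive a profiler."""

    def __init__(self, profiler, total_steps: int) -> None:
        self.profiler = profiler
        self.total_steps = total_steps
        self.steps_seen = 0
        self.running = False

    def on_train_start(self, state=None, unit=None) -> None:
        # Begin profiling before the first training step.
        self.profiler.start()
        self.running = True

    def on_train_step_end(self, state=None, unit=None) -> None:
        if not self.running:
            return
        self.steps_seen += 1
        # Advance the profiler schedule by one step.
        self.profiler.step()
        if self.steps_seen >= self.total_steps:
            # Stop once the wait/warmup/active window has been covered.
            self.profiler.stop()
            self.running = False
```

Stopping after a fixed number of steps keeps the profiling overhead confined to the start of training, which matches the `if i < total_profile_steps` guard in the usage example for `get_profile_schedule`.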