Convenient torch profiler that profiles only the needed part
import torch


def on_trace_ready(dir_name, use_gzip=True):
    """Build a trace handler that exports a Chrome trace into ``dir_name``."""
    import os
    import socket
    import time

    def handler_fn(prof) -> None:
        if not os.path.isdir(dir_name):
            try:
                os.makedirs(dir_name, exist_ok=True)
            except Exception as e:
                raise RuntimeError("Can't create directory: " + dir_name) from e
        # Fall back to rank 0 when torch.distributed is not initialized,
        # so the handler also works in single-process runs.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
        else:
            rank = 0
        worker_name = f"rank{rank}.{socket.gethostname()}"
        curr_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        file_name = f"{curr_time}.{worker_name}.pt.trace.json"
        if use_gzip:
            # export_chrome_trace gzips automatically when the path ends in .gz.
            file_name = file_name + ".gz"
        print(f"Exporting trace file to {os.path.join(dir_name, file_name)}", flush=True)
        prof.export_chrome_trace(os.path.join(dir_name, file_name))
        print(f"Trace file exported to {os.path.join(dir_name, file_name)}", flush=True)

    return handler_fn
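

# Example (a minimal sketch): ``on_trace_ready`` can also be handed to a
# hand-managed ``torch.profiler.profile`` outside the decorator below; the
# ``model``/``inputs`` names are hypothetical placeholders.
#
#   from torch.profiler import profile, ProfilerActivity
#
#   with profile(
#       activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#       on_trace_ready=on_trace_ready("./traces", use_gzip=True),
#   ):
#       model(inputs)  # hypothetical workload to record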


from enum import Enum


class _ProfilerState(Enum):
    NOT_STARTED = "not_started"
    RUNNING = "running"
    STOPPED = "stopped"

def my_profile(
    begin_cnt: int,
    len_cnt: int | None = None,
    rank_list: list[int] | None = None,
    my_profiler=None,
    enable_profile: bool = True,
    trace_dir: str = "./traces",
    use_gzip: bool = True,
    with_stack: bool = True,
):
    """Decorator that wraps a function with torch profiler recording.

    The profiler starts after ``begin_cnt`` calls and records the next
    ``len_cnt`` calls, then stops automatically and exports the trace.
    When ``len_cnt`` is ``None``, profiling keeps running until the
    ``stop_profiler()`` function attached to the wrapped function is
    called explicitly.

    Each decorated function gets its own independent counter and state
    (no globals).

    If ``my_profiler`` is ``None``, a ``torch.profiler.profile`` is created
    internally from ``trace_dir``, ``use_gzip``, and ``with_stack``.
    """
    if rank_list is None:
        rank_list = [0]
    if my_profiler is None and enable_profile:
        from torch.profiler import profile, ProfilerActivity

        my_profiler = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            with_stack=with_stack,
            on_trace_ready=on_trace_ready(trace_dir, use_gzip=use_gzip),
        )

    def decorator(func):
        if not enable_profile:
            return func
        from functools import wraps

        # Per-function call counter and profiler state, captured in the closure.
        counter = 0
        state = _ProfilerState.NOT_STARTED

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            nonlocal counter, state
            # Treat single-process runs as rank 0 instead of crashing.
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
            else:
                rank = 0
            if rank in rank_list and state != _ProfilerState.STOPPED:
                if state == _ProfilerState.NOT_STARTED:
                    if counter >= begin_cnt:
                        print(f"\n=== Profiler started at count {counter} for rank {rank} ===\n")
                        my_profiler.start()
                        state = _ProfilerState.RUNNING
                elif state == _ProfilerState.RUNNING:
                    if len_cnt is not None and counter >= begin_cnt + len_cnt:
                        my_profiler.stop()
                        state = _ProfilerState.STOPPED
                        print(f"\n=== Profiler stopped at count {counter} for rank {rank} ===\n")
                counter += 1
            return func(*args, **kwargs)

        def stop_profiler():
            # Manual stop for the ``len_cnt is None`` case; stopping triggers
            # the trace export via on_trace_ready.
            nonlocal state
            if state == _ProfilerState.RUNNING:
                my_profiler.stop()
                state = _ProfilerState.STOPPED

        wrapped_func.stop_profiler = stop_profiler
        return wrapped_func

    return decorator
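

# Example usage (a minimal sketch; ``train_step`` and ``loader`` are
# hypothetical): profile calls 10..14 on rank 0, writing gzipped Chrome
# traces into ./traces.
#
#   @my_profile(begin_cnt=10, len_cnt=5, rank_list=[0])
#   def train_step(batch):
#       ...  # forward/backward/optimizer step
#
#   for batch in loader:
#       train_step(batch)
#
#   # With ``len_cnt=None``, stop manually once the region of interest ends:
#   #   train_step.stop_profiler()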