A convenient torch profiler wrapper that profiles only the part you need.
import torch


def on_trace_ready(dir_name, use_gzip=True):
    """Return a handler that exports a Chrome trace into ``dir_name``."""
    import os
    import socket
    import time

    def handler_fn(prof) -> None:
        if not os.path.isdir(dir_name):
            try:
                os.makedirs(dir_name, exist_ok=True)
            except Exception as e:
                raise RuntimeError("Can't create directory: " + dir_name) from e
        # Name the trace after rank and host so workers don't overwrite each
        # other. Fall back to rank 0 when torch.distributed is not
        # initialized (e.g. single-process runs).
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        worker_name = f"rank{rank}.{socket.gethostname()}"
        curr_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        file_name = f"{curr_time}.{worker_name}.pt.trace.json"
        if use_gzip:
            file_name = file_name + ".gz"
        trace_path = os.path.join(dir_name, file_name)
        print(f"Exporting trace file to {trace_path}", flush=True)
        prof.export_chrome_trace(trace_path)
        print(f"Trace file exported to {trace_path}", flush=True)

    return handler_fn
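
# The factory above plugs into ``torch.profiler.profile`` wherever a trace
# handler is accepted. A minimal sketch with the standard schedule-based
# flow (``run_one_step`` is a hypothetical workload; the decorator below
# drives the profiler with start()/stop() instead of a schedule):
#
#     from torch.profiler import profile, schedule, ProfilerActivity
#
#     with profile(
#         activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#         schedule=schedule(wait=1, warmup=1, active=3),
#         on_trace_ready=on_trace_ready("./traces"),
#     ) as prof:
#         for _ in range(5):
#             run_one_step()
#             prof.step()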

from enum import Enum


class _ProfilerState(Enum):
    NOT_STARTED = "not_started"
    RUNNING = "running"
    STOPPED = "stopped"

def my_profile(
    begin_cnt: int,
    len_cnt: int | None = None,
    rank_list: list[int] | None = None,
    my_profiler=None,
    enable_profile: bool = True,
    trace_dir: str = "./traces",
    use_gzip: bool = True,
    with_stack: bool = True,
):
    """Decorator that wraps a function with torch profiler recording.

    The profiler starts after ``begin_cnt`` calls and records the next
    ``len_cnt`` calls, then stops automatically. When ``len_cnt`` is
    ``None``, profiling continues until ``stop_profiler()`` is called on
    the wrapped function.

    Each decorated function gets its own independent counter and state
    (no globals).

    If ``my_profiler`` is ``None``, a ``torch.profiler.profile`` is created
    internally from ``trace_dir``, ``use_gzip``, and ``with_stack``.
    """
    if rank_list is None:
        rank_list = [0]
    if my_profiler is None and enable_profile:
        from torch.profiler import profile, ProfilerActivity

        my_profiler = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            with_stack=with_stack,
            on_trace_ready=on_trace_ready(trace_dir, use_gzip=use_gzip),
        )
    def decorator(func):
        if not enable_profile:
            return func
        from functools import wraps

        # Per-function counter and state, captured in the closure.
        counter = 0
        state = _ProfilerState.NOT_STARTED

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            nonlocal counter, state
            # Fall back to rank 0 when torch.distributed is not initialized.
            rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
            if rank in rank_list and state != _ProfilerState.STOPPED:
                if state == _ProfilerState.NOT_STARTED:
                    if counter >= begin_cnt:
                        print(f"\n=== Profiler started at count {counter} for rank {rank} ===\n")
                        my_profiler.start()
                        state = _ProfilerState.RUNNING
                elif state == _ProfilerState.RUNNING:
                    if len_cnt is not None and counter >= begin_cnt + len_cnt:
                        my_profiler.stop()
                        state = _ProfilerState.STOPPED
                        print(f"\n=== Profiler stopped at count {counter} for rank {rank} ===\n")
                counter += 1
            return func(*args, **kwargs)

        def stop_profiler():
            nonlocal state
            if state == _ProfilerState.RUNNING:
                my_profiler.stop()
                state = _ProfilerState.STOPPED

        # Expose a manual stop, needed when ``len_cnt`` is ``None``.
        wrapped_func.stop_profiler = stop_profiler
        return wrapped_func

    return decorator
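
# Usage: a minimal sketch (``train_step`` and ``batches`` are hypothetical;
# assumes torch.distributed is initialized, e.g. under torchrun, otherwise
# the rank falls back to 0). Calls 0-9 run unprofiled, calls 10-14 are
# recorded, then the trace is exported to ./traces.
#
#     @my_profile(begin_cnt=10, len_cnt=5, rank_list=[0], trace_dir="./traces")
#     def train_step(batch):
#         ...  # forward / backward / optimizer step
#
#     for batch in batches:
#         train_step(batch)
#
#     # With len_cnt=None, recording continues until stopped manually:
#     # train_step.stop_profiler()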