@ssghost
Forked from skye/tpu_topology_env_vars.py
Created November 7, 2022 11:15
You can use these environment variables to run a Python process on a subset of the TPU cores on a Cloud TPU VM. Only one process can access a given TPU chip at a time, so giving each process its own disjoint set of chips lets multiple TPU processes run at the same time. Note that in JAX, 1 TPU core = 1 TpuDevice as reported by `jax.devices()`, and each TPU chip has 2 cores on the v2-8 and v3-8 hosts covered here.
import os

# 4 processes, 1 chip (2 cores) each:
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,1,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0" # "1", "2", "3"
# Pick a unique port per process
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
# 2 processes, 2 chips (4 cores) each:
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,2,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0,1" # "2,3"
# Pick a unique port per process
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
# 1 process per host using all 4 chips (8 cores) (default on v2-8, v3-8):
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "2,2,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
os.environ["TPU_VISIBLE_DEVICES"] = "0,1,2,3"
# 1-liner for bash: TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1,2,3
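The three configurations above differ only in the chip bounds, the visible chip IDs, and the controller port, so they can be generated from a process index. The following is a minimal sketch, not part of the original snippet: `tpu_subset_env` and `launch` are hypothetical helper names, and using `base_port + process_index` as the "unique port per process" is an assumption about which ports are free on the host.

```python
import os
import subprocess
import sys

# Topology strings copied from the configurations above, keyed by chips per process.
CHIP_BOUNDS = {1: "1,1,1", 2: "1,2,1", 4: "2,2,1"}
CHIPS_PER_HOST = 4  # one v2-8 / v3-8 host, as in the configurations above


def tpu_subset_env(process_index, chips_per_process, base_port=8476):
    """Return the environment variables for one process (hypothetical helper)."""
    if chips_per_process not in CHIP_BOUNDS:
        raise ValueError("chips_per_process must be 1, 2, or 4")
    num_processes = CHIPS_PER_HOST // chips_per_process
    if not 0 <= process_index < num_processes:
        raise ValueError(f"process_index must be in [0, {num_processes})")

    # Each process gets a contiguous, disjoint range of chip IDs.
    first_chip = process_index * chips_per_process
    chips = range(first_chip, first_chip + chips_per_process)
    env = {
        "TPU_CHIPS_PER_HOST_BOUNDS": CHIP_BOUNDS[chips_per_process],
        "TPU_HOST_BOUNDS": "1,1,1",
        "TPU_VISIBLE_DEVICES": ",".join(str(c) for c in chips),
    }
    if num_processes > 1:
        # Assumption: base_port + process_index is a free port on this host.
        port = base_port + process_index
        env["TPU_MESH_CONTROLLER_ADDRESS"] = f"localhost:{port}"
        env["TPU_MESH_CONTROLLER_PORT"] = str(port)
    return env


def launch(script_args, chips_per_process):
    """Start one child Python process per chip subset; return the Popen handles."""
    procs = []
    for i in range(CHIPS_PER_HOST // chips_per_process):
        # Variables are passed through the child's environment, so they are
        # already in place when the child starts.
        env = {**os.environ, **tpu_subset_env(i, chips_per_process)}
        procs.append(subprocess.Popen([sys.executable, *script_args], env=env))
    return procs
```

For example, `launch(["train.py"], chips_per_process=2)` would start two processes that see chips `0,1` and `2,3` respectively (`train.py` is a placeholder script name). Passing the variables through the child's environment avoids mutating `os.environ` in the parent process.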