Skip to content

Instantly share code, notes, and snippets.

@arhanjain
Created March 23, 2026 21:32
Show Gist options
  • Select an option

  • Save arhanjain/e95bf05904ec3a316259ee16211b1e3d to your computer and use it in GitHub Desktop.

Select an option

Save arhanjain/e95bf05904ec3a316259ee16211b1e3d to your computer and use it in GitHub Desktop.
polaris generate initial conditions
"""
Generate random initial conditions for sim-evals environments via rejection sampling.
Launches IsaacLab to read the scene USD. Expects a /randomization prim whose
bounding box defines the XY spawn region. Discovers rigid body objects, computes
collision radii from their bounding boxes, then generates collision-free random poses
with random Z-axis rotation.
Usage:
uv run generate_initial_conditions.py --environment FoodBussing --num 50 --
uv run generate_initial_conditions.py --environment TapeIntoContainer --num 100 --instruction "Put the tape in the container" --
"""
import tyro
import argparse
import json
import numpy as np
from pathlib import Path
from itertools import combinations
def main(
    environment: str,
    num: int = 100,
    instruction: str = "",
    seed: int = 42,
    radius_scale: float = 1.2,
    include: tuple[str, ...] = (),
    exclude: tuple[str, ...] = (),
    overwrite: bool = False,
    headless: bool = True,
):
    """Generate initial conditions for an environment.

    Launches the IsaacLab app, runs the generation, and always closes the app
    afterwards, including when generation raises.

    Args:
        environment: Gym environment ID (e.g. FoodBussing, TapeIntoContainer)
        num: Number of new pose sets to generate
        instruction: Task instruction string (required if creating new file)
        seed: Random seed
        radius_scale: Multiplier on bounding-box radii for collision checking
        include: If specified, only randomize these objects
        exclude: Object names to exclude from randomization (e.g. Table243)
        overwrite: If True, replace existing file instead of appending
        headless: Run headless (no GUI)

    Raises:
        RuntimeError: If the USD stage lacks a /World or /randomization prim.
    """
    import traceback

    # Launch IsaacLab (needed for pxr imports)
    from isaaclab.app import AppLauncher

    parser = argparse.ArgumentParser()
    AppLauncher.add_app_launcher_args(parser)
    args_cli, _ = parser.parse_known_args()
    args_cli.headless = headless
    app_launcher = AppLauncher(args_cli)

    try:
        _generate_initial_conditions(
            environment,
            num,
            instruction,
            seed,
            radius_scale,
            include,
            exclude,
            overwrite,
        )
    except Exception:
        # NOTE(review): report the error before shutting the app down, in case
        # app shutdown prevents the interpreter from printing the traceback.
        traceback.print_exc()
        raise
    finally:
        # Close the app on every exit path (success, early return, exception).
        app_launcher.app.close()


def _read_default_pose(prim):
    """Return ``(position, quaternion)`` authored on ``prim``.

    The position is a 3-tuple and the quaternion is ``(w, x, y, z)``.
    Orientation comes from ``xformOp:orient`` if authored, otherwise from
    ``xformOp:rotateXYZ``, otherwise identity.
    """
    from pxr import Gf, UsdGeom

    translate = prim.GetAttribute("xformOp:translate").Get()
    if translate is None:
        # No translate op authored: use the translation of the prim's local
        # transform rather than failing on tuple(None).
        xformable = UsdGeom.Xformable(prim)
        if xformable:
            translate = xformable.GetLocalTransformation().ExtractTranslation()
        else:
            translate = (0.0, 0.0, 0.0)

    orient_val = prim.GetAttribute("xformOp:orient").Get()
    if orient_val is not None:
        imag = orient_val.GetImaginary()
        quat = (orient_val.GetReal(), imag[0], imag[1], imag[2])
    else:
        rotate_val = prim.GetAttribute("xformOp:rotateXYZ").Get()
        if rotate_val is not None:
            rx = Gf.Rotation(Gf.Vec3d(1, 0, 0), rotate_val[0])
            ry = Gf.Rotation(Gf.Vec3d(0, 1, 0), rotate_val[1])
            rz = Gf.Rotation(Gf.Vec3d(0, 0, 1), rotate_val[2])
            # NOTE(review): composition order kept as originally written;
            # confirm it matches USD's rotateXYZ convention.
            combined = rz * ry * rx
            q = combined.GetQuat()
            q_imag = q.GetImaginary()
            quat = (q.GetReal(), q_imag[0], q_imag[1], q_imag[2])
        else:
            quat = (1.0, 0.0, 0.0, 0.0)

    return tuple(translate), quat


def _generate_initial_conditions(
    environment,
    num,
    instruction,
    seed,
    radius_scale,
    include,
    exclude,
    overwrite,
):
    """Sample pose sets for ``environment`` and write initial_conditions.json.

    Must be called after the IsaacLab app has been launched (pxr is imported
    here). Arguments have the same meaning as the matching ones of ``main``.
    Returns None; results are written next to the environment's USD file.
    """
    from pxr import Usd, UsdGeom, UsdPhysics

    # Import environments to get the registry
    import polaris.environments  # noqa: F401
    import gymnasium as gym

    # Look up the env spec to find the usd_file
    spec = gym.spec(environment)
    usd_file = spec.kwargs.get("usd_file")
    if usd_file is None:
        print(f"Error: environment '{environment}' has no usd_file in its registration kwargs.")
        return
    usd_path = Path(usd_file).resolve()
    ic_path = usd_path.parent / "initial_conditions.json"
    print(f"USD: {usd_path}")
    print(f"Initial conditions: {ic_path}")

    # Open the USD stage
    stage = Usd.Stage.Open(str(usd_path))
    scene_prim = stage.GetPrimAtPath("/World")
    if not scene_prim or not scene_prim.IsValid():
        raise RuntimeError("USD is missing a /World prim to search for rigid bodies.")
    bbox_cache = UsdGeom.BBoxCache(Usd.TimeCode.Default(), ["default", "render"])

    # Read the /randomization prim for spawn bounds
    rand_prim = stage.GetPrimAtPath("/randomization")
    if not rand_prim or not rand_prim.IsValid():
        raise RuntimeError(
            "USD is missing a /randomization prim. "
            "Add a cube/volume prim at /randomization whose bounding box "
            "defines the XY spawn region."
        )
    rand_range = bbox_cache.ComputeWorldBound(rand_prim).ComputeAlignedRange()
    rand_min = rand_range.GetMin()
    rand_max = rand_range.GetMax()
    spawn_xmin, spawn_xmax = rand_min[0], rand_max[0]
    spawn_ymin, spawn_ymax = rand_min[1], rand_max[1]
    print("\nRandomization bounds from /randomization:")
    print(f" x=[{spawn_xmin:.3f}, {spawn_xmax:.3f}], y=[{spawn_ymin:.3f}, {spawn_ymax:.3f}]")

    # Discover rigid body objects + bounding boxes
    objects = []
    default_poses = {}  # name -> (pos, quat)
    bbox_sizes = {}  # name -> (sx, sy, sz)
    for child in scene_prim.GetChildren():
        name = child.GetName()
        # A prim counts as rigid if it, or one of its direct children, has the API.
        is_rigid = child.HasAPI(UsdPhysics.RigidBodyAPI) or any(
            c.HasAPI(UsdPhysics.RigidBodyAPI) for c in child.GetChildren()
        )
        if not is_rigid:
            continue
        # Skip kinematic (fixed) bodies
        # NOTE(review): only the top-level prim is checked for the kinematic
        # flag, not the children considered above.
        kinematic_attr = UsdPhysics.RigidBodyAPI(child).GetKinematicEnabledAttr()
        if kinematic_attr and kinematic_attr.Get() is True:
            print(f" Skipping kinematic body: {name}")
            continue
        # Filter by include/exclude
        if include and name not in include:
            print(f" Skipping (not in include): {name}")
            continue
        if name in exclude:
            print(f" Skipping excluded: {name}")
            continue
        objects.append(name)
        default_poses[name] = _read_default_pose(child)
        # Compute bounding box
        rng_bbox = bbox_cache.ComputeWorldBound(child).ComputeAlignedRange()
        mn = rng_bbox.GetMin()
        mx = rng_bbox.GetMax()
        bbox_sizes[name] = (mx[0] - mn[0], mx[1] - mn[1], mx[2] - mn[2])

    print(f"\nDiscovered {len(objects)} rigid objects:")
    for name in objects:
        pos, _quat = default_poses[name]
        sx, sy, sz = bbox_sizes[name]
        print(f" {name}: pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f}), "
              f"bbox=({sx:.4f}, {sy:.4f}, {sz:.4f})")
    if not objects:
        print("No rigid bodies found. Nothing to do.")
        return

    radii = {}
    xy_bounds = {}
    for name in objects:
        sx, sy, _ = bbox_sizes[name]
        # Collision radii from XY bounding box
        radii[name] = max(sx, sy) / 2.0 * radius_scale
        # Per-object spawn bounds, shrunk inward by half the object's XY extent
        # so the full object stays inside the randomization region
        pad_x = sx / 2.0
        pad_y = sy / 2.0
        bx_lo, bx_hi = spawn_xmin + pad_x, spawn_xmax - pad_x
        by_lo, by_hi = spawn_ymin + pad_y, spawn_ymax - pad_y
        # Object wider than the region: pin it to the region's centre.
        if bx_lo > bx_hi:
            bx_lo = bx_hi = (spawn_xmin + spawn_xmax) / 2.0
        if by_lo > by_hi:
            by_lo = by_hi = (spawn_ymin + spawn_ymax) / 2.0
        xy_bounds[name] = (bx_lo, bx_hi, by_lo, by_hi)

    z_values = {name: default_poses[name][0][2] for name in objects}
    default_quats = {name: default_poses[name][1] for name in objects}
    print("\nCollision radii:")
    for name in objects:
        print(f" {name}: {radii[name]:.4f}")

    # Rejection sampling: keep drawing until we have `num` pose sets or the
    # budget runs out. Each call draws exactly one candidate layout, so
    # `total_attempts` counts individual draws and the budget averages
    # `max_attempts_per_sample` draws per requested pose set. The previous
    # form allowed that many *calls* of up to 1000 draws each, which made an
    # infeasible scene run for roughly num * 1000 * 1000 draws.
    rng = np.random.default_rng(seed)
    new_poses = []
    total_attempts = 0
    max_attempts_per_sample = 1000
    max_total_attempts = num * max_attempts_per_sample
    while len(new_poses) < num and total_attempts < max_total_attempts:
        pose_set = _sample_pose_set(objects, xy_bounds, z_values, default_quats, radii, rng, 1)
        total_attempts += 1
        if pose_set is not None:
            new_poses.append(pose_set)
    print(f"\nGenerated {len(new_poses)}/{num} pose sets ({total_attempts} total attempts)")
    if len(new_poses) < num:
        print(f"Warning: could only generate {len(new_poses)}/{num}. Try decreasing radius_scale.")
    if not new_poses:
        return

    # Load existing file or create new
    if ic_path.exists() and not overwrite:
        with open(ic_path, "r") as f:
            data = json.load(f)
        # Tolerate an existing file that has no "poses" list yet.
        poses = data.setdefault("poses", [])
        existing_count = len(poses)
        poses.extend(new_poses)
        print(f"Appended {len(new_poses)} poses to existing {existing_count} (total: {len(data['poses'])})")
    else:
        if not instruction:
            print("Warning: no --instruction provided for new file, using empty string.")
        data = {
            "instruction": instruction,
            "poses": new_poses,
        }
        if overwrite and ic_path.exists():
            print(f"Overwriting existing file with {len(new_poses)} poses")
        else:
            print(f"Created new file with {len(new_poses)} poses")
    with open(ic_path, "w") as f:
        json.dump(data, f, indent=4)
    print(f"Wrote to {ic_path}")
def _quat_multiply(q1, q2):
    """Return the Hamilton product ``q1 * q2`` of two quaternions.

    Both inputs and the result are 4-element sequences ordered [w, x, y, z];
    the result is a new list.
    """
    aw, ax, ay, az = q1
    bw, bx, by, bz = q2
    real = aw*bw - ax*bx - ay*by - az*bz
    i_part = aw*bx + ax*bw + ay*bz - az*by
    j_part = aw*by - ax*bz + ay*bw + az*bx
    k_part = aw*bz + ax*by - ay*bx + az*bw
    return [real, i_part, j_part, k_part]
def _random_yaw_quaternion(rng: np.random.Generator, default_quat: tuple) -> list[float]:
    """Apply a uniformly random rotation about Z on top of ``default_quat``.

    The yaw angle is drawn from ``rng`` in [0, 2*pi). The yaw quaternion is
    left-multiplied onto the default orientation. Returns [qw, qx, qy, qz]
    as plain Python floats.
    """
    theta = rng.uniform(0, 2 * np.pi)
    half_angle = theta / 2
    # Quaternion for a rotation of `theta` about the Z axis.
    yaw = [float(np.cos(half_angle)), 0.0, 0.0, float(np.sin(half_angle))]
    composed = _quat_multiply(yaw, list(default_quat))
    return [float(component) for component in composed]
def _sample_pose_set(
objects: list[str],
xy_bounds: dict[str, tuple[float, float, float, float]],
z_values: dict[str, float],
default_quats: dict[str, tuple],
radii: dict[str, float],
rng: np.random.Generator,
max_attempts: int = 1000,
) -> dict[str, list[float]] | None:
"""Sample collision-free positions for all objects via rejection sampling."""
for _ in range(max_attempts):
positions = {}
for obj in objects:
xmin, xmax, ymin, ymax = xy_bounds[obj]
x = rng.uniform(xmin, xmax)
y = rng.uniform(ymin, ymax)
positions[obj] = np.array([x, y])
# Check all pairwise distances
collision = False
for o1, o2 in combinations(objects, 2):
dist = np.linalg.norm(positions[o1] - positions[o2])
if dist < radii[o1] + radii[o2]:
collision = True
break
if not collision:
pose_set = {}
for obj in objects:
quat = _random_yaw_quaternion(rng, default_quats[obj])
pose_set[obj] = [
float(positions[obj][0]),
float(positions[obj][1]),
z_values[obj],
quat[0], quat[1], quat[2], quat[3],
]
return pose_set
return None
# Script entry point: tyro exposes main's parameters as command-line flags.
if __name__ == "__main__":
    tyro.cli(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment