Created
March 23, 2026 21:32
-
-
Save arhanjain/e95bf05904ec3a316259ee16211b1e3d to your computer and use it in GitHub Desktop.
polaris generate initial conditions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Generate random initial conditions for sim-evals environments via rejection sampling. | |
| Launches IsaacLab to read the scene USD. Expects a /randomization prim whose | |
| bounding box defines the XY spawn region. Discovers rigid body objects, computes | |
| collision radii from their bounding boxes, then generates collision-free random poses | |
| with random Z-axis rotation. | |
| Usage: | |
| uv run generate_initial_conditions.py --environment FoodBussing --num 50 -- | |
| uv run generate_initial_conditions.py --environment TapeIntoContainer --num 100 --instruction "Put the tape in the container" -- | |
| """ | |
| import tyro | |
| import argparse | |
| import json | |
| import numpy as np | |
| from pathlib import Path | |
| from itertools import combinations | |
def main(
    environment: str,
    num: int = 100,
    instruction: str = "",
    seed: int = 42,
    radius_scale: float = 1.2,
    include: tuple[str, ...] = (),
    exclude: tuple[str, ...] = (),
    overwrite: bool = False,
    headless: bool = True,
):
    """Generate initial conditions for an environment.

    Args:
        environment: Gym environment ID (e.g. FoodBussing, TapeIntoContainer)
        num: Number of new pose sets to generate (must be >= 1)
        instruction: Task instruction string (required if creating new file)
        seed: Random seed
        radius_scale: Multiplier on bounding-box radii for collision checking
        include: If specified, only randomize these objects
        exclude: Object names to exclude from randomization (e.g. Table243)
        overwrite: If True, replace existing file instead of appending
        headless: Run headless (no GUI)
    """
    # Validate cheap arguments before paying for the simulator launch.
    if num < 1:
        raise ValueError(f"num must be >= 1, got {num}")

    # Launch IsaacLab first: pxr is only importable once the app is running.
    from isaaclab.app import AppLauncher

    parser = argparse.ArgumentParser()
    AppLauncher.add_app_launcher_args(parser)
    args_cli, _ = parser.parse_known_args()
    args_cli.headless = headless
    app_launcher = AppLauncher(args_cli)
    try:
        _generate(
            environment=environment,
            num=num,
            instruction=instruction,
            seed=seed,
            radius_scale=radius_scale,
            include=include,
            exclude=exclude,
            overwrite=overwrite,
        )
    finally:
        # Close the app on every exit path, including exceptions such as a
        # missing /randomization prim, so the simulator process does not linger.
        app_launcher.app.close()


def _generate(*, environment, num, instruction, seed, radius_scale, include, exclude, overwrite) -> None:
    """Run the full pipeline: locate the USD, discover objects, sample, save.

    Assumes the IsaacLab app has already been launched by ``main``.
    """
    from pxr import Usd, UsdGeom

    # Import environments to get the registry
    import polaris.environments  # noqa: F401
    import gymnasium as gym

    # Look up the env spec to find the usd_file
    spec = gym.spec(environment)
    usd_file = spec.kwargs.get("usd_file")
    if usd_file is None:
        print(f"Error: environment '{environment}' has no usd_file in its registration kwargs.")
        return
    usd_path = Path(usd_file).resolve()
    ic_path = usd_path.parent / "initial_conditions.json"
    print(f"USD: {usd_path}")
    print(f"Initial conditions: {ic_path}")

    stage = Usd.Stage.Open(str(usd_path))
    if stage is None:
        raise RuntimeError(f"Could not open USD stage: {usd_path}")
    bbox_cache = UsdGeom.BBoxCache(Usd.TimeCode.Default(), ["default", "render"])

    spawn_region = _read_spawn_region(stage, bbox_cache)
    objects, default_poses, bbox_sizes = _discover_objects(stage, bbox_cache, include, exclude)

    print(f"\nDiscovered {len(objects)} rigid objects:")
    for name in objects:
        pos, quat = default_poses[name]
        sx, sy, sz = bbox_sizes[name]
        print(f" {name}: pos=({pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f}), "
              f"bbox=({sx:.4f}, {sy:.4f}, {sz:.4f})")
    if not objects:
        print("No rigid bodies found. Nothing to do.")
        return

    # Collision radii from the XY bounding box footprint.
    radii = {}
    for name in objects:
        sx, sy, _ = bbox_sizes[name]
        radii[name] = max(sx, sy) / 2.0 * radius_scale

    xy_bounds = _compute_xy_bounds(objects, bbox_sizes, spawn_region)
    z_values = {name: default_poses[name][0][2] for name in objects}
    default_quats = {name: default_poses[name][1] for name in objects}
    print("\nCollision radii:")
    for name in objects:
        print(f" {name}: {radii[name]:.4f}")

    # Rejection sampling — keep retrying until we have exactly `num` poses.
    rng = np.random.default_rng(seed)
    new_poses = []
    total_attempts = 0
    max_attempts_per_sample = 1000
    max_total_attempts = num * max_attempts_per_sample
    while len(new_poses) < num and total_attempts < max_total_attempts:
        pose_set = _sample_pose_set(objects, xy_bounds, z_values, default_quats, radii, rng, max_attempts_per_sample)
        total_attempts += 1
        if pose_set is not None:
            new_poses.append(pose_set)
    print(f"\nGenerated {len(new_poses)}/{num} pose sets ({total_attempts} total attempts)")
    if len(new_poses) < num:
        print(f"Warning: could only generate {len(new_poses)}/{num}. Try decreasing radius_scale.")
    if not new_poses:
        return

    _save_initial_conditions(ic_path, new_poses, instruction, overwrite, objects)


def _read_spawn_region(stage, bbox_cache) -> tuple[float, float, float, float]:
    """Return (xmin, xmax, ymin, ymax) of the /randomization prim's world AABB.

    Raises:
        RuntimeError: if /randomization is missing or has an empty bounding box.
    """
    rand_prim = stage.GetPrimAtPath("/randomization")
    if not rand_prim or not rand_prim.IsValid():
        raise RuntimeError(
            "USD is missing a /randomization prim. "
            "Add a cube/volume prim at /randomization whose bounding box "
            "defines the XY spawn region."
        )
    rand_range = bbox_cache.ComputeWorldBound(rand_prim).ComputeAlignedRange()
    # An empty range has min=+inf / max=-inf, which would poison every bound.
    if rand_range.IsEmpty():
        raise RuntimeError(
            "/randomization has an empty bounding box; give it geometry or an extent."
        )
    rand_min = rand_range.GetMin()
    rand_max = rand_range.GetMax()
    spawn_xmin, spawn_xmax = rand_min[0], rand_max[0]
    spawn_ymin, spawn_ymax = rand_min[1], rand_max[1]
    print("\nRandomization bounds from /randomization:")
    print(f" x=[{spawn_xmin:.3f}, {spawn_xmax:.3f}], y=[{spawn_ymin:.3f}, {spawn_ymax:.3f}]")
    return spawn_xmin, spawn_xmax, spawn_ymin, spawn_ymax


def _discover_objects(stage, bbox_cache, include, exclude):
    """Find randomizable rigid bodies among the direct children of /World.

    A child counts as rigid if it, or one of its direct children, has
    UsdPhysics.RigidBodyAPI. Objects whose rigid-body prims are all kinematic,
    or that fail the include/exclude filters, are skipped.

    Returns:
        (objects, default_poses, bbox_sizes): the object names in stage order,
        name -> ((x, y, z), (qw, qx, qy, qz)), and name -> (sx, sy, sz) of the
        world-aligned bounding box.
    """
    from pxr import UsdPhysics

    scene_prim = stage.GetPrimAtPath("/World")
    if not scene_prim or not scene_prim.IsValid():
        raise RuntimeError("USD is missing a /World prim.")

    objects = []
    default_poses = {}  # name -> (pos, quat)
    bbox_sizes = {}  # name -> (sx, sy, sz)
    for child in scene_prim.GetChildren():
        name = child.GetName()
        # The RigidBodyAPI may be on the object's root prim or on a direct child.
        if child.HasAPI(UsdPhysics.RigidBodyAPI):
            body_prims = [child]
        else:
            body_prims = [c for c in child.GetChildren() if c.HasAPI(UsdPhysics.RigidBodyAPI)]
        if not body_prims:
            continue
        # Skip kinematic (fixed) bodies. Check the prims that actually carry
        # the API, not just the root, which may have no RigidBodyAPI at all.
        if all(_is_kinematic(p) for p in body_prims):
            print(f" Skipping kinematic body: {name}")
            continue
        # Filter by include/exclude
        if include and name not in include:
            print(f" Skipping (not in include): {name}")
            continue
        if name in exclude:
            print(f" Skipping excluded: {name}")
            continue
        objects.append(name)
        default_poses[name] = _read_default_pose(child)
        bbox_sizes[name] = _bbox_size(bbox_cache, child, name)
    return objects, default_poses, bbox_sizes


def _is_kinematic(prim) -> bool:
    """Return True if the prim's physics:kinematicEnabled attribute is authored True."""
    from pxr import UsdPhysics

    kinematic_attr = UsdPhysics.RigidBodyAPI(prim).GetKinematicEnabledAttr()
    return bool(kinematic_attr) and kinematic_attr.Get() is True


def _read_default_pose(prim):
    """Read the prim's local default pose as ((x, y, z), (qw, qx, qy, qz))."""
    from pxr import Gf, UsdGeom

    translate = prim.GetAttribute("xformOp:translate").Get()
    if translate is None:
        # No explicit translate op (e.g. an xformOp:transform matrix): take the
        # translation from the composed local transform instead of crashing.
        translate = UsdGeom.Xformable(prim).GetLocalTransformation().ExtractTranslation()
    pos = tuple(translate)

    orient_val = prim.GetAttribute("xformOp:orient").Get()
    if orient_val is not None:
        imag = orient_val.GetImaginary()
        return pos, (orient_val.GetReal(), imag[0], imag[1], imag[2])

    rotate_val = prim.GetAttribute("xformOp:rotateXYZ").Get()
    if rotate_val is None:
        return pos, (1.0, 0.0, 0.0, 0.0)
    # USD's rotateXYZ applies X first, then Y, then Z. Gf.Rotation products
    # compose left-to-right ("a * b" = a followed by b), so the matching
    # product is rx * ry * rz (USD itself builds rotateXYZ this way).
    rx = Gf.Rotation(Gf.Vec3d(1, 0, 0), rotate_val[0])
    ry = Gf.Rotation(Gf.Vec3d(0, 1, 0), rotate_val[1])
    rz = Gf.Rotation(Gf.Vec3d(0, 0, 1), rotate_val[2])
    q = (rx * ry * rz).GetQuat()
    imag = q.GetImaginary()
    return pos, (q.GetReal(), imag[0], imag[1], imag[2])


def _bbox_size(bbox_cache, prim, name) -> tuple[float, float, float]:
    """Return the (sx, sy, sz) extent of the prim's world-aligned bounding box.

    An empty box (no geometry/extent) is reported and treated as zero-sized
    instead of producing infinite sizes.
    """
    rng_bbox = bbox_cache.ComputeWorldBound(prim).ComputeAlignedRange()
    if rng_bbox.IsEmpty():
        print(f" Warning: {name} has an empty bounding box; treating its size as zero.")
        return (0.0, 0.0, 0.0)
    mn = rng_bbox.GetMin()
    mx = rng_bbox.GetMax()
    return (mx[0] - mn[0], mx[1] - mn[1], mx[2] - mn[2])


def _compute_xy_bounds(objects, bbox_sizes, spawn_region):
    """Return name -> (xmin, xmax, ymin, ymax) for each object's center.

    The spawn region is shrunk inward by half the object's XY bounding-box
    extent so the footprint stays inside it. If an object is wider than the
    region along an axis, that axis is pinned to the region's midpoint.
    NOTE(review): the padding uses the un-yawed AABB; the random yaw applied
    during sampling can push corners of elongated objects slightly outside.
    """
    spawn_xmin, spawn_xmax, spawn_ymin, spawn_ymax = spawn_region
    xy_bounds = {}
    for name in objects:
        sx, sy, _ = bbox_sizes[name]
        pad_x = sx / 2.0
        pad_y = sy / 2.0
        bx_lo, bx_hi = spawn_xmin + pad_x, spawn_xmax - pad_x
        by_lo, by_hi = spawn_ymin + pad_y, spawn_ymax - pad_y
        if bx_lo > bx_hi:
            bx_lo = bx_hi = (spawn_xmin + spawn_xmax) / 2.0
        if by_lo > by_hi:
            by_lo = by_hi = (spawn_ymin + spawn_ymax) / 2.0
        xy_bounds[name] = (bx_lo, bx_hi, by_lo, by_hi)
    return xy_bounds


def _save_initial_conditions(ic_path, new_poses, instruction, overwrite, objects) -> None:
    """Append *new_poses* to ic_path (or create/overwrite it) and write it atomically."""
    if ic_path.exists() and not overwrite:
        with open(ic_path, "r") as f:
            data = json.load(f)
        # Tolerate an existing file without a "poses" list instead of raising KeyError.
        poses = data.setdefault("poses", [])
        existing_count = len(poses)
        if poses and set(poses[0]) != set(objects):
            print("Warning: existing poses randomize a different set of objects than the new ones.")
        if instruction and data.get("instruction") != instruction:
            print("Warning: --instruction differs from the existing file; keeping the existing instruction.")
        poses.extend(new_poses)
        print(f"Appended {len(new_poses)} poses to existing {existing_count} (total: {len(poses)})")
    else:
        if not instruction:
            print("Warning: no --instruction provided for new file, using empty string.")
        data = {
            "instruction": instruction,
            "poses": new_poses,
        }
        if overwrite and ic_path.exists():
            print(f"Overwriting existing file with {len(new_poses)} poses")
        else:
            print(f"Created new file with {len(new_poses)} poses")
    # Write to a sibling temp file, then rename over the target, so a crash
    # mid-write cannot truncate the file and lose previously generated poses.
    tmp_path = ic_path.with_name(ic_path.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(data, f, indent=4)
    tmp_path.replace(ic_path)
    print(f"Wrote to {ic_path}")
def _quat_multiply(q1, q2):
    """Return the Hamilton product ``q1 * q2``.

    Both inputs and the returned 4-element list use [w, x, y, z] ordering.
    """
    (aw, ax, ay, az), (bw, bx, by, bz) = q1, q2
    real = aw*bw - ax*bx - ay*by - az*bz
    i_part = aw*bx + ax*bw + ay*bz - az*by
    j_part = aw*by - ax*bz + ay*bw + az*bx
    k_part = aw*bz + ax*by - ay*bx + az*bw
    return [real, i_part, j_part, k_part]
def _random_yaw_quaternion(rng: np.random.Generator, default_quat: tuple) -> list[float]:
    """Apply a uniformly random rotation about the world Z axis to *default_quat*.

    The yaw angle is drawn from ``rng.uniform(0, 2*pi)`` and its quaternion is
    pre-multiplied onto the default orientation. Returns [qw, qx, qy, qz] as
    plain Python floats.
    """
    theta_half = rng.uniform(0, 2 * np.pi) / 2
    yaw = [float(np.cos(theta_half)), 0.0, 0.0, float(np.sin(theta_half))]
    composed = _quat_multiply(yaw, list(default_quat))
    return list(map(float, composed))
| def _sample_pose_set( | |
| objects: list[str], | |
| xy_bounds: dict[str, tuple[float, float, float, float]], | |
| z_values: dict[str, float], | |
| default_quats: dict[str, tuple], | |
| radii: dict[str, float], | |
| rng: np.random.Generator, | |
| max_attempts: int = 1000, | |
| ) -> dict[str, list[float]] | None: | |
| """Sample collision-free positions for all objects via rejection sampling.""" | |
| for _ in range(max_attempts): | |
| positions = {} | |
| for obj in objects: | |
| xmin, xmax, ymin, ymax = xy_bounds[obj] | |
| x = rng.uniform(xmin, xmax) | |
| y = rng.uniform(ymin, ymax) | |
| positions[obj] = np.array([x, y]) | |
| # Check all pairwise distances | |
| collision = False | |
| for o1, o2 in combinations(objects, 2): | |
| dist = np.linalg.norm(positions[o1] - positions[o2]) | |
| if dist < radii[o1] + radii[o2]: | |
| collision = True | |
| break | |
| if not collision: | |
| pose_set = {} | |
| for obj in objects: | |
| quat = _random_yaw_quaternion(rng, default_quats[obj]) | |
| pose_set[obj] = [ | |
| float(positions[obj][0]), | |
| float(positions[obj][1]), | |
| z_values[obj], | |
| quat[0], quat[1], quat[2], quat[3], | |
| ] | |
| return pose_set | |
| return None | |
if __name__ == "__main__":
    # tyro turns main()'s parameters into command-line flags such as
    # --environment and --num.
    tyro.cli(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment