reinforce_game+AI
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
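# Note: UnityEnvironment("gg_det") below expects a Unity build named "gg_det"
# next to this script; passing file_name=None instead would make mlagents_envs
# wait for a running Unity Editor instance to connect.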
# ── Policy network ────────────────────────────────────────────────────────────
class Policy(nn.Module):
    def __init__(self, obs_dim=12, hidden=64):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.head_x = nn.Linear(hidden, 3)  # action head for x component
        self.head_y = nn.Linear(hidden, 3)  # action head for y component

    def forward(self, x):
        features = self.trunk(x)
        return self.head_x(features), self.head_y(features)

    def act(self, obs):
        """Sample (comp_x, comp_y) actions, return (action array, log_prob)."""
        logits_x, logits_y = self.forward(obs)
        dist_x = Categorical(logits=logits_x)
        dist_y = Categorical(logits=logits_y)
        ax = dist_x.sample()
        ay = dist_y.sample()
        # The two heads are independent, so the joint log-prob is a sum
        log_prob = dist_x.log_prob(ax) + dist_y.log_prob(ay)
        return np.array([ax.item(), ay.item()]), log_prob
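# A minimal usage sketch for Policy.act (the zero observation is a placeholder,
# not a real game state): one 12-dim observation in, one 2-component discrete
# action plus its summed log-prob out.
#
#   policy = Policy()
#   action, log_prob = policy.act(torch.zeros(12))
#   # action   -> np.array of shape (2,), each entry in {0, 1, 2}
#   # log_prob -> 0-dim tensor, log P(ax) + log P(ay), kept for the loss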
def compute_returns(rewards, gamma=0.9):
    """Discounted returns G_t = r_t + γ r_{t+1} + γ² r_{t+2} + …"""
    G = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        G.insert(0, running)
    G = torch.tensor(G, dtype=torch.float32)
    # Normalise for stable gradients
    G = (G - G.mean()) / (G.std() + 1e-8)
    return G
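# Worked example with made-up rewards: for rewards [1.0, 0.0, 1.0] and
# gamma = 0.9, the backward accumulation above gives
#   G_2 = 1.0
#   G_1 = 0.0 + 0.9 * 1.0 = 0.9
#   G_0 = 1.0 + 0.9 * 0.9 = 1.81
# i.e. G = [1.81, 0.9, 1.0] before the mean/std normalisation.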
# ── Training loop ─────────────────────────────────────────────────────────────
def train(n_episodes=100, lr=1e-3, gamma=0.9, print_every=10):
    unity_env = UnityEnvironment("gg_det")
    env = UnityToGymWrapper(unity_env)
    policy = Policy()
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    for ep in range(1, n_episodes + 1):
        # ═══════════════════════════════════════════════════════════════
        # COLLECT TRAJECTORY
        # ═══════════════════════════════════════════════════════════════
        obs = env.reset()
        done = False
        log_probs = []
        rewards = []
        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32)
            action, log_prob = policy.act(obs_t)
            obs, reward, done, _ = env.step(action)
            log_probs.append(log_prob)
            rewards.append(reward)
        # ═══════════════════════════════════════════════════════════════
        # TRAINING
        # ═══════════════════════════════════════════════════════════════
        G = compute_returns(rewards, gamma)
        # REINFORCE: minimising -Σ_t log π(a_t|s_t) · G_t follows the
        # policy-gradient estimate of the expected return
        loss = torch.stack([-lp * g for lp, g in zip(log_probs, G)]).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if ep % print_every == 0:
            total_return = sum(rewards)
            print(f"Episode {ep:4d} | Steps: {len(rewards):4d} | "
                  f"Return: {total_return:8.3f} | Loss: {loss.item():8.4f}")
    env.close()
    return policy
if __name__ == "__main__":
trained_policy = train(n_episodes=100, lr=1e-3, gamma=0.9, print_every=10)
torch.save(trained_policy.state_dict(), "policy.pt")
print("Saved policy.pt")