reinforce_game+AI
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper

# ── Policy network ────────────────────────────────────────────────────────────
class Policy(nn.Module):
    def __init__(self, obs_dim=12, hidden=64):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.head_x = nn.Linear(hidden, 3)  # action head for x component
        self.head_y = nn.Linear(hidden, 3)  # action head for y component

    def forward(self, x):
        features = self.trunk(x)
        return self.head_x(features), self.head_y(features)

    def act(self, obs):
        """Sample (comp_x, comp_y) actions, return (action array, log_prob)."""
        logits_x, logits_y = self.forward(obs)
        dist_x = Categorical(logits=logits_x)
        dist_y = Categorical(logits=logits_y)
        ax = dist_x.sample()
        ay = dist_y.sample()
        log_prob = dist_x.log_prob(ax) + dist_y.log_prob(ay)
        return np.array([ax.item(), ay.item()]), log_prob
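
# Quick sanity check of the factorized action space (a sketch, not part of the
# training run; the zero observation below is an arbitrary illustration):
#
#   p = Policy()
#   action, lp = p.act(torch.zeros(12))
#   # action is e.g. array([0, 2]), each component in {0, 1, 2}; lp is a
#   # scalar tensor. The joint policy factorizes as
#   # π(ax, ay | s) = π_x(ax | s) · π_y(ay | s), so the joint log-prob is the
#   # sum of the head log-probs; near-uniform init gives lp ≈ 2·log(1/3) ≈ -2.2.
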
def compute_returns(rewards, gamma=0.9):
    """Discounted returns G_t = r_t + γ r_{t+1} + γ² r_{t+2} + …"""
    G = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        G.insert(0, running)
    G = torch.tensor(G, dtype=torch.float32)
    # Normalise for stable gradients
    G = (G - G.mean()) / (G.std() + 1e-8)
    return G
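
# Worked example for compute_returns (illustration, not from the original file):
# with rewards = [1.0, 0.0, 1.0] and gamma = 0.9, the backward sweep gives
#   G_2 = 1.0
#   G_1 = 0.0 + 0.9 * 1.0 = 0.9
#   G_0 = 1.0 + 0.9 * 0.9 = 1.81
# so G = [1.81, 0.9, 1.0] before normalisation.
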
# ── Training loop ─────────────────────────────────────────────────────────────
def train(n_episodes=100, lr=1e-3, gamma=0.9, print_every=10):
    unity_env = UnityEnvironment("gg_det")
    env = UnityToGymWrapper(unity_env)
    policy = Policy()
    optimizer = optim.Adam(policy.parameters(), lr=lr)

    for ep in range(1, n_episodes + 1):
        # ═══════════════════════════════════════════════════════════════════
        # COLLECT TRAJECTORY
        # ═══════════════════════════════════════════════════════════════════
        obs = env.reset()
        done = False
        log_probs = []
        rewards = []
        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32)
            action, log_prob = policy.act(obs_t)
            obs, reward, done, _ = env.step(action)
            log_probs.append(log_prob)
            rewards.append(reward)

        # ═══════════════════════════════════════════════════════════════════
        # TRAINING
        # ═══════════════════════════════════════════════════════════════════
        G = compute_returns(rewards, gamma)
        loss = torch.stack([-lp * g for lp, g in zip(log_probs, G)]).sum()
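        # Why this loss: the REINFORCE (score-function) identity
        #   ∇θ J(θ) = E[ Σ_t ∇θ log π_θ(a_t | s_t) · G_t ]
        # means minimising -Σ_t log π_θ(a_t | s_t) · G_t is a sample-based
        # ascent step on expected return. The log_probs collected during the
        # rollout still carry the computation graph, so loss.backward()
        # reaches the network weights.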
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if ep % print_every == 0:
            total_return = sum(rewards)
            print(f"Episode {ep:4d} | Steps: {len(rewards):4d} | "
                  f"Return: {total_return:8.3f} | Loss: {loss.item():8.4f}")

    env.close()
    return policy

if __name__ == "__main__":
    trained_policy = train(n_episodes=100, lr=1e-3, gamma=0.9, print_every=10)
    torch.save(trained_policy.state_dict(), "policy.pt")
    print("Saved policy.pt")
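
Once training finishes, the saved weights can be reloaded for evaluation. Below is a minimal sketch of a greedy rollout, assuming the same gg_det build and that the Policy class above is importable; the argmax action selection and file name are illustrative, not part of the original script.

# eval_sketch.py — hypothetical evaluation script, not part of the gist.
import torch
import numpy as np
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
# from <training script above> import Policy  # assumed importable

policy = Policy()
policy.load_state_dict(torch.load("policy.pt"))
policy.eval()

env = UnityToGymWrapper(UnityEnvironment("gg_det"))
obs, done, total = env.reset(), False, 0.0
while not done:
    with torch.no_grad():
        logits_x, logits_y = policy(torch.tensor(obs, dtype=torch.float32))
    # Greedy evaluation: take the argmax of each head instead of sampling
    action = np.array([logits_x.argmax().item(), logits_y.argmax().item()])
    obs, reward, done, _ = env.step(action)
    total += reward
print(f"Greedy episode return: {total:.3f}")
env.close()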