Pong (Reinforcement Learning) with WebSocket Backend

Aug 2025
Reinforcement Learning

For this project, I implemented Proximal Policy Optimization (PPO) to train agents to play Pong. Unlike standard Supervised Learning where we have a ground truth (e.g., “This image is a cat”), Reinforcement Learning (RL) involves an agent learning through trial and error by interacting with an environment.

The unique challenge here was Self-Play (Multi-Agent). The agent doesn’t play against a pre-programmed computer opponent with a fixed difficulty; it plays against a copy of itself. As the agent gets better, its opponent gets better, creating a natural curriculum of increasing difficulty.

Here is the high-level architecture of the training loop:

  1. Environment: pong_v3 (PettingZoo) - Multi-agent Atari environment.
  2. Algorithm: PPO (Proximal Policy Optimization).
  3. Model: A 3-Layer Convolutional Neural Network (CNN) processing stacked frames.

The Model Architecture

The agent “sees” the game not as a single static image, but as a stack of 4 consecutive grayscale frames ($84 \times 84$ pixels). Additionally, because this is a multi-agent environment, 2 extra channels are added to indicate which agent (paddle) the model is currently controlling (Agent Indicator).

The input tensor shape is (Batch, 6, 84, 84) (4 Frame Stack + 2 Agent Indicators).
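For context, those 6 channels come directly from the environment wrappers (4 stacked frames plus the agent indicator). A sketch of the training-time setup, following the standard PettingZoo + SuperSuit pipeline and mirroring the inference setup shown later in the Backend section, might look like this (the vectorization details are assumptions, not the exact training script):

import supersuit as ss
from pettingzoo.atari import pong_v3

def make_train_env(num_envs=8):
    env = pong_v3.parallel_env()
    env = ss.max_observation_v0(env, 2)                 # remove Atari sprite flicker
    env = ss.frame_skip_v0(env, 4)
    env = ss.clip_reward_v0(env, lower_bound=-1, upper_bound=1)
    env = ss.color_reduction_v0(env, mode="B")          # grayscale
    env = ss.resize_v1(env, x_size=84, y_size=84)
    env = ss.frame_stack_v1(env, 4)                     # 4 stacked frames
    env = ss.agent_indicator_v0(env, type_only=False)   # +2 channels -> 6 total
    # Flatten the 2-agent env into a batch of vectorized single-agent envs for PPO
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    return ss.concat_vec_envs_v1(env, num_envs // 2, num_cpus=0, base_class="gym")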

self.network = nn.Sequential(
    layer_init(nn.Conv2d(6, 32, 8, stride=4)),  # Layer 1
    nn.ReLU(),
    layer_init(nn.Conv2d(32, 64, 4, stride=2)), # Layer 2
    nn.ReLU(),
    layer_init(nn.Conv2d(64, 64, 3, stride=1)), # Layer 3
    nn.ReLU(),
    nn.Flatten(),
    layer_init(nn.Linear(64 * 7 * 7, 512)),     # Dense Representation
    nn.ReLU(),
)

This backbone feeds into two separate “heads”:

  1. Actor (Policy): Outputs logits for the 6 possible actions (NOOP, FIRE, UP, DOWN, UP+FIRE, DOWN+FIRE).
  2. Critic (Value): Estimates the value $V(s)$ of the current state (how likely am I to win from here?).
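As a sketch, the two heads and the forward pass might be wired up as follows, reusing the layer_init helper from the backbone snippet and following the CleanRL PPO style (the std values, the pixel scaling, and the constructor signature are assumptions, not the exact code from this project):

import torch.nn as nn
from torch.distributions import Categorical

class Agent(nn.Module):
    def __init__(self, backbone: nn.Module, num_actions: int = 6):
        super().__init__()
        self.network = backbone                                          # the CNN shown above
        self.actor = layer_init(nn.Linear(512, num_actions), std=0.01)   # action logits
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)             # value estimate V(s)

    def get_action_and_value(self, x, action=None):
        hidden = self.network(x / 255.0)   # assumes channel-first uint8 frames; scaling is an assumption
        probs = Categorical(logits=self.actor(hidden))
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)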

Proximal Policy Optimization (PPO)

PPO is a policy gradient method that optimizes the agent’s decision-making policy. It iteratively improves the policy by taking small, safe update steps, preventing the unstable learning that can occur if the policy changes too drastically in a single update.

1. Policy Gradient & The “Clip”

The core idea of Policy Gradient is simple: nudge the policy so that actions which led to better-than-expected outcomes become more probable, and actions which led to worse outcomes become less probable. Formally, we follow the gradient

$$\nabla_\theta J(\theta) = \hat{\mathbb{E}}_t \left[ \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t \right]$$

However, standard policy gradient methods can be unstable. If we take too large a step based on a single batch of data, the policy might change drastically and never recover (the “cliff” problem).

PPO solves this with a Clipping Mechanism. It prevents the new policy $\pi_\theta$ from deviating too far from the old policy $\pi_{\text{old}}$ in a single update.

The Objective Function:

$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t \left[ \min\left(r_t(\theta)\hat{A}_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right) \right]$$

Where:

- $r_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{old}}(a_t \mid s_t)}$ is the probability ratio between the new and old policy for the action taken.
- $\hat{A}_t$ is the estimated advantage of that action (how much better it was than expected).
- $\epsilon$ is the clip range (0.2 in the examples below).

Here is a simple calculation demonstrating how this works. Assume an advantage $A = 1.0$ and $\epsilon = 0.2$.

Case 1: Safe Update. The policy changes slightly (ratio $r = 1.1$).

$$\text{Objective} = \min(1.1 \times 1.0,\ \text{clip}(1.1, 0.8, 1.2) \times 1.0) = \min(1.1, 1.1) = 1.1$$

(Since $r = 1.1$ is within the bounds $[0.8, 1.2]$, the clip function returns it unchanged.)

Result: The update is accepted as is.

Case 2: Dangerous Update. The policy changes drastically (ratio $r = 2.0$, doubling the probability).

$$\text{Objective} = \min(2.0 \times 1.0,\ \text{clip}(2.0, 0.8, 1.2) \times 1.0) = \min(2.0, 1.2) = 1.2$$

(Since $r = 2.0$ is greater than the upper bound $1+\epsilon = 1.2$, the clip function limits it to $1.2$.)

Result: The update is clipped. The model is only rewarded as if it took the maximum safe step (1.2), removing the incentive to make such a large jump.
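The two cases can be reproduced in a few lines of PyTorch (a standalone illustration, not part of the training code):

import torch

def clipped_objective(ratio, advantage, eps=0.2):
    """PPO clipped surrogate objective for a single sample."""
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage
    return torch.min(unclipped, clipped)

adv = torch.tensor(1.0)
print(clipped_objective(torch.tensor(1.1), adv))  # tensor(1.1000) -> safe update, kept as is
print(clipped_objective(torch.tensor(2.0), adv))  # tensor(1.2000) -> dangerous update, clipped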

2. Generalized Advantage Estimation (GAE)

How do we know if an action was “good”? We use the Advantage Function.

Simply using the immediate reward is shortsighted (hitting the ball now might lead to losing later). PPO uses Generalized Advantage Estimation (GAE) to balance variance and bias. It calculates a weighted sum of Bellman errors ($\delta$) to propagate rewards backward in time.

$$\hat{A}_t = \delta_t + (\gamma \lambda)\delta_{t+1} + (\gamma \lambda)^2\delta_{t+2} + \dots$$

To calculate this in practice, we use two components:

  1. TD (Temporal Difference) Error ($\delta_t$): The difference between “what actually happened + what we expect next” and “what we expected originally”.
$$\delta_t = \underbrace{r_t + \gamma V(s_{t+1})}_{\text{Target}} - \underbrace{V(s_t)}_{\text{Baseline}}$$
  2. Advantage ($\hat{A}_t$): The recursive sum of these errors (GAE).
$$\hat{A}_t = \delta_t + (\gamma \lambda) \hat{A}_{t+1}$$

Where:

- $\gamma$ (gamma) is the discount factor: how much future rewards count relative to immediate ones (0.99 here).
- $\lambda$ (lambda) is the GAE parameter controlling how far credit is propagated back in time (0.95 here).
- $V(s)$ is the Critic’s value estimate for state $s$.

To see this propagation in action, let’s look at a key moment in the game: The agent hits the ball past the opponent.

Here are the raw values the agent observes:

| Step | Reward | Value (the Critic’s predicted probability of winning) |
| ---- | ------ | ----------------------------------------------------- |
| 3    | 1.0    | 0.80 |
| 2    | 0.0    | 0.50 |
| 1    | 0.0    | 0.60 |

Now, let’s calculate the Advantage (credit) for each step working backwards, assuming $\gamma = 0.99$ and $\lambda = 0.95$:

Step 3 (Winning Hit): The agent hits the ball, and it passes the opponent. Reward = +1. The rally ends here, so there is no next state to bootstrap from:

$$\delta_3 = 1.0 + 0.99 \times 0 - 0.80 = 0.20, \qquad \hat{A}_3 = \delta_3 = 0.20$$

Step 2 (The Setup): The agent stays in position as the ball approaches. Reward = 0, but the TD error captures the jump in the Critic’s value:

$$\delta_2 = 0 + 0.99 \times 0.80 - 0.50 = 0.292, \qquad \hat{A}_2 = 0.292 + (0.99 \times 0.95) \times 0.20 \approx 0.48$$

Step 1 (The Approach): The agent sees the ball coming and starts moving up. Reward = 0:

$$\delta_1 = 0 + 0.99 \times 0.50 - 0.60 = -0.105, \qquad \hat{A}_1 = -0.105 + (0.99 \times 0.95) \times 0.48 \approx 0.35$$

Even though Steps 1 and 2 earned no immediate reward, they still received a clearly positive advantage because they led to the point scored at Step 3. The agent learns that “moving to the right spot” is just as valuable as the final hit.
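The same numbers fall out of a direct implementation of the backward recursion (a standalone sketch using the values from the table above, not the training code itself):

rewards = [0.0, 0.0, 1.0]   # steps 1, 2, 3
values = [0.60, 0.50, 0.80]
gamma, lam = 0.99, 0.95

advantages = [0.0, 0.0, 0.0]
lastgaelam = 0.0
for t in reversed(range(3)):
    next_value = values[t + 1] if t < 2 else 0.0  # the rally ends after step 3
    delta = rewards[t] + gamma * next_value - values[t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages[t] = lastgaelam

print([round(a, 2) for a in advantages])  # [0.35, 0.48, 0.2]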

3. The Total Loss Function

Finally, we combine everything into a single number that we want to minimize. The total loss function consists of three parts:

  1. Policy Loss ($L^{CLIP}$): “Do more of what worked.” We want to maximize this.
$$L^{CLIP} = \min\left(r \times A_t,\ \text{clip}(r, 1-\epsilon, 1+\epsilon) \times A_t\right)$$
  2. Value Loss ($L^{VF}$): “Predict the score better.” We minimize the error (MSE) between predictions and actual returns.
$$L^{VF} = (V(s_t) - V_t^{\text{target}})^2$$
  3. Entropy Bonus ($S$): “Don’t be too sure of yourself.” We want to maximize entropy to encourage exploration.

To optimize all three simultaneously using standard gradient descent (which minimizes loss), we flip the signs of the terms we want to maximize:

$$\text{Loss} = \underbrace{-L^{CLIP}}_{\text{Maximize Policy}} + \underbrace{c_1 L^{VF}}_{\text{Minimize Prediction Error}} - \underbrace{c_2 S}_{\text{Maximize Entropy}}$$

Where:

- $c_1$ is the value loss coefficient (0.5 in this example).
- $c_2$ is the entropy coefficient (0.01 in this example).

Let’s calculate the loss for Step 2 of the example above.

1. Inputs (taken from Step 2 of the GAE example)

- Advantage: $\hat{A}_2 = 0.48$
- Critic’s prediction: $V_2 = 0.50$
- Value target (return): $V_2^{\text{target}} = \hat{A}_2 + V_2 = 0.98$
- Probability ratio: $r = 1.2$ (the new policy made this action slightly more likely)
- Entropy of the action distribution: $S = 0.6$

2. Hyperparameters

- Clip range: $\epsilon = 0.2$
- Value coefficient: $c_1 = 0.5$
- Entropy coefficient: $c_2 = 0.01$

3. Loss Calculation

Policy term (the ratio sits exactly at the clip boundary, so clipping leaves it unchanged):

$$L^{CLIP} = \min\left(1.2 \times 0.48,\ \text{clip}(1.2, 0.8, 1.2) \times 0.48\right) = \mathbf{0.576}$$

Value term:

$$\begin{aligned} L^{VF} &= (V_2 - V_2^{\text{target}})^2 \\ &= (0.50 - 0.98)^2 \\ &= (-0.48)^2 \\ &= \mathbf{0.2304} \end{aligned}$$

Final Loss:

$$\begin{aligned} \text{Loss} &= -L^{CLIP} + c_1 L^{VF} - c_2 S \\ &= -\underbrace{0.576}_{\text{Policy}} + (0.5 \times \underbrace{0.2304}_{\text{Value}}) - (0.01 \times \underbrace{0.6}_{\text{Entropy}}) \\ &= -0.576 + 0.1152 - 0.006 \\ &= \mathbf{-0.4668} \end{aligned}$$

The loss is negative, which is good. The optimizer will try to make it even more negative by:

  1. Increasing the probability of this good action (increasing the $L^{CLIP}$ term, which is subtracted from the loss).
  2. Making the Critic’s prediction closer to 0.98 (decreasing the value loss term).
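For completeness, here is the same arithmetic in PyTorch (a standalone check using the assumed inputs above, including the $r = 1.2$ ratio):

import torch

ratio, advantage = torch.tensor(1.2), torch.tensor(0.48)
value_pred, value_target = torch.tensor(0.50), torch.tensor(0.98)
entropy = torch.tensor(0.6)
clip_eps, vf_coef, ent_coef = 0.2, 0.5, 0.01

pg_objective = torch.min(ratio * advantage,
                         torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantage)  # 0.576
value_loss = (value_pred - value_target) ** 2                                          # 0.2304
loss = -pg_objective + vf_coef * value_loss - ent_coef * entropy
print(loss.item())  # ~ -0.4668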

Training Loop

1. Experience Collection (Rollout)

First, the agent interacts with the environment for num_steps (typically 128) to collect a batch of data. We disable gradient calculation here (torch.no_grad()) because we are only collecting data, not training yet.

for step in range(0, args.num_steps):
    # 1. Forward Pass
    with torch.no_grad():
        action, logprob, _, value = agent.get_action_and_value(next_obs)
        values[step] = value.flatten()

    # 2. Storage
    obs[step] = next_obs        # needed later as b_obs during the PPO update
    dones[step] = next_done     # needed later for the GAE masking
    actions[step] = action
    logprobs[step] = logprob

    # 3. Environment Step
    next_obs, reward, done, info = envs.step(action.cpu().numpy())

    # 4. Update State
    rewards[step] = torch.tensor(reward).to(device).view(-1)
    next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
  1. Forward Pass: We feed the next_obs (from the previous step) into the agent to get action, logprob, and value. We use no_grad() because we don’t want to backpropagate yet; we are just collecting data.
  2. Storage: We save the observation, done flag, action, and log-probability into our rollout buffers (obs, dones, actions, logprobs, values) to use later for calculating advantages and losses.
  3. Environment Step: We execute the action in the game environment. Note the .cpu().numpy() conversion because the environment expects NumPy arrays, not GPU tensors.
  4. Update State: We convert the new next_obs and reward back to PyTorch Tensors and move them to the GPU (to(device)), ready for the next iteration of the loop.

2. Generalized Advantage Estimation (GAE)

Once we have a full batch of experience, we calculate the advantages backwards from the last step to the first. This is where we apply the GAE formula to balance bias and variance.

with torch.no_grad():
    # 1. Bootstrap Value
    next_value = agent.get_value(next_obs).reshape(1, -1)
    advantages = torch.zeros_like(rewards).to(device)
    lastgaelam = 0

    # 2. Reverse Loop
    for t in reversed(range(args.num_steps)):
        if t == args.num_steps - 1:
            nextnonterminal = 1.0 - next_done
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]

        # 3. Delta Calculation (TD Error)
        delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]

        # 4. Recursive Advantage
        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam

    # 5. Returns
    returns = advantages + values
  1. Bootstrap Value: We need the value of the very last state ($V(s_{t+1})$) to kickstart the backward calculation.
  2. Reverse Loop: We iterate backwards from the end of the episode to the beginning to allow rewards to “flow” back in time.
  3. Delta Calculation: We compute the TD error $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$. The nextnonterminal term ensures we don’t look past the end of an episode (if the game ended, the next value is 0).
  4. Recursive Advantage: We calculate the GAE $A_t = \delta_t + (\gamma\lambda)A_{t+1}$ using lastgaelam to store the advantage from the previous step (which is $t+1$ since we are going backwards).
  5. Returns: Finally, we compute the target returns $R_t = A_t + V_t$ which we will use to train the Value network.

3. Optimization (PPO Update)

Finally, we use the collected experience to update the neural network. We loop through the data multiple times (update_epochs) and calculate the Total Loss (Policy + Value - Entropy) to update the weights.

# Optimizing the policy and value network
for epoch in range(args.update_epochs):
    np.random.shuffle(b_inds)
    for start in range(0, args.batch_size, args.minibatch_size):
        end = start + args.minibatch_size
        mb_inds = b_inds[start:end]

        # 1. Re-Evaluation
        _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
        logratio = newlogprob - b_logprobs[mb_inds]
        ratio = logratio.exp()

        # 2. Diagnostics (KL & Clipping)
        with torch.no_grad():
            old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

        mb_advantages = b_advantages[mb_inds]
        if args.norm_adv:
            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

        # 3. Policy Loss Calculation
        pg_loss1 = -mb_advantages * ratio
        pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # 4. Value Loss Calculation
        newvalue = newvalue.view(-1)
        if args.clip_vloss:
            v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
            v_clipped = b_values[mb_inds] + torch.clamp(
                newvalue - b_values[mb_inds],
                -args.clip_coef,
                args.clip_coef,
            )
            v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

        # 5. Final Optimization Step
        entropy_loss = entropy.mean()
        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
        optimizer.step()

    if args.target_kl is not None:
        if approx_kl > args.target_kl:
            break
  1. Re-Evaluation: We pass the batch of observations back through the network to get new probabilities and values. This is crucial because the policy changes slightly after every mini-batch update.
  2. Diagnostics (KL & Clipping): We calculate the Approximate KL Divergence to measure how much the policy has changed. The code calculates this using the estimator (r1)log(r)(r-1) - \log(r): DKLEt[(rt1)log(rt)]D_{KL} \approx \mathbb{E}_t [ (r_t - 1) - \log(r_t) ] If this change is too drastic (larger than target_kl), we stop the update early to preserve training stability. We also track clipfracs to see how often the PPO clipping limits are triggering.
  3. Policy Loss: We implement the Clipped Surrogate Objective here. torch.clamp handles the clipping range [1ϵ,1+ϵ][1-\epsilon, 1+\epsilon], and torch.max (since we are minimizing negative objective) picks the pessimistic bound.
  4. Value Loss: We calculate the MSE between predicted values and actual returns. Note that we also clip the value updates (v_clipped) to prevent the critic from changing too drastically in one go.
  5. Final Optimization: We combine the three terms: pg_loss (Policy), v_loss (Value), and entropy_loss. We then perform standard backpropagation (loss.backward()) and an optimizer step (optimizer.step()).

Backend

The backend is built with FastAPI and serves as the bridge between the browser and the Reinforcement Learning model. Because the agent was trained on raw pixels from the Atari emulator, the backend must run the exact same environment to ensure valid inference.

Tech Stack

| Technology | Role | Description |
| ---------- | ---- | ----------- |
| FastAPI | API Framework | High-performance Python framework for handling WebSocket connections. |
| PettingZoo | Environment | The Multi-Agent Reinforcement Learning (MARL) environment wrapper for Atari Pong. |
| PyTorch | Inference | Runs the trained PPO agent to predict actions from game states. |
| OpenCV | Rendering | Processes raw pixel frames from the emulator into JPEG images for the frontend. |
| SuperSuit | Preprocessing | Applies the same frame skipping, resizing, and stacking used during training. |
| uv | Package Manager | Fast, secure, and user-friendly package manager for Python. |
| WebSocket | Communication | Enables real-time bidirectional communication between the browser and backend. |
| Docker | Containerization | Ensures consistent environments across development, testing, and production. |

Architecture: Server-Side Rendering

Unlike typical web games where the game logic runs in JavaScript on the client, this project uses Server-Side Rendering (SSR) for the game state.

The RL agent’s “brain” is a Convolutional Neural Network (CNN) trained on specific pixel patterns (84x84 grayscale, stacked 4 frames deep) produced by the ALE (Arcade Learning Environment). Replicating the exact physics and rendering quirks of the Atari 2600 in a browser-based JavaScript emulator is incredibly difficult and prone to “distribution shift”, where slight visual differences confuse the agent.

Instead, this application runs the actual PettingZoo environment on the server and streams the visual output to the client.

  1. Server: Runs the game loop, queries the AI for actions, renders the frame.
  2. Network: Sends the frame (Base64 JPEG) via WebSocket.
  3. Client: Displays the image and sends back user keystrokes.

Real-Time Streaming via WebSockets

To achieve the low latency required for a playable Pong game, the application relies on WebSockets rather than standard HTTP requests.

A traditional HTTP model (Request → Response) would add too much overhead for streaming 15-30 frames per second. WebSockets facilitate Real-Time Distribution by:

- Keeping a single persistent connection open, so there is no per-frame handshake overhead.
- Letting the server push frames to the browser as soon as they are rendered, instead of waiting for the client to poll.
- Letting the client send paddle inputs back over the same connection with minimal round-trip delay.
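For orientation, a minimal sketch of how the WebSocket endpoint might be wired up in FastAPI is shown below. It assumes the game_loop shown later is adapted to receive the websocket, and that a shared state dict carries the latest human input; the names and message format are illustrative, not the exact implementation:

import asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()
state = {"running": False, "human_action": 0}

@app.websocket("/ws")
async def pong_endpoint(websocket: WebSocket):
    await websocket.accept()
    state["running"] = True
    game_task = asyncio.create_task(game_loop(websocket))  # stream frames in the background
    try:
        while True:
            # The client sends paddle inputs as small JSON messages
            msg = await websocket.receive_json()
            state["human_action"] = msg.get("action", 0)
    except WebSocketDisconnect:
        state["running"] = False
        game_task.cancel()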

Implementation Details

1. Environment Setup

The application uses the exact same SuperSuit wrappers as the training phase to ensure the agent sees what it expects.

import supersuit as ss
from pettingzoo.atari import pong_v3

def create_env():
    """Create Pong environment - using rgb_array for manual rendering"""
    env = pong_v3.parallel_env(render_mode="rgb_array")

    # Critical: Must match training preprocessing exactly
    env = ss.max_observation_v0(env, 2)           # Maximize over 2 frames (flicker fix)
    env = ss.frame_skip_v0(env, 4)                # Skip 4 frames (standard Atari)
    env = ss.clip_reward_v0(env, lower_bound=-1, upper_bound=1)
    env = ss.color_reduction_v0(env, mode="B")    # B&W
    env = ss.resize_v1(env, x_size=84, y_size=84) # Downsample to 84x84
    env = ss.frame_stack_v1(env, 4)               # Stack 4 frames
    return env

2. The Game Loop

The core of the backend is an asyncio loop handling the WebSocket connection. It manages the game state, synchronizes the Human and AI actions, and maintains a stable frame rate.

async def game_loop():
    env = create_env()
    observations, infos = env.reset()

    while state["running"]:
        if is_paused:
            await asyncio.sleep(0.1)
            continue

        actions = {}  # per-agent action dict, rebuilt every frame

        # 1. AI Inference (Right Paddle)
        if "first_0" in env.agents:
            obs_tensor = torch.from_numpy(observations["first_0"]).float().unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                action = agent.get_action(obs_tensor)
            actions["first_0"] = action.item()

        # 2. Human Input (Left Paddle)
        if "second_0" in env.agents:
            actions["second_0"] = state["human_action"]

        # 3. Environment Step
        observations, rewards, terminations, truncations, infos = env.step(actions)

        # 4. Render & Send
        frame = env.render()
        frame_base64 = encode_frame(frame) # Helper to convert numpy -> jpg base64

        await websocket.send_json({
            "frame": f"data:image/jpeg;base64,{frame_base64}",
            "reward": float(rewards.get("second_0", 0)),
            "scores": episode_rewards
        })

        # Cap at ~15 FPS to match standard Atari play speed
        await asyncio.sleep(1/15)
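The encode_frame helper referenced above is not shown in the snippet; a minimal sketch using OpenCV might look like this (the function name comes from the comment above, while the JPEG quality and color conversion are assumptions):

import base64
import cv2
import numpy as np

def encode_frame(frame: np.ndarray, quality: int = 80) -> str:
    """Encode an RGB frame from env.render() as a base64 JPEG string."""
    bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # OpenCV expects BGR channel order
    ok, buffer = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    if not ok:
        raise RuntimeError("JPEG encoding failed")
    return base64.b64encode(buffer.tobytes()).decode("utf-8")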
