Building a Robust Soft Actor-Critic (SAC) Agent for LunarLanderContinuous-v2

Leveraging Prioritized Experience Replay, Layer Normalization, Gradient Clipping, and More

Introduction

Deep Reinforcement Learning (DRL) has revolutionized how agents learn complex control tasks. In this article, we explore a comprehensive implementation of the Soft Actor-Critic (SAC) algorithm applied to OpenAI Gym’s LunarLanderContinuous-v2 environment. Our implementation goes beyond a simple example by incorporating advanced improvements such as:

  • Prioritized Experience Replay: Samples important transitions more frequently to improve learning efficiency.
  • Layer Normalization: Stabilizes training by normalizing intermediate activations in both the actor and critic networks.
  • Gradient Clipping: Prevents exploding gradients, ensuring smoother updates.
  • Learning Rate Scheduling: Dynamically decays learning rates during training.
  • TensorBoard Logging & Checkpointing: Enables real-time monitoring of training progress and model saving.
  • Video Recording: Uses Gym’s RecordVideo wrapper to visualize the agent’s performance post-training.

This blog post presents the code divided into three modular files: sac_model.py (shared definitions), train.py (training the agent), and visualize.py (recording and visualizing the agent’s behavior).

Algorithm Overview

Soft Actor-Critic (SAC) is an off-policy actor-critic algorithm that optimizes a stochastic policy in an entropy-regularized framework. This means that the agent not only aims to maximize cumulative rewards but also strives to maximize the policy’s entropy. The benefits of such an approach include improved exploration and more stable convergence.
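
Formally, instead of maximizing only the expected return, SAC maximizes the return augmented with an entropy bonus weighted by a temperature α. This is the standard maximum-entropy objective from the SAC papers:

LaTeX
J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_{\pi}} \left[ r(s_t, a_t) + \alpha \, \mathcal{H}\bigl( \pi(\cdot \mid s_t) \bigr) \right]

A larger α pushes the policy toward more exploratory behavior, while a smaller α makes it greedier with respect to the Q-values; our implementation tunes α automatically, as described below.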

Key components of our implementation include:

  • Actor (Policy) Network: Outputs a Gaussian distribution (mean and log standard deviation) over actions. We employ the reparameterization trick along with tanh squashing to ensure actions fall within valid ranges.
  • Critic (Q) Networks: Two critic networks mitigate overestimation bias by taking the minimum Q-value estimate during updates.
  • Target Networks & Soft Updates: Target networks for the critics are updated slowly using a soft update factor (TAU) to improve stability.
  • Automatic Entropy Tuning: The entropy temperature (alpha) is automatically adjusted to balance exploration and exploitation.
  • Prioritized Replay Buffer: Samples transitions based on their TD error, ensuring that more “informative” experiences are replayed more often (the exact sampling formulas are given just after this list).
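
As noted in the last bullet, prioritized replay draws transition i with probability proportional to its priority p_i (in our code, the mean absolute TD error of the two critics plus a small constant) raised to a prioritization exponent α, and corrects the resulting sampling bias with normalized importance-sampling weights. Note that this α (0.6 in our code) is the buffer’s prioritization exponent, not the entropy temperature:

LaTeX
P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}, \qquad
w_i = \frac{\bigl(N \cdot P(i)\bigr)^{-\beta}}{\max_j \bigl(N \cdot P(j)\bigr)^{-\beta}}

Here N is the current buffer size and β is annealed from 0.4 toward 1 over the course of training, exactly as done in train.py below.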

Code Structure

Our code is modularized into three files:

  1. sac_model.py:
    Contains shared definitions including the replay buffer (both prioritized and uniform versions), the GaussianPolicy network, QNetwork, and the SACAgent class. This file encapsulates the model architecture and core components.
  2. train.py:
    Implements the training loop. This file initializes the environment, sets up the SAC agent and prioritized replay buffer, configures optimizers with gradient clipping and learning rate schedulers, and logs training metrics via TensorBoard. It also periodically saves checkpoints of the trained model.
  3. visualize.py:
    Loads the saved SAC agent and records evaluation episodes using Gym’s RecordVideo wrapper. It registers the necessary classes for pickle compatibility, then runs the agent in the environment to produce video files showing the agent’s performance.

Detailed Code Walkthrough

1. sac_model.py

This file defines our core model components with detailed docstrings:

Python
"""
sac_model.py

This module implements the core components of the Soft Actor-Critic (SAC)
algorithm for continuous control tasks. It includes:

- PrioritizedReplayBuffer: A replay buffer with prioritized sampling.
- ReplayBuffer: A uniform replay buffer.
- GaussianPolicy: The actor network with layer normalization.
- QNetwork: The critic network with layer normalization.
- SACAgent: The SAC agent that aggregates the networks and automatic entropy tuning.

All components run on the available device (CPU or GPU).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import numpy as np
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PrioritizedReplayBuffer:
    """
    A prioritized replay buffer that samples transitions with probabilities 
    proportional to their TD error (raised to a power alpha).
    """
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.buffer = []
        self.pos = 0
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.alpha = alpha

    def push(self, state, action, reward, next_state, done):
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.pos] = (state, action, reward, next_state, done)
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:len(self.buffer)]
        probs = prios ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]
        total = len(self.buffer)
        weights = (total * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)
        state, action, reward, next_state, done = map(np.stack, zip(*samples))
        return (torch.FloatTensor(state).to(device),
                torch.FloatTensor(action).to(device),
                torch.FloatTensor(reward).unsqueeze(1).to(device),
                torch.FloatTensor(next_state).to(device),
                torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device),
                torch.FloatTensor(weights).unsqueeze(1).to(device),
                indices)

    def update_priorities(self, indices, priorities):
        for idx, prio in zip(indices, priorities):
            self.priorities[idx] = prio

    def __len__(self):
        return len(self.buffer)

class ReplayBuffer:
    """
    A uniform replay buffer for storing transitions.
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return (torch.FloatTensor(state).to(device),
                torch.FloatTensor(action).to(device),
                torch.FloatTensor(reward).unsqueeze(1).to(device),
                torch.FloatTensor(next_state).to(device),
                torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device))

    def __len__(self):
        return len(self.buffer)

class GaussianPolicy(nn.Module):
    """
    A Gaussian policy network that outputs a Gaussian distribution over actions.
    Uses two hidden layers with layer normalization for stability.
    """
    def __init__(self, num_inputs, num_actions, hidden_size, log_std_min=-20, log_std_max=2):
        super(GaussianPolicy, self).__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.fc1 = nn.Linear(num_inputs, hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.mean_layer = nn.Linear(hidden_size, num_actions)
        self.log_std_layer = nn.Linear(hidden_size, num_actions)
    
    def forward(self, state):
        x = F.relu(self.ln1(self.fc1(state)))
        x = F.relu(self.ln2(self.fc2(x)))
        mean = self.mean_layer(x)
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # reparameterization trick
        y_t = torch.tanh(x_t)
        action = y_t
        log_prob = normal.log_prob(x_t)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        log_prob -= torch.log(1 - y_t.pow(2) + 1e-6).sum(dim=-1, keepdim=True)
        mean = torch.tanh(mean)
        return action, log_prob, mean

class QNetwork(nn.Module):
    """
    A critic network (Q-network) that estimates the Q-value for a given state-action pair.
    Uses two hidden layers with layer normalization.
    """
    def __init__(self, num_inputs, num_actions, hidden_size):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
    
    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.ln1(self.fc1(x)))
        x = F.relu(self.ln2(self.fc2(x)))
        return self.fc3(x)

class SACAgent:
    """
    Soft Actor-Critic (SAC) Agent.
    
    This agent aggregates the actor, critics, target networks, and automatic entropy tuning.
    Optimizers are assigned externally during training.
    """
    def __init__(self, state_dim, action_dim, hidden_size, init_alpha=0.2):
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.policy = GaussianPolicy(state_dim, action_dim, hidden_size).to(device)
        self.policy_optimizer = None  # Set externally during training

        self.q1 = QNetwork(state_dim, action_dim, hidden_size).to(device)
        self.q2 = QNetwork(state_dim, action_dim, hidden_size).to(device)
        self.q1_optimizer = None
        self.q2_optimizer = None

        self.q1_target = QNetwork(state_dim, action_dim, hidden_size).to(device)
        self.q2_target = QNetwork(state_dim, action_dim, hidden_size).to(device)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        self.target_entropy = -action_dim
        self.log_alpha = torch.tensor(np.log(init_alpha), requires_grad=True, device=device)
        self.alpha_optimizer = None  # Set externally during training

    @property
    def alpha(self):
        """
        Return the current temperature (alpha) value computed as exp(log_alpha).
        """
        return self.log_alpha.exp()

    def select_action(self, state, evaluate=False):
        """
        Select an action given the state.
        
        Args:
            state (np.array): Current state.
            evaluate (bool): Whether to use evaluation mode (deterministic).
            
        Returns:
            np.array: Selected action.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, _, mean = self.policy.sample(state)
        if evaluate:
            # Deterministic action: tanh of the Gaussian mean
            return mean.cpu().numpy()[0]
        return action.cpu().numpy()[0]
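
Before wiring these components into the training loop, it can be useful to verify tensor shapes in isolation. Below is a minimal sanity-check sketch, not part of the three project files, that assumes sac_model.py is importable from the current directory; the dimensions 8 and 2 match LunarLanderContinuous-v2’s observation and action spaces.

Python
# shape_check.py -- illustrative sanity check for the shared components
import numpy as np
from sac_model import SACAgent, PrioritizedReplayBuffer

state_dim, action_dim = 8, 2  # LunarLanderContinuous-v2 dimensions
agent = SACAgent(state_dim, action_dim, hidden_size=256)
buffer = PrioritizedReplayBuffer(capacity=1000, alpha=0.6)

# Fill the buffer with random transitions
for _ in range(300):
    s = np.random.randn(state_dim).astype(np.float32)
    a = np.random.uniform(-1.0, 1.0, action_dim).astype(np.float32)
    s_next = np.random.randn(state_dim).astype(np.float32)
    buffer.push(s, a, 0.0, s_next, False)

# Sample a prioritized batch and inspect shapes
states, actions, rewards, next_states, dones, weights, indices = buffer.sample(64, beta=0.4)
print(states.shape, actions.shape, rewards.shape, weights.shape)  # (64, 8) (64, 2) (64, 1) (64, 1)

# Single-step action selection
action = agent.select_action(np.random.randn(state_dim).astype(np.float32))
print(action.shape)  # (2,)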

2. train.py

This file contains the training loop with extensive improvements and logging.

Python
"""
train.py

Training script for the Soft Actor-Critic (SAC) agent on LunarLanderContinuous-v2.

This script:
- Initializes the environment and the SAC agent.
- Uses a prioritized replay buffer.
- Implements improvements including layer normalization, gradient clipping,
  learning rate scheduling, TensorBoard logging, and checkpointing.
- Trains the agent and saves the final model.
"""

import gym
import numpy as np
import random
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from sac_model import SACAgent, PrioritizedReplayBuffer
import os

# Hyperparameters
ENV_NAME = "LunarLanderContinuous-v2"
HIDDEN_SIZE = 256
REPLAY_BUFFER_SIZE = int(1e6)
BATCH_SIZE = 256
GAMMA = 0.99
TAU = 0.005         # Soft update factor
LR_ACTOR = 3e-4
LR_CRITIC = 3e-4
INIT_ALPHA = 0.2

MAX_EPISODES = 1000
MAX_STEPS = 1000    # Maximum steps per episode
START_STEPS = 10000 # Use random actions until this many steps are collected
UPDATE_AFTER = 1000 # Begin network updates after this many steps
UPDATE_EVERY = 50   # Perform updates every n steps
GRAD_CLIP = 1.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():
    """Main training loop for the SAC agent."""
    env = gym.make(ENV_NAME)
    state, _ = env.reset(seed=0)
    env.action_space.seed(0)
    env.observation_space.seed(0)
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Initialize the prioritized replay buffer
    replay_buffer = PrioritizedReplayBuffer(REPLAY_BUFFER_SIZE, alpha=0.6)

    # Initialize SAC agent
    agent = SACAgent(state_dim, action_dim, HIDDEN_SIZE, init_alpha=INIT_ALPHA)
    # Set up optimizers
    agent.policy_optimizer = optim.Adam(agent.policy.parameters(), lr=LR_ACTOR)
    agent.q1_optimizer = optim.Adam(agent.q1.parameters(), lr=LR_CRITIC)
    agent.q2_optimizer = optim.Adam(agent.q2.parameters(), lr=LR_CRITIC)
    agent.alpha_optimizer = optim.Adam([agent.log_alpha], lr=LR_ACTOR)

    # Set up learning rate schedulers (decay every 100 episodes)
    policy_scheduler = optim.lr_scheduler.StepLR(agent.policy_optimizer, step_size=100, gamma=0.95)
    q1_scheduler = optim.lr_scheduler.StepLR(agent.q1_optimizer, step_size=100, gamma=0.95)
    q2_scheduler = optim.lr_scheduler.StepLR(agent.q2_optimizer, step_size=100, gamma=0.95)
    alpha_scheduler = optim.lr_scheduler.StepLR(agent.alpha_optimizer, step_size=100, gamma=0.95)

    writer = SummaryWriter("runs/sac_training")
    os.makedirs("checkpoints", exist_ok=True)

    total_steps = 0
    beta_start = 0.4
    beta_frames = MAX_EPISODES * MAX_STEPS
    for episode in range(1, MAX_EPISODES + 1):
        state, _ = env.reset()
        episode_reward = 0
        for step in range(MAX_STEPS):
            if total_steps < START_STEPS:
                action = env.action_space.sample()
            else:
                action = agent.select_action(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            total_steps += 1

            if total_steps >= UPDATE_AFTER and total_steps % UPDATE_EVERY == 0:
                # Increase beta linearly from beta_start to 1 over beta_frames
                beta = min(1.0, beta_start + total_steps * (1.0 - beta_start) / beta_frames)
                for _ in range(UPDATE_EVERY):
                    (state_batch, action_batch, reward_batch, next_state_batch,
                     done_batch, weights, indices) = replay_buffer.sample(BATCH_SIZE, beta=beta)
                    with torch.no_grad():
                        next_action, next_log_prob, _ = agent.policy.sample(next_state_batch)
                        q1_next = agent.q1_target(next_state_batch, next_action)
                        q2_next = agent.q2_target(next_state_batch, next_action)
                        min_q_next = torch.min(q1_next, q2_next) - agent.alpha * next_log_prob
                        q_target = reward_batch + (1 - done_batch) * GAMMA * min_q_next

                    q1_current = agent.q1(state_batch, action_batch)
                    q2_current = agent.q2(state_batch, action_batch)
                    # Compute TD errors for updating priorities
                    td_error1 = (q1_current - q_target).abs().detach().cpu().numpy().flatten() + 1e-6
                    td_error2 = (q2_current - q_target).abs().detach().cpu().numpy().flatten() + 1e-6
                    td_error = (td_error1 + td_error2) / 2.0
                    replay_buffer.update_priorities(indices, td_error)

                    q1_loss = (weights * F.mse_loss(q1_current, q_target, reduction='none')).mean()
                    q2_loss = (weights * F.mse_loss(q2_current, q_target, reduction='none')).mean()

                    agent.q1_optimizer.zero_grad()
                    q1_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.q1.parameters(), GRAD_CLIP)
                    agent.q1_optimizer.step()

                    agent.q2_optimizer.zero_grad()
                    q2_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.q2.parameters(), GRAD_CLIP)
                    agent.q2_optimizer.step()

                    new_action, log_prob, _ = agent.policy.sample(state_batch)
                    q1_new = agent.q1(state_batch, new_action)
                    q2_new = agent.q2(state_batch, new_action)
                    q_new = torch.min(q1_new, q2_new)
                    policy_loss = (agent.alpha * log_prob - q_new).mean()

                    agent.policy_optimizer.zero_grad()
                    policy_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.policy.parameters(), GRAD_CLIP)
                    agent.policy_optimizer.step()

                    alpha_loss = -(agent.log_alpha * (log_prob + agent.target_entropy).detach()).mean()
                    agent.alpha_optimizer.zero_grad()
                    alpha_loss.backward()
                    torch.nn.utils.clip_grad_norm_([agent.log_alpha], GRAD_CLIP)
                    agent.alpha_optimizer.step()

                    # Soft update target networks
                    for target_param, param in zip(agent.q1_target.parameters(), agent.q1.parameters()):
                        target_param.data.copy_(target_param.data * (1.0 - TAU) + param.data * TAU)
                    for target_param, param in zip(agent.q2_target.parameters(), agent.q2.parameters()):
                        target_param.data.copy_(target_param.data * (1.0 - TAU) + param.data * TAU)

                # Log update losses (logged once per update block)
                writer.add_scalar("Loss/q1", q1_loss.item(), total_steps)
                writer.add_scalar("Loss/q2", q2_loss.item(), total_steps)
                writer.add_scalar("Loss/policy", policy_loss.item(), total_steps)
                writer.add_scalar("Loss/alpha", alpha_loss.item(), total_steps)

            if done:
                break
        writer.add_scalar("Reward/Episode", episode_reward, episode)
        print(f"Episode: {episode}, Reward: {episode_reward:.2f}, Total Steps: {total_steps}")

        # Step learning rate schedulers if updates have begun
        if total_steps >= UPDATE_AFTER:
            policy_scheduler.step()
            q1_scheduler.step()
            q2_scheduler.step()
            alpha_scheduler.step()

        # Checkpointing: Save model every 50 episodes
        if episode % 50 == 0:
            checkpoint_path = f"checkpoints/sac_agent_episode_{episode}.pth"
            torch.save(agent, checkpoint_path)
            print(f"Checkpoint saved at episode {episode} to {checkpoint_path}")

    env.close()
    torch.save(agent, "sac_agent.pth")
    print("Training complete. Final agent saved as sac_agent.pth.")
    writer.close()

if __name__ == "__main__":
    main()
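
Two small practical notes on the training loop. First, the metrics logged above can be monitored live by pointing TensorBoard at the runs/sac_training directory. Second, a convenient extension, not included in train.py, is tracking a 100-episode moving average of the return, since LunarLander is conventionally considered solved once that average reaches 200. A sketch of such a helper follows; the class name RollingReward and the wiring shown in the comments are hypothetical additions, not part of the scripts above.

Python
# rolling_reward.py -- optional helper (not part of train.py) for tracking a
# 100-episode moving average of episode returns.
from collections import deque
import numpy as np

class RollingReward:
    """Keeps the most recent `window` episode returns and reports their mean."""
    def __init__(self, window=100):
        self.returns = deque(maxlen=window)

    def add(self, episode_return):
        self.returns.append(episode_return)
        return float(np.mean(self.returns))

    def solved(self, threshold=200.0):
        return len(self.returns) == self.returns.maxlen and float(np.mean(self.returns)) >= threshold

# Intended usage inside the episode loop of main() (hypothetical wiring):
#   tracker = RollingReward()
#   avg100 = tracker.add(episode_reward)
#   writer.add_scalar("Reward/Avg100", avg100, episode)
#   if tracker.solved():
#       print(f"Solved at episode {episode} (100-episode average >= 200)")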

3. visualize.py

This script loads the trained agent and records evaluation episodes as videos using Gym’s RecordVideo wrapper.

Python
"""
visualize.py

Visualization script for the trained Soft Actor-Critic (SAC) agent on LunarLanderContinuous-v2.

This script:
- Loads the trained agent from a saved checkpoint.
- Uses Gym’s RecordVideo wrapper to record evaluation episodes.
- Saves the recorded videos in the specified folder.
"""

import gym
import torch
from gym.wrappers import RecordVideo
import sac_model
import __main__
# Register classes in __main__ for pickle compatibility
__main__.GaussianPolicy = sac_model.GaussianPolicy
__main__.QNetwork = sac_model.QNetwork
__main__.SACAgent = sac_model.SACAgent

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENV_NAME = "LunarLanderContinuous-v2"

def visualize_policy_video(agent, num_episodes=3, video_folder="videos"):
    """
    Record and save evaluation episodes as videos.
    
    Args:
        agent (SACAgent): The trained SAC agent.
        num_episodes (int): Number of episodes to record.
        video_folder (str): Folder where the videos will be saved.
    """
    env = gym.make(ENV_NAME, render_mode="rgb_array")
    env = RecordVideo(env, video_folder=video_folder, episode_trigger=lambda episode: True)
    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False
        episode_reward = 0
        while not done:
            action = agent.select_action(state, evaluate=True)  # deterministic policy for evaluation
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            state = next_state
            episode_reward += reward
        print(f"Recorded Evaluation Episode {episode+1}: Reward {episode_reward:.2f}")
    env.close()
    print(f"Videos saved to folder '{video_folder}'.")

def main():
    """
    Load the trained SAC agent and record evaluation videos.
    """
    agent = torch.load("sac_agent.pth", map_location=device)
    agent.policy.to(device)
    visualize_policy_video(agent, num_episodes=3, video_folder="videos")

if __name__ == "__main__":
    main()
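
One practical note on loading the checkpoint: because train.py saves the entire SACAgent object rather than a state_dict, recent PyTorch releases, which default torch.load to weights_only=True, may refuse to unpickle it. If loading fails with an unpickling error, the flag can be passed explicitly, as in this small sketch (only do this for checkpoints you trust):

Python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Explicitly allow full-object unpickling of the saved SACAgent.
agent = torch.load("sac_agent.pth", map_location=device, weights_only=False)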

Example Videos of Agent Performance

After training, our SAC agent demonstrates robust landing behaviors on LunarLanderContinuous-v2. For instance, the saved videos in the videos folder show the agent’s evolution—from early unstable attempts (with low or negative rewards) to later, more refined landings achieving high cumulative rewards.

Conclusion

This project demonstrates a comprehensive, professional implementation of the Soft Actor-Critic algorithm applied to the LunarLanderContinuous-v2 environment. By incorporating advanced techniques—such as prioritized replay, layer normalization, gradient clipping, learning rate scheduling, TensorBoard logging, and checkpointing—we built an agent that learns robust control policies. Additionally, using Gym’s RecordVideo wrapper allows us to visualize the agent’s performance, making it easier to analyze and improve upon the model.

This modular approach not only simplifies debugging and future modifications but also serves as a solid foundation for further research and application in more complex continuous control tasks.

You can download and inspect the full implementation from this post in this GitHub repository.

Happy coding and good luck with your reinforcement learning projects!