Model-Based RL - Learning Environment Models for Planning
Welcome to our comprehensive guide on Model-Based Reinforcement Learning! In this post, we’ll explore how Model-Based RL learns environment dynamics to enable planning and more sample-efficient learning. We’ll cover the fundamental concepts and popular algorithms, and implement working toy examples.
Introduction: What is Model-Based RL?
Model-Based Reinforcement Learning is an approach where the agent learns a model of the environment’s dynamics and uses this model for planning and decision-making. Unlike model-free methods, which learn a policy or value function directly from experience, model-based methods first learn how the environment works and then exploit that knowledge.
Model-Free vs Model-Based RL
Model-Free RL (DQN, PPO, SAC):
- Learns policy or value functions directly
- No explicit model of environment
- Requires many samples
- Harder to plan ahead
- Examples: DQN, PPO, SAC, A3C
Model-Based RL:
- Learns environment dynamics model
- Uses model for planning
- More sample efficient
- Can plan multiple steps ahead
- Examples: PETS, MBPO, Dreamer, AlphaZero
Why Model-Based RL?
Sample Efficiency:
- Learn from fewer interactions
- Simulate experience using learned model
- Reuse model for multiple tasks
- Faster learning in data-scarce scenarios
Planning:
- Look ahead multiple steps
- Optimize actions using model
- Handle long-term dependencies
- Better strategic decision-making
Generalization:
- Transfer learned models to new tasks
- Adapt to changing environments
- Combine with model-free methods
- Uncertainty estimates can flag distribution shift
The Model-Based RL Framework
Core Components
1. Dynamics Model
Learns to predict next state given current state and action: \[s_{t+1} = f(s_t, a_t) + \epsilon\]
Where:
- $s_t$: Current state
- $a_t$: Current action
- $f$: Learned dynamics model
- $\epsilon$: Prediction noise
2. Planning Algorithm
Uses learned model to find optimal actions:
- Random Shooting: Sample random action sequences, pick best
- Cross-Entropy Method (CEM): Iteratively refine action sequences
- Model Predictive Control (MPC): Optimize actions over horizon
- Tree Search: Explore action space systematically
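Random shooting, CEM, and MPC are all implemented later in this post; tree search is not, so here is a minimal sketch of a depth-limited exhaustive search over discrete actions. It assumes a one-step `model.predict(state, action)` like the grid-world model trained below, and a hypothetical `reward_fn(state)` for scoring states:

```python
import numpy as np

def tree_search(model, state, reward_fn, n_actions=4, depth=3):
    """Depth-limited exhaustive search over discrete action sequences.

    model.predict(state, action) is a learned one-step model (as below);
    reward_fn(state) is a hypothetical state-scoring function.
    """
    def best_return(s, d):
        if d == 0:
            return 0.0
        best = float('-inf')
        for a in range(n_actions):
            s_next = model.predict(s, a)
            best = max(best, reward_fn(s_next) + best_return(s_next, d - 1))
        return best

    # Score each first action by the best achievable return of its subtree
    scores = []
    for a in range(n_actions):
        s_next = model.predict(state, a)
        scores.append(reward_fn(s_next) + best_return(s_next, depth - 1))
    return int(np.argmax(scores))
```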
3. Data Collection
Gathers experience from real environment:
- Exploration strategies
- Balance exploration and exploitation
- Update model with new data
- Maintain diverse dataset
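These three components plug into one loop: act in the real environment, grow the dataset, refit the model, and plan with it. Here is a minimal sketch of that loop, assuming the collect/train helpers and planners defined later in this post (the planner's actions are only as good as the current model, so early iterations behave close to random):

```python
def model_based_rl_loop(env, model, planner, n_iterations=10, max_steps=50):
    """Skeleton of the model-based RL loop: act, collect, refit, repeat.

    train_dynamics_model is defined later in this post; planner is any
    object with a .plan(state) method (see the planners below).
    """
    dataset = []
    for it in range(n_iterations):
        # 1. Collect real experience with the current planner
        state = env.reset()
        for step in range(max_steps):
            action = planner.plan(state)
            next_state, reward, done, _ = env.step(action)
            dataset.append({'state': state.copy(), 'action': action,
                            'next_state': next_state.copy(), 'reward': reward})
            state = next_state
            if done:
                break
        # 2. Refit the dynamics model on everything gathered so far
        train_dynamics_model(model, dataset, epochs=20)
    return model
```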
Toy Example: Simple Grid World
Let’s start with a simple 2D grid world to illustrate model-based concepts:
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
class GridWorld:
"""
Simple 2D grid world environment
Agent moves in 4 directions, collects rewards
"""
def __init__(self, size=5, goal_pos=(4, 4)):
self.size = size
self.goal_pos = goal_pos
self.agent_pos = (0, 0)
# Action space: 0=up, 1=down, 2=left, 3=right
self.action_space = 4
self.state_space = (size, size)
# Obstacles
self.obstacles = [(2, 2), (2, 3), (3, 2)]
def reset(self):
self.agent_pos = (0, 0)
return self.get_state()
def get_state(self):
return np.array(self.agent_pos)
def step(self, action):
# Get current position
x, y = self.agent_pos
# Apply action
if action == 0: # up
new_pos = (x, min(y + 1, self.size - 1))
elif action == 1: # down
new_pos = (x, max(y - 1, 0))
elif action == 2: # left
new_pos = (max(x - 1, 0), y)
elif action == 3: # right
new_pos = (min(x + 1, self.size - 1), y)
# Check obstacles
if new_pos not in self.obstacles:
self.agent_pos = new_pos
# Compute reward
reward = -1 # Step penalty
if self.agent_pos == self.goal_pos:
reward = 100 # Goal reward
# Check done
done = self.agent_pos == self.goal_pos
return self.get_state(), reward, done, {}
def render(self):
grid = np.zeros((self.size, self.size))
# Mark goal
grid[self.goal_pos[1], self.goal_pos[0]] = 0.5
# Mark obstacles
for obs in self.obstacles:
grid[obs[1], obs[0]] = -1
# Mark agent
grid[self.agent_pos[1], self.agent_pos[0]] = 1
plt.figure(figsize=(6, 6))
plt.imshow(grid, cmap='RdYlGn', vmin=-1, vmax=1)
plt.colorbar()
plt.title(f"Agent at {self.agent_pos}, Goal at {self.goal_pos}")
plt.show()
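Before learning anything, it helps to sanity-check the environment with a few random steps:

```python
# Quick smoke test: take a few random steps and print the transitions
env = GridWorld(size=5)
state = env.reset()
for _ in range(5):
    action = np.random.randint(0, env.action_space)
    state, reward, done, _ = env.step(action)
    print(f"action={action}, state={state}, reward={reward}, done={done}")
```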
Learning the Dynamics Model
import torch
import torch.nn as nn
import torch.optim as optim
class GridWorldDynamicsModel(nn.Module):
"""
Neural network to predict next state in grid world
"""
def __init__(self, state_dim=2, action_dim=4, hidden_dim=64):
super().__init__()
# Encode state and action
self.encoder = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# Predict next state
self.decoder = nn.Linear(hidden_dim, state_dim)
    def forward(self, state, action):
        # One-hot encode the action; accept index tensors of shape (N,) or (N, 1)
        if action.dim() == 1:
            action = action.unsqueeze(1)
        action_onehot = torch.zeros(action.shape[0], 4)  # 4 discrete actions
        action_onehot.scatter_(1, action.long(), 1)
# Concatenate state and action
x = torch.cat([state, action_onehot], dim=1)
# Encode and decode
features = self.encoder(x)
next_state = self.decoder(features)
return next_state
def predict(self, state, action):
"""
Predict next state (handles numpy inputs)
"""
with torch.no_grad():
if isinstance(state, np.ndarray):
state = torch.FloatTensor(state)
if isinstance(action, (int, np.integer)):
action = torch.LongTensor([[action]])
elif isinstance(action, np.ndarray):
action = torch.LongTensor(action)
if len(state.shape) == 1:
state = state.unsqueeze(0)
if len(action.shape) == 0:
action = action.unsqueeze(0)
pred = self.forward(state, action)
return pred.squeeze(0).numpy()
def collect_data(env, n_episodes=100):
"""
Collect experience data from environment
"""
data = []
for episode in range(n_episodes):
state = env.reset()
for step in range(50):
# Random action
action = np.random.randint(0, 4)
# Step environment
next_state, reward, done, _ = env.step(action)
# Store transition
data.append({
'state': state.copy(),
'action': action,
'next_state': next_state.copy(),
'reward': reward
})
state = next_state
if done:
break
return data
def train_dynamics_model(model, data, epochs=100, batch_size=32):
"""
Train dynamics model on collected data
"""
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
for epoch in range(epochs):
# Shuffle data
np.random.shuffle(data)
total_loss = 0
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
# Prepare batch
states = torch.FloatTensor([d['state'] for d in batch])
actions = torch.LongTensor([d['action'] for d in batch])
next_states = torch.FloatTensor([d['next_state'] for d in batch])
# Forward pass
pred_next_states = model(states, actions)
# Compute loss
loss = criterion(pred_next_states, next_states)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
        if (epoch + 1) % 20 == 0:
            n_batches = (len(data) + batch_size - 1) // batch_size
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/n_batches:.6f}")
return model
# Train dynamics model
env = GridWorld(size=5)
data = collect_data(env, n_episodes=200)
model = GridWorldDynamicsModel()
model = train_dynamics_model(model, data, epochs=100)
# Test model
test_state = np.array([2, 1])  # a free cell; note that (2, 2) is an obstacle
test_action = 3  # right
predicted_next = model.predict(test_state, test_action)
print(f"State: {test_state}, Action: {test_action}")
print(f"Predicted next state: {predicted_next}")
Planning with Learned Model
class RandomShootingPlanner:
"""
Plan actions by sampling random action sequences
"""
def __init__(self, model, action_space, horizon=10, n_samples=100):
self.model = model
self.action_space = action_space
self.horizon = horizon
self.n_samples = n_samples
def plan(self, state):
"""
Find best action sequence using random shooting
"""
best_action = None
best_reward = float('-inf')
# Sample random action sequences
for _ in range(self.n_samples):
actions = np.random.randint(0, self.action_space, self.horizon)
# Simulate trajectory
current_state = state.copy()
total_reward = 0
for action in actions:
# Predict next state
next_state = self.model.predict(current_state, action)
                # Reward shaping: negative distance to the goal at (4, 4)
                distance = np.linalg.norm(next_state - np.array([4, 4]))
                reward = -distance
total_reward += reward
current_state = next_state
# Update best
if total_reward > best_reward:
best_reward = total_reward
best_action = actions[0]
return best_action
class CrossEntropyMethodPlanner:
"""
Plan actions using Cross-Entropy Method (CEM)
"""
def __init__(self, model, action_space, horizon=10, n_samples=100, n_elite=10):
self.model = model
self.action_space = action_space
self.horizon = horizon
self.n_samples = n_samples
self.n_elite = n_elite
def plan(self, state, n_iterations=5):
"""
Find best action sequence using CEM
"""
        # Initialize a Gaussian over action sequences, centered on the
        # middle of the discrete action range
        action_mean = np.ones(self.horizon) * (self.action_space - 1) / 2.0
        action_std = np.ones(self.horizon) * 2.0
for iteration in range(n_iterations):
# Sample action sequences
samples = []
rewards = []
for _ in range(self.n_samples):
# Sample action sequence
                actions = np.random.normal(action_mean, action_std)
                # Round and clip the continuous samples onto the discrete actions
                actions = np.clip(np.round(actions), 0, self.action_space - 1).astype(int)
# Simulate trajectory
current_state = state.copy()
total_reward = 0
for action in actions:
next_state = self.model.predict(current_state, action)
distance = np.linalg.norm(next_state - np.array([4, 4]))
reward = -distance
total_reward += reward
current_state = next_state
samples.append(actions)
rewards.append(total_reward)
# Select elite samples
elite_indices = np.argsort(rewards)[-self.n_elite:]
elite_samples = [samples[i] for i in elite_indices]
# Update distribution
action_mean = np.mean(elite_samples, axis=0)
action_std = np.std(elite_samples, axis=0) + 1e-6
        return int(np.clip(np.round(action_mean[0]), 0, self.action_space - 1))
# Test planners
state = env.reset()
# Random shooting planner
rs_planner = RandomShootingPlanner(model, action_space=4)
action = rs_planner.plan(state)
print(f"Random Shooting Planner action: {action}")
# CEM planner
cem_planner = CrossEntropyMethodPlanner(model, action_space=4)
action = cem_planner.plan(state)
print(f"CEM Planner action: {action}")
Toy Example: Classic Control (CartPole)
Let’s implement model-based RL on a classic control task with a continuous state space and discrete actions:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class CartPoleDynamicsModel(nn.Module):
"""
Learn dynamics model for CartPole environment
"""
def __init__(self, state_dim=4, action_dim=1, hidden_dim=128):
super().__init__()
# Encoder
self.encoder = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# Predict state delta
self.delta_head = nn.Linear(hidden_dim, state_dim)
# Predict reward
self.reward_head = nn.Linear(hidden_dim, 1)
# Predict done
self.done_head = nn.Linear(hidden_dim, 1)
def forward(self, state, action):
# Concatenate state and action
x = torch.cat([state, action], dim=-1)
# Encode
features = self.encoder(x)
# Predictions
delta = self.delta_head(features)
reward = self.reward_head(features)
done = torch.sigmoid(self.done_head(features))
return delta, reward, done
def predict(self, state, action):
"""
Predict next state, reward, done
"""
with torch.no_grad():
if isinstance(state, np.ndarray):
state = torch.FloatTensor(state)
if isinstance(action, np.ndarray):
action = torch.FloatTensor(action)
if len(state.shape) == 1:
state = state.unsqueeze(0)
if len(action.shape) == 1:
action = action.unsqueeze(0)
delta, reward, done = self.forward(state, action)
next_state = state + delta
return (next_state.squeeze(0).numpy(),
reward.item(),
done.item())
class EnsembleDynamicsModel:
"""
Ensemble of dynamics models for uncertainty estimation
"""
def __init__(self, n_models=5, state_dim=4, action_dim=1, hidden_dim=128):
self.models = [
CartPoleDynamicsModel(state_dim, action_dim, hidden_dim)
for _ in range(n_models)
]
    def predict(self, state, action):
        """
        Predict the next state with the ensemble; return the mean and std
        (disagreement) of the member predictions
        """
        predictions = []
        for model in self.models:
            next_state, _, _ = model.predict(state, action)
            predictions.append(next_state)
        # Compute mean and std across ensemble members
        predictions = np.array(predictions)
        mean_pred = np.mean(predictions, axis=0)
        std_pred = np.std(predictions, axis=0)
        return mean_pred, std_pred
def collect_cartpole_data(env, n_episodes=100):
"""
Collect data from CartPole environment
"""
data = []
for episode in range(n_episodes):
state = env.reset()
for step in range(200):
# Random action
action = env.action_space.sample()
# Step environment
next_state, reward, done, _ = env.step(action)
# Store transition
data.append({
'state': state.copy(),
'action': np.array([action]),
'next_state': next_state.copy(),
'reward': reward,
'done': done
})
state = next_state
if done:
break
return data
def train_cartpole_model(model, data, epochs=50, batch_size=64):
"""
Train CartPole dynamics model
"""
optimizer = optim.Adam(model.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()
bce_loss = nn.BCELoss()
for epoch in range(epochs):
np.random.shuffle(data)
total_loss = 0
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
# Prepare batch
states = torch.FloatTensor([d['state'] for d in batch])
actions = torch.FloatTensor([d['action'] for d in batch])
next_states = torch.FloatTensor([d['next_state'] for d in batch])
rewards = torch.FloatTensor([d['reward'] for d in batch]).unsqueeze(1)
dones = torch.FloatTensor([d['done'] for d in batch]).unsqueeze(1)
# Forward pass
pred_delta, pred_reward, pred_done = model(states, actions)
# Compute losses
delta_loss = mse_loss(pred_delta, next_states - states)
reward_loss = mse_loss(pred_reward, rewards)
done_loss = bce_loss(pred_done, dones)
loss = delta_loss + reward_loss + done_loss
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
        if (epoch + 1) % 10 == 0:
            n_batches = (len(data) + batch_size - 1) // batch_size
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/n_batches:.6f}")
return model
# Train CartPole model (this post uses the classic gym API, where reset()
# returns only the observation; newer gym/gymnasium versions return
# (obs, info) and step() returns five values)
env = gym.make('CartPole-v1')
data = collect_cartpole_data(env, n_episodes=100)
model = CartPoleDynamicsModel()
model = train_cartpole_model(model, data, epochs=50)
# Test model
test_state = env.reset()
test_action = np.array([1])
next_state, reward, done = model.predict(test_state, test_action)
print(f"State: {test_state}")
print(f"Action: {test_action}")
print(f"Predicted next state: {next_state}")
print(f"Predicted reward: {reward}, done: {done}")
MPC with Learned Model
class MPCPlanner:
"""
Model Predictive Control planner
"""
def __init__(self, model, action_space, horizon=20, n_samples=100):
self.model = model
self.action_space = action_space
self.horizon = horizon
self.n_samples = n_samples
def plan(self, state):
"""
Find best action sequence using MPC
"""
best_action_sequence = None
best_reward = float('-inf')
for _ in range(self.n_samples):
# Sample action sequence
actions = np.random.randint(0, self.action_space, self.horizon)
# Simulate trajectory
current_state = state.copy()
total_reward = 0
done = False
for action in actions:
if done:
break
# Predict next state
next_state, reward, done_pred = self.model.predict(
current_state,
np.array([action])
)
total_reward += reward
current_state = next_state
if done_pred > 0.5:
done = True
# Update best
if total_reward > best_reward:
best_reward = total_reward
best_action_sequence = actions
return best_action_sequence[0] if best_action_sequence is not None else 0
def run_mpc_agent(env, model, n_episodes=10):
"""
Run MPC agent in environment
"""
planner = MPCPlanner(model, action_space=env.action_space.n)
for episode in range(n_episodes):
state = env.reset()
total_reward = 0
done = False
while not done:
# Plan action
action = planner.plan(state)
# Execute action
next_state, reward, done, _ = env.step(action)
total_reward += reward
state = next_state
print(f"Episode {episode+1}, Reward: {total_reward}")
# Run MPC agent
run_mpc_agent(env, model, n_episodes=10)
Popular Model-Based RL Algorithms
1. PETS (Probabilistic Ensembles with Trajectory Sampling)
Uses an ensemble of probabilistic models for uncertainty-aware planning:
class ProbabilisticDynamicsModel(nn.Module):
"""
Probabilistic dynamics model with variance prediction
"""
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# Predict mean and log variance
self.mean_head = nn.Linear(hidden_dim, state_dim)
self.logvar_head = nn.Linear(hidden_dim, state_dim)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
features = self.encoder(x)
mean = self.mean_head(features)
logvar = self.logvar_head(features)
return mean, logvar
def sample(self, state, action):
"""
Sample next state from learned distribution
"""
mean, logvar = self.forward(state, action)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mean + eps * std
class PETSAgent:
"""
    PETS: Probabilistic Ensembles with Trajectory Sampling
"""
    def __init__(self, state_dim, n_actions, n_models=5, horizon=10):
        # n_actions is the number of discrete actions; each model receives
        # the action index as a single float input, hence action_dim=1
        self.models = [
            ProbabilisticDynamicsModel(state_dim, action_dim=1)
            for _ in range(n_models)
        ]
        self.horizon = horizon
        self.action_space = n_actions
def plan(self, state, n_samples=100):
"""
        Plan using trajectory sampling with the probabilistic ensemble
"""
best_action = None
best_reward = float('-inf')
for _ in range(n_samples):
# Sample action sequence
actions = np.random.randint(0, self.action_space, self.horizon)
# Sample model for each step
total_reward = 0
current_state = state.copy()
for action in actions:
# Randomly select model
model_idx = np.random.randint(len(self.models))
model = self.models[model_idx]
# Sample next state
if isinstance(current_state, np.ndarray):
state_tensor = torch.FloatTensor(current_state).unsqueeze(0)
else:
state_tensor = current_state
                action_tensor = torch.FloatTensor([[action]])  # float, so it can be concatenated with the state
next_state = model.sample(state_tensor, action_tensor)
current_state = next_state.squeeze(0).detach().numpy()
                # Toy reward: stay near the origin in the first two state dims
                total_reward -= np.linalg.norm(current_state[:2])
if total_reward > best_reward:
best_reward = total_reward
best_action = actions[0]
return best_action
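The probabilistic model above is defined but never trained. Here is a minimal training sketch, fitting one ensemble member by the Gaussian negative log-likelihood that PETS prescribes; it assumes transitions in the same dict format produced by `collect_cartpole_data` above:

```python
def train_probabilistic_model(model, data, epochs=50, batch_size=64):
    """Fit a ProbabilisticDynamicsModel by Gaussian negative log-likelihood.

    Assumes `data` is a list of dicts with 'state', 'action', and
    'next_state' keys, as produced by collect_cartpole_data above.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(epochs):
        np.random.shuffle(data)
        for i in range(0, len(data), batch_size):
            batch = data[i:i+batch_size]
            states = torch.FloatTensor([d['state'] for d in batch])
            actions = torch.FloatTensor([d['action'] for d in batch])
            next_states = torch.FloatTensor([d['next_state'] for d in batch])
            mean, logvar = model(states, actions)
            # Gaussian NLL (up to constants): precision-weighted squared
            # error plus a log-variance penalty
            inv_var = torch.exp(-logvar)
            loss = ((mean - next_states) ** 2 * inv_var + logvar).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model

# Example: fit each member of a PETSAgent ensemble on the CartPole data
# agent = PETSAgent(state_dim=4, n_actions=2)
# for member in agent.models:
#     train_probabilistic_model(member, data, epochs=20)
```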
2. MBPO (Model-Based Policy Optimization)
Combines model-based planning with policy optimization:
class MBPOAgent:
"""
MBPO: Model-Based Policy Optimization
"""
def __init__(self, state_dim, action_dim, hidden_dim=128):
# Dynamics model
self.dynamics_model = CartPoleDynamicsModel(state_dim, action_dim, hidden_dim)
# Policy network
self.policy = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Tanh()
)
# Value network
self.value = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def get_action(self, state):
"""
Get action from policy
"""
with torch.no_grad():
if isinstance(state, np.ndarray):
state = torch.FloatTensor(state)
if len(state.shape) == 1:
state = state.unsqueeze(0)
action = self.policy(state)
return action.squeeze(0).numpy()
def generate_imagined_data(self, real_data, n_imagined=1000):
"""
Generate imagined data using learned model
"""
imagined_data = []
for _ in range(n_imagined):
# Sample random state from real data
sample = real_data[np.random.randint(len(real_data))]
state = sample['state']
# Get action from policy
action = self.get_action(state)
# Predict next state using model
next_state, reward, done = self.dynamics_model.predict(state, action)
imagined_data.append({
'state': state.copy(),
'action': action,
'next_state': next_state,
'reward': reward,
'done': done
})
return imagined_data
def update_policy(self, data, epochs=10):
"""
Update policy using policy gradient
"""
optimizer = optim.Adam(self.policy.parameters(), lr=1e-4)
for epoch in range(epochs):
np.random.shuffle(data)
for i in range(0, len(data), 32):
batch = data[i:i+32]
states = torch.FloatTensor([d['state'] for d in batch])
actions = torch.FloatTensor([d['action'] for d in batch])
rewards = torch.FloatTensor([d['reward'] for d in batch])
# Get action from policy
pred_actions = self.policy(states)
                # Simplified surrogate loss (a real MBPO implementation
                # uses an off-policy actor-critic such as SAC here)
                advantage = rewards - torch.mean(rewards)
                loss = -(pred_actions * advantage.unsqueeze(1)).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
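A sketch of the MBPO outer loop ties these methods together: collect real data, refit the model, imagine short rollouts, and update the policy mostly on imagined data. Note that `collect_cartpole_data` above acts randomly; a faithful MBPO would collect with the current policy and use SAC for the policy update:

```python
def run_mbpo(env, agent, n_iterations=10):
    """Simplified MBPO outer loop using the helpers defined earlier."""
    real_data = []
    for it in range(n_iterations):
        # 1. Gather a little real experience
        real_data += collect_cartpole_data(env, n_episodes=5)
        # 2. Refit the dynamics model on all real data so far
        train_cartpole_model(agent.dynamics_model, real_data, epochs=10)
        # 3. Imagine transitions branched from real states
        imagined = agent.generate_imagined_data(real_data, n_imagined=1000)
        # 4. Update the policy mostly on imagined data
        agent.update_policy(imagined, epochs=5)
    return agent
```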
3. Dreamer
Model-based RL with learned world model in latent space:
class DreamerWorldModel(nn.Module):
"""
Dreamer-style world model with latent dynamics
"""
def __init__(self, obs_dim, action_dim, latent_dim=32, hidden_dim=128):
super().__init__()
# Encoder: observation to latent
self.encoder = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * 2) # Mean and logvar
)
# Dynamics: latent transition
self.dynamics = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * 2)
)
# Decoder: latent to observation
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, obs_dim)
)
# Reward predictor
self.reward_head = nn.Linear(latent_dim, 1)
def encode(self, obs):
"""
Encode observation to latent distribution
"""
x = self.encoder(obs)
mean, logvar = x.chunk(2, dim=-1)
return mean, logvar
def decode(self, latent):
"""
Decode latent to observation
"""
return self.decoder(latent)
def transition(self, latent, action):
"""
Predict next latent state
"""
x = torch.cat([latent, action], dim=-1)
x = self.dynamics(x)
mean, logvar = x.chunk(2, dim=-1)
return mean, logvar
def predict_reward(self, latent):
"""
Predict reward from latent
"""
return self.reward_head(latent)
def sample_latent(self, obs):
"""
Sample latent from observation
"""
mean, logvar = self.encode(obs)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mean + eps * std
def imagine_trajectory(self, initial_obs, actions):
"""
Imagine trajectory from initial observation
"""
latent = self.sample_latent(initial_obs)
rewards = []
for action in actions:
# Transition latent
mean, logvar = self.transition(latent, action)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
latent = mean + eps * std
# Predict reward
reward = self.predict_reward(latent)
rewards.append(reward)
return rewards
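The world model above also needs a training objective. Here is a simplified sketch (real Dreamer trains on sequences with a recurrent state and a KL balance between prior and posterior): reconstruct the observation, predict the reward, and keep the dynamics prior close to the encoder's posterior for the next observation. `obs`, `action`, `next_obs`, and `reward` are assumed to be batched float tensors, with `reward` shaped (batch, 1):

```python
def world_model_loss(model, obs, action, next_obs, reward):
    """One training objective for DreamerWorldModel: reconstruction,
    reward prediction, and a KL term tying the dynamics prior to the
    encoder's posterior over the next latent. A simplified sketch.
    """
    # Encode the current observation and sample a latent
    latent = model.sample_latent(obs)
    # Reconstruction and reward prediction losses
    recon_loss = ((model.decode(latent) - obs) ** 2).mean()
    reward_loss = ((model.predict_reward(latent) - reward) ** 2).mean()
    # KL(posterior from next obs || prior from the transition model)
    prior_mean, prior_logvar = model.transition(latent, action)
    post_mean, post_logvar = model.encode(next_obs)
    kl = 0.5 * (prior_logvar - post_logvar
                + (torch.exp(post_logvar) + (post_mean - prior_mean) ** 2)
                / torch.exp(prior_logvar) - 1).mean()
    return recon_loss + reward_loss + kl
```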
Advantages and Disadvantages
Advantages of Model-Based RL
Sample Efficiency:
- Learn from fewer environment interactions
- Generate synthetic experience using model
- Reuse model across tasks
- Faster convergence
Planning:
- Look ahead multiple steps
- Optimize actions over horizon
- Handle long-term dependencies
- Better strategic decisions
Interpretability:
- Learned dynamics can be inspected and probed
- Understand environment dynamics
- Debug and analyze behavior
- Transfer knowledge
Disadvantages of Model-Based RL
Model Bias:
- Model errors compound over rollouts (see the sketch after this section)
- Inaccuracies lead to poor decisions
- Hard to learn complex dynamics
- Model misspecification
Computational Cost:
- Planning is computationally expensive
- Need to simulate many trajectories
- Real-time constraints
- Memory requirements
Training Complexity:
- Need to train both model and policy
- Balance model accuracy and policy performance
- More hyperparameters to tune
- Complex implementation
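The first of these, compounding model error, is easy to observe: roll the learned model forward open-loop, feeding back its own predictions, and compare against the real environment. A sketch using the GridWorld pieces from earlier in this post:

```python
def rollout_error(model, env, horizon=20, n_rollouts=20):
    """Measure open-loop prediction error vs. horizon to see model bias
    compound: the model feeds back its own predictions at every step.

    Assumes the GridWorld env and GridWorldDynamicsModel defined above.
    """
    errors = np.zeros(horizon)
    counts = np.zeros(horizon)
    for _ in range(n_rollouts):
        true_state = env.reset()
        pred_state = true_state.astype(np.float32)
        for t in range(horizon):
            action = np.random.randint(0, env.action_space)
            true_state, _, done, _ = env.step(action)
            pred_state = model.predict(pred_state, action)
            errors[t] += np.linalg.norm(pred_state - true_state)
            counts[t] += 1
            if done:
                break
    return errors / np.maximum(counts, 1)

# The curve typically grows with t: small one-step errors accumulate.
```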
Practical Tips
1. Start Simple
Begin with simple environments and models:
# Start with linear model
class LinearDynamicsModel(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.linear = nn.Linear(state_dim + action_dim, state_dim)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
return self.linear(x)
# Progress to neural network
class NNDynamicsModel(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, state_dim)
)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
return self.net(x)
2. Use Ensembles
Ensembles provide an estimate of model uncertainty:
def train_ensemble(data, n_models=5, state_dim=4, action_dim=1):
    """
    Train an ensemble of dynamics models, each on a different data subset
    """
    models = []
    for i in range(n_models):
        # Initialize model
        model = CartPoleDynamicsModel(state_dim, action_dim)
        # Train on a different slice of the data
        subset = data[i::n_models]
        model = train_cartpole_model(model, subset, epochs=50)
        models.append(model)
    return models
def predict_with_ensemble(models, state, action):
    """
    Predict the next state with an ensemble; return the mean prediction
    and its standard deviation (disagreement) across members
    """
    predictions = []
    for model in models:
        next_state, _, _ = model.predict(state, action)
        predictions.append(next_state)
    predictions = np.array(predictions)
    mean = np.mean(predictions, axis=0)
    std = np.std(predictions, axis=0)
    return mean, std
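A common use of this disagreement is as a pessimism penalty on imagined rewards, as in uncertainty-aware MBPO variants: where ensemble members diverge, the model is less trustworthy, so imagined reward is discounted. A sketch; `reward_fn` is a hypothetical state-based reward function:

```python
def penalized_reward(models, state, action, reward_fn, penalty_coef=1.0):
    """Pessimistic reward for imagined rollouts: subtract a penalty
    proportional to ensemble disagreement (std of predicted next states).

    reward_fn is a hypothetical state-based reward; predict_with_ensemble
    is the helper defined just above.
    """
    mean_next, std_next = predict_with_ensemble(models, state, action)
    disagreement = np.linalg.norm(std_next)
    return reward_fn(mean_next) - penalty_coef * disagreement, mean_next
```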
3. Balance Real and Imagined Data
Combine real and simulated experience:
def train_mbpo(real_data, imagined_data, ratio=0.5):
    """
    Train the policy on a mix of real and imagined data.
    (`get_batches`, `update_policy`, and `policy` are placeholders for
    your own batching helper, policy-update step, and policy object.)
    """
    # Mix data so that real samples make up `ratio` of the training set
    n_real = int(len(real_data) * ratio)
    n_imagined = min(len(imagined_data), int(n_real * (1 - ratio) / max(ratio, 1e-8)))
    mixed_data = real_data[:n_real] + imagined_data[:n_imagined]
np.random.shuffle(mixed_data)
# Train on mixed data
for epoch in range(10):
for batch in get_batches(mixed_data, batch_size=32):
update_policy(batch)
return policy
4. Monitor Model Accuracy
Track model prediction error:
def evaluate_model(model, test_data):
"""
Evaluate model prediction accuracy
"""
errors = []
for sample in test_data:
state = sample['state']
action = sample['action']
true_next = sample['next_state']
# Predict
pred_next, _, _ = model.predict(state, action)
# Compute error
error = np.linalg.norm(pred_next - true_next)
errors.append(error)
mean_error = np.mean(errors)
std_error = np.std(errors)
print(f"Model error: {mean_error:.4f} ± {std_error:.4f}")
return mean_error, std_error
Testing the Code
Here’s a comprehensive test script to verify all code examples work correctly:
#!/usr/bin/env python3
"""
Test script to verify code in Model-Based RL blog post
"""
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
print("Testing code from Model-Based RL blog post...\n")
# Test 1: GridWorld Environment
print("Test 1: GridWorld Environment")
try:
class GridWorld:
def __init__(self, size=5, goal_pos=(4, 4)):
self.size = size
self.goal_pos = goal_pos
self.agent_pos = (0, 0)
self.action_space = 4
self.state_space = (size, size)
self.obstacles = [(2, 2), (2, 3), (3, 2)]
def reset(self):
self.agent_pos = (0, 0)
return self.get_state()
def get_state(self):
return np.array(self.agent_pos)
def step(self, action):
x, y = self.agent_pos
if action == 0:
new_pos = (x, min(y + 1, self.size - 1))
elif action == 1:
new_pos = (x, max(y - 1, 0))
elif action == 2:
new_pos = (max(x - 1, 0), y)
elif action == 3:
new_pos = (min(x + 1, self.size - 1), y)
if new_pos not in self.obstacles:
self.agent_pos = new_pos
reward = -1
if self.agent_pos == self.goal_pos:
reward = 100
done = self.agent_pos == self.goal_pos
return self.get_state(), reward, done, {}
env = GridWorld()
state = env.reset()
next_state, reward, done, _ = env.step(3)
assert state.shape == (2,), f"Expected shape (2,), got {state.shape}"
assert isinstance(reward, (int, float)), f"Expected number, got {type(reward)}"
assert isinstance(done, bool), f"Expected bool, got {type(done)}"
print("✓ GridWorld works correctly\n")
except Exception as e:
print(f"✗ GridWorld failed: {e}\n")
sys.exit(1)
# Test 2: GridWorldDynamicsModel
print("Test 2: GridWorldDynamicsModel")
try:
class GridWorldDynamicsModel(nn.Module):
def __init__(self, state_dim=2, action_dim=4, hidden_dim=64):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
self.decoder = nn.Linear(hidden_dim, state_dim)
        def forward(self, state, action):
            # Accept action index tensors of shape (N,) or (N, 1)
            if action.dim() == 1:
                action = action.unsqueeze(1)
            action_onehot = torch.zeros(action.shape[0], 4)
            action_onehot.scatter_(1, action.long(), 1)
x = torch.cat([state, action_onehot], dim=1)
features = self.encoder(x)
next_state = self.decoder(features)
return next_state
def predict(self, state, action):
with torch.no_grad():
if isinstance(state, np.ndarray):
state = torch.FloatTensor(state)
if isinstance(action, (int, np.integer)):
action = torch.LongTensor([[action]])
elif isinstance(action, np.ndarray):
action = torch.LongTensor(action)
if len(state.shape) == 1:
state = state.unsqueeze(0)
if len(action.shape) == 0:
action = action.unsqueeze(0)
pred = self.forward(state, action)
return pred.squeeze(0).numpy()
model = GridWorldDynamicsModel()
state = torch.randn(1, 2)
action = torch.LongTensor([[1]])
output = model(state, action)
assert output.shape == (1, 2), f"Expected shape (1, 2), got {output.shape}"
test_state = np.array([2, 2])
test_action = 3
pred = model.predict(test_state, test_action)
assert pred.shape == (2,), f"Expected shape (2,), got {pred.shape}"
print("✓ GridWorldDynamicsModel works correctly\n")
except Exception as e:
print(f"✗ GridWorldDynamicsModel failed: {e}\n")
sys.exit(1)
# Test 3: RandomShootingPlanner
print("Test 3: RandomShootingPlanner")
try:
class RandomShootingPlanner:
def __init__(self, model, action_space, horizon=10, n_samples=100):
self.model = model
self.action_space = action_space
self.horizon = horizon
self.n_samples = n_samples
def plan(self, state):
best_action = None
best_reward = float('-inf')
for _ in range(self.n_samples):
actions = np.random.randint(0, self.action_space, self.horizon)
current_state = state.copy()
total_reward = 0
for action in actions:
next_state = self.model.predict(current_state, action)
distance = np.linalg.norm(next_state - np.array([4, 4]))
reward = -distance
total_reward += reward
current_state = next_state
if total_reward > best_reward:
best_reward = total_reward
best_action = actions[0]
return best_action
planner = RandomShootingPlanner(model, action_space=4, n_samples=10)
state = np.array([0, 0])
action = planner.plan(state)
assert isinstance(action, (int, np.integer)), f"Expected int, got {type(action)}"
assert 0 <= action < 4, f"Expected action in [0,3], got {action}"
print("✓ RandomShootingPlanner works correctly\n")
except Exception as e:
print(f"✗ RandomShootingPlanner failed: {e}\n")
sys.exit(1)
# Test 4: CartPoleDynamicsModel
print("Test 4: CartPoleDynamicsModel")
try:
class CartPoleDynamicsModel(nn.Module):
def __init__(self, state_dim=4, action_dim=1, hidden_dim=128):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
self.delta_head = nn.Linear(hidden_dim, state_dim)
self.reward_head = nn.Linear(hidden_dim, 1)
self.done_head = nn.Linear(hidden_dim, 1)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
features = self.encoder(x)
delta = self.delta_head(features)
reward = self.reward_head(features)
done = torch.sigmoid(self.done_head(features))
return delta, reward, done
def predict(self, state, action):
with torch.no_grad():
if isinstance(state, np.ndarray):
state = torch.FloatTensor(state)
if isinstance(action, np.ndarray):
action = torch.FloatTensor(action)
if len(state.shape) == 1:
state = state.unsqueeze(0)
if len(action.shape) == 1:
action = action.unsqueeze(0)
delta, reward, done = self.forward(state, action)
next_state = state + delta
return (next_state.squeeze(0).numpy(), reward.item(), done.item())
model = CartPoleDynamicsModel()
state = torch.randn(1, 4)
action = torch.randn(1, 1)
delta, reward, done = model(state, action)
assert delta.shape == (1, 4), f"Expected shape (1, 4), got {delta.shape}"
assert reward.shape == (1, 1), f"Expected shape (1, 1), got {reward.shape}"
assert done.shape == (1, 1), f"Expected shape (1, 1), got {done.shape}"
print("✓ CartPoleDynamicsModel works correctly\n")
except Exception as e:
print(f"✗ CartPoleDynamicsModel failed: {e}\n")
sys.exit(1)
# Test 5: EnsembleDynamicsModel
print("Test 5: EnsembleDynamicsModel")
try:
class EnsembleDynamicsModel:
def __init__(self, n_models=5, state_dim=4, action_dim=1, hidden_dim=128):
self.models = [
CartPoleDynamicsModel(state_dim, action_dim, hidden_dim)
for _ in range(n_models)
]
def predict(self, state, action):
predictions = []
for model in self.models:
pred = model.predict(state, action)
predictions.append(pred[0] if isinstance(pred, tuple) else pred)
predictions_array = np.array(predictions)
mean_pred = np.mean(predictions_array, axis=0)
std_pred = np.std(predictions_array, axis=0)
return mean_pred, std_pred
ensemble = EnsembleDynamicsModel(n_models=3)
state = np.random.randn(4)
action = np.array([1])
mean, std = ensemble.predict(state, action)
assert mean.shape == (4,), f"Expected shape (4,), got {mean.shape}"
assert std.shape == (4,), f"Expected shape (4,), got {std.shape}"
print("✓ EnsembleDynamicsModel works correctly\n")
except Exception as e:
print(f"✗ EnsembleDynamicsModel failed: {e}\n")
sys.exit(1)
# Test 6: ProbabilisticDynamicsModel
print("Test 6: ProbabilisticDynamicsModel")
try:
class ProbabilisticDynamicsModel(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
self.mean_head = nn.Linear(hidden_dim, state_dim)
self.logvar_head = nn.Linear(hidden_dim, state_dim)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
features = self.encoder(x)
mean = self.mean_head(features)
logvar = self.logvar_head(features)
return mean, logvar
def sample(self, state, action):
mean, logvar = self.forward(state, action)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mean + eps * std
model = ProbabilisticDynamicsModel(state_dim=4, action_dim=1)
state = torch.randn(1, 4)
action = torch.randn(1, 1)
mean, logvar = model(state, action)
sample = model.sample(state, action)
assert mean.shape == (1, 4), f"Expected shape (1, 4), got {mean.shape}"
assert logvar.shape == (1, 4), f"Expected shape (1, 4), got {logvar.shape}"
assert sample.shape == (1, 4), f"Expected shape (1, 4), got {sample.shape}"
print("✓ ProbabilisticDynamicsModel works correctly\n")
except Exception as e:
print(f"✗ ProbabilisticDynamicsModel failed: {e}\n")
sys.exit(1)
# Test 7: DreamerWorldModel
print("Test 7: DreamerWorldModel")
try:
class DreamerWorldModel(nn.Module):
def __init__(self, obs_dim, action_dim, latent_dim=32, hidden_dim=128):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * 2)
)
self.dynamics = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * 2)
)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, obs_dim)
)
self.reward_head = nn.Linear(latent_dim, 1)
def encode(self, obs):
x = self.encoder(obs)
mean, logvar = x.chunk(2, dim=-1)
return mean, logvar
def decode(self, latent):
return self.decoder(latent)
def transition(self, latent, action):
x = torch.cat([latent, action], dim=-1)
x = self.dynamics(x)
mean, logvar = x.chunk(2, dim=-1)
return mean, logvar
def predict_reward(self, latent):
return self.reward_head(latent)
def sample_latent(self, obs):
mean, logvar = self.encode(obs)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mean + eps * std
model = DreamerWorldModel(obs_dim=10, action_dim=2, latent_dim=16)
obs = torch.randn(1, 10)
action = torch.randn(1, 2)
mean, logvar = model.encode(obs)
assert mean.shape == (1, 16), f"Expected shape (1, 16), got {mean.shape}"
latent = model.sample_latent(obs)
assert latent.shape == (1, 16), f"Expected shape (1, 16), got {latent.shape}"
next_mean, next_logvar = model.transition(latent, action)
assert next_mean.shape == (1, 16), f"Expected shape (1, 16), got {next_mean.shape}"
reward = model.predict_reward(latent)
assert reward.shape == (1, 1), f"Expected shape (1, 1), got {reward.shape}"
decoded = model.decode(latent)
assert decoded.shape == (1, 10), f"Expected shape (1, 10), got {decoded.shape}"
print("✓ DreamerWorldModel works correctly\n")
except Exception as e:
print(f"✗ DreamerWorldModel failed: {e}\n")
sys.exit(1)
# Test 8: LinearDynamicsModel
print("Test 8: LinearDynamicsModel")
try:
class LinearDynamicsModel(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.linear = nn.Linear(state_dim + action_dim, state_dim)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
return self.linear(x)
model = LinearDynamicsModel(state_dim=4, action_dim=1)
state = torch.randn(1, 4)
action = torch.randn(1, 1)
output = model(state, action)
assert output.shape == (1, 4), f"Expected shape (1, 4), got {output.shape}"
print("✓ LinearDynamicsModel works correctly\n")
except Exception as e:
print(f"✗ LinearDynamicsModel failed: {e}\n")
sys.exit(1)
print("=" * 50)
print("All tests passed! ✓")
print("=" * 50)
print("\nThe code in blog post is syntactically correct")
print("and should work as expected.")
To run this test script, save it as test_modelbased_code.py and execute:
python test_modelbased_code.py
Conclusion
Model-Based Reinforcement Learning offers a powerful approach to learning efficient policies by understanding environment dynamics. By learning models of the world and using them for planning, we can achieve sample-efficient learning and better long-term decision-making.
Key Takeaways
- Learn Dynamics Models: Train neural networks to predict state transitions
- Use Planning: Employ MPC, CEM, or other planning algorithms
- Handle Uncertainty: Use ensembles and probabilistic models
- Balance Real/Simulated: Mix real and imagined experience
- Start Simple: Begin with linear models, progress to neural networks
Future Directions
- Better World Models: More accurate and generalizable models
- Sample Efficient Methods: Learn from even fewer interactions
- Uncertainty Quantification: Better handling of model uncertainty
- Hierarchical Planning: Multi-level planning for complex tasks
- Meta-Learning: Learn to learn models quickly
Resources
- Libraries: mbrl-lib, garage, rlpyt
- Simulators: MuJoCo, PyBullet, Isaac Gym
Happy model-based learning!