안정성과 성능을 동시에 잡은 PPO 알고리즘 완전 해부
Proximal Policy Optimization(PPO)은 OpenAI가 제안한 대표적인 정책 기반 강화학습 알고리즘입니다. PPO는 Actor-Critic 구조 위에 "정책 변화의 폭을 제한하는 방식"을 더해 학습 안정성과 성능을 동시에 확보합니다.
왜 PPO인가?
기존의 정책 경사(policy gradient) 방식은 다음과 같은 문제를 안고 있었습니다:
- 학습률을 너무 크게 주면 정책이 급격히 바뀌어 불안정해짐
- 너무 작게 주면 수렴 속도가 느려짐
- TRPO(Trust Region Policy Optimization)는 이를 해결했지만 계산량이 많고 구현이 복잡
PPO는 TRPO의 아이디어를 유지하면서도, 계산은 훨씬 간단한 방식으로 안정적인 학습을 가능하게 합니다.
PPO의 핵심 아이디어
PPO는 정책을 크게 바꾸지 않도록 **정책 비율**(probability ratio)을 도입합니다.
\[ r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)} \]
- pi_theta: 현재 정책
- pi_{theta_old}: 업데이트 전 정책 (행동을 선택할 때 사용된)
- r_t: 정책이 얼마나 바뀌었는지를 나타내는 지표
Clipped Objective
PPO는 정책이 너무 바뀌지 않도록 다음과 같은 목적 함수를 사용합니다:
\[ L^{\text{clip}}(\theta) = \mathbb{E}_t \left[ \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right] \]
- A_t: Advantage — 행동이 얼마나 좋은지를 나타내는 값
- epsilon: 허용된 정책 변화의 폭 (보통 0.1~0.3)
따라서, 정책의 개선이 너무 클 경우, 이를 epsilon으로 제한한다는 것을 의미합니다. 다시말해, 너무 좋은 업데이트도 막는다 인거죠.
PPO 핵심 코드 설명 (PyTorch 기반)
PPO는 정책의 업데이트 폭을 제한하면서도 고속으로 학습되는 알고리즘입니다.
아래 코드는 간략한 예이며, 우선 Actor-Critic 처럼 네트워크 객체를 정의합니다.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
# 1. Actor-Critic 네트워크 구조
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.shared = nn.Sequential(
nn.Linear(state_dim, 128),
nn.ReLU()
)
self.actor = nn.Linear(128, action_dim) # 행동 확률
self.critic = nn.Linear(128, 1) # 상태 가치 V(s)
def forward(self, state):
x = self.shared(state)
return self.actor(x), self.critic(x)
def get_action(self, state):
logits, _ = self.forward(state)
dist = Categorical(logits=logits)
action = dist.sample()
return action.item(), dist.log_prob(action), dist.entropy()
def evaluate(self, state, action):
logits, value = self.forward(state)
dist = Categorical(logits=logits)
log_prob = dist.log_prob(action)
entropy = dist.entropy()
return log_prob, entropy, value.squeeze(-1)
Rollout 단계(경험 수집)
에이전트는 일정 step 동안 환경과 상호작용하여 상태, 행동, 보상 등을 수집합니다.
states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []
state = env.reset()
for t in range(rollout_steps):
state_tensor = torch.FloatTensor(state).unsqueeze(0)
action, log_prob, _ = model.get_action(state_tensor)
next_state, reward, done, _ = env.step(action)
states.append(state)
actions.append(action)
rewards.append(reward)
dones.append(done)
log_probs.append(log_prob)
values.append(model(state_tensor)[1])
state = next_state
if done:
state = env.reset()
GAE 기반 Advantage & Return 계산
PPO는 일반적으로 GAE(Generalized Advantage Estimation)를 사용하여 Advantage를 추정합니다.
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
advantages = []
gae = 0
next_value = 0
for t in reversed(range(len(rewards))):
delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
gae = delta + gamma * lam * (1 - dones[t]) * gae
advantages.insert(0, gae)
next_value = values[t]
returns = [a + v for a, v in zip(advantages, values)]
return torch.tensor(advantages), torch.tensor(returns)
정책 및 가치 함수 학습
본문에서 설명드린 PPO의 클리핑 손실 함수와 MSE 기반의 가치 손실 함수를 계산하여 동시에 학습합니다.
states = torch.FloatTensor(states)
actions = torch.LongTensor(actions)
old_log_probs = torch.stack(log_probs).detach()
advantages, returns = compute_gae(rewards, values, dones)
for _ in range(ppo_epochs):
log_probs, entropy, values = model.evaluate(states, actions)
ratio = torch.exp(log_probs - old_log_probs) # r_t(θ)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
value_loss = nn.MSELoss()(values, returns)
loss = policy_loss + 0.5 * value_loss - 0.01 * entropy.mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
- ratio: 현재 정책이 이전 정책보다 얼마나 확률을 높였는지 비교
- clip: 이 비율이 너무 커지지 않도록 제어 → 정책 안정성
- entropy: 탐험을 유도하기 위한 보너스 항
PPO의 장점!
- 복잡한 제약 조건 없이 안정적인 학습 가능
- 연속 상태 및 행동 공간에서 우수한 성능
- 다양한 환경에서 매우 강력하고 일반적인 성능
단점
- Advantage 추정 성능에 따라 성능이 크게 좌우됨
- 클리핑 범위(epsilon)에 민감할 수 있음
- 여전히 여러 하이퍼파라미터 튜닝이 필요
감사합니다.
'인공지능 > 강화학습' 카테고리의 다른 글
[강화학습] Actor-Critic 이란? (0) | 2025.06.29 |
---|---|
[강화학습] DQN 이란? (2) | 2025.06.28 |
[강화학습] Q-learning 이란? (10) | 2025.06.26 |