import numpy as np
import matplotlib.pyplot as plt
import random
import time
from tqdm import tqdm
print("=" * 70)
print("SARSA: APRENDIZADO ON-POLICY POR DIFERENÇA TEMPORAL")
print("=" * 70)
# ============================================
# AMBIENTE: CLIFF WALKING (PENHASCO)
# ============================================
class CliffWalking:
"""
Ambiente 4x12 onde o agente deve evitar o penhasco.
Recompensas: -1 por passo, -100 ao cair no penhasco, 0 ao chegar ao objetivo.
"""
def __init__(self):
self.n_linhas = 4
self.n_colunas = 12
self.n_estados = self.n_linhas * self.n_colunas
self.n_acoes = 4 # 0=cima, 1=baixo, 2=esq, 3=dir
# Posições especiais
self.inicio = 3 * 12 + 0 # linha 3, coluna 0 (canto inferior esquerdo)
self.objetivo = 3 * 12 + 11 # linha 3, coluna 11 (canto inferior direito)
# Penhasco: linha 3, colunas 1 a 10
self.penhasco = [3 * 12 + c for c in range(1, 11)]
def reset(self):
self.estado = self.inicio
return self.estado
def step(self, acao):
"""Executa ação e retorna (próximo_estado, recompensa, terminou)"""
linha = self.estado // self.n_colunas
coluna = self.estado % self.n_colunas
# Calcula movimento
if acao == 0: # cima
linha = max(0, linha - 1)
elif acao == 1: # baixo
linha = min(self.n_linhas - 1, linha + 1)
elif acao == 2: # esquerda
coluna = max(0, coluna - 1)
else: # direita
coluna = min(self.n_colunas - 1, coluna + 1)
novo_estado = linha * self.n_colunas + coluna
self.estado = novo_estado
# Verifica penhasco
if novo_estado in self.penhasco:
return novo_estado, -100, True
# Verifica objetivo
if novo_estado == self.objetivo:
return novo_estado, 0, True
# Movimento normal
return novo_estado, -1, False
def render(self, estado=None):
"""Mostra o estado atual do ambiente"""
if estado is None:
estado = self.estado
print("\n" + "-" * 30)
for i in range(self.n_linhas):
linha = ""
for j in range(self.n_colunas):
s = i * self.n_colunas + j
if s == estado:
linha += "🤠 "
elif s == self.objetivo:
linha += "🏆 "
elif s in self.penhasco:
linha += "💀 "
else:
linha += "⬜ "
print(linha)
print("-" * 30)
# ============================================
# AGENTE SARSA (ON-POLICY)
# ============================================
class SarsaAgente:
"""Agente que aprende com SARSA on-policy"""
def __init__(self, n_estados, n_acoes, alpha=0.1, gamma=0.95, epsilon=0.1):
self.Q = np.zeros((n_estados, n_acoes))
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
self.n_estados = n_estados
self.n_acoes = n_acoes
def escolher_acao(self, estado):
"""Política ε-greedy on-policy"""
if random.random() < self.epsilon:
return random.randint(0, self.n_acoes - 1)
return np.argmax(self.Q[estado])
def aprender(self, estado, acao, recompensa, prox_estado, prox_acao, terminou):
"""Atualização SARSA: usa a próxima ação real"""
if terminou:
alvo = recompensa
else:
alvo = recompensa + self.gamma * self.Q[prox_estado, prox_acao]
erro_td = alvo - self.Q[estado, acao]
self.Q[estado, acao] += self.alpha * erro_td
return erro_td
# ============================================
# AGENTE Q-LEARNING (OFF-POLICY PARA COMPARAÇÃO)
# ============================================
class QLearningAgente:
"""Agente Q-learning off-policy para comparação"""
def __init__(self, n_estados, n_acoes, alpha=0.1, gamma=0.95, epsilon=0.1):
self.Q = np.zeros((n_estados, n_acoes))
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
self.n_estados = n_estados
self.n_acoes = n_acoes
def escolher_acao(self, estado):
if random.random() < self.epsilon:
return random.randint(0, self.n_acoes - 1)
return np.argmax(self.Q[estado])
def aprender(self, estado, acao, recompensa, prox_estado, terminou):
"""Atualização Q-learning: usa o máximo sobre ações futuras"""
if terminou:
alvo = recompensa
else:
alvo = recompensa + self.gamma * np.max(self.Q[prox_estado])
erro_td = alvo - self.Q[estado, acao]
self.Q[estado, acao] += self.alpha * erro_td
return erro_td
# ============================================
# EXPERIMENTO: SARSA vs Q-LEARNING
# ============================================
print("\n" + "=" * 70)
print("COMPARAÇÃO: SARSA (ON-POLICY) vs Q-LEARNING (OFF-POLICY)")
print("=" * 70)
num_episodios = 1000
env = CliffWalking()
# Inicializa agentes
sarsa_agente = SarsaAgente(n_estados=48, n_acoes=4, alpha=0.1, gamma=0.95, epsilon=0.1)
qlearning_agente = QLearningAgente(n_estados=48, n_acoes=4, alpha=0.1, gamma=0.95, epsilon=0.1)
# Armazena histórico
recompensas_sarsa = []
recompensas_qlearning = []
passos_sarsa = []
passos_qlearning = []
print("\n🚀 Treinando SARSA...\n")
# Treino SARSA
with tqdm(total=num_episodios, desc="SARSA", unit="ep",
ncols=80, mininterval=0.5, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
for ep in range(num_episodios):
estado = env.reset()
acao = sarsa_agente.escolher_acao(estado)
recompensa_total = 0
passos = 0
terminou = False
while not terminou and passos < 200:
prox_estado, recompensa, terminou = env.step(acao)
prox_acao = sarsa_agente.escolher_acao(prox_estado) if not terminou else None
sarsa_agente.aprender(estado, acao, recompensa, prox_estado, prox_acao, terminou)
recompensa_total += recompensa
estado = prox_estado
acao = prox_acao if not terminou else acao
passos += 1
recompensas_sarsa.append(recompensa_total)
passos_sarsa.append(passos)
if (ep + 1) % 100 == 0:
media_recomp = np.mean(recompensas_sarsa[-100:])
pbar.set_postfix({'Recomp': f'{media_recomp:.1f}'})
pbar.update(100)
elif ep == 0:
pbar.update(1)
print("\n🚀 Treinando Q-Learning...\n")
# Treino Q-Learning
with tqdm(total=num_episodios, desc="Q-Learning", unit="ep",
ncols=80, mininterval=0.5, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
for ep in range(num_episodios):
estado = env.reset()
recompensa_total = 0
passos = 0
terminou = False
while not terminou and passos < 200:
acao = qlearning_agente.escolher_acao(estado)
prox_estado, recompensa, terminou = env.step(acao)
qlearning_agente.aprender(estado, acao, recompensa, prox_estado, terminou)
recompensa_total += recompensa
estado = prox_estado
passos += 1
recompensas_qlearning.append(recompensa_total)
passos_qlearning.append(passos)
if (ep + 1) % 100 == 0:
media_recomp = np.mean(recompensas_qlearning[-100:])
pbar.set_postfix({'Recomp': f'{media_recomp:.1f}'})
pbar.update(100)
elif ep == 0:
pbar.update(1)
print("\n✅ Treinamento concluído!")
# ============================================
# AVALIAÇÃO DOS AGENTES (SEM EXPLORAÇÃO)
# ============================================
print("\n" + "=" * 70)
print("AVALIAÇÃO DOS AGENTES (ε = 0)")
print("=" * 70)
def avaliar_agente(agente, nome, n_testes=100):
"""Avalia agente sem exploração"""
# Salva epsilon original
eps_original = agente.epsilon
agente.epsilon = 0
recompensas_teste = []
passos_teste = []
sucessos = 0
with tqdm(total=n_testes, desc=f"Avaliando {nome}", unit="teste",
ncols=80, mininterval=0.5, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as pbar:
for _ in range(n_testes):
estado = env.reset()
terminou = False
recompensa_total = 0
passos = 0
while not terminou and passos < 200:
acao = agente.escolher_acao(estado)
estado, recompensa, terminou = env.step(acao)
recompensa_total += recompensa
passos += 1
if recompensa_total > -50: # Não caiu no penhasco
sucessos += 1
recompensas_teste.append(recompensa_total)
passos_teste.append(passos)
pbar.update(1)
# Restaura epsilon
agente.epsilon = eps_original
return np.mean(recompensas_teste), np.mean(passos_teste), sucessos / n_testes * 100
recomp_sarsa, passos_sarsa_avg, taxa_sarsa = avaliar_agente(sarsa_agente, "SARSA")
recomp_ql, passos_ql_avg, taxa_ql = avaliar_agente(qlearning_agente, "Q-Learning")
print(f"\n🏆 RESULTADOS DA AVALIAÇÃO:")
print(f" SARSA:")
print(f" - Recompensa média: {recomp_sarsa:.1f}")
print(f" - Passos médios: {passos_sarsa_avg:.1f}")
print(f" - Taxa de sucesso: {taxa_sarsa:.1f}%")
print(f"\n Q-Learning:")
print(f" - Recompensa média: {recomp_ql:.1f}")
print(f" - Passos médios: {passos_ql_avg:.1f}")
print(f" - Taxa de sucesso: {taxa_ql:.1f}%")
# ============================================
# VISUALIZAÇÃO DOS RESULTADOS
# ============================================
print("\n📊 Gerando gráficos...")
plt.figure(figsize=(14, 6))
# Gráfico 1: Recompensas por episódio
plt.subplot(2, 2, 1)
media_sarsa = np.convolve(recompensas_sarsa, np.ones(50)/50, mode='valid')
media_ql = np.convolve(recompensas_qlearning, np.ones(50)/50, mode='valid')
plt.plot(media_sarsa, 'g-', linewidth=1.5, label='SARSA', alpha=0.8)
plt.plot(media_ql, 'b-', linewidth=1.5, label='Q-Learning', alpha=0.8)
plt.xlabel('Episódio')
plt.ylabel('Recompensa média (janela 50)')
plt.title('Convergência: SARSA vs Q-Learning')
plt.legend()
plt.grid(True, alpha=0.3)
# Gráfico 2: Passos por episódio (quanto menor, melhor)
plt.subplot(2, 2, 2)
passos_sarsa_smooth = np.convolve(passos_sarsa, np.ones(50)/50, mode='valid')
passos_ql_smooth = np.convolve(passos_qlearning, np.ones(50)/50, mode='valid')
plt.plot(passos_sarsa_smooth, 'g-', linewidth=1.5, label='SARSA', alpha=0.8)
plt.plot(passos_ql_smooth, 'b-', linewidth=1.5, label='Q-Learning', alpha=0.8)
plt.xlabel('Episódio')
plt.ylabel('Passos por episódio')
plt.title('Eficiência (menos passos = melhor)')
plt.legend()
plt.grid(True, alpha=0.3)
# Gráfico 3: Política do SARSA (mapa de calor das ações)
plt.subplot(2, 2, 3)
politica_sarsa = np.argmax(sarsa_agente.Q, axis=1).reshape(4, 12)
setas = ['↑', '↓', '←', '→']
mapa_sarsa = np.empty((4, 12), dtype='<U3')
for i in range(4):
for j in range(12):
s = i * 12 + j
if s in env.penhasco:
mapa_sarsa[i, j] = '💀'
elif s == env.objetivo:
mapa_sarsa[i, j] = '🏆'
else:
mapa_sarsa[i, j] = setas[politica_sarsa[i, j]]
plt.imshow(np.zeros((4, 12)), cmap='gray', alpha=0.1)
for i in range(4):
for j in range(12):
plt.text(j, i, mapa_sarsa[i, j], ha='center', va='center', fontsize=10)
plt.title('Política SARSA (evita o penhasco 💀)')
plt.xlim(-0.5, 11.5)
plt.ylim(3.5, -0.5)
plt.axis('off')
# Gráfico 4: Política do Q-Learning
plt.subplot(2, 2, 4)
politica_ql = np.argmax(qlearning_agente.Q, axis=1).reshape(4, 12)
mapa_ql = np.empty((4, 12), dtype='<U3')
for i in range(4):
for j in range(12):
s = i * 12 + j
if s in env.penhasco:
mapa_ql[i, j] = '💀'
elif s == env.objetivo:
mapa_ql[i, j] = '🏆'
else:
mapa_ql[i, j] = setas[politica_ql[i, j]]
plt.imshow(np.zeros((4, 12)), cmap='gray', alpha=0.1)
for i in range(4):
for j in range(12):
plt.text(j, i, mapa_ql[i, j], ha='center', va='center', fontsize=10)
plt.title('Política Q-Learning (pode se aproximar do penhasco)')
plt.xlim(-0.5, 11.5)
plt.ylim(3.5, -0.5)
plt.axis('off')
plt.tight_layout()
plt.show()
# ============================================
# VISUALIZAÇÃO DA FUNÇÃO VALOR
# ============================================
print("\n📊 Visualizando função valor aprendida...")
plt.figure(figsize=(14, 4))
# Função valor do SARSA
plt.subplot(1, 2, 1)
V_sarsa = np.max(sarsa_agente.Q, axis=1).reshape(4, 12)
V_sarsa_masked = np.ma.masked_where(np.array([[s in env.penhasco for s in range(i*12, (i+1)*12)] for i in range(4)]), V_sarsa)
im1 = plt.imshow(V_sarsa_masked, cmap='RdYlGn', interpolation='nearest')
plt.colorbar(im1, label='Valor V(s)')
for i in range(4):
for j in range(12):
s = i * 12 + j
if s in env.penhasco:
plt.text(j, i, '💀', ha='center', va='center', fontsize=12)
elif s == env.objetivo:
plt.text(j, i, '🏆', ha='center', va='center', fontsize=12)
else:
plt.text(j, i, f'{V_sarsa[i, j]:.1f}', ha='center', va='center', fontsize=7)
plt.title('Função Valor - SARSA (segura)')
plt.xlabel('Coluna')
plt.ylabel('Linha')
# Função valor do Q-Learning
plt.subplot(1, 2, 2)
V_ql = np.max(qlearning_agente.Q, axis=1).reshape(4, 12)
V_ql_masked = np.ma.masked_where(np.array([[s in env.penhasco for s in range(i*12, (i+1)*12)] for i in range(4)]), V_ql)
im2 = plt.imshow(V_ql_masked, cmap='RdYlGn', interpolation='nearest')
plt.colorbar(im2, label='Valor V(s)')
for i in range(4):
for j in range(12):
s = i * 12 + j
if s in env.penhasco:
plt.text(j, i, '💀', ha='center', va='center', fontsize=12)
elif s == env.objetivo:
plt.text(j, i, '🏆', ha='center', va='center', fontsize=12)
else:
plt.text(j, i, f'{V_ql[i, j]:.1f}', ha='center', va='center', fontsize=7)
plt.title('Função Valor - Q-Learning (otimista)')
plt.xlabel('Coluna')
plt.ylabel('Linha')
plt.tight_layout()
plt.show()
# ============================================
# EXPLICAÇÃO MATEMÁTICA
# ============================================
print("\n" + "=" * 70)
print("FUNDAMENTOS DO SARSA")
print("=" * 70)
print("""
✅ SARSA: STATE-ACTION-REWARD-STATE-ACTION
SARSA é um algoritmo on-policy de diferença temporal.
✅ FÓRMULA DE ATUALIZAÇÃO:
[latex] Q(s,a) \leftarrow Q(s,a) + \\alpha [r + \\gamma Q(s',a') - Q(s,a)] [/latex]
Onde:
- (s,a): estado e ação atuais
- r: recompensa recebida
- s': próximo estado
- a': próxima ação (escolhida pela política atual)
✅ DIFERENÇA CRÍTICA PARA Q-LEARNING:
• Q-LEARNING: usa max_a Q(s',a) (off-policy)
• SARSA: usa Q(s',a') onde a' é a ação real (on-policy)
✅ IMPLICAÇÕES PRÁTICAS:
SARSA é mais conservador porque considera a exploração futura.
Ele evita caminhos perigosos que o Q-learning poderia tentar.
✅ HIPERPARÂMETROS:
• α (alpha): Taxa de aprendizado (0.1 típico)
• γ (gamma): Fator de desconto (0.95 típico)
• ε (epsilon): Exploração (0.1 típico)
✅ QUANDO USAR SARSA:
✓ Ambientes com riscos reais (robótica, finanças)
✓ Quando segurança é prioridade
✓ Quando a política ótima precisa ser robusta à exploração
✅ VANTAGENS DO SARSA:
• Mais estável em ambientes estocásticos
• Aprende políticas mais seguras
• Evita comportamentos de "beira de penhasco"
✅ DESVANTAGENS:
• Pode ser excessivamente conservador
• Converge mais lentamente que Q-learning
• Não encontra a política ótima se ε > 0
""")
print("\n" + "=" * 70)
print("CONCLUSÃO")
print("=" * 70)
print("""
✅ SARSA é um algoritmo on-policy poderoso e seguro.
✅ Ele aprende considerando a política atual, não a ótima.
✅ No problema do penhasco, SARSA evita o perigo.
✅ Q-Learning pode cair no penhasco durante o treino.
✅ A escolha entre SARSA e Q-learning depende do problema.
RESUMO DO EXPERIMENTO:
• SARSA aprendeu um caminho seguro pelo topo.
• Q-Learning aprendeu o caminho ótimo (mas arriscado).
• SARSA teve maior taxa de sucesso na avaliação.
• Q-Learning obteve recompensas ligeiramente melhores.
""")
print("\n✅ PROGRAMA CONCLUÍDO COM SUCESSO!")