Files
majiang-rl/majiang_rl/rl/trainer.py
2026-01-14 10:49:00 +08:00

62 lines
1.7 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import List
from .agents import Transition
@dataclass()
class EpisodeResult:
    """Summary of a single rolled-out episode.

    Produced once per episode by the episode runner.
    """

    # Sum of the per-step rewards collected during the episode (undiscounted).
    total_reward: float
    # Number of environment steps taken before the episode stopped.
    steps: int
    # Value of the ``terminated`` flag after the last step taken.
    terminated: bool
    # Value of the ``truncated`` flag after the last step taken.
    truncated: bool
class EpisodeRunner:
    """Drive an agent through single episodes of an environment.

    Each episode resets both ``env`` and ``agent``, then alternates
    ``agent.act`` and ``env.step`` calls, handing every step to
    ``agent.observe`` as a :class:`Transition`.
    """

    def __init__(self, env, agent, max_steps: int = 1000):
        """Store the collaborators and the per-episode step cap.

        Args:
            env: Environment whose ``reset()`` returns ``(obs, info)`` and
                whose ``step(action)`` returns
                ``(obs, reward, terminated, truncated, info)``.
            agent: Agent providing ``reset()``, ``act(obs, info)`` and
                ``observe(transition)``.
            max_steps: Maximum number of ``env.step`` calls per episode.
        """
        self.env = env
        self.agent = agent
        self.max_steps = max_steps

    def run_episode(self) -> EpisodeResult:
        """Play one episode and return its summary.

        The episode ends when the environment reports ``terminated`` or
        ``truncated``, or when ``max_steps`` steps have been taken. Reaching
        the step cap is reported as truncation, both on the final
        transition given to the agent and in the returned result, mirroring
        Gymnasium's ``TimeLimit`` wrapper.

        Returns:
            EpisodeResult holding the undiscounted sum of rewards, the number
            of steps taken and the final ``terminated``/``truncated`` flags.
            If ``max_steps`` is not positive, no step is taken and the
            result is all zeros/False.
        """
        obs, info = self.env.reset()
        self.agent.reset()
        total_reward = 0.0
        terminated = False
        truncated = False
        steps = 0
        while steps < self.max_steps:
            action = self.agent.act(obs, info)
            next_obs, reward, terminated, truncated, next_info = self.env.step(action)
            steps += 1
            if steps >= self.max_steps:
                # Hitting the runner's own step cap is a time-limit
                # truncation. Without this flag, the agent's last transition
                # and the returned result would both look like an episode
                # that is still in progress.
                truncated = True
            self.agent.observe(
                Transition(
                    obs=obs,
                    action=action,
                    reward=reward,
                    terminated=terminated,
                    truncated=truncated,
                    next_obs=next_obs,
                    info=next_info,
                )
            )
            total_reward += reward
            # The next action is chosen from the observation/info just returned.
            obs, info = next_obs, next_info
            if terminated or truncated:
                break
        return EpisodeResult(
            total_reward=total_reward,
            steps=steps,
            terminated=terminated,
            truncated=truncated,
        )
def run_training(
    env, agent, episodes: int = 100, *, max_steps: int = 1000
) -> List[EpisodeResult]:
    """Run ``episodes`` consecutive episodes and collect their summaries.

    Args:
        env: Environment handed to :class:`EpisodeRunner`.
        agent: Agent handed to :class:`EpisodeRunner`.
        episodes: Number of episodes to play; zero or negative plays none.
        max_steps: Per-episode step cap forwarded to :class:`EpisodeRunner`.
            The default of 1000 matches the value previously hard-coded by
            relying on the runner's default.

    Returns:
        One EpisodeResult per episode, in the order the episodes were played.
    """
    runner = EpisodeRunner(env, agent, max_steps=max_steps)
    return [runner.run_episode() for _ in range(episodes)]