from __future__ import annotations

from dataclasses import dataclass
from typing import List

from .agents import Transition


@dataclass
class EpisodeResult:
    """Summary statistics for a single episode."""

    total_reward: float
    steps: int
    terminated: bool
    truncated: bool


class EpisodeRunner:
    """Runs individual episodes of an agent interacting with an environment."""

    def __init__(self, env, agent, max_steps: int = 1000):
        self.env = env
        self.agent = agent
        self.max_steps = max_steps

    def run_episode(self) -> EpisodeResult:
        obs, info = self.env.reset()
        self.agent.reset()

        total_reward = 0.0
        terminated = False
        truncated = False
        steps = 0

        for step in range(self.max_steps):
            action = self.agent.act(obs, info)
            next_obs, reward, terminated, truncated, next_info = self.env.step(action)

            # Package the step into a Transition so the agent can observe it.
            transition = Transition(
                obs=obs,
                action=action,
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                next_obs=next_obs,
                info=next_info,
            )
            self.agent.observe(transition)

            total_reward += reward
            obs, info = next_obs, next_info
            steps = step + 1

            if terminated or truncated:
                break

        return EpisodeResult(
            total_reward=total_reward,
            steps=steps,
            terminated=terminated,
            truncated=truncated,
        )


def run_training(env, agent, episodes: int = 100) -> List[EpisodeResult]:
    """Run `episodes` episodes back to back and collect their results."""
    runner = EpisodeRunner(env, agent)
    results: List[EpisodeResult] = []
    for _ in range(episodes):
        results.append(runner.run_episode())
    return results
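

# Usage sketch (an assumption, not part of this module's required API): the env is
# expected to follow the Gymnasium-style interface used above (reset() -> (obs, info),
# step() -> (obs, reward, terminated, truncated, info)), and the agent must expose
# reset(), act(obs, info), and observe(transition). The RandomAgent and the CartPole
# environment below are hypothetical choices used only to illustrate that protocol.
if __name__ == "__main__":
    import gymnasium as gym

    class RandomAgent:
        """Minimal agent: samples random actions and ignores observed transitions."""

        def __init__(self, action_space):
            self.action_space = action_space

        def reset(self) -> None:
            pass

        def act(self, obs, info):
            return self.action_space.sample()

        def observe(self, transition) -> None:
            pass

    env = gym.make("CartPole-v1")
    agent = RandomAgent(env.action_space)
    results = run_training(env, agent, episodes=10)
    print("mean reward:", sum(r.total_reward for r in results) / len(results))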