from __future__ import annotations

from dataclasses import dataclass
from typing import List

from .agents import Transition


@dataclass
class EpisodeResult:
    """Summary of one episode: return, length, and how it ended."""

    total_reward: float
    steps: int
    terminated: bool
    truncated: bool


class EpisodeRunner:
    """Runs one agent in one environment for up to ``max_steps`` per episode."""

    def __init__(self, env, agent, max_steps: int = 1000):
        self.env = env
        self.agent = agent
        self.max_steps = max_steps

    def run_episode(self) -> EpisodeResult:
        """Roll out a single episode, feeding every transition to the agent."""
        obs, info = self.env.reset()
        self.agent.reset()
        total_reward = 0.0
        terminated = False
        truncated = False
        steps = 0
        for step in range(self.max_steps):
            action = self.agent.act(obs, info)
            next_obs, reward, terminated, truncated, next_info = self.env.step(action)
            transition = Transition(
                obs=obs,
                action=action,
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                next_obs=next_obs,
                info=next_info,
            )
            self.agent.observe(transition)
            total_reward += reward
            obs, info = next_obs, next_info
            steps = step + 1
            if terminated or truncated:
                break
        # Note: if the step cap is hit without the env ending the episode, both
        # flags stay False; callers can detect this via steps == max_steps.
        return EpisodeResult(
            total_reward=total_reward,
            steps=steps,
            terminated=terminated,
            truncated=truncated,
        )


def run_training(env, agent, episodes: int = 100) -> List[EpisodeResult]:
    """Run ``episodes`` back-to-back episodes and collect their results."""
    runner = EpisodeRunner(env, agent)
    results: List[EpisodeResult] = []
    for _ in range(episodes):
        results.append(runner.run_episode())
    return results
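

# Minimal usage sketch, assuming a Gymnasium-style env (reset() -> (obs, info),
# step(action) -> (obs, reward, terminated, truncated, info)) and an agent with
# reset()/act()/observe(), which is the interface the runner above relies on.
# _CountdownEnv and _NoopAgent are hypothetical stubs written only to exercise
# the runner; run via ``python -m`` so the relative import resolves.
if __name__ == "__main__":

    class _CountdownEnv:
        """Toy env: reward 1.0 per step, terminates after three steps."""

        def reset(self):
            self.t = 3
            return self.t, {}

        def step(self, action):
            self.t -= 1
            return self.t, 1.0, self.t == 0, False, {}

    class _NoopAgent:
        """Agent stub: fixed action, ignores observed transitions."""

        def reset(self):
            pass

        def act(self, obs, info):
            return 0

        def observe(self, transition):
            pass  # a learning agent would update its policy here

    demo = run_training(_CountdownEnv(), _NoopAgent(), episodes=2)
    print([(r.total_reward, r.steps, r.terminated) for r in demo])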