feat: initialize majiang-rl project
This commit is contained in:
61
majiang_rl/rl/trainer.py
Normal file
61
majiang_rl/rl/trainer.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from .agents import Transition
|
||||
|
||||
|
||||
@dataclass
class EpisodeResult:
    """Summary of a single episode, as returned by ``EpisodeRunner.run_episode``.

    Attributes:
        total_reward: Sum of the rewards returned by the environment over
            the episode.
        steps: Number of environment steps that were taken.
        terminated: ``terminated`` flag of the last step taken (``False`` if
            no step was taken).
        truncated: ``truncated`` flag of the last step taken (``False`` if
            no step was taken).
    """

    total_reward: float
    steps: int
    terminated: bool
    truncated: bool
|
||||
|
||||
|
||||
class EpisodeRunner:
    """Run single episodes of ``agent`` acting in ``env``, capped at ``max_steps``.

    The environment must provide ``reset() -> (obs, info)`` and
    ``step(action) -> (obs, reward, terminated, truncated, info)``.
    The agent must provide ``reset()``, ``act(obs, info)`` and
    ``observe(transition)``.
    """

    def __init__(self, env, agent, max_steps: int = 1000):
        """Store the environment, the agent and the per-episode step cap.

        Args:
            env: Environment to step through.
            agent: Agent that picks actions and receives transitions.
            max_steps: Maximum number of environment steps per episode.
        """
        self.env = env
        self.agent = agent
        self.max_steps = max_steps

    def run_episode(self) -> EpisodeResult:
        """Play one episode and return its summary.

        Resets the environment and then the agent, and alternates
        ``agent.act`` / ``env.step`` until the environment reports
        ``terminated`` or ``truncated`` or ``max_steps`` steps have been
        taken. Every step is passed to ``agent.observe`` as a ``Transition``.

        If the step cap is reached before the environment terminates, the
        last step is flagged ``truncated`` in both the transition given to
        the agent and the returned result, so the cut-off is visible to both.

        Returns:
            EpisodeResult with the accumulated reward, the number of steps
            taken and the final ``terminated`` / ``truncated`` flags. With
            ``max_steps <= 0`` no step is taken and both flags are ``False``.
        """
        obs, info = self.env.reset()
        self.agent.reset()

        total_reward = 0.0
        terminated = False
        truncated = False
        steps = 0

        for step in range(self.max_steps):
            action = self.agent.act(obs, info)
            next_obs, reward, terminated, truncated, next_info = self.env.step(action)
            steps = step + 1

            # The runner's own cap ends the episode just like an env-side
            # time limit would, so report it as a truncation rather than
            # leaving both flags False on an unfinished episode.
            if steps >= self.max_steps and not terminated:
                truncated = True

            self.agent.observe(
                Transition(
                    obs=obs,
                    action=action,
                    reward=reward,
                    terminated=terminated,
                    truncated=truncated,
                    next_obs=next_obs,
                    info=next_info,
                )
            )
            total_reward += reward
            obs, info = next_obs, next_info

            if terminated or truncated:
                break

        return EpisodeResult(
            total_reward=total_reward,
            steps=steps,
            terminated=terminated,
            truncated=truncated,
        )
|
||||
|
||||
|
||||
def run_training(
    env,
    agent,
    episodes: int = 100,
    max_steps: int = 1000,
) -> List[EpisodeResult]:
    """Run ``episodes`` consecutive episodes and collect their results.

    A single ``EpisodeRunner`` is built for ``env`` and ``agent`` and reused
    for every episode.

    Args:
        env: Environment handed to ``EpisodeRunner``.
        agent: Agent handed to ``EpisodeRunner``.
        episodes: Number of episodes to run; a value ``<= 0`` runs none.
        max_steps: Per-episode step cap forwarded to ``EpisodeRunner``. The
            default matches ``EpisodeRunner``'s own default, so existing
            callers behave as before.

    Returns:
        One ``EpisodeResult`` per episode, in the order they were run.
    """
    runner = EpisodeRunner(env, agent, max_steps=max_steps)
    return [runner.run_episode() for _ in range(episodes)]
|
||||
Reference in New Issue
Block a user