35 lines
1.0 KiB
Python
35 lines
1.0 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional
|
|
|
|
from .. import MahjongEnv
|
|
from ..rl import RandomAgent
|
|
|
|
|
|
def record_episode(
    seed: Optional[int] = None,
    max_steps: int = 200,
) -> List[Dict[str, object]]:
    """Play one episode with a random agent and return the collected events.

    Both a ``MahjongEnv`` and a ``RandomAgent`` are built from ``seed``. Right
    after ``env.reset()`` a ``"start"`` event carrying ``env._snapshot()`` is
    recorded. Then, for at most ``max_steps`` iterations, the agent picks an
    action from the latest observation/info, ``env.step`` is called with it,
    and any entries in ``env.event_log`` are copied into the result before the
    environment's log is replaced with an empty list. Recording stops early as
    soon as a step reports ``terminated`` or ``truncated``.

    Args:
        seed: Seed forwarded to both the environment and the agent.
        max_steps: Upper bound on the number of ``env.step`` calls.

    Returns:
        The recorded events. Each is a shallow copy of the environment's event
        whose ``"id"`` key is set to that event's index in the returned list.
    """
    env = MahjongEnv(seed=seed)
    agent = RandomAgent(seed=seed)
    obs, info = env.reset()

    recorded: List[Dict[str, object]] = [
        {"id": 0, "type": "start", "snapshot": env._snapshot()},
    ]

    for _step in range(max_steps):
        obs, _, terminated, truncated, info = env.step(agent.act(obs, info))

        if env.event_log:
            # Ids continue from the current length, so every id matches the
            # event's position in ``recorded``; ``dict(..., id=...)`` makes a
            # shallow copy so the env's own event objects are left untouched.
            recorded.extend(
                dict(entry, id=position)
                for position, entry in enumerate(env.event_log, start=len(recorded))
            )
            # Drop the drained events so the next step only yields new ones.
            env.event_log = []

        if terminated or truncated:
            break

    return recorded
|
|
|
|
|
|
def write_record(path: Path | str, events: List[Dict[str, object]]) -> None:
    """Serialize ``events`` to ``path`` as the JSON document ``{"events": [...]}``.

    The output is pretty-printed with a two-space indent and produced with
    ``ensure_ascii=True``, so non-ASCII characters are written as ``\\uXXXX``
    escapes and the file content is pure ASCII. An existing file at ``path``
    is overwritten.

    Args:
        path: Destination file, as a ``pathlib.Path``, a string, or any other
            path-like object accepted by ``Path()``. The parent directory must
            already exist.
        events: The event dicts to store (e.g. the list returned by
            ``record_episode``).

    Raises:
        TypeError: If an event holds a value the ``json`` module cannot
            serialize.
        OSError: If the file cannot be written.
    """
    payload = {"events": events}
    text = json.dumps(payload, indent=2, ensure_ascii=True)
    # Coerce so plain string paths (and other path-likes) work, not just Path.
    # Pin the encoding so the write never depends on the platform locale.
    Path(path).write_text(text, encoding="utf-8")
|