feat: initialize majiang-rl project
This commit is contained in:
34
majiang_rl/ui/record.py
Normal file
34
majiang_rl/ui/record.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .. import MahjongEnv
|
||||
from ..rl import RandomAgent
|
||||
|
||||
|
||||
def record_episode(seed: Optional[int] = None, max_steps: int = 200) -> List[Dict[str, object]]:
|
||||
env = MahjongEnv(seed=seed)
|
||||
agent = RandomAgent(seed=seed)
|
||||
obs, info = env.reset()
|
||||
events: List[Dict[str, object]] = [{"id": 0, "type": "start", "snapshot": env._snapshot()}]
|
||||
|
||||
for _ in range(max_steps):
|
||||
action = agent.act(obs, info)
|
||||
obs, _, terminated, truncated, info = env.step(action)
|
||||
if env.event_log:
|
||||
for event in env.event_log:
|
||||
event_copy = dict(event)
|
||||
event_copy["id"] = len(events)
|
||||
events.append(event_copy)
|
||||
env.event_log = []
|
||||
if terminated or truncated:
|
||||
break
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def write_record(path: Path, events: List[Dict[str, object]]) -> None:
|
||||
payload = {"events": events}
|
||||
path.write_text(json.dumps(payload, indent=2, ensure_ascii=True))
|
||||
Reference in New Issue
Block a user