Files
majiang-rl/majiang_rl/rl/agents.py
2026-01-14 10:49:00 +08:00

77 lines
2.3 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Protocol
import numpy as np
class Agent(Protocol):
    """Structural interface for anything that can play through the RL loop.

    Any object providing ``reset``, ``act`` and ``observe`` with these
    signatures satisfies the protocol; no inheritance is required.
    """

    def reset(self) -> None:
        """Clear any per-episode state held by the agent."""
        ...

    def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
        """Choose an action for the given observation and info dict.

        Returns a mapping from action-field name to an integer choice.
        """
        ...

    def observe(self, transition: "Transition") -> None:
        """Receive one completed transition (e.g. for learning or logging)."""
        ...
@dataclass
class Transition:
    """One environment step as handed to ``Agent.observe``.

    NOTE(review): the field names suggest a Gymnasium-style step tuple
    (``terminated`` / ``truncated``); confirm against the environment.
    """

    # Observation the action was chosen from.
    obs: Dict[str, Any]
    # Action that was taken, as a field-name -> integer mapping.
    action: Dict[str, int]
    # Scalar reward received for the step.
    reward: float
    # Episode-end flags reported for the step.
    terminated: bool
    truncated: bool
    # Observation produced after the step.
    next_obs: Dict[str, Any]
    # Auxiliary info dict accompanying the step.
    info: Dict[str, Any]
class RandomAgent:
    """Agent that picks uniformly random actions permitted by the action mask.

    It keeps no state other than its random generator and never learns.
    """

    def __init__(self, seed: int | None = None):
        # A dedicated Generator keeps runs reproducible for a given seed.
        self.rng = np.random.default_rng(seed)

    def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
        """Return a random masked action, or a fixed default without a mask.

        The mask is read from ``info["action_mask"]``. When that entry is
        missing or ``None`` the all-zero action is returned instead.
        """
        legal = info.get("action_mask")
        if legal is not None:
            return sample_action_from_mask(legal, self.rng)
        return {"type": 0, "tile": 0, "chi": 0}

    def observe(self, transition: Transition) -> None:
        """Ignore the transition; a random agent has nothing to learn."""

    def reset(self) -> None:
        """Do nothing; there is no per-episode state to clear."""
# Action types that carry a parameter: type id -> (mask key, action field).
# Types not listed here (e.g. 1 and 3) are returned with tile=0 and chi=0.
_PARAM_MASK_BY_TYPE = {
    0: ("discard", "tile"),
    2: ("kong", "tile"),
    4: ("pong", "tile"),
    5: ("chi", "chi"),
}


def sample_action_from_mask(mask: Dict[str, np.ndarray], rng: np.random.Generator) -> Dict[str, int]:
    """Sample a uniformly random action that respects the given masks.

    An action type is drawn from the truthy positions of ``mask["type"]``.
    If that type takes a parameter, the parameter is drawn from the truthy
    positions of the matching mask: type 0 uses ``mask["discard"]``, type 2
    uses ``mask["kong"]``, type 4 uses ``mask["pong"]`` (all filling
    ``"tile"``), and type 5 uses ``mask["chi"]`` (filling ``"chi"``).

    Only the mask for the chosen type is read, so a mask dict does not need
    entries for action types that were not selected.

    Args:
        mask: Mapping of mask name to an array-like of truthy/falsy flags.
            ``"type"`` is always required.
        rng: NumPy random generator used for every draw.

    Returns:
        A dict with integer ``"type"``, ``"tile"`` and ``"chi"`` entries.
        Fields that do not apply to the chosen type, or whose mask has no
        valid entry, are 0. When no action type is allowed the result is
        ``{"type": 3, "tile": 0, "chi": 0}`` (presumably the "pass" action;
        confirm against the environment's action encoding).

    Raises:
        KeyError: If ``"type"`` or the mask needed for the chosen type is
            missing from ``mask``.
    """
    valid_types = np.flatnonzero(np.asarray(mask["type"], dtype=bool))
    if valid_types.size == 0:
        return {"type": 3, "tile": 0, "chi": 0}

    action = {"type": int(rng.choice(valid_types)), "tile": 0, "chi": 0}

    param = _PARAM_MASK_BY_TYPE.get(action["type"])
    if param is not None:
        mask_key, field = param
        # Convert only the mask we actually need for the chosen type.
        candidates = np.flatnonzero(np.asarray(mask[mask_key], dtype=bool))
        if candidates.size > 0:
            action[field] = int(rng.choice(candidates))
    return action