from __future__ import annotations from dataclasses import dataclass from typing import Any, Dict, Protocol import numpy as np class Agent(Protocol): def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]: ... def observe(self, transition: "Transition") -> None: ... def reset(self) -> None: ... @dataclass class Transition: obs: Dict[str, Any] action: Dict[str, int] reward: float terminated: bool truncated: bool next_obs: Dict[str, Any] info: Dict[str, Any] class RandomAgent: def __init__(self, seed: int | None = None): self.rng = np.random.default_rng(seed) def reset(self) -> None: return None def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]: mask = info.get("action_mask") if mask is None: return {"type": 0, "tile": 0, "chi": 0} return sample_action_from_mask(mask, self.rng) def observe(self, transition: Transition) -> None: return None def sample_action_from_mask(mask: Dict[str, np.ndarray], rng: np.random.Generator) -> Dict[str, int]: type_mask = np.asarray(mask["type"], dtype=bool) discard_mask = np.asarray(mask["discard"], dtype=bool) pong_mask = np.asarray(mask["pong"], dtype=bool) kong_mask = np.asarray(mask["kong"], dtype=bool) chi_mask = np.asarray(mask["chi"], dtype=bool) valid_types = np.flatnonzero(type_mask) if len(valid_types) == 0: return {"type": 3, "tile": 0, "chi": 0} action_type = int(rng.choice(valid_types)) tile_id = 0 chi_choice = 0 if action_type == 0: valid_tiles = np.flatnonzero(discard_mask) if len(valid_tiles) > 0: tile_id = int(rng.choice(valid_tiles)) if action_type == 2: valid_tiles = np.flatnonzero(kong_mask) if len(valid_tiles) > 0: tile_id = int(rng.choice(valid_tiles)) if action_type == 4: valid_tiles = np.flatnonzero(pong_mask) if len(valid_tiles) > 0: tile_id = int(rng.choice(valid_tiles)) if action_type == 5: valid_chi = np.flatnonzero(chi_mask) if len(valid_chi) > 0: chi_choice = int(rng.choice(valid_chi)) return {"type": action_type, "tile": tile_id, "chi": chi_choice}