77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Protocol
|
|
|
|
import numpy as np
|
|
|
|
|
|
class Agent(Protocol):
    """Structural interface for anything that can play an episode.

    Any object exposing ``reset``, ``act`` and ``observe`` with these
    signatures satisfies the protocol; no inheritance is required.
    """

    def reset(self) -> None:
        """Clear any per-episode state before a new episode starts."""

    def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
        """Return the action to take given the observation and info dicts."""

    def observe(self, transition: "Transition") -> None:
        """Receive the transition that resulted from a previous action."""
|
|
|
|
|
|
@dataclass
class Transition:
    """One recorded step, handed to ``Agent.observe``.

    Field order is part of the positional constructor and must not change.
    """

    # State before the step and the action that was applied to it.
    obs: Dict[str, Any]
    action: Dict[str, int]

    # Outcome of the step.
    # NOTE(review): terminated/truncated look like the Gymnasium split
    # between a natural episode end and a cut-off -- confirm against the env.
    reward: float
    terminated: bool
    truncated: bool

    # State after the step, plus the auxiliary info dict.
    next_obs: Dict[str, Any]
    info: Dict[str, Any]
|
|
|
|
|
|
class RandomAgent:
    """Agent that samples its actions at random from the action mask.

    The agent keeps no state other than its random generator, so ``reset``
    and ``observe`` are no-ops.
    """

    def __init__(self, seed: int | None = None):
        """Create the agent's private generator; ``seed`` makes it reproducible."""
        self.rng = np.random.default_rng(seed)

    def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
        """Sample an action allowed by ``info["action_mask"]``.

        ``obs`` is ignored. When ``info`` carries no mask (key missing or set
        to ``None``) the fixed action ``{"type": 0, "tile": 0, "chi": 0}`` is
        returned without consuming any randomness.
        """
        action_mask = info.get("action_mask")
        if action_mask is not None:
            return sample_action_from_mask(action_mask, self.rng)
        return {"type": 0, "tile": 0, "chi": 0}

    def observe(self, transition: Transition) -> None:
        """Ignore the transition; a random agent does not learn."""

    def reset(self) -> None:
        """Nothing to reset between episodes."""
|
|
|
|
|
|
def sample_action_from_mask(mask: Dict[str, np.ndarray], rng: np.random.Generator) -> Dict[str, int]:
    """Draw a random action dict from a set of boolean masks.

    Args:
        mask: Must contain the keys ``"type"``, ``"discard"``, ``"pong"``,
            ``"kong"`` and ``"chi"``. Each value is converted to a boolean
            array; truthy positions mark the allowed indices.
        rng: Generator used for every random draw.

    Returns:
        A dict with integer ``"type"``, ``"tile"`` and ``"chi"`` entries.
        If no type is enabled the result is ``{"type": 3, "tile": 0,
        "chi": 0}``. Otherwise a type is drawn uniformly from the enabled
        ones. Types 0, 2 and 4 then draw ``"tile"`` from the discard, kong
        and pong masks respectively, and type 5 draws ``"chi"`` from the chi
        mask. A field whose mask is empty, or that the chosen type does not
        use, stays 0.

    Raises:
        KeyError: If any of the five mask keys is missing, even one the
            sampled type would not have needed.
    """

    def _as_flags(key: str) -> np.ndarray:
        return np.asarray(mask[key], dtype=bool)

    # Every sub-mask is converted before sampling, in a fixed order, so a
    # malformed mask fails the same way regardless of what would be drawn.
    type_flags = _as_flags("type")
    detail_for_type = {
        0: ("tile", _as_flags("discard")),
        4: ("tile", _as_flags("pong")),
        2: ("tile", _as_flags("kong")),
        5: ("chi", _as_flags("chi")),
    }

    # Start from the result used when no type is enabled.
    action = {"type": 3, "tile": 0, "chi": 0}

    enabled_types = np.flatnonzero(type_flags)
    if enabled_types.size == 0:
        return action
    action["type"] = int(rng.choice(enabled_types))

    detail = detail_for_type.get(action["type"])
    if detail is not None:
        field, flags = detail
        candidates = np.flatnonzero(flags)
        # An empty sub-mask leaves the field at 0 and draws nothing.
        if candidates.size > 0:
            action[field] = int(rng.choice(candidates))
    return action
|