feat: initialize majiang-rl project
This commit is contained in:
76
majiang_rl/rl/agents.py
Normal file
76
majiang_rl/rl/agents.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Agent(Protocol):
|
||||
def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
|
||||
...
|
||||
|
||||
def observe(self, transition: "Transition") -> None:
|
||||
...
|
||||
|
||||
def reset(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transition:
|
||||
obs: Dict[str, Any]
|
||||
action: Dict[str, int]
|
||||
reward: float
|
||||
terminated: bool
|
||||
truncated: bool
|
||||
next_obs: Dict[str, Any]
|
||||
info: Dict[str, Any]
|
||||
|
||||
|
||||
class RandomAgent:
|
||||
def __init__(self, seed: int | None = None):
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
|
||||
mask = info.get("action_mask")
|
||||
if mask is None:
|
||||
return {"type": 0, "tile": 0, "chi": 0}
|
||||
return sample_action_from_mask(mask, self.rng)
|
||||
|
||||
def observe(self, transition: Transition) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def sample_action_from_mask(mask: Dict[str, np.ndarray], rng: np.random.Generator) -> Dict[str, int]:
|
||||
type_mask = np.asarray(mask["type"], dtype=bool)
|
||||
discard_mask = np.asarray(mask["discard"], dtype=bool)
|
||||
pong_mask = np.asarray(mask["pong"], dtype=bool)
|
||||
kong_mask = np.asarray(mask["kong"], dtype=bool)
|
||||
chi_mask = np.asarray(mask["chi"], dtype=bool)
|
||||
valid_types = np.flatnonzero(type_mask)
|
||||
if len(valid_types) == 0:
|
||||
return {"type": 3, "tile": 0, "chi": 0}
|
||||
action_type = int(rng.choice(valid_types))
|
||||
tile_id = 0
|
||||
chi_choice = 0
|
||||
if action_type == 0:
|
||||
valid_tiles = np.flatnonzero(discard_mask)
|
||||
if len(valid_tiles) > 0:
|
||||
tile_id = int(rng.choice(valid_tiles))
|
||||
if action_type == 2:
|
||||
valid_tiles = np.flatnonzero(kong_mask)
|
||||
if len(valid_tiles) > 0:
|
||||
tile_id = int(rng.choice(valid_tiles))
|
||||
if action_type == 4:
|
||||
valid_tiles = np.flatnonzero(pong_mask)
|
||||
if len(valid_tiles) > 0:
|
||||
tile_id = int(rng.choice(valid_tiles))
|
||||
if action_type == 5:
|
||||
valid_chi = np.flatnonzero(chi_mask)
|
||||
if len(valid_chi) > 0:
|
||||
chi_choice = int(rng.choice(valid_chi))
|
||||
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
|
||||
Reference in New Issue
Block a user