from __future__ import annotations import random from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple import gymnasium as gym import numpy as np from gymnasium import spaces from ..rules import can_chi_options, is_win from ..tiles import TILE_TYPES, build_wall, is_flower, tile_name PHASE_AGENT_TURN = 0 PHASE_AGENT_RESPONSE = 1 ACTION_DISCARD = 0 ACTION_DECLARE_WIN = 1 ACTION_DECLARE_KONG = 2 ACTION_PASS = 3 ACTION_DECLARE_PONG = 4 ACTION_DECLARE_CHI = 5 @dataclass class PlayerState: hand: List[int] = field(default_factory=list) melds: List[List[int]] = field(default_factory=list) discards: List[int] = field(default_factory=list) flowers: List[int] = field(default_factory=list) class MahjongEnv(gym.Env): metadata = {"render_modes": ["human"]} def __init__(self, seed: Optional[int] = None, max_rounds: int = 1): super().__init__() self.max_rounds = max_rounds self.action_space = spaces.Dict( { "type": spaces.Discrete(6), "tile": spaces.Discrete(len(TILE_TYPES)), "chi": spaces.Discrete(3), } ) self.observation_space = spaces.Dict( { "hand": spaces.Box(0, 4, shape=(len(TILE_TYPES),), dtype=np.int8), "melds": spaces.Box(0, 4, shape=(4, len(TILE_TYPES)), dtype=np.int8), "discards": spaces.Box(0, 4, shape=(4, len(TILE_TYPES)), dtype=np.int8), "flowers": spaces.Box(0, 8, shape=(4,), dtype=np.int8), "wall_count": spaces.Box(0, 144, shape=(1,), dtype=np.int16), "current_player": spaces.Box(0, 3, shape=(1,), dtype=np.int8), "phase": spaces.Box(0, 1, shape=(1,), dtype=np.int8), } ) self._random = random.Random(seed) self._round = 0 self.agent_index = 0 self.players: List[PlayerState] = [] self.wall: List[int] = [] self.current_player = 0 self.phase = PHASE_AGENT_TURN self.pending_discard: Optional[Tuple[int, int]] = None self.skip_draw_for_player: Optional[int] = None self.done = False self.winner: Optional[int] = None self.terminal_reason: Optional[str] = None self.event_log: List[Dict[str, object]] = [] def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): super().reset(seed=seed) if seed is not None: self._random.seed(seed) self._round += 1 self.players = [PlayerState() for _ in range(4)] self.wall = build_wall(self._random) self.current_player = 0 self.pending_discard = None self.skip_draw_for_player = None self.done = False self.winner = None self.terminal_reason = None self.event_log = [] for _ in range(13): for player in range(4): self._draw_tile(player) self._draw_tile(0) self.skip_draw_for_player = 0 self.phase = PHASE_AGENT_TURN self._advance_to_agent() return self._get_obs(), self._get_info() def step(self, action: Dict[str, int]): if self.done: return self._get_obs(), 0.0, True, False, self._get_info() reward = 0.0 action = self._sanitize_action(action) action_type = int(action["type"]) tile_id = int(action["tile"]) chi_choice = int(action["chi"]) if self.phase == PHASE_AGENT_TURN: reward += self._handle_agent_turn(action_type, tile_id) elif self.phase == PHASE_AGENT_RESPONSE: reward += self._handle_agent_response(action_type, tile_id, chi_choice) if not self.done: self._advance_to_agent() if self.done: if self.winner == self.agent_index: reward += 1.0 elif self.winner is not None: reward -= 1.0 return self._get_obs(), reward, self.done, False, self._get_info() def render(self): if self.done: print(f"Winner: {self.winner}, reason: {self.terminal_reason}") else: print(f"Player {self.current_player} to act, wall {len(self.wall)}") def _get_obs(self): hand_counts = self._counts(self.players[self.agent_index].hand) melds = np.zeros((4, len(TILE_TYPES)), dtype=np.int8) discards = np.zeros((4, len(TILE_TYPES)), dtype=np.int8) flowers = np.zeros(4, dtype=np.int8) for idx, player in enumerate(self.players): for meld in player.melds: for tile in meld: melds[idx, tile] += 1 for tile in player.discards: discards[idx, tile] += 1 flowers[idx] = len(player.flowers) return { "hand": hand_counts, "melds": melds, "discards": discards, "flowers": flowers, "wall_count": np.array([len(self.wall)], dtype=np.int16), "current_player": np.array([self.current_player], dtype=np.int8), "phase": np.array([self.phase], dtype=np.int8), } def _get_info(self): action_mask = self._action_mask() return { "winner": self.winner, "terminal_reason": self.terminal_reason, "pending_discard": self.pending_discard, "action_mask": action_mask, } def _sanitize_action(self, action: Dict[str, int]) -> Dict[str, int]: if not isinstance(action, dict): return self._sample_valid_action() if "type" not in action or "tile" not in action or "chi" not in action: return self._sample_valid_action() mask = self._action_mask() action_type = int(action["type"]) tile_id = int(action["tile"]) chi_choice = int(action["chi"]) if action_type < 0 or action_type >= 6: return self._sample_valid_action() if not mask["type"][action_type]: return self._sample_valid_action() if action_type == ACTION_DISCARD: if not mask["discard"][tile_id]: return self._sample_valid_action() if action_type == ACTION_DECLARE_PONG: if not mask["pong"][tile_id]: return self._sample_valid_action() if action_type == ACTION_DECLARE_KONG: if not mask["kong"][tile_id]: return self._sample_valid_action() if action_type == ACTION_DECLARE_CHI: if not mask["chi"][chi_choice]: return self._sample_valid_action() return {"type": action_type, "tile": tile_id, "chi": chi_choice} def _sample_valid_action(self) -> Dict[str, int]: mask = self._action_mask() valid_types = np.flatnonzero(mask["type"]) if len(valid_types) == 0: return {"type": ACTION_PASS, "tile": 0, "chi": 0} action_type = int(self._random.choice(valid_types)) tile_id = 0 chi_choice = 0 if action_type == ACTION_DISCARD: valid_tiles = np.flatnonzero(mask["discard"]) tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0 if action_type == ACTION_DECLARE_PONG: valid_tiles = np.flatnonzero(mask["pong"]) tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0 if action_type == ACTION_DECLARE_KONG: valid_tiles = np.flatnonzero(mask["kong"]) tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0 if action_type == ACTION_DECLARE_CHI: valid_chi = np.flatnonzero(mask["chi"]) chi_choice = int(self._random.choice(valid_chi)) if len(valid_chi) else 0 return {"type": action_type, "tile": tile_id, "chi": chi_choice} def _action_mask(self) -> Dict[str, np.ndarray]: mask_type = np.zeros(6, dtype=bool) mask_discard = np.zeros(len(TILE_TYPES), dtype=bool) mask_pong = np.zeros(len(TILE_TYPES), dtype=bool) mask_kong = np.zeros(len(TILE_TYPES), dtype=bool) mask_chi = np.zeros(3, dtype=bool) agent = self.players[self.agent_index] counts = self._counts(agent.hand) if self.phase == PHASE_AGENT_TURN: mask_type[ACTION_DISCARD] = True for tile in agent.hand: mask_discard[tile] = True if is_win(agent.hand, len(agent.melds)): mask_type[ACTION_DECLARE_WIN] = True for tile_id, count in enumerate(counts): if count >= 4: mask_type[ACTION_DECLARE_KONG] = True mask_kong[tile_id] = True elif self.phase == PHASE_AGENT_RESPONSE and self.pending_discard is not None: mask_type[ACTION_PASS] = True _, tile = self.pending_discard virtual_hand = agent.hand + [tile] if is_win(virtual_hand, len(agent.melds)): mask_type[ACTION_DECLARE_WIN] = True if counts[tile] >= 2: mask_type[ACTION_DECLARE_PONG] = True mask_pong[tile] = True if counts[tile] >= 3: mask_type[ACTION_DECLARE_KONG] = True mask_kong[tile] = True if self._is_next_player(self.agent_index, self.pending_discard[0]): options = can_chi_options(counts, tile) if options: mask_type[ACTION_DECLARE_CHI] = True for option in options: mask_chi[option] = True return { "type": mask_type, "discard": mask_discard, "pong": mask_pong, "kong": mask_kong, "chi": mask_chi, } def _handle_agent_turn(self, action_type: int, tile_id: int) -> float: agent = self.players[self.agent_index] if action_type == ACTION_DECLARE_WIN: if is_win(agent.hand, len(agent.melds)): self._end_game(self.agent_index, "self_draw") return 0.0 if action_type == ACTION_DECLARE_KONG: if self._can_concealed_kong(agent, tile_id): self._make_kong(self.agent_index, tile_id, concealed=True) self._draw_tile(self.agent_index) self.skip_draw_for_player = self.agent_index self.phase = PHASE_AGENT_TURN return 0.0 if action_type == ACTION_DISCARD: if tile_id in agent.hand: agent.hand.remove(tile_id) agent.discards.append(tile_id) self.pending_discard = (self.agent_index, tile_id) self._log_event( "discard", player=self.agent_index, tile_id=tile_id, ) self._resolve_discard() return 0.0 return 0.0 def _handle_agent_response(self, action_type: int, tile_id: int, chi_choice: int) -> float: if self.pending_discard is None: self.phase = PHASE_AGENT_TURN return 0.0 discarder, tile = self.pending_discard agent = self.players[self.agent_index] counts = self._counts(agent.hand) if action_type == ACTION_DECLARE_WIN: if is_win(agent.hand + [tile], len(agent.melds)): self._end_game(self.agent_index, "ron") return 0.0 if action_type == ACTION_DECLARE_PONG: if counts[tile] >= 2: self._consume_discard(discarder, tile) agent.hand.remove(tile) agent.hand.remove(tile) agent.melds.append([tile, tile, tile]) self.current_player = self.agent_index self.skip_draw_for_player = self.agent_index self.pending_discard = None self._log_event( "claim", player=self.agent_index, claim="pong", tile_id=tile, from_player=discarder, ) self.phase = PHASE_AGENT_TURN return 0.0 if action_type == ACTION_DECLARE_KONG: if counts[tile] >= 3: self._consume_discard(discarder, tile) for _ in range(3): agent.hand.remove(tile) agent.melds.append([tile, tile, tile, tile]) self.current_player = self.agent_index self.skip_draw_for_player = self.agent_index self.pending_discard = None self._draw_tile(self.agent_index) self._log_event( "claim", player=self.agent_index, claim="kong", tile_id=tile, from_player=discarder, ) self.phase = PHASE_AGENT_TURN return 0.0 if action_type == ACTION_DECLARE_CHI: if self._is_next_player(self.agent_index, discarder): options = can_chi_options(counts, tile) if chi_choice in options: self._consume_discard(discarder, tile) if chi_choice == 0: chi_tiles = [tile - 2, tile - 1, tile] elif chi_choice == 1: chi_tiles = [tile - 1, tile, tile + 1] else: chi_tiles = [tile, tile + 1, tile + 2] for chi_tile in chi_tiles: if chi_tile != tile: agent.hand.remove(chi_tile) agent.melds.append(chi_tiles) self.current_player = self.agent_index self.skip_draw_for_player = self.agent_index self.pending_discard = None self._log_event( "claim", player=self.agent_index, claim="chi", tile_id=tile, from_player=discarder, ) self.phase = PHASE_AGENT_TURN return 0.0 self.pending_discard = None self.phase = PHASE_AGENT_TURN self._resolve_discard(after_pass=True, original_discarder=discarder, tile=tile) return 0.0 def _advance_to_agent(self): while not self.done: if self.phase == PHASE_AGENT_RESPONSE: return if self.current_player == self.agent_index: self._start_turn(self.agent_index) if self.done: return self.phase = PHASE_AGENT_TURN return self._npc_turn(self.current_player) if self.done or self.phase == PHASE_AGENT_RESPONSE: return def _npc_turn(self, player_idx: int): self._start_turn(player_idx) if self.done: return player = self.players[player_idx] if is_win(player.hand, len(player.melds)): self._end_game(player_idx, "self_draw") return tile = self._random.choice(player.hand) player.hand.remove(tile) player.discards.append(tile) self.pending_discard = (player_idx, tile) self._log_event( "discard", player=player_idx, tile_id=tile, ) self._resolve_discard() def _resolve_discard(self, after_pass: bool = False, original_discarder: Optional[int] = None, tile: Optional[int] = None): if self.pending_discard is None and not after_pass: return discarder, discarded_tile = self.pending_discard if self.pending_discard is not None else (original_discarder, tile) if discarder is None or discarded_tile is None: return if discarder != self.agent_index and not after_pass: if self._agent_can_respond(discarder, discarded_tile): self.pending_discard = (discarder, discarded_tile) self.phase = PHASE_AGENT_RESPONSE return winner = self._npc_win_on_discard(discarder, discarded_tile) if winner is not None: self._end_game(winner, "ron") return claim = self._npc_claim(discarder, discarded_tile) if claim is not None: self.pending_discard = None self._apply_npc_claim(*claim) return self.pending_discard = None self.current_player = (discarder + 1) % 4 def _agent_can_respond(self, discarder: int, tile: int) -> bool: agent = self.players[self.agent_index] counts = self._counts(agent.hand) if is_win(agent.hand + [tile], len(agent.melds)): return True if counts[tile] >= 2: return True if counts[tile] >= 3: return True if self._is_next_player(self.agent_index, discarder): return bool(can_chi_options(counts, tile)) return False def _npc_win_on_discard(self, discarder: int, tile: int) -> Optional[int]: for offset in range(1, 4): player_idx = (discarder + offset) % 4 if player_idx == self.agent_index: continue player = self.players[player_idx] if is_win(player.hand + [tile], len(player.melds)): return player_idx return None def _npc_claim(self, discarder: int, tile: int): for offset in range(1, 4): player_idx = (discarder + offset) % 4 if player_idx == self.agent_index: continue player = self.players[player_idx] counts = self._counts(player.hand) if counts[tile] >= 3: return (player_idx, "kong", tile, None, discarder) if counts[tile] >= 2: return (player_idx, "pong", tile, None, discarder) next_player = (discarder + 1) % 4 if next_player != self.agent_index: player = self.players[next_player] options = can_chi_options(self._counts(player.hand), tile) if options: return (next_player, "chi", tile, options[0], discarder) return None def _apply_npc_claim( self, player_idx: int, claim_type: str, tile: int, chi_choice: Optional[int], discarder: int, ): player = self.players[player_idx] self._consume_discard(discarder, tile) if claim_type == "pong": player.hand.remove(tile) player.hand.remove(tile) player.melds.append([tile, tile, tile]) self.current_player = player_idx self.skip_draw_for_player = player_idx elif claim_type == "kong": for _ in range(3): player.hand.remove(tile) player.melds.append([tile, tile, tile, tile]) self.current_player = player_idx self.skip_draw_for_player = player_idx self._draw_tile(player_idx) elif claim_type == "chi" and chi_choice is not None: if chi_choice == 0: chi_tiles = [tile - 2, tile - 1, tile] elif chi_choice == 1: chi_tiles = [tile - 1, tile, tile + 1] else: chi_tiles = [tile, tile + 1, tile + 2] for chi_tile in chi_tiles: if chi_tile != tile: player.hand.remove(chi_tile) player.melds.append(chi_tiles) self.current_player = player_idx self.skip_draw_for_player = player_idx self._log_event( "claim", player=player_idx, claim=claim_type, tile_id=tile, from_player=discarder, ) def _draw_tile(self, player_idx: int): while self.wall: tile = self.wall.pop() if is_flower(tile): self.players[player_idx].flowers.append(tile) continue self.players[player_idx].hand.append(tile) return tile self._end_game(None, "wall_empty") return None def _start_turn(self, player_idx: int): if self.skip_draw_for_player == player_idx: self.skip_draw_for_player = None return self._draw_tile(player_idx) def _make_kong(self, player_idx: int, tile_id: int, concealed: bool): player = self.players[player_idx] for _ in range(4): player.hand.remove(tile_id) player.melds.append([tile_id] * 4) def _can_concealed_kong(self, player: PlayerState, tile_id: int) -> bool: return player.hand.count(tile_id) >= 4 def _counts(self, tiles: List[int]) -> np.ndarray: counts = np.zeros(len(TILE_TYPES), dtype=np.int8) for tile in tiles: counts[tile] += 1 return counts def _is_next_player(self, player: int, discarder: int) -> bool: return player == (discarder + 1) % 4 def _consume_discard(self, discarder: int, tile: int): discards = self.players[discarder].discards if discards and discards[-1] == tile: discards.pop() def _end_game(self, winner: Optional[int], reason: str): if self.done: return self.done = True self.winner = winner self.terminal_reason = reason self._log_event( "win" if winner is not None else "draw", player=winner, reason=reason, ) def _snapshot(self) -> Dict[str, object]: return { "hands": [[tile_name(tile) for tile in sorted(p.hand)] for p in self.players], "discards": [[tile_name(tile) for tile in p.discards] for p in self.players], "melds": [[[tile_name(tile) for tile in meld] for meld in p.melds] for p in self.players], "flowers": [[tile_name(tile) for tile in p.flowers] for p in self.players], "wall_count": len(self.wall), "current_player": self.current_player, } def _log_event(self, event_type: str, **payload: object): event = { "id": len(self.event_log), "type": event_type, "snapshot": self._snapshot(), } event.update(payload) if "tile_id" in event: event["tile"] = tile_name(int(event["tile_id"])) self.event_log.append(event)