"""Four-player Mahjong environment for self-play training.

Each seat acts through ``SelfPlayMahjong.step``; ``observe`` and ``action_mask``
expose a per-player view and the legal-action masks. Tile ids index into
``TILE_TYPES``; wall construction, flower detection and hand evaluation are
delegated to the sibling ``tiles`` and ``rules`` modules.
"""

from __future__ import annotations

import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np

from ..rules import can_chi_options, is_win
from ..tiles import TILE_TYPES, build_wall, is_flower

# Game phases.
PHASE_TURN = 0  # the current player acts on their own hand (win / kong / discard)
PHASE_RESPONSE = 1  # other seats, one at a time, may claim the pending discard

# Action types; each value indexes mask["type"] returned by action_mask().
ACTION_DISCARD = 0
ACTION_DECLARE_WIN = 1
ACTION_DECLARE_KONG = 2
ACTION_PASS = 3
ACTION_DECLARE_PONG = 4
ACTION_DECLARE_CHI = 5

_NUM_ACTION_TYPES = 6
_NUM_CHI_OPTIONS = 3
_NUM_PLAYERS = 4
_INITIAL_HAND_SIZE = 13

# Action types whose "tile" field must be legal in the named per-tile mask.
_TILE_MASK_KEYS = {
    ACTION_DISCARD: "discard",
    ACTION_DECLARE_PONG: "pong",
    ACTION_DECLARE_KONG: "kong",
}


@dataclass
class PlayerState:
    """Mutable per-seat state; every list holds tile ids."""

    hand: List[int] = field(default_factory=list)  # concealed tiles
    melds: List[List[int]] = field(default_factory=list)  # declared pong/kong/chi sets
    discards: List[int] = field(default_factory=list)  # discards not claimed by another seat
    flowers: List[int] = field(default_factory=list)  # flower tiles set aside when drawn


class SelfPlayMahjong:
    """Turn-based four-player Mahjong game driven one action at a time.

    Typical loop::

        env.reset()
        while not env.done:
            p = env.current_player
            obs, mask = env.observe(p), env.action_mask(p)
            env.step(p, policy(obs, mask))

    The turn player's draw is performed as soon as the turn is handed to them,
    so ``observe``/``action_mask`` for that player already include the drawn
    tile (and therefore can offer a self-draw win or discarding that tile).
    """

    def __init__(self, seed: Optional[int] = None, pong_reward: float = 0.0):
        """Create an (un-dealt) game; call ``reset`` before playing.

        Args:
            seed: Seed for the environment's private RNG (wall shuffle and
                random replacement of illegal actions).
            pong_reward: Reward returned by ``step`` to a player who pongs.
        """
        self._random = random.Random(seed)
        self.pong_reward = pong_reward
        self.players: List[PlayerState] = []
        self.wall: List[int] = []
        self.current_player = 0
        self.phase = PHASE_TURN
        # (discarder seat, tile id) while other seats may still claim the discard.
        self.pending_discard: Optional[Tuple[int, int]] = None
        # Seats still to be asked about the pending discard, in order.
        self.response_queue: List[int] = []
        # True once the turn player's draw (if any) has been performed.
        self.turn_started = False
        # Seat whose next lazily-started turn must not draw (see _start_turn).
        self.skip_draw_for_player: Optional[int] = None
        self.done = False
        self.winner: Optional[int] = None
        self.terminal_reason: Optional[str] = None
        self.winning_tiles: List[int] = []
        self.winning_melds: List[List[int]] = []

    def reset(self, *, seed: Optional[int] = None) -> None:
        """Start a new hand: build a fresh wall, deal, and give seat 0 its first draw.

        Args:
            seed: If given, reseeds the environment's RNG before the wall is built.
        """
        if seed is not None:
            self._random.seed(seed)
        self.players = [PlayerState() for _ in range(_NUM_PLAYERS)]
        self.wall = build_wall(self._random)
        self.current_player = 0
        self.phase = PHASE_TURN
        self.pending_discard = None
        self.response_queue = []
        self.turn_started = False
        self.skip_draw_for_player = None
        self.done = False
        self.winner = None
        self.terminal_reason = None
        self.winning_tiles = []
        self.winning_melds = []
        # Deal round-robin, one tile at a time; drawn flowers are set aside and replaced.
        for _ in range(_INITIAL_HAND_SIZE):
            for player in range(_NUM_PLAYERS):
                self._draw_tile(player)
        # Seat 0 draws now so its first observation/mask already includes that tile.
        self._begin_turn(0)

    def observe(self, player_idx: int) -> Dict[str, np.ndarray]:
        """Return the game as seen by *player_idx*.

        Only that player's concealed hand is revealed; melds, discards and
        flower counts are public for every seat. Keys:

        * ``hand``: per-tile-type counts of the player's hand, shape ``(len(TILE_TYPES),)``.
        * ``melds`` / ``discards``: per-seat tile counts, shape ``(4, len(TILE_TYPES))``.
        * ``flowers``: number of flowers per seat, shape ``(4,)``.
        * ``wall_count`` / ``current_player`` / ``phase``: 1-element arrays.
        * ``pending_discard``: one-hot of the tile awaiting claims (all zero if none).
        """
        n_tiles = len(TILE_TYPES)
        hand_counts = self._counts(self.players[player_idx].hand)
        melds = np.zeros((_NUM_PLAYERS, n_tiles), dtype=np.int8)
        discards = np.zeros((_NUM_PLAYERS, n_tiles), dtype=np.int8)
        flowers = np.zeros(_NUM_PLAYERS, dtype=np.int8)
        for idx, player in enumerate(self.players):
            for meld in player.melds:
                for tile in meld:
                    melds[idx, tile] += 1
            for tile in player.discards:
                discards[idx, tile] += 1
            flowers[idx] = len(player.flowers)
        pending = np.zeros(n_tiles, dtype=np.int8)
        if self.pending_discard is not None:
            pending[self.pending_discard[1]] = 1
        return {
            "hand": hand_counts,
            "melds": melds,
            "discards": discards,
            "flowers": flowers,
            "wall_count": np.array([len(self.wall)], dtype=np.int16),
            "current_player": np.array([self.current_player], dtype=np.int8),
            "phase": np.array([self.phase], dtype=np.int8),
            "pending_discard": pending,
        }

    def action_mask(self, player_idx: int) -> Dict[str, np.ndarray]:
        """Return boolean masks of the actions *player_idx* may legally take now.

        ``type`` has one entry per ACTION_* constant; ``discard``/``pong``/``kong``
        are per-tile-type; ``chi`` has one entry per chi option (0, 1, 2), where
        the option selects which position the claimed tile takes in the run.
        """
        n_tiles = len(TILE_TYPES)
        mask_type = np.zeros(_NUM_ACTION_TYPES, dtype=bool)
        mask_discard = np.zeros(n_tiles, dtype=bool)
        mask_pong = np.zeros(n_tiles, dtype=bool)
        mask_kong = np.zeros(n_tiles, dtype=bool)
        mask_chi = np.zeros(_NUM_CHI_OPTIONS, dtype=bool)
        player = self.players[player_idx]
        counts = self._counts(player.hand)

        if self.phase == PHASE_TURN:
            mask_type[ACTION_DISCARD] = True
            mask_discard[:] = counts > 0
            if is_win(player.hand, len(player.melds)):
                mask_type[ACTION_DECLARE_WIN] = True
            # Concealed kong: four copies already in hand.
            mask_kong[:] = counts >= 4
            mask_type[ACTION_DECLARE_KONG] = bool(mask_kong.any())
        elif self.phase == PHASE_RESPONSE and self.pending_discard is not None:
            mask_type[ACTION_PASS] = True
            discarder, tile = self.pending_discard
            if is_win(player.hand + [tile], len(player.melds)):
                mask_type[ACTION_DECLARE_WIN] = True
            if counts[tile] >= 2:
                mask_type[ACTION_DECLARE_PONG] = True
                mask_pong[tile] = True
            if counts[tile] >= 3:
                mask_type[ACTION_DECLARE_KONG] = True
                mask_kong[tile] = True
            # Only the seat right after the discarder may chi.
            if player_idx == (discarder + 1) % _NUM_PLAYERS:
                options = can_chi_options(counts, tile)
                if options:
                    mask_type[ACTION_DECLARE_CHI] = True
                    for option in options:
                        mask_chi[option] = True
        return {
            "type": mask_type,
            "discard": mask_discard,
            "pong": mask_pong,
            "kong": mask_kong,
            "chi": mask_chi,
        }

    def step(self, player_idx: int, action: Dict[str, int]) -> float:
        """Apply *action* for *player_idx* and return that player's immediate reward.

        *action* may carry ``"type"`` (an ACTION_* value), ``"tile"`` and ``"chi"``.
        An illegal action is replaced by a uniformly random legal one. After the
        game has ended every call is a no-op returning 0.0.

        Raises:
            ValueError: if *player_idx* is not ``current_player``.
        """
        if self.done:
            return 0.0
        if player_idx != self.current_player:
            raise ValueError("Action from non-active player")
        if self.phase == PHASE_TURN and not self.turn_started:
            # Fallback for a turn that was handed over without its draw. The draw
            # must happen BEFORE validation so the action is checked against the
            # hand it will actually be applied to.
            self._start_turn(player_idx)
            self.turn_started = True
            if self.done:
                return 0.0
        action = self._sanitize_action(player_idx, action)
        if self.phase == PHASE_TURN:
            return self._handle_turn_action(player_idx, action)
        return self._handle_response_action(player_idx, action)

    def _handle_turn_action(self, player_idx: int, action: Dict[str, int]) -> float:
        """Resolve a self-draw win, concealed kong or discard by the turn player."""
        player = self.players[player_idx]
        action_type = int(action.get("type", ACTION_DISCARD))
        tile_id = int(action.get("tile", 0))

        if action_type == ACTION_DECLARE_WIN and is_win(player.hand, len(player.melds)):
            self._end_game(player_idx, "self_draw", player.hand)
            return 0.0

        if action_type == ACTION_DECLARE_KONG and player.hand.count(tile_id) >= 4:
            for _ in range(4):
                player.hand.remove(tile_id)
            player.melds.append([tile_id] * 4)
            # Replacement draw; the same player keeps the turn and acts again.
            self._draw_tile(player_idx)
            self.turn_started = True
            return 0.0

        if action_type == ACTION_DISCARD and tile_id in player.hand:
            player.hand.remove(tile_id)
            player.discards.append(tile_id)
            self.pending_discard = (player_idx, tile_id)
            # Offer the discard to the other seats in order, starting with the next one.
            self.response_queue = [
                (player_idx + offset) % _NUM_PLAYERS for offset in range(1, _NUM_PLAYERS)
            ]
            self.phase = PHASE_RESPONSE
            self.current_player = self.response_queue.pop(0)
            self.turn_started = False
            return 0.0

        return 0.0

    def _handle_response_action(self, player_idx: int, action: Dict[str, int]) -> float:
        """Resolve a claim (ron / pong / kong / chi) or pass on the pending discard.

        Seats are asked one at a time in the order set when the tile was
        discarded, and the first successful claim takes the tile.
        NOTE(review): this does not model the win > pong/kong > chi priority
        that many rule sets apply across seats.
        """
        if self.pending_discard is None:
            # Defensive: nothing to respond to; step() will lazily start the turn.
            self.phase = PHASE_TURN
            self.turn_started = False
            return 0.0

        discarder, tile = self.pending_discard
        player = self.players[player_idx]
        counts = self._counts(player.hand)
        action_type = int(action.get("type", ACTION_PASS))
        chi_choice = int(action.get("chi", 0))

        if action_type == ACTION_DECLARE_WIN and is_win(player.hand + [tile], len(player.melds)):
            self._end_game(player_idx, "ron", player.hand + [tile])
            return 0.0

        if action_type == ACTION_DECLARE_PONG and counts[tile] >= 2:
            self._consume_discard(discarder, tile)
            for _ in range(2):
                player.hand.remove(tile)
            player.melds.append([tile] * 3)
            # The claimed tile stands in for this turn's draw.
            self._begin_turn(player_idx, draw=False)
            return self.pong_reward

        if action_type == ACTION_DECLARE_KONG and counts[tile] >= 3:
            self._consume_discard(discarder, tile)
            for _ in range(3):
                player.hand.remove(tile)
            player.melds.append([tile] * 4)
            # A kong is followed by a replacement draw.
            self._begin_turn(player_idx)
            return 0.0

        if action_type == ACTION_DECLARE_CHI and player_idx == (discarder + 1) % _NUM_PLAYERS:
            options = can_chi_options(counts, tile)
            if chi_choice in options:
                self._consume_discard(discarder, tile)
                if chi_choice == 0:
                    chi_tiles = [tile - 2, tile - 1, tile]
                elif chi_choice == 1:
                    chi_tiles = [tile - 1, tile, tile + 1]
                else:
                    chi_tiles = [tile, tile + 1, tile + 2]
                for chi_tile in chi_tiles:
                    if chi_tile != tile:
                        player.hand.remove(chi_tile)
                player.melds.append(chi_tiles)
                self._begin_turn(player_idx, draw=False)
                return 0.0

        # Pass (or an unusable claim): offer the discard to the next seat.
        self._advance_response_queue(discarder)
        return 0.0

    def _sanitize_action(self, player_idx: int, action: Dict[str, int]) -> Dict[str, int]:
        """Return *action* normalised to ints, or a random legal action if it is illegal.

        Out-of-range ``type``/``tile``/``chi`` values are treated as illegal
        rather than indexing the masks (which would raise or wrap around).
        """
        mask = self.action_mask(player_idx)
        action_type = int(action.get("type", ACTION_PASS))
        tile_id = int(action.get("tile", 0))
        chi_choice = int(action.get("chi", 0))

        if not 0 <= action_type < _NUM_ACTION_TYPES or not mask["type"][action_type]:
            return self._sample_valid_action(mask)
        tile_key = _TILE_MASK_KEYS.get(action_type)
        if tile_key is not None:
            tile_mask = mask[tile_key]
            if not (0 <= tile_id < len(tile_mask) and tile_mask[tile_id]):
                return self._sample_valid_action(mask)
        if action_type == ACTION_DECLARE_CHI:
            chi_mask = mask["chi"]
            if not (0 <= chi_choice < len(chi_mask) and chi_mask[chi_choice]):
                return self._sample_valid_action(mask)
        return {"type": action_type, "tile": tile_id, "chi": chi_choice}

    def _sample_valid_action(self, mask: Dict[str, np.ndarray]) -> Dict[str, int]:
        """Draw a uniformly random legal action type, then a random legal argument for it."""
        valid_types = np.flatnonzero(mask["type"])
        if len(valid_types) == 0:
            return {"type": ACTION_PASS, "tile": 0, "chi": 0}
        action_type = int(self._random.choice(valid_types))
        tile_id = 0
        chi_choice = 0
        tile_key = _TILE_MASK_KEYS.get(action_type)
        if tile_key is not None:
            valid_tiles = np.flatnonzero(mask[tile_key])
            if len(valid_tiles):
                tile_id = int(self._random.choice(valid_tiles))
        elif action_type == ACTION_DECLARE_CHI:
            valid_chi = np.flatnonzero(mask["chi"])
            if len(valid_chi):
                chi_choice = int(self._random.choice(valid_chi))
        return {"type": action_type, "tile": tile_id, "chi": chi_choice}

    def _advance_response_queue(self, discarder: int) -> None:
        """Ask the next queued seat, or, if every seat passed, start the next turn."""
        if self.response_queue:
            self.current_player = self.response_queue.pop(0)
            return
        self._begin_turn((discarder + 1) % _NUM_PLAYERS)

    def _begin_turn(self, player_idx: int, *, draw: bool = True) -> None:
        """Hand the turn to *player_idx*, drawing its tile immediately when *draw*.

        Drawing eagerly means ``observe``/``action_mask`` for the new turn player
        already reflect the drawn tile. A draw from an empty wall ends the game.
        """
        self.current_player = player_idx
        self.phase = PHASE_TURN
        self.pending_discard = None
        self.response_queue = []
        self.skip_draw_for_player = None
        self.turn_started = True
        if draw:
            self._draw_tile(player_idx)

    def _start_turn(self, player_idx: int) -> None:
        """Lazily perform the turn draw, unless ``skip_draw_for_player`` exempts this seat."""
        if self.skip_draw_for_player == player_idx:
            self.skip_draw_for_player = None
            return
        self._draw_tile(player_idx)

    def _draw_tile(self, player_idx: int) -> Optional[int]:
        """Move the next non-flower wall tile into the player's hand and return it.

        Flowers popped on the way are set aside in the player's ``flowers``.
        If the wall runs out, the game ends as ``"wall_empty"`` and None is returned.
        """
        while self.wall:
            tile = self.wall.pop()
            if is_flower(tile):
                self.players[player_idx].flowers.append(tile)
                continue
            self.players[player_idx].hand.append(tile)
            return tile
        self._end_game(None, "wall_empty", [])
        return None

    def _consume_discard(self, discarder: int, tile: int) -> None:
        """Remove a claimed *tile* from the end of the discarder's discard pile."""
        discards = self.players[discarder].discards
        if discards and discards[-1] == tile:
            discards.pop()

    def _end_game(self, winner: Optional[int], reason: str, winning_tiles: List[int]) -> None:
        """Mark the game finished (first call wins) and record the result."""
        if self.done:
            return
        self.done = True
        self.winner = winner
        self.terminal_reason = reason
        # Snapshot so the record does not alias the players' live lists.
        self.winning_tiles = list(winning_tiles)
        self.winning_melds = (
            [] if winner is None else [list(meld) for meld in self.players[winner].melds]
        )

    def _counts(self, tiles: List[int]) -> np.ndarray:
        """Return an int8 histogram of *tiles* over all tile types."""
        counts = np.zeros(len(TILE_TYPES), dtype=np.int8)
        for tile in tiles:
            counts[tile] += 1
        return counts