338 lines
13 KiB
Python
338 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import random
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from ..rules import can_chi_options, is_win
|
|
from ..tiles import TILE_TYPES, build_wall, is_flower
|
|
|
|
# Game phases stored in ``SelfPlayMahjong.phase``:
#   PHASE_TURN     -- the current player acts on their own hand.
#   PHASE_RESPONSE -- players are polled about a pending discard.
PHASE_TURN, PHASE_RESPONSE = 0, 1

# Values for the "type" entry of an action dict (also the index into the
# "type" action mask). Listed in numeric order 0..5.
(
    ACTION_DISCARD,
    ACTION_DECLARE_WIN,
    ACTION_DECLARE_KONG,
    ACTION_PASS,
    ACTION_DECLARE_PONG,
    ACTION_DECLARE_CHI,
) = range(6)
|
|
|
|
|
|
@dataclass
class PlayerState:
    """Per-seat mutable state for one game.

    Every field starts as its own fresh empty list, so separate instances
    never share storage. Tiles are stored as integer tile ids.
    """

    # Tiles currently held in hand (not yet melded).
    hand: List[int] = field(default_factory=list)
    # Exposed sets claimed or declared by this seat; each is a list of tile ids.
    melds: List[List[int]] = field(default_factory=list)
    # Tiles this seat has discarded, in the order they were discarded.
    discards: List[int] = field(default_factory=list)
    # Flower tiles set aside when drawn.
    flowers: List[int] = field(default_factory=list)
|
|
|
|
|
|
class SelfPlayMahjong:
    """Four-player Mahjong environment advanced one player action at a time.

    The game alternates between two phases:

    * ``PHASE_TURN`` -- ``current_player`` holds their drawn tile (no draw is
      made when the turn was entered by a pong or chi claim) and may discard,
      declare a concealed kong, or declare a self-drawn win.
    * ``PHASE_RESPONSE`` -- a discard is pending and the other seats are
      polled one at a time, in seat order after the discarder, to win on it,
      pong, kong, chi (next seat only) or pass. The first seat that claims the
      tile takes it and the remaining seats are not asked.

    A new turn's draw happens as soon as the turn begins, so ``observe`` and
    ``action_mask`` always reflect the hand the player is actually deciding on.
    Invalid actions passed to ``step`` are replaced by a random legal action.
    """

    # Number of seats at the table.
    _NUM_PLAYERS = 4
    # Length of the "type" mask: ACTION_DISCARD (0) .. ACTION_DECLARE_CHI (5).
    _NUM_ACTION_TYPES = 6
    # Chi choices: the claimed tile is the high (0), middle (1) or low (2)
    # tile of the run (see _handle_response_action).
    _NUM_CHI_OPTIONS = 3
    # Tiles dealt to each seat by reset().
    _HAND_SIZE = 13

    def __init__(self, seed: Optional[int] = None, pong_reward: float = 0.0):
        """Create an un-dealt environment; ``reset`` must be called before play.

        Args:
            seed: Seed for the private RNG used to build the wall and to sample
                replacement actions for invalid requests.
            pong_reward: Reward returned by ``step`` when a player pongs.
        """
        self._random = random.Random(seed)
        self.pong_reward = pong_reward
        self.players: List[PlayerState] = []
        self.wall: List[int] = []
        self.current_player = 0
        self.phase = PHASE_TURN
        # (discarder seat, tile id) while a discard awaits responses.
        self.pending_discard: Optional[Tuple[int, int]] = None
        # Seats still to be asked about the pending discard, in order.
        self.response_queue: List[int] = []
        # True once the current turn's draw (or draw skip) has been handled.
        self.turn_started = False
        # Seat whose next turn must not draw (it just claimed a tile via pong/chi).
        self.skip_draw_for_player: Optional[int] = None
        self.done = False
        self.winner: Optional[int] = None
        self.terminal_reason: Optional[str] = None
        self.winning_tiles: List[int] = []
        self.winning_melds: List[List[int]] = []

    def reset(self, *, seed: Optional[int] = None) -> None:
        """Start a new game: fresh wall, 13 tiles per seat, seat 0 to act.

        Args:
            seed: If given, re-seeds the environment RNG before the wall is built.
        """
        if seed is not None:
            self._random.seed(seed)
        self.players = [PlayerState() for _ in range(self._NUM_PLAYERS)]
        self.wall = build_wall(self._random)
        self.current_player = 0
        self.phase = PHASE_TURN
        self.pending_discard = None
        self.response_queue = []
        self.turn_started = False
        self.skip_draw_for_player = None
        self.done = False
        self.winner = None
        self.terminal_reason = None
        self.winning_tiles = []
        self.winning_melds = []

        # Deal round-robin; _draw_tile sets flowers aside and draws replacements.
        for _ in range(self._HAND_SIZE):
            for player in range(self._NUM_PLAYERS):
                self._draw_tile(player)
        # Draw seat 0's first tile now so observe()/action_mask() include it.
        self._begin_turn()

    def observe(self, player_idx: int) -> Dict[str, np.ndarray]:
        """Return the observation for ``player_idx``.

        Keys:
            hand: per-tile-id counts of this player's hand.
            melds, discards: per-seat, per-tile-id counts (shape ``(4, n_tiles)``).
            flowers: number of flowers collected by each seat.
            wall_count, current_player, phase: single-element arrays.
            pending_discard: one-hot over tile ids of the pending discard
                (all zeros when there is none).
        """
        n_tiles = len(TILE_TYPES)
        hand_counts = self._counts(self.players[player_idx].hand)
        melds = np.zeros((self._NUM_PLAYERS, n_tiles), dtype=np.int8)
        discards = np.zeros((self._NUM_PLAYERS, n_tiles), dtype=np.int8)
        flowers = np.zeros(self._NUM_PLAYERS, dtype=np.int8)
        for idx, player in enumerate(self.players):
            melds[idx] = self._counts([tile for meld in player.melds for tile in meld])
            discards[idx] = self._counts(player.discards)
            flowers[idx] = len(player.flowers)

        pending = np.zeros(n_tiles, dtype=np.int8)
        if self.pending_discard is not None:
            pending[self.pending_discard[1]] = 1

        return {
            "hand": hand_counts,
            "melds": melds,
            "discards": discards,
            "flowers": flowers,
            "wall_count": np.array([len(self.wall)], dtype=np.int16),
            "current_player": np.array([self.current_player], dtype=np.int8),
            "phase": np.array([self.phase], dtype=np.int8),
            "pending_discard": pending,
        }

    def action_mask(self, player_idx: int) -> Dict[str, np.ndarray]:
        """Return boolean masks of the actions ``player_idx`` may take now.

        Keys: ``"type"`` (indexed by ACTION_* value), ``"discard"``, ``"pong"``
        and ``"kong"`` (indexed by tile id) and ``"chi"`` (indexed by chi
        choice). In ``PHASE_TURN`` the masks come from the current hand; in
        ``PHASE_RESPONSE`` they describe what claiming the pending discard would
        allow. If neither applies (e.g. no pending discard) every mask is False.
        """
        n_tiles = len(TILE_TYPES)
        mask_type = np.zeros(self._NUM_ACTION_TYPES, dtype=bool)
        mask_discard = np.zeros(n_tiles, dtype=bool)
        mask_pong = np.zeros(n_tiles, dtype=bool)
        mask_kong = np.zeros(n_tiles, dtype=bool)
        mask_chi = np.zeros(self._NUM_CHI_OPTIONS, dtype=bool)
        player = self.players[player_idx]
        counts = self._counts(player.hand)

        if self.phase == PHASE_TURN:
            mask_type[ACTION_DISCARD] = True
            for tile in player.hand:
                mask_discard[tile] = True
            if is_win(player.hand, len(player.melds)):
                mask_type[ACTION_DECLARE_WIN] = True
            # Concealed kong: all four copies already in hand.
            for tile_id, count in enumerate(counts):
                if count >= 4:
                    mask_type[ACTION_DECLARE_KONG] = True
                    mask_kong[tile_id] = True
        elif self.phase == PHASE_RESPONSE and self.pending_discard is not None:
            mask_type[ACTION_PASS] = True
            discarder, tile = self.pending_discard
            if is_win(player.hand + [tile], len(player.melds)):
                mask_type[ACTION_DECLARE_WIN] = True
            if counts[tile] >= 2:
                mask_type[ACTION_DECLARE_PONG] = True
                mask_pong[tile] = True
            if counts[tile] >= 3:
                mask_type[ACTION_DECLARE_KONG] = True
                mask_kong[tile] = True
            # Only the seat immediately after the discarder may chi.
            if player_idx == (discarder + 1) % self._NUM_PLAYERS:
                options = can_chi_options(counts, tile)
                if options:
                    mask_type[ACTION_DECLARE_CHI] = True
                    for option in options:
                        mask_chi[option] = True

        return {
            "type": mask_type,
            "discard": mask_discard,
            "pong": mask_pong,
            "kong": mask_kong,
            "chi": mask_chi,
        }

    def step(self, player_idx: int, action: Dict[str, int]) -> float:
        """Apply one action for the active player and advance the game.

        Args:
            player_idx: Seat submitting the action; must be ``current_player``.
            action: Dict with optional keys ``"type"`` (ACTION_*), ``"tile"``
                (tile id for discard/pong/kong) and ``"chi"`` (chi choice).
                An illegal action is replaced by a random legal one.

        Returns:
            ``pong_reward`` when the action was a pong, otherwise ``0.0``
            (also ``0.0`` once the game is over).

        Raises:
            ValueError: if ``player_idx`` is not the active player.
        """
        if self.done:
            return 0.0
        if player_idx != self.current_player:
            raise ValueError("Action from non-active player")

        if self.phase == PHASE_TURN:
            # Turns are normally begun eagerly; this fallback guarantees the
            # draw happens BEFORE the action is validated, so a self-drawn win
            # or a discard of the drawn tile is recognised as legal.
            self._begin_turn()
            if self.done:
                return 0.0
            action = self._sanitize_action(player_idx, action)
            return self._handle_turn_action(player_idx, action)

        action = self._sanitize_action(player_idx, action)
        return self._handle_response_action(player_idx, action)

    def _handle_turn_action(self, player_idx: int, action: Dict[str, int]) -> float:
        """Apply a (sanitized) win, concealed kong, or discard in PHASE_TURN."""
        player = self.players[player_idx]
        action_type = int(action.get("type", ACTION_DISCARD))
        tile_id = int(action.get("tile", 0))

        if action_type == ACTION_DECLARE_WIN and is_win(player.hand, len(player.melds)):
            self._end_game(player_idx, "self_draw", player.hand)
            return 0.0

        if action_type == ACTION_DECLARE_KONG and player.hand.count(tile_id) >= 4:
            for _ in range(4):
                player.hand.remove(tile_id)
            player.melds.append([tile_id] * 4)
            # Replacement draw; the same player keeps acting this turn.
            self._draw_tile(player_idx)
            self.turn_started = True
            return 0.0

        if action_type == ACTION_DISCARD and tile_id in player.hand:
            player.hand.remove(tile_id)
            player.discards.append(tile_id)
            self.pending_discard = (player_idx, tile_id)
            # Poll the other seats in order, starting with the next seat.
            self.response_queue = [
                (player_idx + offset) % self._NUM_PLAYERS
                for offset in range(1, self._NUM_PLAYERS)
            ]
            self.phase = PHASE_RESPONSE
            self.current_player = self.response_queue.pop(0)
            self.turn_started = False
            return 0.0

        return 0.0

    def _handle_response_action(self, player_idx: int, action: Dict[str, int]) -> float:
        """Apply a (sanitized) ron, pong, kong, chi, or pass on the pending discard.

        NOTE(review): claims are resolved strictly in polling order, so an
        earlier seat's pong/chi pre-empts a later seat's win on the same tile.
        Confirm this simplification is intended.
        """
        if self.pending_discard is None:
            self.phase = PHASE_TURN
            self.turn_started = False
            self._begin_turn()
            return 0.0
        discarder, tile = self.pending_discard
        player = self.players[player_idx]
        counts = self._counts(player.hand)
        action_type = int(action.get("type", ACTION_PASS))
        chi_choice = int(action.get("chi", 0))

        if action_type == ACTION_DECLARE_WIN and is_win(player.hand + [tile], len(player.melds)):
            self._end_game(player_idx, "ron", player.hand + [tile])
            return 0.0

        if action_type == ACTION_DECLARE_PONG and counts[tile] >= 2:
            self._consume_discard(discarder, tile)
            player.hand.remove(tile)
            player.hand.remove(tile)
            player.melds.append([tile, tile, tile])
            self.current_player = player_idx
            self.pending_discard = None
            self.response_queue = []
            self.phase = PHASE_TURN
            self.turn_started = False
            # The claimed tile replaces this turn's draw.
            self.skip_draw_for_player = player_idx
            self._begin_turn()
            return self.pong_reward

        if action_type == ACTION_DECLARE_KONG and counts[tile] >= 3:
            self._consume_discard(discarder, tile)
            for _ in range(3):
                player.hand.remove(tile)
            player.melds.append([tile, tile, tile, tile])
            self.current_player = player_idx
            self.pending_discard = None
            self.response_queue = []
            self.phase = PHASE_TURN
            self.turn_started = True
            # Replacement draw after an exposed kong.
            self._draw_tile(player_idx)
            return 0.0

        if action_type == ACTION_DECLARE_CHI and player_idx == (discarder + 1) % self._NUM_PLAYERS:
            options = can_chi_options(counts, tile)
            if chi_choice in options:
                self._consume_discard(discarder, tile)
                if chi_choice == 0:
                    chi_tiles = [tile - 2, tile - 1, tile]
                elif chi_choice == 1:
                    chi_tiles = [tile - 1, tile, tile + 1]
                else:
                    chi_tiles = [tile, tile + 1, tile + 2]
                for chi_tile in chi_tiles:
                    if chi_tile != tile:
                        player.hand.remove(chi_tile)
                player.melds.append(chi_tiles)
                self.current_player = player_idx
                self.pending_discard = None
                self.response_queue = []
                self.phase = PHASE_TURN
                self.turn_started = False
                # The claimed tile replaces this turn's draw.
                self.skip_draw_for_player = player_idx
                self._begin_turn()
                return 0.0

        self._advance_response_queue(discarder)
        return 0.0

    def _sanitize_action(self, player_idx: int, action: Dict[str, int]) -> Dict[str, int]:
        """Return ``action`` normalised to a legal action for ``player_idx``.

        Any illegal request -- unknown type, a type the mask forbids, or a
        tile/chi index that is out of range or masked out -- is replaced by a
        random legal action.
        """
        mask = self.action_mask(player_idx)
        action_type = int(action.get("type", ACTION_PASS))
        tile_id = int(action.get("tile", 0))
        chi_choice = int(action.get("chi", 0))

        if not 0 <= action_type < self._NUM_ACTION_TYPES or not mask["type"][action_type]:
            return self._sample_valid_action(mask)

        # Action types that carry an index, and which mask validates it.
        indexed = {
            ACTION_DISCARD: ("discard", tile_id),
            ACTION_DECLARE_PONG: ("pong", tile_id),
            ACTION_DECLARE_KONG: ("kong", tile_id),
            ACTION_DECLARE_CHI: ("chi", chi_choice),
        }
        if action_type in indexed:
            key, index = indexed[action_type]
            allowed = mask[key]
            # Bounds-check explicitly: numpy would raise on a too-large index
            # and silently wrap a negative one.
            if not 0 <= index < len(allowed) or not allowed[index]:
                return self._sample_valid_action(mask)
        return {"type": action_type, "tile": tile_id, "chi": chi_choice}

    def _sample_valid_action(self, mask: Dict[str, np.ndarray]) -> Dict[str, int]:
        """Pick a uniformly random legal action type (and index) from ``mask``.

        Falls back to a pass when no action type is legal.
        """
        valid_types = np.flatnonzero(mask["type"])
        if len(valid_types) == 0:
            return {"type": ACTION_PASS, "tile": 0, "chi": 0}
        action_type = int(self._random.choice(valid_types))
        tile_id = 0
        chi_choice = 0
        if action_type == ACTION_DISCARD:
            valid_tiles = np.flatnonzero(mask["discard"])
            tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
        if action_type == ACTION_DECLARE_PONG:
            valid_tiles = np.flatnonzero(mask["pong"])
            tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
        if action_type == ACTION_DECLARE_KONG:
            valid_tiles = np.flatnonzero(mask["kong"])
            tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
        if action_type == ACTION_DECLARE_CHI:
            valid_chi = np.flatnonzero(mask["chi"])
            chi_choice = int(self._random.choice(valid_chi)) if len(valid_chi) else 0
        return {"type": action_type, "tile": tile_id, "chi": chi_choice}

    def _advance_response_queue(self, discarder: int):
        """Ask the next seat about the discard, or hand the turn to the discarder's successor."""
        if self.response_queue:
            self.current_player = self.response_queue.pop(0)
            return
        self.pending_discard = None
        self.phase = PHASE_TURN
        self.current_player = (discarder + 1) % self._NUM_PLAYERS
        self.turn_started = False
        self._begin_turn()

    def _begin_turn(self) -> None:
        """Perform the current player's turn-start draw (or draw skip) exactly once.

        No-op when the game is over, when not in PHASE_TURN, or when the turn
        has already started.
        """
        if self.done or self.phase != PHASE_TURN or self.turn_started:
            return
        self._start_turn(self.current_player)
        self.turn_started = True

    def _start_turn(self, player_idx: int):
        """Draw a tile for ``player_idx`` unless their draw is being skipped."""
        if self.skip_draw_for_player == player_idx:
            self.skip_draw_for_player = None
            return
        self._draw_tile(player_idx)

    def _draw_tile(self, player_idx: int) -> Optional[int]:
        """Draw from the wall into ``player_idx``'s hand, setting flowers aside.

        Returns the tile added to the hand, or ``None`` if the wall ran out, in
        which case the game ends with reason ``"wall_empty"``.
        """
        while self.wall:
            tile = self.wall.pop()
            if is_flower(tile):
                self.players[player_idx].flowers.append(tile)
                continue
            self.players[player_idx].hand.append(tile)
            return tile
        self._end_game(None, "wall_empty", [])
        return None

    def _consume_discard(self, discarder: int, tile: int):
        """Remove a claimed tile from the end of the discarder's discard list."""
        discards = self.players[discarder].discards
        if discards and discards[-1] == tile:
            discards.pop()

    def _end_game(self, winner: Optional[int], reason: str, winning_tiles: List[int]):
        """Mark the game finished; only the first call has any effect.

        ``winning_tiles`` and the winner's melds are copied so the stored
        result does not alias lists owned by the players.
        """
        if self.done:
            return
        self.done = True
        self.winner = winner
        self.terminal_reason = reason
        self.winning_tiles = list(winning_tiles)
        self.winning_melds = (
            [] if winner is None else [list(meld) for meld in self.players[winner].melds]
        )

    def _counts(self, tiles: List[int]) -> np.ndarray:
        """Return an int8 array of per-tile-id counts for ``tiles``."""
        counts = np.zeros(len(TILE_TYPES), dtype=np.int8)
        for tile in tiles:
            counts[tile] += 1
        return counts
|