# Source: majiang-rl/majiang_rl/envs/self_play_env.py
# Snapshot: 2026-01-14 10:49:00 +08:00 (338 lines, 13 KiB, Python)

from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
from ..rules import can_chi_options, is_win
from ..tiles import TILE_TYPES, build_wall, is_flower
# Game phases stored in ``SelfPlayMahjong.phase``.
# PHASE_TURN: the current player acts on their own hand (discard / kong / win).
# PHASE_RESPONSE: the other players answer a pending discard in turn.
PHASE_TURN = 0
PHASE_RESPONSE = 1
# Action type ids. They are used as indices into the 6-element "type"
# action mask, so the values must stay dense in the range 0..5.
ACTION_DISCARD = 0
ACTION_DECLARE_WIN = 1
ACTION_DECLARE_KONG = 2
ACTION_PASS = 3
ACTION_DECLARE_PONG = 4
ACTION_DECLARE_CHI = 5
@dataclass
class PlayerState:
    """Per-seat tile bookkeeping for one player.

    Every field defaults to a fresh empty list, so separate instances never
    share state.
    """

    # Tile ids currently held in the concealed hand.
    hand: List[int] = field(default_factory=list)
    # Declared melds; each meld is a list of the tile ids it contains.
    melds: List[List[int]] = field(default_factory=list)
    # Tile ids this player has discarded, oldest first.
    discards: List[int] = field(default_factory=list)
    # Flower tiles that were set aside when drawn.
    flowers: List[int] = field(default_factory=list)
class SelfPlayMahjong:
    """Four-seat Mahjong environment that is stepped one action at a time.

    The game alternates between two phases:

    * ``PHASE_TURN`` - ``current_player`` draws a tile (unless they just
      claimed a discard with pong/chi) and then discards, declares a
      concealed kong, or declares a self-drawn win.
    * ``PHASE_RESPONSE`` - the other three seats are asked, in seat order
      after the discarder, whether they claim the pending discard. The first
      seat in that order that makes a valid claim gets the tile.

    Actions are dicts with integer entries ``"type"`` (an ``ACTION_*`` id),
    ``"tile"`` (tile id for discard/pong/kong) and ``"chi"`` (sequence choice
    0..2). Invalid actions are never rejected; they are replaced by a
    randomly sampled valid action.

    ``reset()`` must be called before the environment is used: ``__init__``
    leaves ``players`` and ``wall`` empty.
    """

    def __init__(self, seed: Optional[int] = None, pong_reward: float = 0.0):
        """Create an environment.

        Args:
            seed: Seed for the private random generator used for building
                the wall and for sampling replacement actions.
            pong_reward: Reward returned by ``step`` when a pong is claimed.
        """
        self._random = random.Random(seed)
        self.pong_reward = pong_reward
        self.players: List[PlayerState] = []
        self.wall: List[int] = []
        self.current_player = 0
        self.phase = PHASE_TURN
        # (discarder seat, tile id) of the discard currently open for claims.
        self.pending_discard: Optional[Tuple[int, int]] = None
        # Seats that still have to answer the pending discard.
        self.response_queue: List[int] = []
        # True once the current player has drawn for this turn.
        self.turn_started = False
        # Seat whose next turn starts without a draw (after pong/chi).
        self.skip_draw_for_player: Optional[int] = None
        self.done = False
        self.winner: Optional[int] = None
        self.terminal_reason: Optional[str] = None
        self.winning_tiles: List[int] = []
        self.winning_melds: List[List[int]] = []

    def reset(self, *, seed: Optional[int] = None) -> None:
        """Start a new game: rebuild the wall and deal 13 tiles to each seat.

        Args:
            seed: If given, re-seeds the private random generator first.
        """
        if seed is not None:
            self._random.seed(seed)
        self.players = [PlayerState() for _ in range(4)]
        self.wall = build_wall(self._random)
        self.current_player = 0
        self.phase = PHASE_TURN
        self.pending_discard = None
        self.response_queue = []
        self.turn_started = False
        self.skip_draw_for_player = None
        self.done = False
        self.winner = None
        self.terminal_reason = None
        self.winning_tiles = []
        self.winning_melds = []
        # Deal one tile at a time in seat order; _draw_tile sets flowers
        # aside and keeps drawing, so every hand ends up with 13 tiles.
        for _ in range(13):
            for seat in range(4):
                self._draw_tile(seat)

    def observe(self, player_idx: int) -> Dict[str, np.ndarray]:
        """Return the observation for ``player_idx``.

        Only the observer's own hand is included; melds, discards and flower
        counts are reported for all four seats.

        Returns:
            Dict with per-tile count vectors ``"hand"`` and
            ``"pending_discard"`` (one-hot, all zero when nothing is
            pending), ``(4, n_tiles)`` count matrices ``"melds"`` and
            ``"discards"``, per-seat ``"flowers"`` counts, and the
            one-element arrays ``"wall_count"``, ``"current_player"`` and
            ``"phase"``.
        """
        n_tiles = len(TILE_TYPES)
        hand_counts = self._counts(self.players[player_idx].hand)
        melds = np.zeros((4, n_tiles), dtype=np.int8)
        discards = np.zeros((4, n_tiles), dtype=np.int8)
        flowers = np.zeros(4, dtype=np.int8)
        for seat, state in enumerate(self.players):
            melds[seat] = self._counts([t for meld in state.melds for t in meld])
            discards[seat] = self._counts(state.discards)
            flowers[seat] = len(state.flowers)
        pending = np.zeros(n_tiles, dtype=np.int8)
        if self.pending_discard is not None:
            pending[self.pending_discard[1]] = 1
        return {
            "hand": hand_counts,
            "melds": melds,
            "discards": discards,
            "flowers": flowers,
            "wall_count": np.array([len(self.wall)], dtype=np.int16),
            "current_player": np.array([self.current_player], dtype=np.int8),
            "phase": np.array([self.phase], dtype=np.int8),
            "pending_discard": pending,
        }

    def action_mask(self, player_idx: int) -> Dict[str, np.ndarray]:
        """Return boolean masks of the actions ``player_idx`` may take now.

        The masks are computed from the player's hand as it is at call time.
        In ``PHASE_TURN`` the turn's tile is drawn inside ``step``, so a mask
        requested before that ``step`` call does not include the tile about
        to be drawn.

        Returns:
            Dict with ``"type"`` (6 entries, indexed by ``ACTION_*``),
            ``"discard"``, ``"pong"`` and ``"kong"`` (one entry per tile
            type) and ``"chi"`` (3 entries, one per sequence choice).
        """
        n_tiles = len(TILE_TYPES)
        mask_type = np.zeros(6, dtype=bool)
        mask_discard = np.zeros(n_tiles, dtype=bool)
        mask_pong = np.zeros(n_tiles, dtype=bool)
        mask_kong = np.zeros(n_tiles, dtype=bool)
        mask_chi = np.zeros(3, dtype=bool)
        player = self.players[player_idx]
        counts = self._counts(player.hand)
        if self.phase == PHASE_TURN:
            mask_type[ACTION_DISCARD] = True
            mask_discard[counts > 0] = True
            if is_win(player.hand, len(player.melds)):
                mask_type[ACTION_DECLARE_WIN] = True
            # Concealed kong: all four copies are in hand.
            mask_kong[counts >= 4] = True
            if mask_kong.any():
                mask_type[ACTION_DECLARE_KONG] = True
        elif self.phase == PHASE_RESPONSE and self.pending_discard is not None:
            mask_type[ACTION_PASS] = True
            discarder, tile = self.pending_discard
            if is_win(player.hand + [tile], len(player.melds)):
                mask_type[ACTION_DECLARE_WIN] = True
            if counts[tile] >= 2:
                mask_type[ACTION_DECLARE_PONG] = True
                mask_pong[tile] = True
            if counts[tile] >= 3:
                mask_type[ACTION_DECLARE_KONG] = True
                mask_kong[tile] = True
            # Only the seat directly after the discarder may chi.
            if player_idx == (discarder + 1) % 4:
                options = can_chi_options(counts, tile)
                if options:
                    mask_type[ACTION_DECLARE_CHI] = True
                    for option in options:
                        mask_chi[option] = True
        return {
            "type": mask_type,
            "discard": mask_discard,
            "pong": mask_pong,
            "kong": mask_kong,
            "chi": mask_chi,
        }

    def step(self, player_idx: int, action: Dict[str, int]) -> float:
        """Apply ``action`` for ``player_idx`` and return the reward.

        At the start of a turn the player's tile is drawn *before* the
        action is validated, so that the freshly drawn tile can be discarded
        and a self-drawn win can be declared. Every action that was valid
        against the pre-draw hand is still valid afterwards.

        Returns:
            ``pong_reward`` when a pong is claimed, otherwise ``0.0``
            (including when the game is already over).

        Raises:
            ValueError: If ``player_idx`` is not ``current_player``.
        """
        if self.done:
            return 0.0
        if player_idx != self.current_player:
            raise ValueError("Action from non-active player")
        if self.phase == PHASE_TURN:
            if not self.turn_started:
                self._start_turn(player_idx)
                if self.done:
                    # The wall ran out while drawing.
                    return 0.0
                self.turn_started = True
            action = self._sanitize_action(player_idx, action)
            return self._handle_turn_action(player_idx, action)
        action = self._sanitize_action(player_idx, action)
        return self._handle_response_action(player_idx, action)

    def _handle_turn_action(self, player_idx: int, action: Dict[str, int]) -> float:
        """Execute a ``PHASE_TURN`` action (win, concealed kong or discard)."""
        player = self.players[player_idx]
        action_type = int(action.get("type", ACTION_DISCARD))
        tile_id = int(action.get("tile", 0))
        if action_type == ACTION_DECLARE_WIN and is_win(player.hand, len(player.melds)):
            self._end_game(player_idx, "self_draw", player.hand)
            return 0.0
        if action_type == ACTION_DECLARE_KONG and player.hand.count(tile_id) >= 4:
            for _ in range(4):
                player.hand.remove(tile_id)
            player.melds.append([tile_id] * 4)
            # Replacement draw; the same player keeps the turn.
            self._draw_tile(player_idx)
            self.turn_started = True
            return 0.0
        if action_type == ACTION_DISCARD and tile_id in player.hand:
            player.hand.remove(tile_id)
            player.discards.append(tile_id)
            self.pending_discard = (player_idx, tile_id)
            # Ask the other three seats in order, starting after the discarder.
            self.response_queue = [(player_idx + offset) % 4 for offset in range(1, 4)]
            self.phase = PHASE_RESPONSE
            self.current_player = self.response_queue.pop(0)
            self.turn_started = False
            return 0.0
        return 0.0

    def _handle_response_action(self, player_idx: int, action: Dict[str, int]) -> float:
        """Execute a ``PHASE_RESPONSE`` action on the pending discard.

        A valid win ends the game; a valid pong, kong or chi moves the
        discard into a meld and gives the turn to the claimer. Anything else
        counts as a pass and moves on to the next seat in the queue.
        """
        if self.pending_discard is None:
            self.phase = PHASE_TURN
            self.turn_started = False
            return 0.0
        discarder, tile = self.pending_discard
        player = self.players[player_idx]
        counts = self._counts(player.hand)
        action_type = int(action.get("type", ACTION_PASS))
        chi_choice = int(action.get("chi", 0))
        if action_type == ACTION_DECLARE_WIN and is_win(player.hand + [tile], len(player.melds)):
            self._end_game(player_idx, "ron", player.hand + [tile])
            return 0.0
        if action_type == ACTION_DECLARE_PONG and counts[tile] >= 2:
            self._claim_discard(player_idx, discarder, tile, [tile, tile, tile])
            # The claimer must discard next without drawing.
            self.turn_started = False
            self.skip_draw_for_player = player_idx
            return self.pong_reward
        if action_type == ACTION_DECLARE_KONG and counts[tile] >= 3:
            self._claim_discard(player_idx, discarder, tile, [tile, tile, tile, tile])
            # A kong takes a replacement tile immediately.
            self.turn_started = True
            self._draw_tile(player_idx)
            return 0.0
        if action_type == ACTION_DECLARE_CHI and player_idx == (discarder + 1) % 4:
            options = can_chi_options(counts, tile)
            if chi_choice in options:
                # The choice says where the claimed tile sits in the run:
                # 0 = highest, 1 = middle, anything else = lowest.
                if chi_choice == 0:
                    chi_tiles = [tile - 2, tile - 1, tile]
                elif chi_choice == 1:
                    chi_tiles = [tile - 1, tile, tile + 1]
                else:
                    chi_tiles = [tile, tile + 1, tile + 2]
                self._claim_discard(player_idx, discarder, tile, chi_tiles)
                self.turn_started = False
                self.skip_draw_for_player = player_idx
                return 0.0
        self._advance_response_queue(discarder)
        return 0.0

    def _claim_discard(self, player_idx: int, discarder: int, tile: int, meld: List[int]) -> None:
        """Turn the pending discard into ``meld`` and hand the turn to the claimer.

        The claimed ``tile`` comes from the discard pile; every other tile of
        ``meld`` is removed from the claimer's hand.
        """
        self._consume_discard(discarder, tile)
        player = self.players[player_idx]
        from_hand = list(meld)
        from_hand.remove(tile)
        for hand_tile in from_hand:
            player.hand.remove(hand_tile)
        player.melds.append(list(meld))
        self.current_player = player_idx
        self.pending_discard = None
        self.response_queue = []
        self.phase = PHASE_TURN

    def _sanitize_action(self, player_idx: int, action: Dict[str, int]) -> Dict[str, int]:
        """Return ``action`` if it is legal now, else a random legal action.

        Out-of-range ``type``, ``tile`` or ``chi`` values are treated as
        illegal instead of being used as (possibly wrapping) array indices.
        """
        mask = self.action_mask(player_idx)
        action_type = int(action.get("type", ACTION_PASS))
        tile_id = int(action.get("tile", 0))
        chi_choice = int(action.get("chi", 0))
        if not 0 <= action_type < len(mask["type"]) or not mask["type"][action_type]:
            return self._sample_valid_action(mask)
        # Action types that carry an argument, with the sub-mask validating it.
        argument_checks = {
            ACTION_DISCARD: ("discard", tile_id),
            ACTION_DECLARE_PONG: ("pong", tile_id),
            ACTION_DECLARE_KONG: ("kong", tile_id),
            ACTION_DECLARE_CHI: ("chi", chi_choice),
        }
        check = argument_checks.get(action_type)
        if check is not None:
            key, index = check
            sub_mask = mask[key]
            if not 0 <= index < len(sub_mask) or not sub_mask[index]:
                return self._sample_valid_action(mask)
        return {"type": action_type, "tile": tile_id, "chi": chi_choice}

    def _sample_valid_action(self, mask: Dict[str, np.ndarray]) -> Dict[str, int]:
        """Draw a uniformly random legal action from ``mask``.

        Falls back to a pass action when no action type is enabled.
        """
        # Candidates are converted to plain lists: random.Random.choice on a
        # NumPy array raises ValueError on CPython 3.11.
        valid_types = np.flatnonzero(mask["type"]).tolist()
        if not valid_types:
            return {"type": ACTION_PASS, "tile": 0, "chi": 0}
        action_type = int(self._random.choice(valid_types))
        tile_id = 0
        chi_choice = 0
        tile_mask_key = {
            ACTION_DISCARD: "discard",
            ACTION_DECLARE_PONG: "pong",
            ACTION_DECLARE_KONG: "kong",
        }.get(action_type)
        if tile_mask_key is not None:
            valid_tiles = np.flatnonzero(mask[tile_mask_key]).tolist()
            tile_id = int(self._random.choice(valid_tiles)) if valid_tiles else 0
        elif action_type == ACTION_DECLARE_CHI:
            valid_chi = np.flatnonzero(mask["chi"]).tolist()
            chi_choice = int(self._random.choice(valid_chi)) if valid_chi else 0
        return {"type": action_type, "tile": tile_id, "chi": chi_choice}

    def _advance_response_queue(self, discarder: int) -> None:
        """Move to the next responder, or start the next turn if none is left."""
        if self.response_queue:
            self.current_player = self.response_queue.pop(0)
            return
        # Nobody claimed the discard: play passes to the seat after the discarder.
        self.pending_discard = None
        self.phase = PHASE_TURN
        self.current_player = (discarder + 1) % 4
        self.turn_started = False

    def _start_turn(self, player_idx: int) -> None:
        """Draw the turn's tile unless the player just claimed a discard."""
        if self.skip_draw_for_player == player_idx:
            self.skip_draw_for_player = None
            return
        self._draw_tile(player_idx)

    def _draw_tile(self, player_idx: int) -> Optional[int]:
        """Draw from the wall into the player's hand.

        Flower tiles are set aside and another tile is drawn in their place.

        Returns:
            The tile added to the hand, or ``None`` if the wall ran out, in
            which case the game ends with reason ``"wall_empty"``.
        """
        seat = self.players[player_idx]
        while self.wall:
            tile = self.wall.pop()
            if is_flower(tile):
                seat.flowers.append(tile)
                continue
            seat.hand.append(tile)
            return tile
        self._end_game(None, "wall_empty", [])
        return None

    def _consume_discard(self, discarder: int, tile: int) -> None:
        """Remove ``tile`` from the discarder's pile if it is the latest discard."""
        discards = self.players[discarder].discards
        if discards and discards[-1] == tile:
            discards.pop()

    def _end_game(self, winner: Optional[int], reason: str, winning_tiles: List[int]) -> None:
        """Mark the game finished and record the result.

        Only the first call has an effect. The winning tiles and melds are
        stored as copies so the record is independent of the players' live
        lists.
        """
        if self.done:
            return
        self.done = True
        self.winner = winner
        self.terminal_reason = reason
        self.winning_tiles = list(winning_tiles)
        if winner is None:
            self.winning_melds = []
        else:
            self.winning_melds = [list(meld) for meld in self.players[winner].melds]

    def _counts(self, tiles: List[int]) -> np.ndarray:
        """Return an int8 vector with the number of copies of each tile type."""
        counts = np.zeros(len(TILE_TYPES), dtype=np.int8)
        for tile in tiles:
            counts[tile] += 1
        return counts