feat: initialize majiang-rl project

This commit is contained in:
game-loader
2026-01-14 10:49:00 +08:00
commit b29a18b459
21 changed files with 18895 additions and 0 deletions

View File

@@ -0,0 +1,578 @@
from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from ..rules import can_chi_options, is_win
from ..tiles import TILE_TYPES, build_wall, is_flower, tile_name
PHASE_AGENT_TURN = 0
PHASE_AGENT_RESPONSE = 1
ACTION_DISCARD = 0
ACTION_DECLARE_WIN = 1
ACTION_DECLARE_KONG = 2
ACTION_PASS = 3
ACTION_DECLARE_PONG = 4
ACTION_DECLARE_CHI = 5
@dataclass
class PlayerState:
hand: List[int] = field(default_factory=list)
melds: List[List[int]] = field(default_factory=list)
discards: List[int] = field(default_factory=list)
flowers: List[int] = field(default_factory=list)
class MahjongEnv(gym.Env):
metadata = {"render_modes": ["human"]}
def __init__(self, seed: Optional[int] = None, max_rounds: int = 1):
super().__init__()
self.max_rounds = max_rounds
self.action_space = spaces.Dict(
{
"type": spaces.Discrete(6),
"tile": spaces.Discrete(len(TILE_TYPES)),
"chi": spaces.Discrete(3),
}
)
self.observation_space = spaces.Dict(
{
"hand": spaces.Box(0, 4, shape=(len(TILE_TYPES),), dtype=np.int8),
"melds": spaces.Box(0, 4, shape=(4, len(TILE_TYPES)), dtype=np.int8),
"discards": spaces.Box(0, 4, shape=(4, len(TILE_TYPES)), dtype=np.int8),
"flowers": spaces.Box(0, 8, shape=(4,), dtype=np.int8),
"wall_count": spaces.Box(0, 144, shape=(1,), dtype=np.int16),
"current_player": spaces.Box(0, 3, shape=(1,), dtype=np.int8),
"phase": spaces.Box(0, 1, shape=(1,), dtype=np.int8),
}
)
self._random = random.Random(seed)
self._round = 0
self.agent_index = 0
self.players: List[PlayerState] = []
self.wall: List[int] = []
self.current_player = 0
self.phase = PHASE_AGENT_TURN
self.pending_discard: Optional[Tuple[int, int]] = None
self.skip_draw_for_player: Optional[int] = None
self.done = False
self.winner: Optional[int] = None
self.terminal_reason: Optional[str] = None
self.event_log: List[Dict[str, object]] = []
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
if seed is not None:
self._random.seed(seed)
self._round += 1
self.players = [PlayerState() for _ in range(4)]
self.wall = build_wall(self._random)
self.current_player = 0
self.pending_discard = None
self.skip_draw_for_player = None
self.done = False
self.winner = None
self.terminal_reason = None
self.event_log = []
for _ in range(13):
for player in range(4):
self._draw_tile(player)
self._draw_tile(0)
self.skip_draw_for_player = 0
self.phase = PHASE_AGENT_TURN
self._advance_to_agent()
return self._get_obs(), self._get_info()
def step(self, action: Dict[str, int]):
if self.done:
return self._get_obs(), 0.0, True, False, self._get_info()
reward = 0.0
action = self._sanitize_action(action)
action_type = int(action["type"])
tile_id = int(action["tile"])
chi_choice = int(action["chi"])
if self.phase == PHASE_AGENT_TURN:
reward += self._handle_agent_turn(action_type, tile_id)
elif self.phase == PHASE_AGENT_RESPONSE:
reward += self._handle_agent_response(action_type, tile_id, chi_choice)
if not self.done:
self._advance_to_agent()
if self.done:
if self.winner == self.agent_index:
reward += 1.0
elif self.winner is not None:
reward -= 1.0
return self._get_obs(), reward, self.done, False, self._get_info()
def render(self):
if self.done:
print(f"Winner: {self.winner}, reason: {self.terminal_reason}")
else:
print(f"Player {self.current_player} to act, wall {len(self.wall)}")
def _get_obs(self):
hand_counts = self._counts(self.players[self.agent_index].hand)
melds = np.zeros((4, len(TILE_TYPES)), dtype=np.int8)
discards = np.zeros((4, len(TILE_TYPES)), dtype=np.int8)
flowers = np.zeros(4, dtype=np.int8)
for idx, player in enumerate(self.players):
for meld in player.melds:
for tile in meld:
melds[idx, tile] += 1
for tile in player.discards:
discards[idx, tile] += 1
flowers[idx] = len(player.flowers)
return {
"hand": hand_counts,
"melds": melds,
"discards": discards,
"flowers": flowers,
"wall_count": np.array([len(self.wall)], dtype=np.int16),
"current_player": np.array([self.current_player], dtype=np.int8),
"phase": np.array([self.phase], dtype=np.int8),
}
def _get_info(self):
action_mask = self._action_mask()
return {
"winner": self.winner,
"terminal_reason": self.terminal_reason,
"pending_discard": self.pending_discard,
"action_mask": action_mask,
}
def _sanitize_action(self, action: Dict[str, int]) -> Dict[str, int]:
if not isinstance(action, dict):
return self._sample_valid_action()
if "type" not in action or "tile" not in action or "chi" not in action:
return self._sample_valid_action()
mask = self._action_mask()
action_type = int(action["type"])
tile_id = int(action["tile"])
chi_choice = int(action["chi"])
if action_type < 0 or action_type >= 6:
return self._sample_valid_action()
if not mask["type"][action_type]:
return self._sample_valid_action()
if action_type == ACTION_DISCARD:
if not mask["discard"][tile_id]:
return self._sample_valid_action()
if action_type == ACTION_DECLARE_PONG:
if not mask["pong"][tile_id]:
return self._sample_valid_action()
if action_type == ACTION_DECLARE_KONG:
if not mask["kong"][tile_id]:
return self._sample_valid_action()
if action_type == ACTION_DECLARE_CHI:
if not mask["chi"][chi_choice]:
return self._sample_valid_action()
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
def _sample_valid_action(self) -> Dict[str, int]:
mask = self._action_mask()
valid_types = np.flatnonzero(mask["type"])
if len(valid_types) == 0:
return {"type": ACTION_PASS, "tile": 0, "chi": 0}
action_type = int(self._random.choice(valid_types))
tile_id = 0
chi_choice = 0
if action_type == ACTION_DISCARD:
valid_tiles = np.flatnonzero(mask["discard"])
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
if action_type == ACTION_DECLARE_PONG:
valid_tiles = np.flatnonzero(mask["pong"])
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
if action_type == ACTION_DECLARE_KONG:
valid_tiles = np.flatnonzero(mask["kong"])
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
if action_type == ACTION_DECLARE_CHI:
valid_chi = np.flatnonzero(mask["chi"])
chi_choice = int(self._random.choice(valid_chi)) if len(valid_chi) else 0
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
def _action_mask(self) -> Dict[str, np.ndarray]:
mask_type = np.zeros(6, dtype=bool)
mask_discard = np.zeros(len(TILE_TYPES), dtype=bool)
mask_pong = np.zeros(len(TILE_TYPES), dtype=bool)
mask_kong = np.zeros(len(TILE_TYPES), dtype=bool)
mask_chi = np.zeros(3, dtype=bool)
agent = self.players[self.agent_index]
counts = self._counts(agent.hand)
if self.phase == PHASE_AGENT_TURN:
mask_type[ACTION_DISCARD] = True
for tile in agent.hand:
mask_discard[tile] = True
if is_win(agent.hand, len(agent.melds)):
mask_type[ACTION_DECLARE_WIN] = True
for tile_id, count in enumerate(counts):
if count >= 4:
mask_type[ACTION_DECLARE_KONG] = True
mask_kong[tile_id] = True
elif self.phase == PHASE_AGENT_RESPONSE and self.pending_discard is not None:
mask_type[ACTION_PASS] = True
_, tile = self.pending_discard
virtual_hand = agent.hand + [tile]
if is_win(virtual_hand, len(agent.melds)):
mask_type[ACTION_DECLARE_WIN] = True
if counts[tile] >= 2:
mask_type[ACTION_DECLARE_PONG] = True
mask_pong[tile] = True
if counts[tile] >= 3:
mask_type[ACTION_DECLARE_KONG] = True
mask_kong[tile] = True
if self._is_next_player(self.agent_index, self.pending_discard[0]):
options = can_chi_options(counts, tile)
if options:
mask_type[ACTION_DECLARE_CHI] = True
for option in options:
mask_chi[option] = True
return {
"type": mask_type,
"discard": mask_discard,
"pong": mask_pong,
"kong": mask_kong,
"chi": mask_chi,
}
def _handle_agent_turn(self, action_type: int, tile_id: int) -> float:
agent = self.players[self.agent_index]
if action_type == ACTION_DECLARE_WIN:
if is_win(agent.hand, len(agent.melds)):
self._end_game(self.agent_index, "self_draw")
return 0.0
if action_type == ACTION_DECLARE_KONG:
if self._can_concealed_kong(agent, tile_id):
self._make_kong(self.agent_index, tile_id, concealed=True)
self._draw_tile(self.agent_index)
self.skip_draw_for_player = self.agent_index
self.phase = PHASE_AGENT_TURN
return 0.0
if action_type == ACTION_DISCARD:
if tile_id in agent.hand:
agent.hand.remove(tile_id)
agent.discards.append(tile_id)
self.pending_discard = (self.agent_index, tile_id)
self._log_event(
"discard",
player=self.agent_index,
tile_id=tile_id,
)
self._resolve_discard()
return 0.0
return 0.0
def _handle_agent_response(self, action_type: int, tile_id: int, chi_choice: int) -> float:
if self.pending_discard is None:
self.phase = PHASE_AGENT_TURN
return 0.0
discarder, tile = self.pending_discard
agent = self.players[self.agent_index]
counts = self._counts(agent.hand)
if action_type == ACTION_DECLARE_WIN:
if is_win(agent.hand + [tile], len(agent.melds)):
self._end_game(self.agent_index, "ron")
return 0.0
if action_type == ACTION_DECLARE_PONG:
if counts[tile] >= 2:
self._consume_discard(discarder, tile)
agent.hand.remove(tile)
agent.hand.remove(tile)
agent.melds.append([tile, tile, tile])
self.current_player = self.agent_index
self.skip_draw_for_player = self.agent_index
self.pending_discard = None
self._log_event(
"claim",
player=self.agent_index,
claim="pong",
tile_id=tile,
from_player=discarder,
)
self.phase = PHASE_AGENT_TURN
return 0.0
if action_type == ACTION_DECLARE_KONG:
if counts[tile] >= 3:
self._consume_discard(discarder, tile)
for _ in range(3):
agent.hand.remove(tile)
agent.melds.append([tile, tile, tile, tile])
self.current_player = self.agent_index
self.skip_draw_for_player = self.agent_index
self.pending_discard = None
self._draw_tile(self.agent_index)
self._log_event(
"claim",
player=self.agent_index,
claim="kong",
tile_id=tile,
from_player=discarder,
)
self.phase = PHASE_AGENT_TURN
return 0.0
if action_type == ACTION_DECLARE_CHI:
if self._is_next_player(self.agent_index, discarder):
options = can_chi_options(counts, tile)
if chi_choice in options:
self._consume_discard(discarder, tile)
if chi_choice == 0:
chi_tiles = [tile - 2, tile - 1, tile]
elif chi_choice == 1:
chi_tiles = [tile - 1, tile, tile + 1]
else:
chi_tiles = [tile, tile + 1, tile + 2]
for chi_tile in chi_tiles:
if chi_tile != tile:
agent.hand.remove(chi_tile)
agent.melds.append(chi_tiles)
self.current_player = self.agent_index
self.skip_draw_for_player = self.agent_index
self.pending_discard = None
self._log_event(
"claim",
player=self.agent_index,
claim="chi",
tile_id=tile,
from_player=discarder,
)
self.phase = PHASE_AGENT_TURN
return 0.0
self.pending_discard = None
self.phase = PHASE_AGENT_TURN
self._resolve_discard(after_pass=True, original_discarder=discarder, tile=tile)
return 0.0
def _advance_to_agent(self):
while not self.done:
if self.phase == PHASE_AGENT_RESPONSE:
return
if self.current_player == self.agent_index:
self._start_turn(self.agent_index)
if self.done:
return
self.phase = PHASE_AGENT_TURN
return
self._npc_turn(self.current_player)
if self.done or self.phase == PHASE_AGENT_RESPONSE:
return
def _npc_turn(self, player_idx: int):
self._start_turn(player_idx)
if self.done:
return
player = self.players[player_idx]
if is_win(player.hand, len(player.melds)):
self._end_game(player_idx, "self_draw")
return
tile = self._random.choice(player.hand)
player.hand.remove(tile)
player.discards.append(tile)
self.pending_discard = (player_idx, tile)
self._log_event(
"discard",
player=player_idx,
tile_id=tile,
)
self._resolve_discard()
def _resolve_discard(self, after_pass: bool = False, original_discarder: Optional[int] = None, tile: Optional[int] = None):
if self.pending_discard is None and not after_pass:
return
discarder, discarded_tile = self.pending_discard if self.pending_discard is not None else (original_discarder, tile)
if discarder is None or discarded_tile is None:
return
if discarder != self.agent_index and not after_pass:
if self._agent_can_respond(discarder, discarded_tile):
self.pending_discard = (discarder, discarded_tile)
self.phase = PHASE_AGENT_RESPONSE
return
winner = self._npc_win_on_discard(discarder, discarded_tile)
if winner is not None:
self._end_game(winner, "ron")
return
claim = self._npc_claim(discarder, discarded_tile)
if claim is not None:
self.pending_discard = None
self._apply_npc_claim(*claim)
return
self.pending_discard = None
self.current_player = (discarder + 1) % 4
def _agent_can_respond(self, discarder: int, tile: int) -> bool:
agent = self.players[self.agent_index]
counts = self._counts(agent.hand)
if is_win(agent.hand + [tile], len(agent.melds)):
return True
if counts[tile] >= 2:
return True
if counts[tile] >= 3:
return True
if self._is_next_player(self.agent_index, discarder):
return bool(can_chi_options(counts, tile))
return False
def _npc_win_on_discard(self, discarder: int, tile: int) -> Optional[int]:
for offset in range(1, 4):
player_idx = (discarder + offset) % 4
if player_idx == self.agent_index:
continue
player = self.players[player_idx]
if is_win(player.hand + [tile], len(player.melds)):
return player_idx
return None
def _npc_claim(self, discarder: int, tile: int):
for offset in range(1, 4):
player_idx = (discarder + offset) % 4
if player_idx == self.agent_index:
continue
player = self.players[player_idx]
counts = self._counts(player.hand)
if counts[tile] >= 3:
return (player_idx, "kong", tile, None, discarder)
if counts[tile] >= 2:
return (player_idx, "pong", tile, None, discarder)
next_player = (discarder + 1) % 4
if next_player != self.agent_index:
player = self.players[next_player]
options = can_chi_options(self._counts(player.hand), tile)
if options:
return (next_player, "chi", tile, options[0], discarder)
return None
def _apply_npc_claim(
self,
player_idx: int,
claim_type: str,
tile: int,
chi_choice: Optional[int],
discarder: int,
):
player = self.players[player_idx]
self._consume_discard(discarder, tile)
if claim_type == "pong":
player.hand.remove(tile)
player.hand.remove(tile)
player.melds.append([tile, tile, tile])
self.current_player = player_idx
self.skip_draw_for_player = player_idx
elif claim_type == "kong":
for _ in range(3):
player.hand.remove(tile)
player.melds.append([tile, tile, tile, tile])
self.current_player = player_idx
self.skip_draw_for_player = player_idx
self._draw_tile(player_idx)
elif claim_type == "chi" and chi_choice is not None:
if chi_choice == 0:
chi_tiles = [tile - 2, tile - 1, tile]
elif chi_choice == 1:
chi_tiles = [tile - 1, tile, tile + 1]
else:
chi_tiles = [tile, tile + 1, tile + 2]
for chi_tile in chi_tiles:
if chi_tile != tile:
player.hand.remove(chi_tile)
player.melds.append(chi_tiles)
self.current_player = player_idx
self.skip_draw_for_player = player_idx
self._log_event(
"claim",
player=player_idx,
claim=claim_type,
tile_id=tile,
from_player=discarder,
)
def _draw_tile(self, player_idx: int):
while self.wall:
tile = self.wall.pop()
if is_flower(tile):
self.players[player_idx].flowers.append(tile)
continue
self.players[player_idx].hand.append(tile)
return tile
self._end_game(None, "wall_empty")
return None
def _start_turn(self, player_idx: int):
if self.skip_draw_for_player == player_idx:
self.skip_draw_for_player = None
return
self._draw_tile(player_idx)
def _make_kong(self, player_idx: int, tile_id: int, concealed: bool):
player = self.players[player_idx]
for _ in range(4):
player.hand.remove(tile_id)
player.melds.append([tile_id] * 4)
def _can_concealed_kong(self, player: PlayerState, tile_id: int) -> bool:
return player.hand.count(tile_id) >= 4
def _counts(self, tiles: List[int]) -> np.ndarray:
counts = np.zeros(len(TILE_TYPES), dtype=np.int8)
for tile in tiles:
counts[tile] += 1
return counts
def _is_next_player(self, player: int, discarder: int) -> bool:
return player == (discarder + 1) % 4
def _consume_discard(self, discarder: int, tile: int):
discards = self.players[discarder].discards
if discards and discards[-1] == tile:
discards.pop()
def _end_game(self, winner: Optional[int], reason: str):
if self.done:
return
self.done = True
self.winner = winner
self.terminal_reason = reason
self._log_event(
"win" if winner is not None else "draw",
player=winner,
reason=reason,
)
def _snapshot(self) -> Dict[str, object]:
return {
"hands": [[tile_name(tile) for tile in sorted(p.hand)] for p in self.players],
"discards": [[tile_name(tile) for tile in p.discards] for p in self.players],
"melds": [[[tile_name(tile) for tile in meld] for meld in p.melds] for p in self.players],
"flowers": [[tile_name(tile) for tile in p.flowers] for p in self.players],
"wall_count": len(self.wall),
"current_player": self.current_player,
}
def _log_event(self, event_type: str, **payload: object):
event = {
"id": len(self.event_log),
"type": event_type,
"snapshot": self._snapshot(),
}
event.update(payload)
if "tile_id" in event:
event["tile"] = tile_name(int(event["tile_id"]))
self.event_log.append(event)