feat: initialize majiang-rl project

This commit is contained in:
game-loader
2026-01-14 10:49:00 +08:00
commit b29a18b459
21 changed files with 18895 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12

64
README.md Normal file
View File

@@ -0,0 +1,64 @@
# majiang-rl
A minimal Mahjong (Guobiao) simulation environment and reinforcement learning scaffold built on gymnasium.
## Features
- 4-player Guobiao tile set (144 tiles, including flowers)
- Draw/discard turn loop with flower replacement
- Basic calls: win (ron/tsumo), pong, chi, kong
- Win checking for standard hands, seven pairs, and thirteen orphans
- Gymnasium-style environment API with action masks (type/discard/pong/kong/chi)
- Simple RL loop and random agent
## Limitations
- No full Guobiao scoring and no 8-fan minimum enforcement (GRPO training rewards use only a simplified fan breakdown)
- NPC players use simple greedy claims and random discards
- No detailed round rules (winds/seat rotation, riichi, etc.)
## Quick start (uv)
```bash
uv venv
uv pip install -e .
uv run python main.py
```
## Environment API
```python
from majiang_rl import MahjongEnv
env = MahjongEnv()
obs, info = env.reset()
# action format
# type: 0 discard, 1 declare win, 2 declare kong, 3 pass, 4 declare pong, 5 declare chi
# tile: tile id (0-41)
# chi: which run is formed with the claimed tile t: 0 = (t-2, t-1, t), 1 = (t-1, t, t+1), 2 = (t, t+1, t+2)
action = {"type": 0, "tile": 0, "chi": 0}
obs, reward, terminated, truncated, info = env.step(action)
```
## RL scaffold
```python
from majiang_rl import MahjongEnv
from majiang_rl.rl import RandomAgent, run_training
env = MahjongEnv()
agent = RandomAgent()
results = run_training(env, agent, episodes=10)
print(results[0])
```
## GRPO self-play training
```bash
uv run python -m majiang_rl.rl.grpo --updates 20 --group-size 16 --device auto
uv run python -m majiang_rl.rl.grpo --updates 20 --group-size 16 --device auto --pong-reward 0.1 --closest-bonus 1.0
uv run python -m majiang_rl.rl.grpo --updates 20 --group-size 16 --device auto --swanlab --swanlab-project majiang-rl --swanlab-run-name grpo-demo
```
Reward uses a simplified fan breakdown (thirteen orphans, seven pairs, pure/half flush, all pungs, all honors).
## Simple web UI
```bash
uv run python -m majiang_rl.ui.web --port 8000
```
Then open `http://localhost:8000/index.html` to watch the playback.

14
main.py Normal file
View File

@@ -0,0 +1,14 @@
from majiang_rl import MahjongEnv
from majiang_rl.rl import RandomAgent, run_training
def main():
    """Run five episodes with a random agent and print one summary line per episode."""
    results = run_training(MahjongEnv(), RandomAgent(), episodes=5)
    for idx, result in enumerate(results, start=1):
        print(f"Episode {idx}: reward={result.total_reward}, steps={result.steps}")


if __name__ == "__main__":
    main()

3
majiang_rl/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .envs import MahjongEnv, SelfPlayMahjong
# Names re-exported as this package's public API.
__all__ = ["MahjongEnv", "SelfPlayMahjong"]

View File

@@ -0,0 +1,4 @@
from .mahjong_env import MahjongEnv
from .self_play_env import SelfPlayMahjong
# Environment classes exposed by this subpackage.
__all__ = ["MahjongEnv", "SelfPlayMahjong"]

View File

@@ -0,0 +1,578 @@
from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from ..rules import can_chi_options, is_win
from ..tiles import TILE_TYPES, build_wall, is_flower, tile_name
# Values stored in ``MahjongEnv.phase``: the agent either acts on its own turn
# or answers another seat's pending discard.
PHASE_AGENT_TURN = 0
PHASE_AGENT_RESPONSE = 1
# Codes for the ``type`` field of an action dict (also the indices of the
# ``type`` action mask).
ACTION_DISCARD = 0
ACTION_DECLARE_WIN = 1
ACTION_DECLARE_KONG = 2
ACTION_PASS = 3
ACTION_DECLARE_PONG = 4
ACTION_DECLARE_CHI = 5
@dataclass
class PlayerState:
    """Mutable per-seat state; every field starts as its own empty list."""

    # Tile ids currently held in the concealed hand.
    hand: List[int] = field(default_factory=list)
    # Declared melds, each stored as a list of tile ids.
    melds: List[List[int]] = field(default_factory=list)
    # Tiles this seat has discarded, oldest first.
    discards: List[int] = field(default_factory=list)
    # Flower tiles set aside while drawing.
    flowers: List[int] = field(default_factory=list)
class MahjongEnv(gym.Env):
    """Gymnasium environment in which seat 0 is the learning agent."""

    metadata = {"render_modes": ["human"]}

    def __init__(self, seed: Optional[int] = None, max_rounds: int = 1):
        """Declare the action/observation spaces and initialise empty table state.

        Args:
            seed: Seed for the environment's private ``random.Random``.
            max_rounds: Stored on the instance as ``max_rounds``.
        """
        super().__init__()
        self.max_rounds = max_rounds

        tile_kinds = len(TILE_TYPES)

        def small_box(high, shape):
            # All count-like observations share the int8 dtype.
            return spaces.Box(0, high, shape=shape, dtype=np.int8)

        self.action_space = spaces.Dict(
            {
                "type": spaces.Discrete(6),
                "tile": spaces.Discrete(tile_kinds),
                "chi": spaces.Discrete(3),
            }
        )
        self.observation_space = spaces.Dict(
            {
                "hand": small_box(4, (tile_kinds,)),
                "melds": small_box(4, (4, tile_kinds)),
                "discards": small_box(4, (4, tile_kinds)),
                "flowers": small_box(8, (4,)),
                "wall_count": spaces.Box(0, 144, shape=(1,), dtype=np.int16),
                "current_player": small_box(3, (1,)),
                "phase": small_box(1, (1,)),
            }
        )

        # Private RNG and bookkeeping counters.
        self._random = random.Random(seed)
        self._round = 0
        self.agent_index = 0

        # Table state; populated for real by reset().
        self.players: List[PlayerState] = []
        self.wall: List[int] = []
        self.current_player = 0
        self.phase = PHASE_AGENT_TURN
        self.pending_discard: Optional[Tuple[int, int]] = None
        self.skip_draw_for_player: Optional[int] = None

        # Terminal information and the event history.
        self.done = False
        self.winner: Optional[int] = None
        self.terminal_reason: Optional[str] = None
        self.event_log: List[Dict[str, object]] = []
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
    """Deal a fresh hand and advance play to the agent's first decision.

    Args:
        seed: When given, passed to ``super().reset`` and also used to reseed
            the environment's private ``random.Random``.
        options: Accepted for API compatibility; not read by this method.

    Returns:
        ``(observation, info)`` from ``_get_obs`` and ``_get_info``.
    """
    super().reset(seed=seed)
    if seed is not None:
        self._random.seed(seed)
    self._round += 1

    # Clear everything left over from the previous hand.
    self.players = [PlayerState() for _ in range(4)]
    self.wall = build_wall(self._random)
    self.current_player = 0
    self.pending_discard = None
    self.skip_draw_for_player = None
    self.done = False
    self.winner = None
    self.terminal_reason = None
    self.event_log = []

    # Thirteen rounds of one draw per seat, in seat order.
    for seat in (s for _ in range(13) for s in range(4)):
        self._draw_tile(seat)

    # Seat 0 takes its opening draw here, so its first turn must not draw again.
    self._draw_tile(0)
    self.skip_draw_for_player = 0
    self.phase = PHASE_AGENT_TURN
    self._advance_to_agent()
    return self._get_obs(), self._get_info()
def step(self, action: Dict[str, int]):
    """Apply one agent action and play on until the agent must decide again.

    Args:
        action: Dict with ``type``, ``tile`` and ``chi`` entries; it is passed
            through ``_sanitize_action`` before being applied.

    Returns:
        ``(obs, reward, terminated, truncated, info)``. ``truncated`` is always
        ``False``. The reward is +1.0 when the agent wins, -1.0 when another
        seat wins and 0.0 otherwise.
    """
    if self.done:
        return self._get_obs(), 0.0, True, False, self._get_info()

    cleaned = self._sanitize_action(action)
    kind = int(cleaned["type"])
    tile_id = int(cleaned["tile"])
    chi_choice = int(cleaned["chi"])

    reward = 0.0
    if self.phase == PHASE_AGENT_TURN:
        reward += self._handle_agent_turn(kind, tile_id)
    elif self.phase == PHASE_AGENT_RESPONSE:
        reward += self._handle_agent_response(kind, tile_id, chi_choice)

    if not self.done:
        self._advance_to_agent()

    # Terminal bonus/penalty; a game that ends without a winner adds nothing.
    if self.done and self.winner is not None:
        reward += 1.0 if self.winner == self.agent_index else -1.0
    return self._get_obs(), reward, self.done, False, self._get_info()
def render(self):
    """Print a one-line status: the result once finished, otherwise who acts next."""
    if not self.done:
        print(f"Player {self.current_player} to act, wall {len(self.wall)}")
        return
    print(f"Winner: {self.winner}, reason: {self.terminal_reason}")
def _get_obs(self):
    """Build the observation dict from the agent's point of view.

    ``hand`` holds the agent's own tile counts; ``melds``, ``discards`` and
    ``flowers`` are indexed by seat.
    """
    tile_kinds = len(TILE_TYPES)
    meld_counts = np.zeros((4, tile_kinds), dtype=np.int8)
    discard_counts = np.zeros((4, tile_kinds), dtype=np.int8)
    flower_counts = np.zeros(4, dtype=np.int8)

    for seat, state in enumerate(self.players):
        for tile in (t for meld in state.melds for t in meld):
            meld_counts[seat, tile] += 1
        for tile in state.discards:
            discard_counts[seat, tile] += 1
        flower_counts[seat] = len(state.flowers)

    return {
        "hand": self._counts(self.players[self.agent_index].hand),
        "melds": meld_counts,
        "discards": discard_counts,
        "flowers": flower_counts,
        "wall_count": np.array([len(self.wall)], dtype=np.int16),
        "current_player": np.array([self.current_player], dtype=np.int8),
        "phase": np.array([self.phase], dtype=np.int8),
    }
def _get_info(self):
    """Return the info dict: result fields, the pending discard and the action mask."""
    return dict(
        winner=self.winner,
        terminal_reason=self.terminal_reason,
        pending_discard=self.pending_discard,
        action_mask=self._action_mask(),
    )
def _sanitize_action(self, action: Dict[str, int]) -> Dict[str, int]:
    """Return ``action`` if it is legal right now, otherwise a random legal action.

    An action is replaced by ``_sample_valid_action()`` when it is not a dict,
    lacks one of ``type``/``tile``/``chi``, holds a value that cannot be
    converted to ``int``, names an action type that is masked out, or (for
    discard, pong, kong and chi) names a tile or chi choice that is out of
    range or masked out.

    Args:
        action: Candidate action dict from the caller.

    Returns:
        A dict with integer ``type``, ``tile`` and ``chi`` entries.
    """
    if not isinstance(action, dict):
        return self._sample_valid_action()
    try:
        action_type = int(action["type"])
        tile_id = int(action["tile"])
        chi_choice = int(action["chi"])
    except (KeyError, TypeError, ValueError):
        # Missing key, or a value that is not integer-like.
        return self._sample_valid_action()

    mask = self._action_mask()

    def allowed(flags, index: int) -> bool:
        # Explicit range check: indexing the mask directly would wrap a
        # negative index around and raise IndexError past the end.
        return 0 <= index < len(flags) and bool(flags[index])

    if not allowed(mask["type"], action_type):
        return self._sample_valid_action()

    # Sub-choice that must also be legal for the chosen action type; win and
    # pass carry no sub-choice.
    required = {
        ACTION_DISCARD: ("discard", tile_id),
        ACTION_DECLARE_PONG: ("pong", tile_id),
        ACTION_DECLARE_KONG: ("kong", tile_id),
        ACTION_DECLARE_CHI: ("chi", chi_choice),
    }.get(action_type)
    if required is not None:
        mask_key, index = required
        if not allowed(mask[mask_key], index):
            return self._sample_valid_action()
    return {"type": action_type, "tile": tile_id, "chi": chi_choice}
def _sample_valid_action(self) -> Dict[str, int]:
    """Draw a random action permitted by the current action mask.

    The action type is sampled first, then (for discard, pong, kong or chi)
    the matching tile or chi choice. Falls back to a pass when no action type
    is enabled.
    """
    mask = self._action_mask()
    legal_types = np.flatnonzero(mask["type"])
    if legal_types.size == 0:
        return {"type": ACTION_PASS, "tile": 0, "chi": 0}

    def pick(flags) -> int:
        # Random enabled index, or 0 when nothing is enabled.
        candidates = np.flatnonzero(flags)
        return int(self._random.choice(candidates)) if candidates.size else 0

    action_type = int(self._random.choice(legal_types))
    sampled = {"type": action_type, "tile": 0, "chi": 0}
    tile_mask_key = {
        ACTION_DISCARD: "discard",
        ACTION_DECLARE_PONG: "pong",
        ACTION_DECLARE_KONG: "kong",
    }.get(action_type)
    if tile_mask_key is not None:
        sampled["tile"] = pick(mask[tile_mask_key])
    elif action_type == ACTION_DECLARE_CHI:
        sampled["chi"] = pick(mask["chi"])
    return sampled
def _action_mask(self) -> Dict[str, np.ndarray]:
    """Compute boolean masks describing what the agent may do in the current phase.

    Returns:
        Dict with ``type`` (6 entries), ``discard``/``pong``/``kong`` (one
        entry per tile id) and ``chi`` (3 entries) boolean arrays.
    """
    tile_kinds = len(TILE_TYPES)
    mask_type = np.zeros(6, dtype=bool)
    mask_discard = np.zeros(tile_kinds, dtype=bool)
    mask_pong = np.zeros(tile_kinds, dtype=bool)
    mask_kong = np.zeros(tile_kinds, dtype=bool)
    mask_chi = np.zeros(3, dtype=bool)

    agent = self.players[self.agent_index]
    counts = self._counts(agent.hand)

    if self.phase == PHASE_AGENT_TURN:
        # Own turn: discard any held tile, win on a complete hand, or kong
        # any tile held four times.
        mask_type[ACTION_DISCARD] = True
        mask_discard = counts > 0
        if is_win(agent.hand, len(agent.melds)):
            mask_type[ACTION_DECLARE_WIN] = True
        mask_kong = counts >= 4
        if mask_kong.any():
            mask_type[ACTION_DECLARE_KONG] = True
    elif self.phase == PHASE_AGENT_RESPONSE and self.pending_discard is not None:
        # Answering a discard: passing is always possible.
        discarder, tile = self.pending_discard
        mask_type[ACTION_PASS] = True
        if is_win(agent.hand + [tile], len(agent.melds)):
            mask_type[ACTION_DECLARE_WIN] = True
        held = counts[tile]
        if held >= 2:
            mask_type[ACTION_DECLARE_PONG] = True
            mask_pong[tile] = True
        if held >= 3:
            mask_type[ACTION_DECLARE_KONG] = True
            mask_kong[tile] = True
        # Chi is only offered when the agent sits right after the discarder.
        if self._is_next_player(self.agent_index, discarder):
            options = can_chi_options(counts, tile)
            if options:
                mask_type[ACTION_DECLARE_CHI] = True
                for option in options:
                    mask_chi[option] = True

    return {
        "type": mask_type,
        "discard": mask_discard,
        "pong": mask_pong,
        "kong": mask_kong,
        "chi": mask_chi,
    }
def _handle_agent_turn(self, action_type: int, tile_id: int) -> float:
    """Apply the agent's own-turn action (win, concealed kong or discard).

    Args:
        action_type: One of the ``ACTION_*`` codes.
        tile_id: Tile to kong or discard; ignored for a win.

    Returns:
        Always ``0.0``; terminal rewards are added by ``step``.
    """
    seat = self.agent_index
    agent = self.players[seat]

    if action_type == ACTION_DECLARE_WIN:
        if is_win(agent.hand, len(agent.melds)):
            self._end_game(seat, "self_draw")
    elif action_type == ACTION_DECLARE_KONG:
        if self._can_concealed_kong(agent, tile_id):
            self._make_kong(seat, tile_id, concealed=True)
            # Replacement draw happens now, so the upcoming turn start must skip its draw.
            self._draw_tile(seat)
            self.skip_draw_for_player = seat
            self.phase = PHASE_AGENT_TURN
    elif action_type == ACTION_DISCARD and tile_id in agent.hand:
        agent.hand.remove(tile_id)
        agent.discards.append(tile_id)
        self.pending_discard = (seat, tile_id)
        self._log_event("discard", player=seat, tile_id=tile_id)
        self._resolve_discard()
    return 0.0
def _handle_agent_response(self, action_type: int, tile_id: int, chi_choice: int) -> float:
# Apply the agent's answer to the pending discard: win, pong, kong, chi, or
# (for anything else) pass. Always returns 0.0. The tile acted on is the
# pending discard itself; the ``tile_id`` argument is not read here.
# NOTE(review): indentation is lost in this view, so whether each
# ``return 0.0`` below belongs to the inner or the outer ``if`` cannot be
# determined from here - confirm against the original file.
#
# Nothing is pending: just go back to the agent-turn phase.
if self.pending_discard is None:
self.phase = PHASE_AGENT_TURN
return 0.0
discarder, tile = self.pending_discard
agent = self.players[self.agent_index]
counts = self._counts(agent.hand)
# Win on the discard: the hand plus the discarded tile must satisfy is_win.
if action_type == ACTION_DECLARE_WIN:
if is_win(agent.hand + [tile], len(agent.melds)):
self._end_game(self.agent_index, "ron")
return 0.0
# Pong: needs two matching tiles in hand. The tile is taken off the
# discarder's pile and the agent acts next without drawing.
if action_type == ACTION_DECLARE_PONG:
if counts[tile] >= 2:
self._consume_discard(discarder, tile)
agent.hand.remove(tile)
agent.hand.remove(tile)
agent.melds.append([tile, tile, tile])
self.current_player = self.agent_index
self.skip_draw_for_player = self.agent_index
self.pending_discard = None
self._log_event(
"claim",
player=self.agent_index,
claim="pong",
tile_id=tile,
from_player=discarder,
)
self.phase = PHASE_AGENT_TURN
return 0.0
# Kong from a discard: needs three matching tiles in hand; a tile is drawn
# right away and the draw at the start of the next turn is skipped.
if action_type == ACTION_DECLARE_KONG:
if counts[tile] >= 3:
self._consume_discard(discarder, tile)
for _ in range(3):
agent.hand.remove(tile)
agent.melds.append([tile, tile, tile, tile])
self.current_player = self.agent_index
self.skip_draw_for_player = self.agent_index
self.pending_discard = None
self._draw_tile(self.agent_index)
self._log_event(
"claim",
player=self.agent_index,
claim="kong",
tile_id=tile,
from_player=discarder,
)
self.phase = PHASE_AGENT_TURN
return 0.0
# Chi: only when the agent is the seat right after the discarder and the
# chosen option is among those reported by can_chi_options.
if action_type == ACTION_DECLARE_CHI:
if self._is_next_player(self.agent_index, discarder):
options = can_chi_options(counts, tile)
if chi_choice in options:
self._consume_discard(discarder, tile)
# chi_choice picks the run: 0 -> claimed tile is the highest,
# 1 -> the middle, anything else -> the lowest.
if chi_choice == 0:
chi_tiles = [tile - 2, tile - 1, tile]
elif chi_choice == 1:
chi_tiles = [tile - 1, tile, tile + 1]
else:
chi_tiles = [tile, tile + 1, tile + 2]
# Only the two tiles that come from the hand are removed.
for chi_tile in chi_tiles:
if chi_tile != tile:
agent.hand.remove(chi_tile)
agent.melds.append(chi_tiles)
self.current_player = self.agent_index
self.skip_draw_for_player = self.agent_index
self.pending_discard = None
self._log_event(
"claim",
player=self.agent_index,
claim="chi",
tile_id=tile,
from_player=discarder,
)
self.phase = PHASE_AGENT_TURN
return 0.0
# No claim was made: clear the pending discard and resolve it for the
# other seats (after_pass=True skips asking the agent again).
self.pending_discard = None
self.phase = PHASE_AGENT_TURN
self._resolve_discard(after_pass=True, original_discarder=discarder, tile=tile)
return 0.0
def _advance_to_agent(self):
    """Play NPC turns until the agent must act or the game ends.

    Stops as soon as the phase becomes ``PHASE_AGENT_RESPONSE`` or the agent's
    own turn has been started.
    """
    while not self.done and self.phase != PHASE_AGENT_RESPONSE:
        seat = self.current_player
        if seat == self.agent_index:
            self._start_turn(seat)
            # Starting the turn can end the game (nothing left to draw).
            if not self.done:
                self.phase = PHASE_AGENT_TURN
            return
        self._npc_turn(seat)
def _npc_turn(self, player_idx: int):
    """Play one scripted turn: start the turn, win if possible, else discard at random."""
    self._start_turn(player_idx)
    if self.done:
        return

    npc = self.players[player_idx]
    if is_win(npc.hand, len(npc.melds)):
        self._end_game(player_idx, "self_draw")
        return

    thrown = self._random.choice(npc.hand)
    npc.hand.remove(thrown)
    npc.discards.append(thrown)
    self.pending_discard = (player_idx, thrown)
    self._log_event("discard", player=player_idx, tile_id=thrown)
    self._resolve_discard()
def _resolve_discard(self, after_pass: bool = False, original_discarder: Optional[int] = None, tile: Optional[int] = None):
    """Decide what happens to a discard: agent response, NPC win, NPC claim, or next seat.

    Args:
        after_pass: True when the agent has already declined this discard, so
            it is not offered to the agent again.
        original_discarder: Discarding seat, used when ``pending_discard`` has
            already been cleared.
        tile: Discarded tile id, used together with ``original_discarder``.
    """
    if self.pending_discard is not None:
        discarder, discarded_tile = self.pending_discard
    elif after_pass:
        discarder, discarded_tile = original_discarder, tile
    else:
        return
    if discarder is None or discarded_tile is None:
        return

    # The agent is asked first, but only about other seats' discards.
    offer_to_agent = discarder != self.agent_index and not after_pass
    if offer_to_agent and self._agent_can_respond(discarder, discarded_tile):
        self.pending_discard = (discarder, discarded_tile)
        self.phase = PHASE_AGENT_RESPONSE
        return

    winner = self._npc_win_on_discard(discarder, discarded_tile)
    if winner is not None:
        self._end_game(winner, "ron")
        return

    claim = self._npc_claim(discarder, discarded_tile)
    self.pending_discard = None
    if claim is not None:
        self._apply_npc_claim(*claim)
        return
    # Unclaimed: play passes to the seat after the discarder.
    self.current_player = (discarder + 1) % 4
def _agent_can_respond(self, discarder: int, tile: int) -> bool:
    """Report whether the agent has any claim available on a discarded tile.

    Args:
        discarder: Seat that discarded ``tile``.
        tile: Discarded tile id.

    Returns:
        True when the agent could win on the tile, holds at least two copies
        of it (pong, and kong with three), or sits right after the discarder
        and has at least one chi option.
    """
    agent = self.players[self.agent_index]
    if is_win(agent.hand + [tile], len(agent.melds)):
        return True
    counts = self._counts(agent.hand)
    # Two copies already cover the kong case (three copies), so no separate
    # ">= 3" test is needed.
    if counts[tile] >= 2:
        return True
    if self._is_next_player(self.agent_index, discarder):
        return bool(can_chi_options(counts, tile))
    return False
def _npc_win_on_discard(self, discarder: int, tile: int) -> Optional[int]:
    """Return the first non-agent seat after the discarder that wins on ``tile``, else None."""
    for seat in ((discarder + offset) % 4 for offset in range(1, 4)):
        if seat == self.agent_index:
            continue
        candidate = self.players[seat]
        if is_win(candidate.hand + [tile], len(candidate.melds)):
            return seat
    return None
def _npc_claim(self, discarder: int, tile: int):
    """Pick the claim an NPC makes on a discard, if any.

    Kong or pong by any non-agent seat (checked in turn order after the
    discarder) comes first; otherwise the seat right after the discarder may
    chi with its first available option.

    Returns:
        ``(seat, claim_type, tile, chi_choice, discarder)`` or ``None``.
    """
    npc_seats = [
        (discarder + offset) % 4
        for offset in range(1, 4)
        if (discarder + offset) % 4 != self.agent_index
    ]
    for seat in npc_seats:
        held = self._counts(self.players[seat].hand)[tile]
        if held >= 3:
            return (seat, "kong", tile, None, discarder)
        if held >= 2:
            return (seat, "pong", tile, None, discarder)

    following_seat = (discarder + 1) % 4
    if following_seat != self.agent_index:
        hand_counts = self._counts(self.players[following_seat].hand)
        options = can_chi_options(hand_counts, tile)
        if options:
            return (following_seat, "chi", tile, options[0], discarder)
    return None
def _apply_npc_claim(
    self,
    player_idx: int,
    claim_type: str,
    tile: int,
    chi_choice: Optional[int],
    discarder: int,
):
    """Carry out an NPC's pong, kong or chi on a discarded tile and log it.

    The tile is taken off the discarder's pile, the claimer's hand tiles are
    moved into a new meld, and the claimer becomes the current player with its
    next turn-start draw skipped. A kong additionally draws a tile right away.
    """
    claimer = self.players[player_idx]
    self._consume_discard(discarder, tile)

    if claim_type == "pong":
        for _ in range(2):
            claimer.hand.remove(tile)
        claimer.melds.append([tile] * 3)
        self.current_player = player_idx
        self.skip_draw_for_player = player_idx
    elif claim_type == "kong":
        for _ in range(3):
            claimer.hand.remove(tile)
        claimer.melds.append([tile] * 4)
        self.current_player = player_idx
        self.skip_draw_for_player = player_idx
        self._draw_tile(player_idx)
    elif claim_type == "chi" and chi_choice is not None:
        # chi_choice 0 -> claimed tile is the highest of the run, 1 -> the
        # middle, anything else -> the lowest.
        lowest = {0: tile - 2, 1: tile - 1}.get(chi_choice, tile)
        run = [lowest, lowest + 1, lowest + 2]
        for needed in run:
            if needed != tile:
                claimer.hand.remove(needed)
        claimer.melds.append(run)
        self.current_player = player_idx
        self.skip_draw_for_player = player_idx

    self._log_event(
        "claim",
        player=player_idx,
        claim=claim_type,
        tile_id=tile,
        from_player=discarder,
    )
def _draw_tile(self, player_idx: int):
    """Draw for ``player_idx``, setting flowers aside until a normal tile appears.

    Returns:
        The tile added to the hand, or ``None`` after ending the game with
        reason ``"wall_empty"`` because the wall ran out.
    """
    while True:
        if not self.wall:
            self._end_game(None, "wall_empty")
            return None
        tile = self.wall.pop()
        if not is_flower(tile):
            self.players[player_idx].hand.append(tile)
            return tile
        # Flower: keep it aside and draw again.
        self.players[player_idx].flowers.append(tile)
def _start_turn(self, player_idx: int):
    """Draw for the seat unless its draw was flagged to be skipped (the flag is then cleared)."""
    if self.skip_draw_for_player != player_idx:
        self._draw_tile(player_idx)
        return
    self.skip_draw_for_player = None
def _make_kong(self, player_idx: int, tile_id: int, concealed: bool):
    """Move four copies of ``tile_id`` from the hand into a new meld.

    ``concealed`` is accepted but not read by this method.
    """
    owner = self.players[player_idx]
    removed = 0
    while removed < 4:
        owner.hand.remove(tile_id)
        removed += 1
    owner.melds.append([tile_id, tile_id, tile_id, tile_id])
def _can_concealed_kong(self, player: PlayerState, tile_id: int) -> bool:
    """Return True when the player's hand holds at least four copies of ``tile_id``."""
    copies = sum(1 for held in player.hand if held == tile_id)
    return copies >= 4
def _counts(self, tiles: List[int]) -> np.ndarray:
    """Return an int8 histogram with one slot per tile id, counting the given tiles."""
    histogram = np.zeros(len(TILE_TYPES), dtype=np.int8)
    for tile_id in tiles:
        histogram[tile_id] += 1
    return histogram
def _is_next_player(self, player: int, discarder: int) -> bool:
    """Return True when ``player`` is the seat immediately after ``discarder`` (mod 4)."""
    following_seat = (discarder + 1) % 4
    return following_seat == player
def _consume_discard(self, discarder: int, tile: int):
    """Remove ``tile`` from the discarder's pile if it is the most recent discard."""
    pile = self.players[discarder].discards
    if not pile:
        return
    if pile[-1] == tile:
        pile.pop()
def _end_game(self, winner: Optional[int], reason: str):
    """Mark the game finished and log the outcome; later calls are ignored.

    Args:
        winner: Winning seat, or ``None`` when nobody won.
        reason: Short label stored as ``terminal_reason``.
    """
    if self.done:
        return
    self.done, self.winner, self.terminal_reason = True, winner, reason
    event_type = "draw" if winner is None else "win"
    self._log_event(event_type, player=winner, reason=reason)
def _snapshot(self) -> Dict[str, object]:
    """Capture the whole table with tile ids converted through ``tile_name``."""

    def names(tiles):
        return [tile_name(tile) for tile in tiles]

    seats = self.players
    return {
        "hands": [names(sorted(seat.hand)) for seat in seats],
        "discards": [names(seat.discards) for seat in seats],
        "melds": [[names(meld) for meld in seat.melds] for seat in seats],
        "flowers": [names(seat.flowers) for seat in seats],
        "wall_count": len(self.wall),
        "current_player": self.current_player,
    }
def _log_event(self, event_type: str, **payload: object):
    """Append an event (sequential id, type, table snapshot, payload) to ``event_log``.

    When the payload carries ``tile_id``, a ``tile`` entry with its
    ``tile_name`` is added as well.
    """
    event = {
        "id": len(self.event_log),
        "type": event_type,
        "snapshot": self._snapshot(),
        **payload,
    }
    if "tile_id" in event:
        event["tile"] = tile_name(int(event["tile_id"]))
    self.event_log.append(event)

View File

@@ -0,0 +1,337 @@
from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
from ..rules import can_chi_options, is_win
from ..tiles import TILE_TYPES, build_wall, is_flower
# Values stored in ``SelfPlayMahjong.phase``: the current player either takes
# its own turn or answers the pending discard.
PHASE_TURN = 0
PHASE_RESPONSE = 1
# Codes for the ``type`` field of an action dict (also the indices of the
# ``type`` action mask).
ACTION_DISCARD = 0
ACTION_DECLARE_WIN = 1
ACTION_DECLARE_KONG = 2
ACTION_PASS = 3
ACTION_DECLARE_PONG = 4
ACTION_DECLARE_CHI = 5
@dataclass
class PlayerState:
    """State of one seat; each field defaults to a separate empty list."""

    # Concealed tiles, as tile ids.
    hand: List[int] = field(default_factory=list)
    # Declared melds, one list of tile ids per meld.
    melds: List[List[int]] = field(default_factory=list)
    # Discarded tiles, in the order they were thrown.
    discards: List[int] = field(default_factory=list)
    # Flower tiles put aside during draws.
    flowers: List[int] = field(default_factory=list)
class SelfPlayMahjong:
    """Four-seat Mahjong state machine in which every seat is driven by the caller."""

    def __init__(self, seed: Optional[int] = None, pong_reward: float = 0.0):
        """Create an empty table; call ``reset`` before playing.

        Args:
            seed: Seed for the private ``random.Random``.
            pong_reward: Value returned by ``step`` when a pong is claimed.
        """
        self._random = random.Random(seed)
        self.pong_reward = pong_reward

        # Table contents.
        self.players: List[PlayerState] = []
        self.wall: List[int] = []

        # Turn bookkeeping.
        self.current_player = 0
        self.phase = PHASE_TURN
        self.pending_discard: Optional[Tuple[int, int]] = None
        self.response_queue: List[int] = []
        self.turn_started = False
        self.skip_draw_for_player: Optional[int] = None

        # Outcome of the finished game.
        self.done = False
        self.winner: Optional[int] = None
        self.terminal_reason: Optional[str] = None
        self.winning_tiles: List[int] = []
        self.winning_melds: List[List[int]] = []
def reset(self, *, seed: Optional[int] = None) -> None:
    """Rebuild the wall, clear all state and deal thirteen draws to each seat.

    Args:
        seed: When given, reseeds the private RNG before the wall is built.
    """
    if seed is not None:
        self._random.seed(seed)

    self.players = [PlayerState() for _ in range(4)]
    self.wall = build_wall(self._random)

    self.current_player = 0
    self.phase = PHASE_TURN
    self.pending_discard = None
    self.response_queue = []
    self.turn_started = False
    self.skip_draw_for_player = None

    self.done = False
    self.winner = None
    self.terminal_reason = None
    self.winning_tiles = []
    self.winning_melds = []

    # One draw per seat in seat order, thirteen times over.
    for seat in (s for _ in range(13) for s in range(4)):
        self._draw_tile(seat)
def observe(self, player_idx: int) -> Dict[str, np.ndarray]:
    """Build the observation for ``player_idx``.

    ``hand`` holds that seat's own tile counts; ``melds``, ``discards`` and
    ``flowers`` are indexed by seat; ``pending_discard`` is a one-hot vector
    of the tile awaiting responses (all zeros when none is pending).
    """
    tile_kinds = len(TILE_TYPES)
    meld_counts = np.zeros((4, tile_kinds), dtype=np.int8)
    discard_counts = np.zeros((4, tile_kinds), dtype=np.int8)
    flower_counts = np.zeros(4, dtype=np.int8)

    for seat, state in enumerate(self.players):
        for tile in (t for meld in state.melds for t in meld):
            meld_counts[seat, tile] += 1
        for tile in state.discards:
            discard_counts[seat, tile] += 1
        flower_counts[seat] = len(state.flowers)

    pending = np.zeros(tile_kinds, dtype=np.int8)
    if self.pending_discard is not None:
        _, pending_tile = self.pending_discard
        pending[pending_tile] = 1

    return {
        "hand": self._counts(self.players[player_idx].hand),
        "melds": meld_counts,
        "discards": discard_counts,
        "flowers": flower_counts,
        "wall_count": np.array([len(self.wall)], dtype=np.int16),
        "current_player": np.array([self.current_player], dtype=np.int8),
        "phase": np.array([self.phase], dtype=np.int8),
        "pending_discard": pending,
    }
def action_mask(self, player_idx: int) -> Dict[str, np.ndarray]:
    """Compute boolean masks of what ``player_idx`` may do in the current phase.

    Returns:
        Dict with ``type`` (6 entries), ``discard``/``pong``/``kong`` (one
        entry per tile id) and ``chi`` (3 entries) boolean arrays.
    """
    tile_kinds = len(TILE_TYPES)
    mask_type = np.zeros(6, dtype=bool)
    mask_discard = np.zeros(tile_kinds, dtype=bool)
    mask_pong = np.zeros(tile_kinds, dtype=bool)
    mask_kong = np.zeros(tile_kinds, dtype=bool)
    mask_chi = np.zeros(3, dtype=bool)

    seat = self.players[player_idx]
    counts = self._counts(seat.hand)

    if self.phase == PHASE_TURN:
        # Own turn: discard any held tile, win on a complete hand, or kong
        # any tile held four times.
        mask_type[ACTION_DISCARD] = True
        mask_discard = counts > 0
        if is_win(seat.hand, len(seat.melds)):
            mask_type[ACTION_DECLARE_WIN] = True
        mask_kong = counts >= 4
        if mask_kong.any():
            mask_type[ACTION_DECLARE_KONG] = True
    elif self.phase == PHASE_RESPONSE and self.pending_discard is not None:
        # Answering a discard: passing is always possible.
        discarder, tile = self.pending_discard
        mask_type[ACTION_PASS] = True
        if is_win(seat.hand + [tile], len(seat.melds)):
            mask_type[ACTION_DECLARE_WIN] = True
        held = counts[tile]
        if held >= 2:
            mask_type[ACTION_DECLARE_PONG] = True
            mask_pong[tile] = True
        if held >= 3:
            mask_type[ACTION_DECLARE_KONG] = True
            mask_kong[tile] = True
        # Chi is only offered to the seat right after the discarder.
        if player_idx == (discarder + 1) % 4:
            options = can_chi_options(counts, tile)
            if options:
                mask_type[ACTION_DECLARE_CHI] = True
                for option in options:
                    mask_chi[option] = True

    return {
        "type": mask_type,
        "discard": mask_discard,
        "pong": mask_pong,
        "kong": mask_kong,
        "chi": mask_chi,
    }
def step(self, player_idx: int, action: Dict[str, int]) -> float:
    """Apply ``action`` for the current player and return the immediate reward.

    Args:
        player_idx: Acting seat; must equal ``current_player``.
        action: Action dict; replaced by a random legal action if illegal.

    Returns:
        The reward from the turn/response handler, or ``0.0`` when the game
        is already over or ends while the turn is being started.

    Raises:
        ValueError: If ``player_idx`` is not the current player.

    NOTE(review): the action is sanitised before the turn-start draw made
    inside this call, so it is checked against the hand as it was before
    that draw - confirm the caller expects this.
    """
    if self.done:
        return 0.0
    if player_idx != self.current_player:
        raise ValueError("Action from non-active player")

    action = self._sanitize_action(player_idx, action)
    if self.phase != PHASE_TURN:
        return self._handle_response_action(player_idx, action)

    if not self.turn_started:
        self._start_turn(player_idx)
        if self.done:
            return 0.0
        self.turn_started = True
    return self._handle_turn_action(player_idx, action)
def _handle_turn_action(self, player_idx: int, action: Dict[str, int]) -> float:
    """Apply an own-turn action (win, concealed kong or discard); always returns 0.0.

    A discard opens the response phase: the other three seats are queued in
    turn order and the first of them becomes the current player.
    """
    seat = self.players[player_idx]
    kind = int(action.get("type", ACTION_DISCARD))
    tile_id = int(action.get("tile", 0))

    if kind == ACTION_DECLARE_WIN:
        if is_win(seat.hand, len(seat.melds)):
            self._end_game(player_idx, "self_draw", seat.hand)
    elif kind == ACTION_DECLARE_KONG:
        if seat.hand.count(tile_id) >= 4:
            for _ in range(4):
                seat.hand.remove(tile_id)
            seat.melds.append([tile_id] * 4)
            # Draw immediately; the same seat keeps its turn.
            self._draw_tile(player_idx)
            self.turn_started = True
    elif kind == ACTION_DISCARD and tile_id in seat.hand:
        seat.hand.remove(tile_id)
        seat.discards.append(tile_id)
        self.pending_discard = (player_idx, tile_id)
        self.response_queue = [(player_idx + offset) % 4 for offset in range(1, 4)]
        self.phase = PHASE_RESPONSE
        self.current_player = self.response_queue.pop(0)
        self.turn_started = False
    return 0.0
def _handle_response_action(self, player_idx: int, action: Dict[str, int]) -> float:
# Apply one seat's answer to the pending discard: win, pong, kong, chi, or
# (for anything else) move on to the next queued responder. Returns
# ``self.pong_reward`` for a pong and 0.0 otherwise.
# NOTE(review): indentation is lost in this view, so whether the chi
# branch's ``return 0.0`` sits inside ``if chi_choice in options`` cannot be
# determined from here - confirm against the original file.
#
# Nothing is pending: return to the turn phase.
if self.pending_discard is None:
self.phase = PHASE_TURN
self.turn_started = False
return 0.0
discarder, tile = self.pending_discard
player = self.players[player_idx]
counts = self._counts(player.hand)
action_type = int(action.get("type", ACTION_PASS))
chi_choice = int(action.get("chi", 0))
# Win on the discard; the winning tiles recorded are the hand plus the tile.
if action_type == ACTION_DECLARE_WIN and is_win(player.hand + [tile], len(player.melds)):
self._end_game(player_idx, "ron", player.hand + [tile])
return 0.0
# Pong: two matching tiles from the hand; the claimer takes the next turn
# and skips its draw.
if action_type == ACTION_DECLARE_PONG and counts[tile] >= 2:
self._consume_discard(discarder, tile)
player.hand.remove(tile)
player.hand.remove(tile)
player.melds.append([tile, tile, tile])
self.current_player = player_idx
self.pending_discard = None
self.response_queue = []
self.phase = PHASE_TURN
self.turn_started = False
self.skip_draw_for_player = player_idx
return self.pong_reward
# Kong from a discard: three matching tiles from the hand; a tile is drawn
# here and the turn is marked as already started.
if action_type == ACTION_DECLARE_KONG and counts[tile] >= 3:
self._consume_discard(discarder, tile)
for _ in range(3):
player.hand.remove(tile)
player.melds.append([tile, tile, tile, tile])
self.current_player = player_idx
self.pending_discard = None
self.response_queue = []
self.phase = PHASE_TURN
self.turn_started = True
self._draw_tile(player_idx)
return 0.0
# Chi: only for the seat right after the discarder, and only with an
# option reported by can_chi_options.
if action_type == ACTION_DECLARE_CHI and player_idx == (discarder + 1) % 4:
options = can_chi_options(counts, tile)
if chi_choice in options:
self._consume_discard(discarder, tile)
# chi_choice picks the run: 0 -> claimed tile is the highest,
# 1 -> the middle, anything else -> the lowest.
if chi_choice == 0:
chi_tiles = [tile - 2, tile - 1, tile]
elif chi_choice == 1:
chi_tiles = [tile - 1, tile, tile + 1]
else:
chi_tiles = [tile, tile + 1, tile + 2]
# Only the two tiles that come from the hand are removed.
for chi_tile in chi_tiles:
if chi_tile != tile:
player.hand.remove(chi_tile)
player.melds.append(chi_tiles)
self.current_player = player_idx
self.pending_discard = None
self.response_queue = []
self.phase = PHASE_TURN
self.turn_started = False
self.skip_draw_for_player = player_idx
return 0.0
# No claim: hand the decision to the next responder, or end the response phase.
self._advance_response_queue(discarder)
return 0.0
def _sanitize_action(self, player_idx: int, action: Dict[str, int]) -> Dict[str, int]:
    """Return ``action`` if it is legal for ``player_idx``, otherwise a random legal action.

    Missing entries default to a pass with tile 0 and chi 0. The action is
    replaced by ``_sample_valid_action(mask)`` when it is not a mapping, holds
    a value that cannot be converted to ``int``, names a masked-out action
    type, or (for discard, pong, kong and chi) names a tile or chi choice that
    is out of range or masked out.

    Args:
        player_idx: Seat whose action mask is used for validation.
        action: Candidate action dict from the caller.

    Returns:
        A dict with integer ``type``, ``tile`` and ``chi`` entries.
    """
    mask = self.action_mask(player_idx)
    try:
        action_type = int(action.get("type", ACTION_PASS))
        tile_id = int(action.get("tile", 0))
        chi_choice = int(action.get("chi", 0))
    except (AttributeError, TypeError, ValueError):
        # Not a mapping, or a value that is not integer-like.
        return self._sample_valid_action(mask)

    def allowed(flags, index: int) -> bool:
        # Explicit range check: indexing the mask directly would wrap a
        # negative index around and raise IndexError past the end.
        return 0 <= index < len(flags) and bool(flags[index])

    if not allowed(mask["type"], action_type):
        return self._sample_valid_action(mask)

    # Sub-choice that must also be legal for the chosen action type; win and
    # pass carry no sub-choice.
    required = {
        ACTION_DISCARD: ("discard", tile_id),
        ACTION_DECLARE_PONG: ("pong", tile_id),
        ACTION_DECLARE_KONG: ("kong", tile_id),
        ACTION_DECLARE_CHI: ("chi", chi_choice),
    }.get(action_type)
    if required is not None:
        mask_key, index = required
        if not allowed(mask[mask_key], index):
            return self._sample_valid_action(mask)
    return {"type": action_type, "tile": tile_id, "chi": chi_choice}
def _sample_valid_action(self, mask: Dict[str, np.ndarray]) -> Dict[str, int]:
    """Draw a random action permitted by ``mask``.

    The action type is sampled first, then (for discard, pong, kong or chi)
    the matching tile or chi choice. Falls back to a pass when no action type
    is enabled.
    """
    legal_types = np.flatnonzero(mask["type"])
    if legal_types.size == 0:
        return {"type": ACTION_PASS, "tile": 0, "chi": 0}

    def pick(flags) -> int:
        # Random enabled index, or 0 when nothing is enabled.
        candidates = np.flatnonzero(flags)
        return int(self._random.choice(candidates)) if candidates.size else 0

    action_type = int(self._random.choice(legal_types))
    sampled = {"type": action_type, "tile": 0, "chi": 0}
    tile_mask_key = {
        ACTION_DISCARD: "discard",
        ACTION_DECLARE_PONG: "pong",
        ACTION_DECLARE_KONG: "kong",
    }.get(action_type)
    if tile_mask_key is not None:
        sampled["tile"] = pick(mask[tile_mask_key])
    elif action_type == ACTION_DECLARE_CHI:
        sampled["chi"] = pick(mask["chi"])
    return sampled
def _advance_response_queue(self, discarder: int):
if self.response_queue:
self.current_player = self.response_queue.pop(0)
return
self.pending_discard = None
self.phase = PHASE_TURN
self.current_player = (discarder + 1) % 4
self.turn_started = False
def _start_turn(self, player_idx: int):
if self.skip_draw_for_player == player_idx:
self.skip_draw_for_player = None
return
self._draw_tile(player_idx)
def _draw_tile(self, player_idx: int):
while self.wall:
tile = self.wall.pop()
if is_flower(tile):
self.players[player_idx].flowers.append(tile)
continue
self.players[player_idx].hand.append(tile)
return tile
self._end_game(None, "wall_empty", [])
return None
def _consume_discard(self, discarder: int, tile: int):
discards = self.players[discarder].discards
if discards and discards[-1] == tile:
discards.pop()
def _end_game(self, winner: Optional[int], reason: str, winning_tiles: List[int]):
if self.done:
return
self.done = True
self.winner = winner
self.terminal_reason = reason
self.winning_tiles = winning_tiles
self.winning_melds = [] if winner is None else self.players[winner].melds
def _counts(self, tiles: List[int]) -> np.ndarray:
    """Return an int8 histogram with one slot per tile id, counting the given tiles."""
    histogram = np.zeros(len(TILE_TYPES), dtype=np.int8)
    for tile_id in tiles:
        histogram[tile_id] += 1
    return histogram

13
majiang_rl/rl/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
from .agents import RandomAgent, Transition
from .grpo import GRPOConfig, train_grpo
from .trainer import EpisodeResult, EpisodeRunner, run_training
# Names re-exported as the public API of this subpackage.
__all__ = [
"RandomAgent",
"Transition",
"EpisodeResult",
"EpisodeRunner",
"run_training",
"GRPOConfig",
"train_grpo",
]

76
majiang_rl/rl/agents.py Normal file
View File

@@ -0,0 +1,76 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Protocol
import numpy as np
class Agent(Protocol):
    """Structural interface for agents: ``act``, ``observe`` and ``reset``."""

    def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
        """Return an action dict for the given observation and info."""
        ...

    def observe(self, transition: "Transition") -> None:
        """Receive one transition."""
        ...

    def reset(self) -> None:
        """Reset the agent."""
        ...
@dataclass
class Transition:
    """One environment step: observation, action taken, and what came back."""

    # Observation the action was chosen from.
    obs: Dict[str, Any]
    # Action dict that was applied.
    action: Dict[str, int]
    # Reward returned for the step.
    reward: float
    # Episode-end flags returned for the step.
    terminated: bool
    truncated: bool
    # Observation returned after the step.
    next_obs: Dict[str, Any]
    # Info dict stored with the step.
    info: Dict[str, Any]
class RandomAgent:
    """Agent that picks a uniformly random action allowed by the action mask."""

    def __init__(self, seed: int | None = None):
        """Create the agent's NumPy random generator from ``seed``."""
        self.rng = np.random.default_rng(seed)

    def reset(self) -> None:
        """Nothing to reset."""
        return None

    def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
        """Sample an action from ``info["action_mask"]``.

        Without a mask the fixed action ``type=0, tile=0, chi=0`` is returned.
        """
        mask = info.get("action_mask")
        if mask is not None:
            return sample_action_from_mask(mask, self.rng)
        return {"type": 0, "tile": 0, "chi": 0}

    def observe(self, transition: Transition) -> None:
        """Ignore the transition."""
        return None
def sample_action_from_mask(mask: Dict[str, np.ndarray], rng: np.random.Generator) -> Dict[str, int]:
    """Draw a uniformly random legal action from an action-mask dict.

    The action type is sampled from ``mask["type"]``; discard (0), kong (2)
    and pong (4) additionally sample a tile from their own mask, and chi (5)
    samples a chi variant. With no legal type the pass action (3) is returned.
    """
    allowed = {
        name: np.asarray(mask[name], dtype=bool)
        for name in ("type", "discard", "pong", "kong", "chi")
    }
    type_choices = np.flatnonzero(allowed["type"])
    if type_choices.size == 0:
        return {"type": 3, "tile": 0, "chi": 0}

    action_type = int(rng.choice(type_choices))
    result = {"type": action_type, "tile": 0, "chi": 0}

    # Which mask supplies the tile for each tile-carrying action type.
    tile_source = {0: "discard", 2: "kong", 4: "pong"}.get(action_type)
    if tile_source is not None:
        candidates = np.flatnonzero(allowed[tile_source])
        if candidates.size > 0:
            result["tile"] = int(rng.choice(candidates))
    elif action_type == 5:
        variants = np.flatnonzero(allowed["chi"])
        if variants.size > 0:
            result["chi"] = int(rng.choice(variants))
    return result

415
majiang_rl/rl/grpo.py Normal file
View File

@@ -0,0 +1,415 @@
from __future__ import annotations
import argparse
import copy
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..envs.self_play_env import (
ACTION_DECLARE_CHI,
ACTION_DECLARE_KONG,
ACTION_DECLARE_PONG,
ACTION_DISCARD,
SelfPlayMahjong,
)
from ..rules import hand_distance_to_win, total_fan
@dataclass
class GRPOConfig:
    """Hyper-parameters and logging options for GRPO self-play training."""

    # Number of self-play games collected per policy update.
    group_size: int = 16
    # Number of policy updates to run.
    updates: int = 50
    # Per-game step cap; values <= 0 select an automatic cap based on wall size.
    max_steps: int = 0
    # Adam learning rate.
    lr: float = 3e-4
    # Weight of the KL penalty in the loss.
    beta_kl: float = 0.02
    # Base seed for torch, numpy and the per-game seeds.
    seed: int = 1
    # Width of the policy network's hidden layers.
    hidden_size: int = 256
    # Torch device string; "auto" picks "cuda" when available, else "cpu".
    device: str = "auto"
    # SwanLab logging switches (logging is skipped when swanlab is not installed).
    use_swanlab: bool = False
    swanlab_project: str = "majiang-rl"
    swanlab_run_name: str = ""
    # Forwarded to SelfPlayMahjong as its pong_reward argument.
    pong_reward: float = 0.1
    # Bonus added to every player tied for the smallest distance-to-win at game end.
    closest_bonus: float = 1.0
@dataclass
class StepSample:
    """One recorded decision of one player during a rollout."""

    # Index of the player who acted.
    player: int
    # Flattened observation tensor the action was sampled from.
    obs: torch.Tensor
    # Boolean action masks (type/discard/pong/kong/chi) in effect at that step.
    mask: Dict[str, torch.Tensor]
    # Sampled action dict with "type", "tile" and "chi" entries.
    action: Dict[str, int]
    # Step reward; the rollout collector later adds the player's final reward to it.
    reward: float = 0.0
def obs_to_tensor(obs: Dict[str, np.ndarray], device: str) -> torch.Tensor:
    """Concatenate the observation fields into one float32 tensor on ``device``.

    Fields are joined in a fixed order; ``melds`` and ``discards`` are
    flattened first, the remaining fields are used as they are.
    """
    # (observation key, flatten before concatenation?)
    layout = (
        ("hand", False),
        ("melds", True),
        ("discards", True),
        ("flowers", False),
        ("wall_count", False),
        ("current_player", False),
        ("phase", False),
        ("pending_discard", False),
    )
    chunks = []
    for key, flatten in layout:
        chunk = obs[key].astype(np.float32)
        chunks.append(chunk.reshape(-1) if flatten else chunk)
    return torch.tensor(np.concatenate(chunks, axis=0), dtype=torch.float32, device=device)
def mask_to_torch(mask: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]:
    """Convert the five action masks to boolean torch tensors on ``device``."""
    mask_names = ("type", "discard", "pong", "kong", "chi")
    return {
        name: torch.tensor(mask[name], dtype=torch.bool, device=device)
        for name in mask_names
    }
class PolicyNet(nn.Module):
    """Two-layer MLP trunk with separate heads for action type, tile and chi variant."""

    def __init__(self, input_dim: int, hidden_size: int = 256):
        super().__init__()
        trunk_layers = [
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk_layers)
        self.type_head = nn.Linear(hidden_size, 6)
        self.tile_head = nn.Linear(hidden_size, 42)
        self.chi_head = nn.Linear(hidden_size, 3)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(type_logits, tile_logits, chi_logits)`` for ``obs``."""
        features = self.net(obs)
        heads = (self.type_head, self.tile_head, self.chi_head)
        return tuple(head(features) for head in heads)

    def sample_action(self, obs: torch.Tensor, mask: Dict[str, torch.Tensor]) -> Dict[str, int]:
        """Sample an action dict, restricting every choice to its mask.

        The action type is drawn first; discard/pong/kong then draw a tile
        from the matching mask and chi draws a chi variant. Fields that the
        chosen type does not use stay 0.
        """
        type_logits, tile_logits, chi_logits = self(obs)
        type_dist = torch.distributions.Categorical(
            logits=masked_logits(type_logits, mask["type"])
        )
        chosen = int(type_dist.sample().item())

        tile_id = 0
        chi_choice = 0
        tile_branches = (
            (ACTION_DISCARD, "discard"),
            (ACTION_DECLARE_PONG, "pong"),
            (ACTION_DECLARE_KONG, "kong"),
        )
        for kind, mask_key in tile_branches:
            if chosen == kind:
                tile_id = sample_from_mask(tile_logits, mask[mask_key])
                break
        else:
            if chosen == ACTION_DECLARE_CHI:
                chi_choice = sample_from_mask(chi_logits, mask["chi"])
        return {"type": chosen, "tile": tile_id, "chi": chi_choice}
def masked_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``logits`` with every disallowed entry set to -1e9."""
    disallowed = ~mask
    return logits.masked_fill(disallowed, -1e9)
def sample_from_mask(logits: torch.Tensor, mask: torch.Tensor) -> int:
    """Sample an index from the masked categorical; return 0 if nothing is allowed."""
    if not bool(mask.any()):
        return 0
    distribution = torch.distributions.Categorical(logits=masked_logits(logits, mask))
    return int(distribution.sample().item())
def logprob_action(
    logits_type: torch.Tensor,
    logits_tile: torch.Tensor,
    logits_chi: torch.Tensor,
    action: Dict[str, int],
    mask: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Log-probability of ``action`` under the masked policy heads.

    Always includes the action-type term; discard/pong/kong add the tile
    term under their own mask, chi adds the chi-variant term.
    """
    kind = int(action["type"])
    total = masked_log_softmax(logits_type, mask["type"])[kind]

    # (action code, head logits, mask key, action field holding the index)
    branches = (
        (ACTION_DISCARD, logits_tile, "discard", "tile"),
        (ACTION_DECLARE_PONG, logits_tile, "pong", "tile"),
        (ACTION_DECLARE_KONG, logits_tile, "kong", "tile"),
        (ACTION_DECLARE_CHI, logits_chi, "chi", "chi"),
    )
    for code, head_logits, mask_key, field in branches:
        if kind == code:
            head_logp = masked_log_softmax(head_logits, mask[mask_key])
            return total + head_logp[int(action[field])]
    return total
def masked_log_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Log-softmax over the last dimension after masking disallowed entries."""
    return F.log_softmax(masked_logits(logits, mask), dim=-1)
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Softmax over the last dimension after masking disallowed entries."""
    return F.softmax(masked_logits(logits, mask), dim=-1)
def kl_divergence(logits: torch.Tensor, ref_logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """KL(policy || reference) over the entries allowed by ``mask``.

    Returns a zero scalar when the mask allows nothing.
    """
    if not bool(mask.any()):
        return torch.tensor(0.0, device=logits.device)
    probs = masked_softmax(logits, mask)
    # The small epsilon keeps log() finite for masked (near-zero) probabilities.
    log_ratio = torch.log(probs + 1e-12) - masked_log_softmax(ref_logits, mask)
    return (probs * log_ratio).sum()
def compute_rewards(game: SelfPlayMahjong, config: GRPOConfig) -> List[float]:
    """Compute each player's end-of-game reward.

    With a winner, the winner receives the hand's total fan and the three
    others split the negated amount equally. Independently, every player
    tied for the smallest distance-to-win receives ``config.closest_bonus``.
    """
    winner = game.winner
    if winner is None:
        payouts = [0.0, 0.0, 0.0, 0.0]
    else:
        fan = float(total_fan(game.winning_tiles, game.winning_melds))
        loser_share = -fan / 3.0
        payouts = [fan if seat == winner else loser_share for seat in range(4)]

    # The winner is measured on the winning tiles, everyone else on their hand.
    distances = [
        hand_distance_to_win(
            game.winning_tiles if game.winner is not None and seat == game.winner else player.hand,
            player.melds,
        )
        for seat, player in enumerate(game.players)
    ]
    closest = min(distances) if distances else 0
    for seat, distance in enumerate(distances):
        if distance == closest:
            payouts[seat] += config.closest_bonus
    return payouts
def collect_rollouts(
    policy: PolicyNet, config: GRPOConfig, seed_offset: int
) -> Tuple[List[StepSample], List[float], Dict[str, float]]:
    """Play ``config.group_size`` self-play games and record every decision.

    Returns a tuple of:
      * all recorded steps, each with reward = step reward + the acting
        player's final reward for that game;
      * the per-player total reward (final + summed step rewards), four
        entries per game;
      * a stats dict with the mean final and mean shaping reward.
    """
    all_steps: List[StepSample] = []
    episode_rewards: List[float] = []
    final_rewards_all: List[float] = []
    shaping_rewards_all: List[float] = []
    for episode in range(config.group_size):
        # Each game gets a distinct seed derived from the base seed and offset.
        game = SelfPlayMahjong(
            seed=config.seed + seed_offset + episode, pong_reward=config.pong_reward
        )
        game.reset()
        episode_steps: List[StepSample] = []
        shaping_rewards = [0.0, 0.0, 0.0, 0.0]
        max_steps = _auto_max_steps(config.max_steps, len(game.wall))
        for _ in range(max_steps):
            if game.done:
                break
            player = game.current_player
            obs = game.observe(player)
            mask = game.action_mask(player)
            obs_tensor = obs_to_tensor(obs, config.device)
            mask_tensor = mask_to_torch(mask, config.device)
            # Sampling needs no gradients; log-probs are recomputed during the update.
            with torch.no_grad():
                action = policy.sample_action(obs_tensor, mask_tensor)
            episode_steps.append(
                StepSample(player=player, obs=obs_tensor, mask=mask_tensor, action=action)
            )
            reward = game.step(player, action)
            shaping_rewards[player] += reward
            episode_steps[-1].reward = reward
        # Step budget exhausted: close the game without a winner.
        if not game.done:
            game._end_game(None, "max_steps", [])
        rewards = compute_rewards(game, config)
        total_rewards = [rewards[i] + shaping_rewards[i] for i in range(4)]
        episode_rewards.extend(total_rewards)
        final_rewards_all.extend(rewards)
        shaping_rewards_all.extend(shaping_rewards)
        # Credit the acting player's final reward to each of their steps.
        for step in episode_steps:
            step.reward += rewards[step.player]
            all_steps.append(step)
    stats = {
        "avg_final_reward": float(np.mean(final_rewards_all)) if final_rewards_all else 0.0,
        "avg_shaping_reward": float(np.mean(shaping_rewards_all)) if shaping_rewards_all else 0.0,
    }
    return all_steps, episode_rewards, stats
def train_grpo(config: GRPOConfig) -> PolicyNet:
    """Train a policy with self-play rollouts and return it.

    Each update collects one group of games, normalises rewards with the
    group's mean and standard deviation, and takes a single Adam step on the
    policy-gradient loss plus a KL penalty. ``config.device`` is overwritten
    with the resolved device string.
    """
    config.device = resolve_device(config.device)
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    print(f"Training on device: {config.device}")
    # A throwaway game is used only to discover the observation size.
    dummy_game = SelfPlayMahjong(seed=config.seed)
    dummy_game.reset()
    input_dim = obs_to_tensor(dummy_game.observe(0), config.device).shape[0]
    policy = PolicyNet(input_dim, hidden_size=config.hidden_size).to(config.device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=config.lr)
    logger = SwanlabLogger(config)
    for update in range(config.updates):
        policy.eval()
        # Offset the seed so every update plays a fresh set of games.
        steps, episode_rewards, reward_stats = collect_rollouts(
            policy, config, update * config.group_size
        )
        if not steps:
            continue
        # NOTE(review): mean/std are taken over per-player episode totals, while
        # the advantage below is computed from per-step rewards — confirm intended.
        rewards = np.array(episode_rewards, dtype=np.float32)
        mean = float(rewards.mean())
        std = float(rewards.std() + 1e-6)
        # NOTE(review): the reference is copied right before the only gradient
        # step of this update, so policy and reference hold identical weights
        # when the loss is evaluated; the KL term therefore appears to be ~0.
        ref_policy = copy.deepcopy(policy).eval()
        policy.train()
        losses = []
        for step in steps:
            logits_type, logits_tile, logits_chi = policy(step.obs)
            with torch.no_grad():
                ref_type, ref_tile, ref_chi = ref_policy(step.obs)
            advantage = torch.tensor((step.reward - mean) / std, device=logits_type.device)
            logp = logprob_action(logits_type, logits_tile, logits_chi, step.action, step.mask)
            # The tile head serves discard, pong and kong, so its KL uses their union.
            tile_union = step.mask["discard"] | step.mask["pong"] | step.mask["kong"]
            kl = kl_divergence(logits_type, ref_type, step.mask["type"])
            kl += kl_divergence(logits_tile, ref_tile, tile_union)
            kl += kl_divergence(logits_chi, ref_chi, step.mask["chi"])
            loss = -(advantage * logp) + config.beta_kl * kl
            losses.append(loss)
        loss_value = torch.stack(losses).mean()
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()
        logger.log(
            {
                "loss": float(loss_value.item()),
                "avg_reward": mean,
                "reward_std": std,
                "avg_final_reward": reward_stats["avg_final_reward"],
                "avg_shaping_reward": reward_stats["avg_shaping_reward"],
            },
            step=update + 1,
        )
        print(f"Update {update + 1}/{config.updates} loss={loss_value.item():.4f} avg_reward={mean:.2f}")
    logger.finish()
    return policy
def main() -> None:
    """Parse the command line into a GRPOConfig and run training."""
    parser = argparse.ArgumentParser(description="GRPO self-play training for Mahjong")
    # Declared in display order so --help output keeps its layout.
    option_specs = (
        ("--updates", {"type": int, "default": 50}),
        ("--group-size", {"type": int, "default": 16}),
        ("--max-steps", {"type": int, "default": 0}),
        ("--lr", {"type": float, "default": 3e-4}),
        ("--beta-kl", {"type": float, "default": 0.02}),
        ("--seed", {"type": int, "default": 1}),
        ("--hidden-size", {"type": int, "default": 256}),
        ("--device", {"type": str, "default": "auto"}),
        ("--swanlab", {"action": "store_true"}),
        ("--swanlab-project", {"type": str, "default": "majiang-rl"}),
        ("--swanlab-run-name", {"type": str, "default": ""}),
        ("--pong-reward", {"type": float, "default": 0.1}),
        ("--closest-bonus", {"type": float, "default": 1.0}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    settings = vars(parser.parse_args())
    # The CLI flag is --swanlab but the config field is use_swanlab.
    settings["use_swanlab"] = settings.pop("swanlab")
    train_grpo(GRPOConfig(**settings))
def resolve_device(requested: str) -> str:
    """Return ``requested`` unchanged, resolving "auto" to "cuda" or "cpu"."""
    if requested == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return requested
def _auto_max_steps(configured: int, wall_len: int) -> int:
    """Return the configured step cap, or ``wall_len * 4 + 50`` when it is not positive."""
    return configured if configured > 0 else wall_len * 4 + 50
class SwanlabLogger:
    """Optional metrics logger backed by swanlab.

    Every method is a no-op when logging is disabled or swanlab cannot be
    imported.
    """

    def __init__(self, config: GRPOConfig):
        self.enabled = config.use_swanlab
        self._swanlab = None
        if self.enabled:
            try:
                import swanlab
            except ImportError:
                print("swanlab not installed, logging disabled")
                self.enabled = False
            else:
                self._swanlab = swanlab
                self._init_run(config)

    def _init_run(self, config: GRPOConfig) -> None:
        """Start a run, trying progressively simpler ``init`` signatures."""
        if not self._swanlab:
            return
        settings = asdict(config)
        settings.pop("use_swanlab", None)
        init = getattr(self._swanlab, "init", None)
        if init is None:
            self.enabled = False
            return

        run_name = config.swanlab_run_name
        primary = {"project": config.swanlab_project, "config": settings}
        fallback = {"project": config.swanlab_project, "config": settings}
        if run_name:
            # The keyword for the run name differs between signatures.
            primary["name"] = run_name
            fallback["experiment_name"] = run_name
        for kwargs in (primary, fallback):
            try:
                init(**kwargs)
            except TypeError:
                continue
            return
        init(project=config.swanlab_project)

    def log(self, metrics: Dict[str, float], step: int) -> None:
        """Log ``metrics``, passing ``step`` when the backend accepts it."""
        if not (self.enabled and self._swanlab):
            return
        log_fn = getattr(self._swanlab, "log", None)
        if log_fn is None:
            return
        try:
            log_fn(metrics, step=step)
        except TypeError:
            log_fn(metrics)

    def finish(self) -> None:
        """Close the run if the backend exposes ``finish``."""
        if not (self.enabled and self._swanlab):
            return
        finish_fn = getattr(self._swanlab, "finish", None)
        if finish_fn:
            finish_fn()
if __name__ == "__main__":
main()

61
majiang_rl/rl/trainer.py Normal file
View File

@@ -0,0 +1,61 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List
from .agents import Transition
@dataclass
class EpisodeResult:
    """Summary of one episode produced by ``EpisodeRunner.run_episode``."""

    # Sum of the rewards returned by every step.
    total_reward: float
    # Number of environment steps taken.
    steps: int
    # Flags from the last environment step of the episode.
    terminated: bool
    truncated: bool
class EpisodeRunner:
    """Run single episodes of a gymnasium-style environment with one agent.

    ``max_steps`` bounds the length of every episode. An episode that uses up
    this budget without the environment terminating is reported as truncated,
    both in the final transition and in the returned result.
    """

    def __init__(self, env, agent, max_steps: int = 1000):
        self.env = env
        self.agent = agent
        self.max_steps = max_steps

    def run_episode(self) -> EpisodeResult:
        """Reset env and agent, play until the episode ends, and summarise it.

        Returns:
            EpisodeResult with the summed reward, the number of steps taken
            and the terminated/truncated flags of the last step.
        """
        obs, info = self.env.reset()
        self.agent.reset()
        total_reward = 0.0
        terminated = False
        truncated = False
        steps = 0
        for step in range(self.max_steps):
            action = self.agent.act(obs, info)
            next_obs, reward, terminated, truncated, next_info = self.env.step(action)
            steps = step + 1
            # Hitting the runner's own step budget is a time-limit truncation;
            # without this the cut-off episode would look like it never ended.
            if steps >= self.max_steps and not terminated:
                truncated = True
            transition = Transition(
                obs=obs,
                action=action,
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                next_obs=next_obs,
                info=next_info,
            )
            self.agent.observe(transition)
            total_reward += reward
            obs, info = next_obs, next_info
            if terminated or truncated:
                break
        return EpisodeResult(
            total_reward=total_reward,
            steps=steps,
            terminated=terminated,
            truncated=truncated,
        )
def run_training(env, agent, episodes: int = 100) -> List[EpisodeResult]:
    """Run ``episodes`` episodes with a default EpisodeRunner and collect the results."""
    runner = EpisodeRunner(env, agent)
    return [runner.run_episode() for _ in range(episodes)]

277
majiang_rl/rules.py Normal file
View File

@@ -0,0 +1,277 @@
from __future__ import annotations
from typing import Dict, Iterable, List, Tuple
from .tiles import FLOWER_RANGE, HONOR_RANGE, SUITED_RANGE
# Tile ids required by the thirteen-orphans hand: rank 1 and rank 9 of each
# suit (0/8, 9/17, 18/26) followed by the seven honor tiles (27-33).
ORPHAN_INDICES = (
    0,
    8,
    9,
    17,
    18,
    26,
    27,
    28,
    29,
    30,
    31,
    32,
    33,
)
def strip_flowers(tiles: Iterable[int]) -> List[int]:
    """Return the tiles in order with all flower tiles removed."""
    kept: List[int] = []
    for tile in tiles:
        if tile in FLOWER_RANGE:
            continue
        kept.append(tile)
    return kept
def is_seven_pairs(counts: List[int]) -> bool:
    """True when the 14 tiles consist only of even counts (four of a kind counts as two pairs)."""
    if sum(counts) != 14:
        return False
    return not any(count % 2 for count in counts)
def is_thirteen_orphans(counts: List[int]) -> bool:
    """True when the 14 tiles are one of each orphan tile plus one duplicate orphan."""
    if sum(counts) != 14:
        return False
    missing_orphan = any(counts[idx] == 0 for idx in ORPHAN_INDICES)
    if missing_orphan:
        return False
    for idx in range(34):
        if idx not in ORPHAN_INDICES and counts[idx] > 0:
            return False
    orphan_total = sum(counts[idx] for idx in ORPHAN_INDICES)
    return orphan_total == 14
def is_all_pungs(counts: List[int]) -> bool:
    """True when the 14 tiles split into one pair plus four triplets.

    ``counts`` is modified while searching but restored before returning.
    """
    if sum(counts) != 14:
        return False
    for idx in range(len(counts)):
        if counts[idx] < 2:
            continue
        # Tentatively use this tile as the pair.
        counts[idx] -= 2
        found = _can_form_pungs(counts, 4)
        counts[idx] += 2
        if found:
            return True
    return False
def _can_form_pungs(counts: List[int], needed: int) -> bool:
    """True when the remaining tiles form exactly ``needed`` triplets.

    ``counts`` is modified during the recursion but restored on return.
    """
    if needed == 0:
        return sum(counts) == 0
    lead = next((idx for idx, count in enumerate(counts) if count > 0), None)
    # The lowest remaining tile can only be consumed by a triplet of itself.
    if lead is None or counts[lead] < 3:
        return False
    counts[lead] -= 3
    solved = _can_form_pungs(counts, needed - 1)
    counts[lead] += 3
    return solved
def is_all_honors(counts: List[int]) -> bool:
    """True when the 14 tiles contain no suited tile."""
    suited_tiles = sum(counts[idx] for idx in SUITED_RANGE)
    if suited_tiles != 0:
        return False
    return sum(counts) == 14
def _suit_presence(counts: List[int]) -> Tuple[bool, Tuple[bool, bool, bool]]:
    """Return ``(has_honors, (suit0, suit1, suit2))`` presence flags for the counts."""
    suits_seen = set()
    for idx in SUITED_RANGE:
        if counts[idx] > 0:
            suits_seen.add(idx // 9)
    suit_flags = tuple(suit in suits_seen for suit in range(3))
    has_honors = sum(counts[idx] for idx in HONOR_RANGE) > 0
    return has_honors, suit_flags
def is_pure_one_suit(counts: List[int]) -> bool:
    """True when tiles of exactly one suit are present and there are no honors."""
    has_honors, suit_flags = _suit_presence(counts)
    if has_honors:
        return False
    return suit_flags.count(True) == 1
def is_half_flush(counts: List[int]) -> bool:
    """True when tiles of exactly one suit are present together with honors."""
    has_honors, suit_flags = _suit_presence(counts)
    if not has_honors:
        return False
    return suit_flags.count(True) == 1
def is_all_terminals_or_honors(counts: List[int]) -> bool:
    """True when all 14 tiles are honors or rank-1/rank-9 suited tiles."""
    if sum(counts) != 14:
        return False
    terminals = {0, 8, 9, 17, 18, 26}
    return all(
        idx in HONOR_RANGE or idx in terminals
        for idx, count in enumerate(counts)
        if count != 0
    )
def can_form_melds(counts: List[int], needed: int) -> bool:
    """True when the tiles split into exactly ``needed`` melds (triplets or runs).

    The lowest remaining tile must start either a triplet or an ascending
    run. ``counts`` is modified during the search but restored on return.
    """
    if needed == 0:
        return sum(counts) == 0
    lead = next((idx for idx, count in enumerate(counts) if count > 0), None)
    if lead is None:
        return False

    # Option 1: triplet of the lowest tile.
    if counts[lead] >= 3:
        counts[lead] -= 3
        solved = can_form_melds(counts, needed - 1)
        counts[lead] += 3
        if solved:
            return True

    # Option 2: run starting at the lowest tile (rank 7 is the highest start).
    run_possible = (
        lead in SUITED_RANGE
        and lead % 9 <= 6
        and counts[lead + 1] > 0
        and counts[lead + 2] > 0
    )
    if not run_possible:
        return False
    run = (lead, lead + 1, lead + 2)
    for tile in run:
        counts[tile] -= 1
    solved = can_form_melds(counts, needed - 1)
    for tile in run:
        counts[tile] += 1
    return solved
def is_standard_hand(counts: List[int], open_melds: int) -> bool:
    """True when the concealed tiles form one pair plus ``4 - open_melds`` melds.

    ``counts`` is modified while searching but restored before returning.
    """
    melds_needed = 4 - open_melds
    if sum(counts) != melds_needed * 3 + 2:
        return False
    for idx in range(len(counts)):
        if counts[idx] < 2:
            continue
        # Tentatively use this tile as the pair.
        counts[idx] -= 2
        found = can_form_melds(counts, melds_needed)
        counts[idx] += 2
        if found:
            return True
    return False
def is_win(tiles: Iterable[int], open_melds: int) -> bool:
    """True when the concealed tiles (flowers ignored) complete a winning hand.

    Accepts seven pairs, thirteen orphans, or a standard hand given the
    number of already-open melds.
    """
    tally = [0] * 34
    for tile in tiles:
        if tile not in FLOWER_RANGE:
            tally[tile] += 1
    if is_seven_pairs(tally) or is_thirteen_orphans(tally):
        return True
    return is_standard_hand(tally, open_melds)
def fan_breakdown(tiles: Iterable[int], melds: List[List[int]]) -> Dict[str, int]:
    """Return the fan patterns matched by a winning hand and their values.

    Args:
        tiles: Concealed tiles of the hand, including the winning tile.
            Flowers are ignored.
        melds: Declared melds, each a list of tile ids. Flowers are ignored.

    Returns:
        Mapping of pattern name to fan value; ``{"standard": 1}`` when no
        special pattern matches.

    Pattern checks run on the combined tile counts of hand and melds, with
    three safeguards against artefacts of that merge:
      * a meld contributes at most three tiles, so a kong counts as a
        triplet and does not break the 14-tile checks;
      * seven pairs and thirteen orphans are only considered for a fully
        concealed hand (no melds);
      * all pungs requires that no declared meld is a run.
    """
    counts = [0] * 34
    for tile in tiles:
        if tile in FLOWER_RANGE:
            continue
        counts[tile] += 1

    melds_are_triplets = True
    for meld in melds:
        meld_tiles = [tile for tile in meld if tile not in FLOWER_RANGE]
        # A meld containing different tiles is a run, which rules out all pungs.
        if len(set(meld_tiles)) > 1:
            melds_are_triplets = False
        # The fourth tile of a kong is not counted for shape detection.
        for tile in meld_tiles[:3]:
            counts[tile] += 1

    concealed = not melds
    fans: Dict[str, int] = {}
    if concealed and is_thirteen_orphans(counts):
        fans["thirteen_orphans"] = 88
    if is_all_honors(counts):
        fans["all_honors"] = 64
    if is_pure_one_suit(counts):
        fans["pure_one_suit"] = 24
    if concealed and is_seven_pairs(counts):
        fans["seven_pairs"] = 24
    if is_all_terminals_or_honors(counts):
        fans["all_terminals_or_honors"] = 32
    if melds_are_triplets and is_all_pungs(counts.copy()):
        fans["all_pungs"] = 6
    if is_half_flush(counts):
        fans["half_flush"] = 6
    if not fans:
        fans["standard"] = 1
    return fans
def total_fan(tiles: Iterable[int], melds: List[List[int]]) -> int:
    """Return the sum of all fan values matched by the hand."""
    breakdown = fan_breakdown(tiles, melds)
    return sum(breakdown.values())
def hand_distance_to_win(hand_tiles: Iterable[int], melds: List[List[int]]) -> int:
    """Estimate how many tiles the hand lacks to reach a 14-tile standard win.

    Declared melds count as three tiles each (at most four melds); the
    concealed tiles contribute the best combination of complete melds plus
    an optional pair. Flowers are ignored. The result is never negative.
    """
    tally = [0] * 34
    for tile in hand_tiles:
        if tile not in FLOWER_RANGE:
            tally[tile] += 1
    completed = min(len(melds), 4)
    remaining = max(0, 4 - completed)
    covered = min(14, 3 * completed + _max_used_tiles(tally, remaining))
    return max(0, 14 - covered)
def _max_used_tiles(counts: List[int], needed_melds: int) -> int:
    """Return the most tiles usable as up to ``needed_melds`` melds plus an optional pair.

    ``counts`` is modified while searching but restored before returning.
    """
    best = 3 * _max_melds(counts, needed_melds)
    for idx in range(len(counts)):
        if counts[idx] < 2:
            continue
        # Reserve this tile as the pair and fit melds into what is left.
        counts[idx] -= 2
        with_pair = 2 + 3 * _max_melds(counts, needed_melds)
        counts[idx] += 2
        if with_pair > best:
            best = with_pair
    return best
def _max_melds(counts: List[int], limit: int) -> int:
    """Return the largest number of melds (at most ``limit``) the tiles can form.

    For the lowest remaining tile the search tries discarding one copy,
    using it in a triplet, and using it to start a run. ``counts`` is
    modified during the recursion but restored on return.
    """
    if limit == 0:
        return 0
    lead = next((idx for idx, count in enumerate(counts) if count > 0), None)
    if lead is None:
        return 0

    # Option 1: leave one copy of the lowest tile unused.
    counts[lead] -= 1
    candidates = [0, _max_melds(counts, limit)]
    counts[lead] += 1

    # Option 2: triplet of the lowest tile.
    if counts[lead] >= 3:
        counts[lead] -= 3
        candidates.append(1 + _max_melds(counts, limit - 1))
        counts[lead] += 3

    # Option 3: run starting at the lowest tile.
    if lead in SUITED_RANGE and lead % 9 <= 6 and counts[lead + 1] > 0 and counts[lead + 2] > 0:
        run = (lead, lead + 1, lead + 2)
        for tile in run:
            counts[tile] -= 1
        candidates.append(1 + _max_melds(counts, limit - 1))
        for tile in run:
            counts[tile] += 1
    return max(candidates)
def can_chi_options(counts: List[int], tile_id: int) -> List[int]:
    """Return the chi variants available for claiming ``tile_id``.

    Variant 0 uses the two tiles below the claimed tile, 1 the tiles on
    either side, 2 the two tiles above. Non-suited tiles give ``[]``.
    """
    if tile_id not in SUITED_RANGE:
        return []
    rank = tile_id % 9
    # (variant code, rank allows it, offsets of the two hand tiles needed)
    patterns = (
        (0, rank >= 2, (-1, -2)),
        (1, 1 <= rank <= 7, (-1, 1)),
        (2, rank <= 6, (1, 2)),
    )
    return [
        code
        for code, rank_ok, (first, second) in patterns
        if rank_ok and counts[tile_id + first] > 0 and counts[tile_id + second] > 0
    ]

65
majiang_rl/tiles.py Normal file
View File

@@ -0,0 +1,65 @@
from __future__ import annotations
from typing import Iterable, List, Tuple
# Name components used to build the tile-type table below.
SUITS = ("wan", "tong", "suo")
WINDS = ("east", "south", "west", "north")
DRAGONS = ("red", "green", "white")
FLOWERS = (
    "plum",
    "orchid",
    "chrysanthemum",
    "bamboo",
    "spring",
    "summer",
    "autumn",
    "winter",
)
# 34 suited + honors, plus 8 flowers = 42
# A tile id is its index in TILE_TYPES: 0-26 suited (9 ranks per suit, in
# SUITS order), 27-30 winds, 31-33 dragons, 34-41 flowers.
TILE_TYPES: List[str] = []
for suit in SUITS:
    for rank in range(1, 10):
        TILE_TYPES.append(f"{suit}_{rank}")
for wind in WINDS:
    TILE_TYPES.append(f"wind_{wind}")
for dragon in DRAGONS:
    TILE_TYPES.append(f"dragon_{dragon}")
for flower in FLOWERS:
    TILE_TYPES.append(f"flower_{flower}")
# Id ranges matching the layout above.
SUITED_RANGE = range(0, 27)
HONOR_RANGE = range(27, 34)
FLOWER_RANGE = range(34, 42)
def build_wall(rng) -> List[int]:
    """Return a shuffled wall: four copies of ids 0-33 plus one of each flower.

    ``rng`` must provide an in-place ``shuffle(list)`` method.
    """
    wall: List[int] = [tile_id for tile_id in range(34) for _ in range(4)]
    wall.extend(FLOWER_RANGE)
    rng.shuffle(wall)
    return wall
def is_flower(tile_id: int) -> bool:
    """True when ``tile_id`` lies in the flower id range."""
    flower_ids = FLOWER_RANGE
    return tile_id in flower_ids
def tile_name(tile_id: int) -> str:
    """Return the name stored in TILE_TYPES for ``tile_id``."""
    name = TILE_TYPES[tile_id]
    return name
def counts_from_tiles(tiles: Iterable[int], size: int = 42) -> List[int]:
    """Return a list of length ``size`` counting how often each tile id occurs."""
    tally = [0] * size
    for tile_id in tiles:
        tally[tile_id] = tally[tile_id] + 1
    return tally
def split_suit_index(tile_id: int) -> Tuple[int, int] | None:
    """Return ``(suit, rank)`` (both zero-based) for a suited tile, else ``None``."""
    if tile_id in SUITED_RANGE:
        return divmod(tile_id, 9)
    return None

View File

@@ -0,0 +1,3 @@
from .record import record_episode, write_record
__all__ = ["record_episode", "write_record"]

34
majiang_rl/ui/record.py Normal file
View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Dict, List, Optional
from .. import MahjongEnv
from ..rl import RandomAgent
def record_episode(seed: Optional[int] = None, max_steps: int = 200) -> List[Dict[str, object]]:
    """Play one episode with a random agent and return its event timeline.

    The timeline starts with a "start" event carrying the initial snapshot;
    events taken from the environment's log are copied and given
    sequential ids. At most ``max_steps`` environment steps are played.
    """
    env = MahjongEnv(seed=seed)
    agent = RandomAgent(seed=seed)
    obs, info = env.reset()
    timeline: List[Dict[str, object]] = [
        {"id": 0, "type": "start", "snapshot": env._snapshot()}
    ]
    for _step in range(max_steps):
        obs, _, terminated, truncated, info = env.step(agent.act(obs, info))
        pending = env.event_log
        if pending:
            for entry in pending:
                stamped = dict(entry)
                stamped["id"] = len(timeline)
                timeline.append(stamped)
            # Clear the log so the next step only yields new events.
            env.event_log = []
        if terminated or truncated:
            break
    return timeline
def write_record(path: Path, events: List[Dict[str, object]]) -> None:
    """Write ``events`` to ``path`` as indented, ASCII-only JSON under the key "events"."""
    serialized = json.dumps({"events": events}, ensure_ascii=True, indent=2)
    path.write_text(serialized)

View File

@@ -0,0 +1,230 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Mahjong RL Viewer</title>
<style>
:root {
--bg: #f7f4ef;
--ink: #1f1f1f;
--accent: #b4572a;
--panel: #ffffff;
--muted: #6b6b6b;
}
body {
margin: 0;
font-family: "IBM Plex Mono", "Courier New", monospace;
background: radial-gradient(circle at 10% 10%, #fdf6e9, var(--bg));
color: var(--ink);
}
header {
padding: 24px;
border-bottom: 2px solid #e5ddd1;
}
h1 {
margin: 0 0 6px 0;
font-size: 22px;
letter-spacing: 1px;
text-transform: uppercase;
}
.subtitle {
color: var(--muted);
font-size: 13px;
}
.layout {
display: grid;
grid-template-columns: 1fr 2fr;
gap: 20px;
padding: 20px;
}
.panel {
background: var(--panel);
border: 2px solid #e5ddd1;
border-radius: 12px;
padding: 16px;
box-shadow: 4px 6px 0 rgba(0, 0, 0, 0.05);
}
.controls {
display: flex;
gap: 8px;
align-items: center;
margin-bottom: 12px;
}
button {
background: var(--accent);
color: #fff;
border: none;
padding: 8px 12px;
border-radius: 8px;
cursor: pointer;
font-weight: 600;
}
button.secondary {
background: #3f3f3f;
}
.status {
font-size: 14px;
line-height: 1.4;
}
.board {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 12px;
}
.player {
border: 1px dashed #d2c4b2;
padding: 10px;
border-radius: 10px;
}
.player h3 {
margin: 0 0 6px 0;
font-size: 14px;
color: var(--accent);
}
.tiles {
font-size: 12px;
line-height: 1.6;
}
.label {
display: inline-block;
min-width: 72px;
color: var(--muted);
}
@media (max-width: 900px) {
.layout {
grid-template-columns: 1fr;
}
}
</style>
</head>
<body>
<header>
<h1>Mahjong RL Viewer</h1>
<div class="subtitle">Playback of a simulated Guobiao round</div>
</header>
<div class="layout">
<div class="panel">
<div class="controls">
<button id="prev">Prev</button>
<button id="next">Next</button>
<button id="play" class="secondary">Play</button>
<span id="counter" class="subtitle"></span>
</div>
<div id="status" class="status"></div>
</div>
<div class="panel">
<div id="board" class="board"></div>
</div>
</div>
<script>
// Playback state: the loaded event list, the index of the event currently
// shown, and the interval handle while auto-play is running (null otherwise).
let events = [];
let idx = 0;
let timer = null;
// Build the one-line status text for an event; unknown or missing events give "".
function describe(event) {
  if (!event) return "";
  switch (event.type) {
    case "start":
      return "Start of round";
    case "discard":
      return `Player ${event.player} discarded ${event.tile}`;
    case "claim":
      return `Player ${event.player} claimed ${event.claim} on ${event.tile} from Player ${event.from_player}`;
    case "win":
      return `Player ${event.player} wins (${event.reason})`;
    case "draw":
      return `Round ends in draw (${event.reason})`;
    default:
      return "";
  }
}
// Join tile ids with spaces; a missing or empty list is shown as "-".
function tilesToText(tiles) {
  const hasTiles = Boolean(tiles) && tiles.length !== 0;
  return hasTiles ? tiles.join(" ") : "-";
}
// Rebuild the board: one panel per player listing hand, discards, melds and flowers.
function renderBoard(snapshot) {
  const board = document.getElementById("board");
  board.innerHTML = "";
  for (let seat = 0; seat < snapshot.hands.length; seat += 1) {
    const panel = document.createElement("div");
    panel.className = "player";

    const heading = document.createElement("h3");
    heading.textContent = `Player ${seat}`;

    // Row label paired with the tiles shown next to it, in display order.
    const rows = [
      ["Hand", snapshot.hands[seat]],
      ["Discards", snapshot.discards[seat]],
      ["Melds", snapshot.melds[seat].flat()],
      ["Flowers", snapshot.flowers[seat]],
    ];

    panel.appendChild(heading);
    rows.forEach(([label, tiles]) => {
      const row = document.createElement("div");
      row.className = "tiles";
      row.innerHTML = `<span class="label">${label}</span>${tilesToText(tiles)}`;
      panel.appendChild(row);
    });
    board.appendChild(panel);
  }
}
// Show the event at the current index: step counter, status line and board.
function render() {
  const current = events[idx];
  if (!current) return;
  const { snapshot } = current;
  const counter = document.getElementById("counter");
  const status = document.getElementById("status");
  counter.textContent = `Step ${idx + 1} / ${events.length}`;
  status.textContent = `${describe(current)} | Wall: ${snapshot.wall_count}`;
  renderBoard(snapshot);
}
// Start or stop auto-play; playback advances every 700 ms and stops itself
// once the last event is reached.
function togglePlay() {
  const playButton = document.getElementById("play");
  if (!timer) {
    playButton.textContent = "Pause";
    timer = setInterval(() => {
      const lastIndex = events.length - 1;
      idx = Math.min(idx + 1, lastIndex);
      render();
      if (idx >= lastIndex) {
        togglePlay();
      }
    }, 700);
    return;
  }
  clearInterval(timer);
  timer = null;
  playButton.textContent = "Play";
}
// Step backwards, never going below the first event.
document.getElementById("prev").addEventListener("click", () => {
  idx = Math.max(0, idx - 1);
  render();
});
// Step forwards, never going past the last event.
document.getElementById("next").addEventListener("click", () => {
  idx = Math.min(events.length - 1, idx + 1);
  render();
});
document.getElementById("play").addEventListener("click", togglePlay);
// Load the recorded episode and show its first event; any fetch or parse
// failure is reported in the status area.
fetch("steps.json")
  .then((response) => response.json())
  .then((data) => {
    events = data.events || [];
    idx = 0;
    render();
  })
  .catch(() => {
    document.getElementById("status").textContent = "Failed to load steps.json";
  });
</script>
</body>
</html>

File diff suppressed because it is too large Load Diff

33
majiang_rl/ui/web.py Normal file
View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import argparse
import functools
import http.server
import socketserver
from pathlib import Path
from .record import record_episode, write_record
def main() -> None:
    """Record one episode to ``static/steps.json`` and serve the viewer over HTTP.

    Blocks in ``serve_forever`` until the process is interrupted.
    """
    cli = argparse.ArgumentParser(description="Mahjong RL web viewer")
    cli.add_argument("--port", type=int, default=8000)
    cli.add_argument("--seed", type=int, default=None)
    cli.add_argument("--max-steps", type=int, default=200)
    options = cli.parse_args()

    # The static directory holds index.html and receives the recorded episode.
    static_dir = Path(__file__).parent / "static"
    static_dir.mkdir(parents=True, exist_ok=True)
    write_record(
        static_dir / "steps.json",
        record_episode(seed=options.seed, max_steps=options.max_steps),
    )

    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=str(static_dir)
    )
    with socketserver.TCPServer(("", options.port), handler) as httpd:
        print(f"Serving UI at http://localhost:{options.port}/index.html")
        httpd.serve_forever()
if __name__ == "__main__":
main()

13
pyproject.toml Normal file
View File

@@ -0,0 +1,13 @@
[project]
name = "majiang-rl"
version = "0.1.0"
description = "A minimal Mahjong (Guobiao) simulation environment and reinforcement learning scaffold built on gymnasium."
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"gymnasium>=0.29.0",
"numpy>=1.26.0",
"torch>=2.2.0",
"torchvision>=0.24.1",
"swanlab>=0.3.0",
]

1102
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff