feat: initialize majiang-rl project
This commit is contained in:
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.12
|
||||
64
README.md
Normal file
64
README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# majiang-rl
|
||||
|
||||
A minimal Mahjong (Guobiao) simulation environment and reinforcement learning scaffold built on gymnasium.
|
||||
|
||||
## Features
|
||||
- 4-player Guobiao tile set (144 tiles, including flowers)
|
||||
- Draw/discard turn loop with flower replacement
|
||||
- Basic calls: win (ron/tsumo), pong, chi, kong
|
||||
- Win checking for standard hands, seven pairs, and thirteen orphans
|
||||
- Gymnasium-style environment API with action masks (type/discard/pong/kong/chi)
|
||||
- Simple RL loop and random agent
|
||||
|
||||
## Limitations
|
||||
- No scoring or 8-fan enforcement
|
||||
- NPC players use simple greedy claims and random discards
|
||||
- No detailed round rules (winds/seat rotation, riichi, etc.)
|
||||
|
||||
## Quick start (uv)
|
||||
```bash
|
||||
uv venv
|
||||
uv pip install -e .
|
||||
uv run python main.py
|
||||
```
|
||||
|
||||
## Environment API
|
||||
```python
|
||||
from majiang_rl import MahjongEnv
|
||||
|
||||
env = MahjongEnv()
|
||||
obs, info = env.reset()
|
||||
|
||||
# action format
|
||||
# type: 0 discard, 1 declare win, 2 declare kong, 3 pass, 4 declare pong, 5 declare chi
|
||||
# tile: tile id (0-41)
|
||||
# chi: 0 left, 1 middle, 2 right
|
||||
action = {"type": 0, "tile": 0, "chi": 0}
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
```
|
||||
|
||||
## RL scaffold
|
||||
```python
|
||||
from majiang_rl import MahjongEnv
|
||||
from majiang_rl.rl import RandomAgent, run_training
|
||||
|
||||
env = MahjongEnv()
|
||||
agent = RandomAgent()
|
||||
results = run_training(env, agent, episodes=10)
|
||||
print(results[0])
|
||||
```
|
||||
|
||||
## GRPO self-play training
|
||||
```bash
|
||||
uv run python -m majiang_rl.rl.grpo --updates 20 --group-size 16 --device auto
|
||||
uv run python -m majiang_rl.rl.grpo --updates 20 --group-size 16 --device auto --pong-reward 0.1 --closest-bonus 1.0
|
||||
uv run python -m majiang_rl.rl.grpo --updates 20 --group-size 16 --device auto --swanlab --swanlab-project majiang-rl --swanlab-run-name grpo-demo
|
||||
```
|
||||
|
||||
Reward uses a simplified fan breakdown (thirteen orphans, seven pairs, pure/half flush, all pungs, all honors).
|
||||
|
||||
## Simple web UI
|
||||
```bash
|
||||
uv run python -m majiang_rl.ui.web --port 8000
|
||||
```
|
||||
Then open `http://localhost:8000/index.html` to watch the playback.
|
||||
14
main.py
Normal file
14
main.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from majiang_rl import MahjongEnv
|
||||
from majiang_rl.rl import RandomAgent, run_training
|
||||
|
||||
|
||||
def main():
|
||||
env = MahjongEnv()
|
||||
agent = RandomAgent()
|
||||
results = run_training(env, agent, episodes=5)
|
||||
for idx, result in enumerate(results, start=1):
|
||||
print(f"Episode {idx}: reward={result.total_reward}, steps={result.steps}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3
majiang_rl/__init__.py
Normal file
3
majiang_rl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .envs import MahjongEnv, SelfPlayMahjong
|
||||
|
||||
__all__ = ["MahjongEnv", "SelfPlayMahjong"]
|
||||
4
majiang_rl/envs/__init__.py
Normal file
4
majiang_rl/envs/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .mahjong_env import MahjongEnv
|
||||
from .self_play_env import SelfPlayMahjong
|
||||
|
||||
__all__ = ["MahjongEnv", "SelfPlayMahjong"]
|
||||
578
majiang_rl/envs/mahjong_env.py
Normal file
578
majiang_rl/envs/mahjong_env.py
Normal file
@@ -0,0 +1,578 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from ..rules import can_chi_options, is_win
|
||||
from ..tiles import TILE_TYPES, build_wall, is_flower, tile_name
|
||||
|
||||
PHASE_AGENT_TURN = 0
|
||||
PHASE_AGENT_RESPONSE = 1
|
||||
|
||||
ACTION_DISCARD = 0
|
||||
ACTION_DECLARE_WIN = 1
|
||||
ACTION_DECLARE_KONG = 2
|
||||
ACTION_PASS = 3
|
||||
ACTION_DECLARE_PONG = 4
|
||||
ACTION_DECLARE_CHI = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlayerState:
|
||||
hand: List[int] = field(default_factory=list)
|
||||
melds: List[List[int]] = field(default_factory=list)
|
||||
discards: List[int] = field(default_factory=list)
|
||||
flowers: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class MahjongEnv(gym.Env):
|
||||
metadata = {"render_modes": ["human"]}
|
||||
|
||||
def __init__(self, seed: Optional[int] = None, max_rounds: int = 1):
|
||||
super().__init__()
|
||||
self.max_rounds = max_rounds
|
||||
self.action_space = spaces.Dict(
|
||||
{
|
||||
"type": spaces.Discrete(6),
|
||||
"tile": spaces.Discrete(len(TILE_TYPES)),
|
||||
"chi": spaces.Discrete(3),
|
||||
}
|
||||
)
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"hand": spaces.Box(0, 4, shape=(len(TILE_TYPES),), dtype=np.int8),
|
||||
"melds": spaces.Box(0, 4, shape=(4, len(TILE_TYPES)), dtype=np.int8),
|
||||
"discards": spaces.Box(0, 4, shape=(4, len(TILE_TYPES)), dtype=np.int8),
|
||||
"flowers": spaces.Box(0, 8, shape=(4,), dtype=np.int8),
|
||||
"wall_count": spaces.Box(0, 144, shape=(1,), dtype=np.int16),
|
||||
"current_player": spaces.Box(0, 3, shape=(1,), dtype=np.int8),
|
||||
"phase": spaces.Box(0, 1, shape=(1,), dtype=np.int8),
|
||||
}
|
||||
)
|
||||
self._random = random.Random(seed)
|
||||
self._round = 0
|
||||
self.agent_index = 0
|
||||
self.players: List[PlayerState] = []
|
||||
self.wall: List[int] = []
|
||||
self.current_player = 0
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
self.pending_discard: Optional[Tuple[int, int]] = None
|
||||
self.skip_draw_for_player: Optional[int] = None
|
||||
self.done = False
|
||||
self.winner: Optional[int] = None
|
||||
self.terminal_reason: Optional[str] = None
|
||||
self.event_log: List[Dict[str, object]] = []
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
if seed is not None:
|
||||
self._random.seed(seed)
|
||||
self._round += 1
|
||||
self.players = [PlayerState() for _ in range(4)]
|
||||
self.wall = build_wall(self._random)
|
||||
self.current_player = 0
|
||||
self.pending_discard = None
|
||||
self.skip_draw_for_player = None
|
||||
self.done = False
|
||||
self.winner = None
|
||||
self.terminal_reason = None
|
||||
self.event_log = []
|
||||
|
||||
for _ in range(13):
|
||||
for player in range(4):
|
||||
self._draw_tile(player)
|
||||
self._draw_tile(0)
|
||||
self.skip_draw_for_player = 0
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
|
||||
self._advance_to_agent()
|
||||
return self._get_obs(), self._get_info()
|
||||
|
||||
def step(self, action: Dict[str, int]):
|
||||
if self.done:
|
||||
return self._get_obs(), 0.0, True, False, self._get_info()
|
||||
|
||||
reward = 0.0
|
||||
action = self._sanitize_action(action)
|
||||
action_type = int(action["type"])
|
||||
tile_id = int(action["tile"])
|
||||
chi_choice = int(action["chi"])
|
||||
|
||||
if self.phase == PHASE_AGENT_TURN:
|
||||
reward += self._handle_agent_turn(action_type, tile_id)
|
||||
elif self.phase == PHASE_AGENT_RESPONSE:
|
||||
reward += self._handle_agent_response(action_type, tile_id, chi_choice)
|
||||
|
||||
if not self.done:
|
||||
self._advance_to_agent()
|
||||
|
||||
if self.done:
|
||||
if self.winner == self.agent_index:
|
||||
reward += 1.0
|
||||
elif self.winner is not None:
|
||||
reward -= 1.0
|
||||
return self._get_obs(), reward, self.done, False, self._get_info()
|
||||
|
||||
def render(self):
|
||||
if self.done:
|
||||
print(f"Winner: {self.winner}, reason: {self.terminal_reason}")
|
||||
else:
|
||||
print(f"Player {self.current_player} to act, wall {len(self.wall)}")
|
||||
|
||||
def _get_obs(self):
|
||||
hand_counts = self._counts(self.players[self.agent_index].hand)
|
||||
melds = np.zeros((4, len(TILE_TYPES)), dtype=np.int8)
|
||||
discards = np.zeros((4, len(TILE_TYPES)), dtype=np.int8)
|
||||
flowers = np.zeros(4, dtype=np.int8)
|
||||
for idx, player in enumerate(self.players):
|
||||
for meld in player.melds:
|
||||
for tile in meld:
|
||||
melds[idx, tile] += 1
|
||||
for tile in player.discards:
|
||||
discards[idx, tile] += 1
|
||||
flowers[idx] = len(player.flowers)
|
||||
return {
|
||||
"hand": hand_counts,
|
||||
"melds": melds,
|
||||
"discards": discards,
|
||||
"flowers": flowers,
|
||||
"wall_count": np.array([len(self.wall)], dtype=np.int16),
|
||||
"current_player": np.array([self.current_player], dtype=np.int8),
|
||||
"phase": np.array([self.phase], dtype=np.int8),
|
||||
}
|
||||
|
||||
def _get_info(self):
|
||||
action_mask = self._action_mask()
|
||||
return {
|
||||
"winner": self.winner,
|
||||
"terminal_reason": self.terminal_reason,
|
||||
"pending_discard": self.pending_discard,
|
||||
"action_mask": action_mask,
|
||||
}
|
||||
|
||||
def _sanitize_action(self, action: Dict[str, int]) -> Dict[str, int]:
|
||||
if not isinstance(action, dict):
|
||||
return self._sample_valid_action()
|
||||
if "type" not in action or "tile" not in action or "chi" not in action:
|
||||
return self._sample_valid_action()
|
||||
mask = self._action_mask()
|
||||
action_type = int(action["type"])
|
||||
tile_id = int(action["tile"])
|
||||
chi_choice = int(action["chi"])
|
||||
if action_type < 0 or action_type >= 6:
|
||||
return self._sample_valid_action()
|
||||
if not mask["type"][action_type]:
|
||||
return self._sample_valid_action()
|
||||
if action_type == ACTION_DISCARD:
|
||||
if not mask["discard"][tile_id]:
|
||||
return self._sample_valid_action()
|
||||
if action_type == ACTION_DECLARE_PONG:
|
||||
if not mask["pong"][tile_id]:
|
||||
return self._sample_valid_action()
|
||||
if action_type == ACTION_DECLARE_KONG:
|
||||
if not mask["kong"][tile_id]:
|
||||
return self._sample_valid_action()
|
||||
if action_type == ACTION_DECLARE_CHI:
|
||||
if not mask["chi"][chi_choice]:
|
||||
return self._sample_valid_action()
|
||||
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
|
||||
|
||||
def _sample_valid_action(self) -> Dict[str, int]:
|
||||
mask = self._action_mask()
|
||||
valid_types = np.flatnonzero(mask["type"])
|
||||
if len(valid_types) == 0:
|
||||
return {"type": ACTION_PASS, "tile": 0, "chi": 0}
|
||||
action_type = int(self._random.choice(valid_types))
|
||||
tile_id = 0
|
||||
chi_choice = 0
|
||||
if action_type == ACTION_DISCARD:
|
||||
valid_tiles = np.flatnonzero(mask["discard"])
|
||||
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
|
||||
if action_type == ACTION_DECLARE_PONG:
|
||||
valid_tiles = np.flatnonzero(mask["pong"])
|
||||
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
|
||||
if action_type == ACTION_DECLARE_KONG:
|
||||
valid_tiles = np.flatnonzero(mask["kong"])
|
||||
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
|
||||
if action_type == ACTION_DECLARE_CHI:
|
||||
valid_chi = np.flatnonzero(mask["chi"])
|
||||
chi_choice = int(self._random.choice(valid_chi)) if len(valid_chi) else 0
|
||||
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
|
||||
|
||||
def _action_mask(self) -> Dict[str, np.ndarray]:
|
||||
mask_type = np.zeros(6, dtype=bool)
|
||||
mask_discard = np.zeros(len(TILE_TYPES), dtype=bool)
|
||||
mask_pong = np.zeros(len(TILE_TYPES), dtype=bool)
|
||||
mask_kong = np.zeros(len(TILE_TYPES), dtype=bool)
|
||||
mask_chi = np.zeros(3, dtype=bool)
|
||||
agent = self.players[self.agent_index]
|
||||
counts = self._counts(agent.hand)
|
||||
|
||||
if self.phase == PHASE_AGENT_TURN:
|
||||
mask_type[ACTION_DISCARD] = True
|
||||
for tile in agent.hand:
|
||||
mask_discard[tile] = True
|
||||
if is_win(agent.hand, len(agent.melds)):
|
||||
mask_type[ACTION_DECLARE_WIN] = True
|
||||
for tile_id, count in enumerate(counts):
|
||||
if count >= 4:
|
||||
mask_type[ACTION_DECLARE_KONG] = True
|
||||
mask_kong[tile_id] = True
|
||||
elif self.phase == PHASE_AGENT_RESPONSE and self.pending_discard is not None:
|
||||
mask_type[ACTION_PASS] = True
|
||||
_, tile = self.pending_discard
|
||||
virtual_hand = agent.hand + [tile]
|
||||
if is_win(virtual_hand, len(agent.melds)):
|
||||
mask_type[ACTION_DECLARE_WIN] = True
|
||||
if counts[tile] >= 2:
|
||||
mask_type[ACTION_DECLARE_PONG] = True
|
||||
mask_pong[tile] = True
|
||||
if counts[tile] >= 3:
|
||||
mask_type[ACTION_DECLARE_KONG] = True
|
||||
mask_kong[tile] = True
|
||||
if self._is_next_player(self.agent_index, self.pending_discard[0]):
|
||||
options = can_chi_options(counts, tile)
|
||||
if options:
|
||||
mask_type[ACTION_DECLARE_CHI] = True
|
||||
for option in options:
|
||||
mask_chi[option] = True
|
||||
|
||||
return {
|
||||
"type": mask_type,
|
||||
"discard": mask_discard,
|
||||
"pong": mask_pong,
|
||||
"kong": mask_kong,
|
||||
"chi": mask_chi,
|
||||
}
|
||||
|
||||
def _handle_agent_turn(self, action_type: int, tile_id: int) -> float:
|
||||
agent = self.players[self.agent_index]
|
||||
if action_type == ACTION_DECLARE_WIN:
|
||||
if is_win(agent.hand, len(agent.melds)):
|
||||
self._end_game(self.agent_index, "self_draw")
|
||||
return 0.0
|
||||
if action_type == ACTION_DECLARE_KONG:
|
||||
if self._can_concealed_kong(agent, tile_id):
|
||||
self._make_kong(self.agent_index, tile_id, concealed=True)
|
||||
self._draw_tile(self.agent_index)
|
||||
self.skip_draw_for_player = self.agent_index
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
return 0.0
|
||||
if action_type == ACTION_DISCARD:
|
||||
if tile_id in agent.hand:
|
||||
agent.hand.remove(tile_id)
|
||||
agent.discards.append(tile_id)
|
||||
self.pending_discard = (self.agent_index, tile_id)
|
||||
self._log_event(
|
||||
"discard",
|
||||
player=self.agent_index,
|
||||
tile_id=tile_id,
|
||||
)
|
||||
self._resolve_discard()
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
def _handle_agent_response(self, action_type: int, tile_id: int, chi_choice: int) -> float:
|
||||
if self.pending_discard is None:
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
return 0.0
|
||||
discarder, tile = self.pending_discard
|
||||
agent = self.players[self.agent_index]
|
||||
counts = self._counts(agent.hand)
|
||||
|
||||
if action_type == ACTION_DECLARE_WIN:
|
||||
if is_win(agent.hand + [tile], len(agent.melds)):
|
||||
self._end_game(self.agent_index, "ron")
|
||||
return 0.0
|
||||
if action_type == ACTION_DECLARE_PONG:
|
||||
if counts[tile] >= 2:
|
||||
self._consume_discard(discarder, tile)
|
||||
agent.hand.remove(tile)
|
||||
agent.hand.remove(tile)
|
||||
agent.melds.append([tile, tile, tile])
|
||||
self.current_player = self.agent_index
|
||||
self.skip_draw_for_player = self.agent_index
|
||||
self.pending_discard = None
|
||||
self._log_event(
|
||||
"claim",
|
||||
player=self.agent_index,
|
||||
claim="pong",
|
||||
tile_id=tile,
|
||||
from_player=discarder,
|
||||
)
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
return 0.0
|
||||
if action_type == ACTION_DECLARE_KONG:
|
||||
if counts[tile] >= 3:
|
||||
self._consume_discard(discarder, tile)
|
||||
for _ in range(3):
|
||||
agent.hand.remove(tile)
|
||||
agent.melds.append([tile, tile, tile, tile])
|
||||
self.current_player = self.agent_index
|
||||
self.skip_draw_for_player = self.agent_index
|
||||
self.pending_discard = None
|
||||
self._draw_tile(self.agent_index)
|
||||
self._log_event(
|
||||
"claim",
|
||||
player=self.agent_index,
|
||||
claim="kong",
|
||||
tile_id=tile,
|
||||
from_player=discarder,
|
||||
)
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
return 0.0
|
||||
if action_type == ACTION_DECLARE_CHI:
|
||||
if self._is_next_player(self.agent_index, discarder):
|
||||
options = can_chi_options(counts, tile)
|
||||
if chi_choice in options:
|
||||
self._consume_discard(discarder, tile)
|
||||
if chi_choice == 0:
|
||||
chi_tiles = [tile - 2, tile - 1, tile]
|
||||
elif chi_choice == 1:
|
||||
chi_tiles = [tile - 1, tile, tile + 1]
|
||||
else:
|
||||
chi_tiles = [tile, tile + 1, tile + 2]
|
||||
for chi_tile in chi_tiles:
|
||||
if chi_tile != tile:
|
||||
agent.hand.remove(chi_tile)
|
||||
agent.melds.append(chi_tiles)
|
||||
self.current_player = self.agent_index
|
||||
self.skip_draw_for_player = self.agent_index
|
||||
self.pending_discard = None
|
||||
self._log_event(
|
||||
"claim",
|
||||
player=self.agent_index,
|
||||
claim="chi",
|
||||
tile_id=tile,
|
||||
from_player=discarder,
|
||||
)
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
return 0.0
|
||||
|
||||
self.pending_discard = None
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
self._resolve_discard(after_pass=True, original_discarder=discarder, tile=tile)
|
||||
return 0.0
|
||||
|
||||
def _advance_to_agent(self):
|
||||
while not self.done:
|
||||
if self.phase == PHASE_AGENT_RESPONSE:
|
||||
return
|
||||
if self.current_player == self.agent_index:
|
||||
self._start_turn(self.agent_index)
|
||||
if self.done:
|
||||
return
|
||||
self.phase = PHASE_AGENT_TURN
|
||||
return
|
||||
self._npc_turn(self.current_player)
|
||||
if self.done or self.phase == PHASE_AGENT_RESPONSE:
|
||||
return
|
||||
|
||||
def _npc_turn(self, player_idx: int):
|
||||
self._start_turn(player_idx)
|
||||
if self.done:
|
||||
return
|
||||
player = self.players[player_idx]
|
||||
if is_win(player.hand, len(player.melds)):
|
||||
self._end_game(player_idx, "self_draw")
|
||||
return
|
||||
tile = self._random.choice(player.hand)
|
||||
player.hand.remove(tile)
|
||||
player.discards.append(tile)
|
||||
self.pending_discard = (player_idx, tile)
|
||||
self._log_event(
|
||||
"discard",
|
||||
player=player_idx,
|
||||
tile_id=tile,
|
||||
)
|
||||
self._resolve_discard()
|
||||
|
||||
def _resolve_discard(self, after_pass: bool = False, original_discarder: Optional[int] = None, tile: Optional[int] = None):
|
||||
if self.pending_discard is None and not after_pass:
|
||||
return
|
||||
discarder, discarded_tile = self.pending_discard if self.pending_discard is not None else (original_discarder, tile)
|
||||
if discarder is None or discarded_tile is None:
|
||||
return
|
||||
|
||||
if discarder != self.agent_index and not after_pass:
|
||||
if self._agent_can_respond(discarder, discarded_tile):
|
||||
self.pending_discard = (discarder, discarded_tile)
|
||||
self.phase = PHASE_AGENT_RESPONSE
|
||||
return
|
||||
|
||||
winner = self._npc_win_on_discard(discarder, discarded_tile)
|
||||
if winner is not None:
|
||||
self._end_game(winner, "ron")
|
||||
return
|
||||
claim = self._npc_claim(discarder, discarded_tile)
|
||||
if claim is not None:
|
||||
self.pending_discard = None
|
||||
self._apply_npc_claim(*claim)
|
||||
return
|
||||
|
||||
self.pending_discard = None
|
||||
self.current_player = (discarder + 1) % 4
|
||||
|
||||
def _agent_can_respond(self, discarder: int, tile: int) -> bool:
|
||||
agent = self.players[self.agent_index]
|
||||
counts = self._counts(agent.hand)
|
||||
if is_win(agent.hand + [tile], len(agent.melds)):
|
||||
return True
|
||||
if counts[tile] >= 2:
|
||||
return True
|
||||
if counts[tile] >= 3:
|
||||
return True
|
||||
if self._is_next_player(self.agent_index, discarder):
|
||||
return bool(can_chi_options(counts, tile))
|
||||
return False
|
||||
|
||||
def _npc_win_on_discard(self, discarder: int, tile: int) -> Optional[int]:
|
||||
for offset in range(1, 4):
|
||||
player_idx = (discarder + offset) % 4
|
||||
if player_idx == self.agent_index:
|
||||
continue
|
||||
player = self.players[player_idx]
|
||||
if is_win(player.hand + [tile], len(player.melds)):
|
||||
return player_idx
|
||||
return None
|
||||
|
||||
def _npc_claim(self, discarder: int, tile: int):
|
||||
for offset in range(1, 4):
|
||||
player_idx = (discarder + offset) % 4
|
||||
if player_idx == self.agent_index:
|
||||
continue
|
||||
player = self.players[player_idx]
|
||||
counts = self._counts(player.hand)
|
||||
if counts[tile] >= 3:
|
||||
return (player_idx, "kong", tile, None, discarder)
|
||||
if counts[tile] >= 2:
|
||||
return (player_idx, "pong", tile, None, discarder)
|
||||
next_player = (discarder + 1) % 4
|
||||
if next_player != self.agent_index:
|
||||
player = self.players[next_player]
|
||||
options = can_chi_options(self._counts(player.hand), tile)
|
||||
if options:
|
||||
return (next_player, "chi", tile, options[0], discarder)
|
||||
return None
|
||||
|
||||
def _apply_npc_claim(
|
||||
self,
|
||||
player_idx: int,
|
||||
claim_type: str,
|
||||
tile: int,
|
||||
chi_choice: Optional[int],
|
||||
discarder: int,
|
||||
):
|
||||
player = self.players[player_idx]
|
||||
self._consume_discard(discarder, tile)
|
||||
if claim_type == "pong":
|
||||
player.hand.remove(tile)
|
||||
player.hand.remove(tile)
|
||||
player.melds.append([tile, tile, tile])
|
||||
self.current_player = player_idx
|
||||
self.skip_draw_for_player = player_idx
|
||||
elif claim_type == "kong":
|
||||
for _ in range(3):
|
||||
player.hand.remove(tile)
|
||||
player.melds.append([tile, tile, tile, tile])
|
||||
self.current_player = player_idx
|
||||
self.skip_draw_for_player = player_idx
|
||||
self._draw_tile(player_idx)
|
||||
elif claim_type == "chi" and chi_choice is not None:
|
||||
if chi_choice == 0:
|
||||
chi_tiles = [tile - 2, tile - 1, tile]
|
||||
elif chi_choice == 1:
|
||||
chi_tiles = [tile - 1, tile, tile + 1]
|
||||
else:
|
||||
chi_tiles = [tile, tile + 1, tile + 2]
|
||||
for chi_tile in chi_tiles:
|
||||
if chi_tile != tile:
|
||||
player.hand.remove(chi_tile)
|
||||
player.melds.append(chi_tiles)
|
||||
self.current_player = player_idx
|
||||
self.skip_draw_for_player = player_idx
|
||||
self._log_event(
|
||||
"claim",
|
||||
player=player_idx,
|
||||
claim=claim_type,
|
||||
tile_id=tile,
|
||||
from_player=discarder,
|
||||
)
|
||||
|
||||
def _draw_tile(self, player_idx: int):
|
||||
while self.wall:
|
||||
tile = self.wall.pop()
|
||||
if is_flower(tile):
|
||||
self.players[player_idx].flowers.append(tile)
|
||||
continue
|
||||
self.players[player_idx].hand.append(tile)
|
||||
return tile
|
||||
self._end_game(None, "wall_empty")
|
||||
return None
|
||||
|
||||
def _start_turn(self, player_idx: int):
|
||||
if self.skip_draw_for_player == player_idx:
|
||||
self.skip_draw_for_player = None
|
||||
return
|
||||
self._draw_tile(player_idx)
|
||||
|
||||
def _make_kong(self, player_idx: int, tile_id: int, concealed: bool):
|
||||
player = self.players[player_idx]
|
||||
for _ in range(4):
|
||||
player.hand.remove(tile_id)
|
||||
player.melds.append([tile_id] * 4)
|
||||
|
||||
def _can_concealed_kong(self, player: PlayerState, tile_id: int) -> bool:
|
||||
return player.hand.count(tile_id) >= 4
|
||||
|
||||
def _counts(self, tiles: List[int]) -> np.ndarray:
|
||||
counts = np.zeros(len(TILE_TYPES), dtype=np.int8)
|
||||
for tile in tiles:
|
||||
counts[tile] += 1
|
||||
return counts
|
||||
|
||||
def _is_next_player(self, player: int, discarder: int) -> bool:
|
||||
return player == (discarder + 1) % 4
|
||||
|
||||
def _consume_discard(self, discarder: int, tile: int):
|
||||
discards = self.players[discarder].discards
|
||||
if discards and discards[-1] == tile:
|
||||
discards.pop()
|
||||
|
||||
def _end_game(self, winner: Optional[int], reason: str):
|
||||
if self.done:
|
||||
return
|
||||
self.done = True
|
||||
self.winner = winner
|
||||
self.terminal_reason = reason
|
||||
self._log_event(
|
||||
"win" if winner is not None else "draw",
|
||||
player=winner,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
def _snapshot(self) -> Dict[str, object]:
|
||||
return {
|
||||
"hands": [[tile_name(tile) for tile in sorted(p.hand)] for p in self.players],
|
||||
"discards": [[tile_name(tile) for tile in p.discards] for p in self.players],
|
||||
"melds": [[[tile_name(tile) for tile in meld] for meld in p.melds] for p in self.players],
|
||||
"flowers": [[tile_name(tile) for tile in p.flowers] for p in self.players],
|
||||
"wall_count": len(self.wall),
|
||||
"current_player": self.current_player,
|
||||
}
|
||||
|
||||
def _log_event(self, event_type: str, **payload: object):
|
||||
event = {
|
||||
"id": len(self.event_log),
|
||||
"type": event_type,
|
||||
"snapshot": self._snapshot(),
|
||||
}
|
||||
event.update(payload)
|
||||
if "tile_id" in event:
|
||||
event["tile"] = tile_name(int(event["tile_id"]))
|
||||
self.event_log.append(event)
|
||||
337
majiang_rl/envs/self_play_env.py
Normal file
337
majiang_rl/envs/self_play_env.py
Normal file
@@ -0,0 +1,337 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..rules import can_chi_options, is_win
|
||||
from ..tiles import TILE_TYPES, build_wall, is_flower
|
||||
|
||||
PHASE_TURN = 0
|
||||
PHASE_RESPONSE = 1
|
||||
|
||||
ACTION_DISCARD = 0
|
||||
ACTION_DECLARE_WIN = 1
|
||||
ACTION_DECLARE_KONG = 2
|
||||
ACTION_PASS = 3
|
||||
ACTION_DECLARE_PONG = 4
|
||||
ACTION_DECLARE_CHI = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlayerState:
|
||||
hand: List[int] = field(default_factory=list)
|
||||
melds: List[List[int]] = field(default_factory=list)
|
||||
discards: List[int] = field(default_factory=list)
|
||||
flowers: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class SelfPlayMahjong:
|
||||
def __init__(self, seed: Optional[int] = None, pong_reward: float = 0.0):
|
||||
self._random = random.Random(seed)
|
||||
self.pong_reward = pong_reward
|
||||
self.players: List[PlayerState] = []
|
||||
self.wall: List[int] = []
|
||||
self.current_player = 0
|
||||
self.phase = PHASE_TURN
|
||||
self.pending_discard: Optional[Tuple[int, int]] = None
|
||||
self.response_queue: List[int] = []
|
||||
self.turn_started = False
|
||||
self.skip_draw_for_player: Optional[int] = None
|
||||
self.done = False
|
||||
self.winner: Optional[int] = None
|
||||
self.terminal_reason: Optional[str] = None
|
||||
self.winning_tiles: List[int] = []
|
||||
self.winning_melds: List[List[int]] = []
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None) -> None:
|
||||
if seed is not None:
|
||||
self._random.seed(seed)
|
||||
self.players = [PlayerState() for _ in range(4)]
|
||||
self.wall = build_wall(self._random)
|
||||
self.current_player = 0
|
||||
self.phase = PHASE_TURN
|
||||
self.pending_discard = None
|
||||
self.response_queue = []
|
||||
self.turn_started = False
|
||||
self.skip_draw_for_player = None
|
||||
self.done = False
|
||||
self.winner = None
|
||||
self.terminal_reason = None
|
||||
self.winning_tiles = []
|
||||
self.winning_melds = []
|
||||
|
||||
for _ in range(13):
|
||||
for player in range(4):
|
||||
self._draw_tile(player)
|
||||
|
||||
def observe(self, player_idx: int) -> Dict[str, np.ndarray]:
|
||||
hand_counts = self._counts(self.players[player_idx].hand)
|
||||
melds = np.zeros((4, len(TILE_TYPES)), dtype=np.int8)
|
||||
discards = np.zeros((4, len(TILE_TYPES)), dtype=np.int8)
|
||||
flowers = np.zeros(4, dtype=np.int8)
|
||||
for idx, player in enumerate(self.players):
|
||||
for meld in player.melds:
|
||||
for tile in meld:
|
||||
melds[idx, tile] += 1
|
||||
for tile in player.discards:
|
||||
discards[idx, tile] += 1
|
||||
flowers[idx] = len(player.flowers)
|
||||
|
||||
pending = np.zeros(len(TILE_TYPES), dtype=np.int8)
|
||||
if self.pending_discard is not None:
|
||||
pending[self.pending_discard[1]] = 1
|
||||
|
||||
return {
|
||||
"hand": hand_counts,
|
||||
"melds": melds,
|
||||
"discards": discards,
|
||||
"flowers": flowers,
|
||||
"wall_count": np.array([len(self.wall)], dtype=np.int16),
|
||||
"current_player": np.array([self.current_player], dtype=np.int8),
|
||||
"phase": np.array([self.phase], dtype=np.int8),
|
||||
"pending_discard": pending,
|
||||
}
|
||||
|
||||
def action_mask(self, player_idx: int) -> Dict[str, np.ndarray]:
|
||||
mask_type = np.zeros(6, dtype=bool)
|
||||
mask_discard = np.zeros(len(TILE_TYPES), dtype=bool)
|
||||
mask_pong = np.zeros(len(TILE_TYPES), dtype=bool)
|
||||
mask_kong = np.zeros(len(TILE_TYPES), dtype=bool)
|
||||
mask_chi = np.zeros(3, dtype=bool)
|
||||
player = self.players[player_idx]
|
||||
counts = self._counts(player.hand)
|
||||
|
||||
if self.phase == PHASE_TURN:
|
||||
mask_type[ACTION_DISCARD] = True
|
||||
for tile in player.hand:
|
||||
mask_discard[tile] = True
|
||||
if is_win(player.hand, len(player.melds)):
|
||||
mask_type[ACTION_DECLARE_WIN] = True
|
||||
for tile_id, count in enumerate(counts):
|
||||
if count >= 4:
|
||||
mask_type[ACTION_DECLARE_KONG] = True
|
||||
mask_kong[tile_id] = True
|
||||
elif self.phase == PHASE_RESPONSE and self.pending_discard is not None:
|
||||
mask_type[ACTION_PASS] = True
|
||||
discarder, tile = self.pending_discard
|
||||
virtual_hand = player.hand + [tile]
|
||||
if is_win(virtual_hand, len(player.melds)):
|
||||
mask_type[ACTION_DECLARE_WIN] = True
|
||||
if counts[tile] >= 2:
|
||||
mask_type[ACTION_DECLARE_PONG] = True
|
||||
mask_pong[tile] = True
|
||||
if counts[tile] >= 3:
|
||||
mask_type[ACTION_DECLARE_KONG] = True
|
||||
mask_kong[tile] = True
|
||||
if player_idx == (discarder + 1) % 4:
|
||||
options = can_chi_options(counts, tile)
|
||||
if options:
|
||||
mask_type[ACTION_DECLARE_CHI] = True
|
||||
for option in options:
|
||||
mask_chi[option] = True
|
||||
|
||||
return {
|
||||
"type": mask_type,
|
||||
"discard": mask_discard,
|
||||
"pong": mask_pong,
|
||||
"kong": mask_kong,
|
||||
"chi": mask_chi,
|
||||
}
|
||||
|
||||
def step(self, player_idx: int, action: Dict[str, int]) -> float:
|
||||
if self.done:
|
||||
return 0.0
|
||||
if player_idx != self.current_player:
|
||||
raise ValueError("Action from non-active player")
|
||||
|
||||
action = self._sanitize_action(player_idx, action)
|
||||
if self.phase == PHASE_TURN:
|
||||
if not self.turn_started:
|
||||
self._start_turn(player_idx)
|
||||
if self.done:
|
||||
return 0.0
|
||||
self.turn_started = True
|
||||
return self._handle_turn_action(player_idx, action)
|
||||
return self._handle_response_action(player_idx, action)
|
||||
|
||||
def _handle_turn_action(self, player_idx: int, action: Dict[str, int]) -> float:
|
||||
player = self.players[player_idx]
|
||||
action_type = int(action.get("type", ACTION_DISCARD))
|
||||
tile_id = int(action.get("tile", 0))
|
||||
|
||||
if action_type == ACTION_DECLARE_WIN and is_win(player.hand, len(player.melds)):
|
||||
self._end_game(player_idx, "self_draw", player.hand)
|
||||
return 0.0
|
||||
|
||||
if action_type == ACTION_DECLARE_KONG and player.hand.count(tile_id) >= 4:
|
||||
for _ in range(4):
|
||||
player.hand.remove(tile_id)
|
||||
player.melds.append([tile_id] * 4)
|
||||
self._draw_tile(player_idx)
|
||||
self.turn_started = True
|
||||
return 0.0
|
||||
|
||||
if action_type == ACTION_DISCARD and tile_id in player.hand:
|
||||
player.hand.remove(tile_id)
|
||||
player.discards.append(tile_id)
|
||||
self.pending_discard = (player_idx, tile_id)
|
||||
self.response_queue = [(player_idx + offset) % 4 for offset in range(1, 4)]
|
||||
self.phase = PHASE_RESPONSE
|
||||
self.current_player = self.response_queue.pop(0)
|
||||
self.turn_started = False
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
def _handle_response_action(self, player_idx: int, action: Dict[str, int]) -> float:
|
||||
if self.pending_discard is None:
|
||||
self.phase = PHASE_TURN
|
||||
self.turn_started = False
|
||||
return 0.0
|
||||
discarder, tile = self.pending_discard
|
||||
player = self.players[player_idx]
|
||||
counts = self._counts(player.hand)
|
||||
action_type = int(action.get("type", ACTION_PASS))
|
||||
chi_choice = int(action.get("chi", 0))
|
||||
|
||||
if action_type == ACTION_DECLARE_WIN and is_win(player.hand + [tile], len(player.melds)):
|
||||
self._end_game(player_idx, "ron", player.hand + [tile])
|
||||
return 0.0
|
||||
|
||||
if action_type == ACTION_DECLARE_PONG and counts[tile] >= 2:
|
||||
self._consume_discard(discarder, tile)
|
||||
player.hand.remove(tile)
|
||||
player.hand.remove(tile)
|
||||
player.melds.append([tile, tile, tile])
|
||||
self.current_player = player_idx
|
||||
self.pending_discard = None
|
||||
self.response_queue = []
|
||||
self.phase = PHASE_TURN
|
||||
self.turn_started = False
|
||||
self.skip_draw_for_player = player_idx
|
||||
return self.pong_reward
|
||||
|
||||
if action_type == ACTION_DECLARE_KONG and counts[tile] >= 3:
|
||||
self._consume_discard(discarder, tile)
|
||||
for _ in range(3):
|
||||
player.hand.remove(tile)
|
||||
player.melds.append([tile, tile, tile, tile])
|
||||
self.current_player = player_idx
|
||||
self.pending_discard = None
|
||||
self.response_queue = []
|
||||
self.phase = PHASE_TURN
|
||||
self.turn_started = True
|
||||
self._draw_tile(player_idx)
|
||||
return 0.0
|
||||
|
||||
if action_type == ACTION_DECLARE_CHI and player_idx == (discarder + 1) % 4:
|
||||
options = can_chi_options(counts, tile)
|
||||
if chi_choice in options:
|
||||
self._consume_discard(discarder, tile)
|
||||
if chi_choice == 0:
|
||||
chi_tiles = [tile - 2, tile - 1, tile]
|
||||
elif chi_choice == 1:
|
||||
chi_tiles = [tile - 1, tile, tile + 1]
|
||||
else:
|
||||
chi_tiles = [tile, tile + 1, tile + 2]
|
||||
for chi_tile in chi_tiles:
|
||||
if chi_tile != tile:
|
||||
player.hand.remove(chi_tile)
|
||||
player.melds.append(chi_tiles)
|
||||
self.current_player = player_idx
|
||||
self.pending_discard = None
|
||||
self.response_queue = []
|
||||
self.phase = PHASE_TURN
|
||||
self.turn_started = False
|
||||
self.skip_draw_for_player = player_idx
|
||||
return 0.0
|
||||
|
||||
self._advance_response_queue(discarder)
|
||||
return 0.0
|
||||
|
||||
def _sanitize_action(self, player_idx: int, action: Dict[str, int]) -> Dict[str, int]:
|
||||
mask = self.action_mask(player_idx)
|
||||
action_type = int(action.get("type", ACTION_PASS))
|
||||
tile_id = int(action.get("tile", 0))
|
||||
chi_choice = int(action.get("chi", 0))
|
||||
|
||||
if action_type < 0 or action_type >= 6 or not mask["type"][action_type]:
|
||||
return self._sample_valid_action(mask)
|
||||
if action_type == ACTION_DISCARD and not mask["discard"][tile_id]:
|
||||
return self._sample_valid_action(mask)
|
||||
if action_type == ACTION_DECLARE_PONG and not mask["pong"][tile_id]:
|
||||
return self._sample_valid_action(mask)
|
||||
if action_type == ACTION_DECLARE_KONG and not mask["kong"][tile_id]:
|
||||
return self._sample_valid_action(mask)
|
||||
if action_type == ACTION_DECLARE_CHI and not mask["chi"][chi_choice]:
|
||||
return self._sample_valid_action(mask)
|
||||
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
|
||||
|
||||
def _sample_valid_action(self, mask: Dict[str, np.ndarray]) -> Dict[str, int]:
|
||||
valid_types = np.flatnonzero(mask["type"])
|
||||
if len(valid_types) == 0:
|
||||
return {"type": ACTION_PASS, "tile": 0, "chi": 0}
|
||||
action_type = int(self._random.choice(valid_types))
|
||||
tile_id = 0
|
||||
chi_choice = 0
|
||||
if action_type == ACTION_DISCARD:
|
||||
valid_tiles = np.flatnonzero(mask["discard"])
|
||||
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
|
||||
if action_type == ACTION_DECLARE_PONG:
|
||||
valid_tiles = np.flatnonzero(mask["pong"])
|
||||
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
|
||||
if action_type == ACTION_DECLARE_KONG:
|
||||
valid_tiles = np.flatnonzero(mask["kong"])
|
||||
tile_id = int(self._random.choice(valid_tiles)) if len(valid_tiles) else 0
|
||||
if action_type == ACTION_DECLARE_CHI:
|
||||
valid_chi = np.flatnonzero(mask["chi"])
|
||||
chi_choice = int(self._random.choice(valid_chi)) if len(valid_chi) else 0
|
||||
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
|
||||
|
||||
def _advance_response_queue(self, discarder: int):
|
||||
if self.response_queue:
|
||||
self.current_player = self.response_queue.pop(0)
|
||||
return
|
||||
self.pending_discard = None
|
||||
self.phase = PHASE_TURN
|
||||
self.current_player = (discarder + 1) % 4
|
||||
self.turn_started = False
|
||||
|
||||
def _start_turn(self, player_idx: int):
|
||||
if self.skip_draw_for_player == player_idx:
|
||||
self.skip_draw_for_player = None
|
||||
return
|
||||
self._draw_tile(player_idx)
|
||||
|
||||
def _draw_tile(self, player_idx: int):
|
||||
while self.wall:
|
||||
tile = self.wall.pop()
|
||||
if is_flower(tile):
|
||||
self.players[player_idx].flowers.append(tile)
|
||||
continue
|
||||
self.players[player_idx].hand.append(tile)
|
||||
return tile
|
||||
self._end_game(None, "wall_empty", [])
|
||||
return None
|
||||
|
||||
def _consume_discard(self, discarder: int, tile: int):
|
||||
discards = self.players[discarder].discards
|
||||
if discards and discards[-1] == tile:
|
||||
discards.pop()
|
||||
|
||||
def _end_game(self, winner: Optional[int], reason: str, winning_tiles: List[int]):
|
||||
if self.done:
|
||||
return
|
||||
self.done = True
|
||||
self.winner = winner
|
||||
self.terminal_reason = reason
|
||||
self.winning_tiles = winning_tiles
|
||||
self.winning_melds = [] if winner is None else self.players[winner].melds
|
||||
|
||||
def _counts(self, tiles: List[int]) -> np.ndarray:
|
||||
counts = np.zeros(len(TILE_TYPES), dtype=np.int8)
|
||||
for tile in tiles:
|
||||
counts[tile] += 1
|
||||
return counts
|
||||
13
majiang_rl/rl/__init__.py
Normal file
13
majiang_rl/rl/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .agents import RandomAgent, Transition
|
||||
from .grpo import GRPOConfig, train_grpo
|
||||
from .trainer import EpisodeResult, EpisodeRunner, run_training
|
||||
|
||||
__all__ = [
|
||||
"RandomAgent",
|
||||
"Transition",
|
||||
"EpisodeResult",
|
||||
"EpisodeRunner",
|
||||
"run_training",
|
||||
"GRPOConfig",
|
||||
"train_grpo",
|
||||
]
|
||||
76
majiang_rl/rl/agents.py
Normal file
76
majiang_rl/rl/agents.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Agent(Protocol):
|
||||
def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
|
||||
...
|
||||
|
||||
def observe(self, transition: "Transition") -> None:
|
||||
...
|
||||
|
||||
def reset(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transition:
|
||||
obs: Dict[str, Any]
|
||||
action: Dict[str, int]
|
||||
reward: float
|
||||
terminated: bool
|
||||
truncated: bool
|
||||
next_obs: Dict[str, Any]
|
||||
info: Dict[str, Any]
|
||||
|
||||
|
||||
class RandomAgent:
|
||||
def __init__(self, seed: int | None = None):
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
|
||||
mask = info.get("action_mask")
|
||||
if mask is None:
|
||||
return {"type": 0, "tile": 0, "chi": 0}
|
||||
return sample_action_from_mask(mask, self.rng)
|
||||
|
||||
def observe(self, transition: Transition) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def sample_action_from_mask(mask: Dict[str, np.ndarray], rng: np.random.Generator) -> Dict[str, int]:
|
||||
type_mask = np.asarray(mask["type"], dtype=bool)
|
||||
discard_mask = np.asarray(mask["discard"], dtype=bool)
|
||||
pong_mask = np.asarray(mask["pong"], dtype=bool)
|
||||
kong_mask = np.asarray(mask["kong"], dtype=bool)
|
||||
chi_mask = np.asarray(mask["chi"], dtype=bool)
|
||||
valid_types = np.flatnonzero(type_mask)
|
||||
if len(valid_types) == 0:
|
||||
return {"type": 3, "tile": 0, "chi": 0}
|
||||
action_type = int(rng.choice(valid_types))
|
||||
tile_id = 0
|
||||
chi_choice = 0
|
||||
if action_type == 0:
|
||||
valid_tiles = np.flatnonzero(discard_mask)
|
||||
if len(valid_tiles) > 0:
|
||||
tile_id = int(rng.choice(valid_tiles))
|
||||
if action_type == 2:
|
||||
valid_tiles = np.flatnonzero(kong_mask)
|
||||
if len(valid_tiles) > 0:
|
||||
tile_id = int(rng.choice(valid_tiles))
|
||||
if action_type == 4:
|
||||
valid_tiles = np.flatnonzero(pong_mask)
|
||||
if len(valid_tiles) > 0:
|
||||
tile_id = int(rng.choice(valid_tiles))
|
||||
if action_type == 5:
|
||||
valid_chi = np.flatnonzero(chi_mask)
|
||||
if len(valid_chi) > 0:
|
||||
chi_choice = int(rng.choice(valid_chi))
|
||||
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
|
||||
415
majiang_rl/rl/grpo.py
Normal file
415
majiang_rl/rl/grpo.py
Normal file
@@ -0,0 +1,415 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..envs.self_play_env import (
|
||||
ACTION_DECLARE_CHI,
|
||||
ACTION_DECLARE_KONG,
|
||||
ACTION_DECLARE_PONG,
|
||||
ACTION_DISCARD,
|
||||
SelfPlayMahjong,
|
||||
)
|
||||
from ..rules import hand_distance_to_win, total_fan
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRPOConfig:
|
||||
group_size: int = 16
|
||||
updates: int = 50
|
||||
max_steps: int = 0
|
||||
lr: float = 3e-4
|
||||
beta_kl: float = 0.02
|
||||
seed: int = 1
|
||||
hidden_size: int = 256
|
||||
device: str = "auto"
|
||||
use_swanlab: bool = False
|
||||
swanlab_project: str = "majiang-rl"
|
||||
swanlab_run_name: str = ""
|
||||
pong_reward: float = 0.1
|
||||
closest_bonus: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepSample:
|
||||
player: int
|
||||
obs: torch.Tensor
|
||||
mask: Dict[str, torch.Tensor]
|
||||
action: Dict[str, int]
|
||||
reward: float = 0.0
|
||||
|
||||
|
||||
def obs_to_tensor(obs: Dict[str, np.ndarray], device: str) -> torch.Tensor:
|
||||
parts = [
|
||||
obs["hand"].astype(np.float32),
|
||||
obs["melds"].astype(np.float32).reshape(-1),
|
||||
obs["discards"].astype(np.float32).reshape(-1),
|
||||
obs["flowers"].astype(np.float32),
|
||||
obs["wall_count"].astype(np.float32),
|
||||
obs["current_player"].astype(np.float32),
|
||||
obs["phase"].astype(np.float32),
|
||||
obs["pending_discard"].astype(np.float32),
|
||||
]
|
||||
flat = np.concatenate(parts, axis=0)
|
||||
return torch.tensor(flat, dtype=torch.float32, device=device)
|
||||
|
||||
|
||||
def mask_to_torch(mask: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]:
|
||||
return {
|
||||
"type": torch.tensor(mask["type"], dtype=torch.bool, device=device),
|
||||
"discard": torch.tensor(mask["discard"], dtype=torch.bool, device=device),
|
||||
"pong": torch.tensor(mask["pong"], dtype=torch.bool, device=device),
|
||||
"kong": torch.tensor(mask["kong"], dtype=torch.bool, device=device),
|
||||
"chi": torch.tensor(mask["chi"], dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
|
||||
class PolicyNet(nn.Module):
|
||||
def __init__(self, input_dim: int, hidden_size: int = 256):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, hidden_size),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.type_head = nn.Linear(hidden_size, 6)
|
||||
self.tile_head = nn.Linear(hidden_size, 42)
|
||||
self.chi_head = nn.Linear(hidden_size, 3)
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
x = self.net(obs)
|
||||
return self.type_head(x), self.tile_head(x), self.chi_head(x)
|
||||
|
||||
def sample_action(self, obs: torch.Tensor, mask: Dict[str, torch.Tensor]) -> Dict[str, int]:
|
||||
logits_type, logits_tile, logits_chi = self(obs)
|
||||
type_mask = mask["type"]
|
||||
masked_type = masked_logits(logits_type, type_mask)
|
||||
type_dist = torch.distributions.Categorical(logits=masked_type)
|
||||
action_type = int(type_dist.sample().item())
|
||||
|
||||
tile_id = 0
|
||||
chi_choice = 0
|
||||
if action_type == ACTION_DISCARD:
|
||||
tile_id = sample_from_mask(logits_tile, mask["discard"])
|
||||
elif action_type == ACTION_DECLARE_PONG:
|
||||
tile_id = sample_from_mask(logits_tile, mask["pong"])
|
||||
elif action_type == ACTION_DECLARE_KONG:
|
||||
tile_id = sample_from_mask(logits_tile, mask["kong"])
|
||||
elif action_type == ACTION_DECLARE_CHI:
|
||||
chi_choice = sample_from_mask(logits_chi, mask["chi"])
|
||||
|
||||
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
|
||||
|
||||
|
||||
def masked_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
masked = logits.clone()
|
||||
masked[~mask] = -1e9
|
||||
return masked
|
||||
|
||||
|
||||
def sample_from_mask(logits: torch.Tensor, mask: torch.Tensor) -> int:
|
||||
if mask.sum().item() == 0:
|
||||
return 0
|
||||
masked = masked_logits(logits, mask)
|
||||
dist = torch.distributions.Categorical(logits=masked)
|
||||
return int(dist.sample().item())
|
||||
|
||||
|
||||
def logprob_action(
|
||||
logits_type: torch.Tensor,
|
||||
logits_tile: torch.Tensor,
|
||||
logits_chi: torch.Tensor,
|
||||
action: Dict[str, int],
|
||||
mask: Dict[str, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
action_type = int(action["type"])
|
||||
logp = masked_log_softmax(logits_type, mask["type"])[action_type]
|
||||
|
||||
if action_type == ACTION_DISCARD:
|
||||
logp = logp + masked_log_softmax(logits_tile, mask["discard"])[int(action["tile"])]
|
||||
elif action_type == ACTION_DECLARE_PONG:
|
||||
logp = logp + masked_log_softmax(logits_tile, mask["pong"])[int(action["tile"])]
|
||||
elif action_type == ACTION_DECLARE_KONG:
|
||||
logp = logp + masked_log_softmax(logits_tile, mask["kong"])[int(action["tile"])]
|
||||
elif action_type == ACTION_DECLARE_CHI:
|
||||
logp = logp + masked_log_softmax(logits_chi, mask["chi"])[int(action["chi"])]
|
||||
|
||||
return logp
|
||||
|
||||
|
||||
def masked_log_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
masked = masked_logits(logits, mask)
|
||||
return F.log_softmax(masked, dim=-1)
|
||||
|
||||
|
||||
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
masked = masked_logits(logits, mask)
|
||||
return F.softmax(masked, dim=-1)
|
||||
|
||||
|
||||
def kl_divergence(logits: torch.Tensor, ref_logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
if mask.sum().item() == 0:
|
||||
return torch.tensor(0.0, device=logits.device)
|
||||
p = masked_softmax(logits, mask)
|
||||
logp = torch.log(p + 1e-12)
|
||||
logp_ref = masked_log_softmax(ref_logits, mask)
|
||||
return torch.sum(p * (logp - logp_ref))
|
||||
|
||||
|
||||
def compute_rewards(game: SelfPlayMahjong, config: GRPOConfig) -> List[float]:
|
||||
rewards = [0.0, 0.0, 0.0, 0.0]
|
||||
if game.winner is not None:
|
||||
fan = float(total_fan(game.winning_tiles, game.winning_melds))
|
||||
rewards[game.winner] = fan
|
||||
split = -fan / 3.0
|
||||
for idx in range(4):
|
||||
if idx != game.winner:
|
||||
rewards[idx] = split
|
||||
|
||||
distances = []
|
||||
for idx, player in enumerate(game.players):
|
||||
tiles = (
|
||||
game.winning_tiles if game.winner is not None and idx == game.winner else player.hand
|
||||
)
|
||||
distances.append(hand_distance_to_win(tiles, player.melds))
|
||||
min_distance = min(distances) if distances else 0
|
||||
for idx, distance in enumerate(distances):
|
||||
if distance == min_distance:
|
||||
rewards[idx] += config.closest_bonus
|
||||
return rewards
|
||||
|
||||
|
||||
def collect_rollouts(
|
||||
policy: PolicyNet, config: GRPOConfig, seed_offset: int
|
||||
) -> Tuple[List[StepSample], List[float], Dict[str, float]]:
|
||||
all_steps: List[StepSample] = []
|
||||
episode_rewards: List[float] = []
|
||||
final_rewards_all: List[float] = []
|
||||
shaping_rewards_all: List[float] = []
|
||||
|
||||
for episode in range(config.group_size):
|
||||
game = SelfPlayMahjong(
|
||||
seed=config.seed + seed_offset + episode, pong_reward=config.pong_reward
|
||||
)
|
||||
game.reset()
|
||||
episode_steps: List[StepSample] = []
|
||||
shaping_rewards = [0.0, 0.0, 0.0, 0.0]
|
||||
max_steps = _auto_max_steps(config.max_steps, len(game.wall))
|
||||
for _ in range(max_steps):
|
||||
if game.done:
|
||||
break
|
||||
player = game.current_player
|
||||
obs = game.observe(player)
|
||||
mask = game.action_mask(player)
|
||||
obs_tensor = obs_to_tensor(obs, config.device)
|
||||
mask_tensor = mask_to_torch(mask, config.device)
|
||||
with torch.no_grad():
|
||||
action = policy.sample_action(obs_tensor, mask_tensor)
|
||||
episode_steps.append(
|
||||
StepSample(player=player, obs=obs_tensor, mask=mask_tensor, action=action)
|
||||
)
|
||||
reward = game.step(player, action)
|
||||
shaping_rewards[player] += reward
|
||||
episode_steps[-1].reward = reward
|
||||
|
||||
if not game.done:
|
||||
game._end_game(None, "max_steps", [])
|
||||
rewards = compute_rewards(game, config)
|
||||
total_rewards = [rewards[i] + shaping_rewards[i] for i in range(4)]
|
||||
episode_rewards.extend(total_rewards)
|
||||
final_rewards_all.extend(rewards)
|
||||
shaping_rewards_all.extend(shaping_rewards)
|
||||
for step in episode_steps:
|
||||
step.reward += rewards[step.player]
|
||||
all_steps.append(step)
|
||||
|
||||
stats = {
|
||||
"avg_final_reward": float(np.mean(final_rewards_all)) if final_rewards_all else 0.0,
|
||||
"avg_shaping_reward": float(np.mean(shaping_rewards_all)) if shaping_rewards_all else 0.0,
|
||||
}
|
||||
return all_steps, episode_rewards, stats
|
||||
|
||||
|
||||
def train_grpo(config: GRPOConfig) -> PolicyNet:
|
||||
config.device = resolve_device(config.device)
|
||||
torch.manual_seed(config.seed)
|
||||
np.random.seed(config.seed)
|
||||
print(f"Training on device: {config.device}")
|
||||
|
||||
dummy_game = SelfPlayMahjong(seed=config.seed)
|
||||
dummy_game.reset()
|
||||
input_dim = obs_to_tensor(dummy_game.observe(0), config.device).shape[0]
|
||||
|
||||
policy = PolicyNet(input_dim, hidden_size=config.hidden_size).to(config.device)
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=config.lr)
|
||||
|
||||
logger = SwanlabLogger(config)
|
||||
for update in range(config.updates):
|
||||
policy.eval()
|
||||
steps, episode_rewards, reward_stats = collect_rollouts(
|
||||
policy, config, update * config.group_size
|
||||
)
|
||||
if not steps:
|
||||
continue
|
||||
rewards = np.array(episode_rewards, dtype=np.float32)
|
||||
mean = float(rewards.mean())
|
||||
std = float(rewards.std() + 1e-6)
|
||||
|
||||
ref_policy = copy.deepcopy(policy).eval()
|
||||
policy.train()
|
||||
losses = []
|
||||
|
||||
for step in steps:
|
||||
logits_type, logits_tile, logits_chi = policy(step.obs)
|
||||
with torch.no_grad():
|
||||
ref_type, ref_tile, ref_chi = ref_policy(step.obs)
|
||||
|
||||
advantage = torch.tensor((step.reward - mean) / std, device=logits_type.device)
|
||||
logp = logprob_action(logits_type, logits_tile, logits_chi, step.action, step.mask)
|
||||
|
||||
tile_union = step.mask["discard"] | step.mask["pong"] | step.mask["kong"]
|
||||
kl = kl_divergence(logits_type, ref_type, step.mask["type"])
|
||||
kl += kl_divergence(logits_tile, ref_tile, tile_union)
|
||||
kl += kl_divergence(logits_chi, ref_chi, step.mask["chi"])
|
||||
|
||||
loss = -(advantage * logp) + config.beta_kl * kl
|
||||
losses.append(loss)
|
||||
|
||||
loss_value = torch.stack(losses).mean()
|
||||
optimizer.zero_grad()
|
||||
loss_value.backward()
|
||||
optimizer.step()
|
||||
|
||||
logger.log(
|
||||
{
|
||||
"loss": float(loss_value.item()),
|
||||
"avg_reward": mean,
|
||||
"reward_std": std,
|
||||
"avg_final_reward": reward_stats["avg_final_reward"],
|
||||
"avg_shaping_reward": reward_stats["avg_shaping_reward"],
|
||||
},
|
||||
step=update + 1,
|
||||
)
|
||||
print(f"Update {update + 1}/{config.updates} loss={loss_value.item():.4f} avg_reward={mean:.2f}")
|
||||
|
||||
logger.finish()
|
||||
return policy
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="GRPO self-play training for Mahjong")
|
||||
parser.add_argument("--updates", type=int, default=50)
|
||||
parser.add_argument("--group-size", type=int, default=16)
|
||||
parser.add_argument("--max-steps", type=int, default=0)
|
||||
parser.add_argument("--lr", type=float, default=3e-4)
|
||||
parser.add_argument("--beta-kl", type=float, default=0.02)
|
||||
parser.add_argument("--seed", type=int, default=1)
|
||||
parser.add_argument("--hidden-size", type=int, default=256)
|
||||
parser.add_argument("--device", type=str, default="auto")
|
||||
parser.add_argument("--swanlab", action="store_true")
|
||||
parser.add_argument("--swanlab-project", type=str, default="majiang-rl")
|
||||
parser.add_argument("--swanlab-run-name", type=str, default="")
|
||||
parser.add_argument("--pong-reward", type=float, default=0.1)
|
||||
parser.add_argument("--closest-bonus", type=float, default=1.0)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = GRPOConfig(
|
||||
group_size=args.group_size,
|
||||
updates=args.updates,
|
||||
max_steps=args.max_steps,
|
||||
lr=args.lr,
|
||||
beta_kl=args.beta_kl,
|
||||
seed=args.seed,
|
||||
hidden_size=args.hidden_size,
|
||||
device=args.device,
|
||||
use_swanlab=args.swanlab,
|
||||
swanlab_project=args.swanlab_project,
|
||||
swanlab_run_name=args.swanlab_run_name,
|
||||
pong_reward=args.pong_reward,
|
||||
closest_bonus=args.closest_bonus,
|
||||
)
|
||||
train_grpo(config)
|
||||
|
||||
|
||||
def resolve_device(requested: str) -> str:
|
||||
if requested != "auto":
|
||||
return requested
|
||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _auto_max_steps(configured: int, wall_len: int) -> int:
|
||||
if configured > 0:
|
||||
return configured
|
||||
return wall_len * 4 + 50
|
||||
|
||||
|
||||
class SwanlabLogger:
|
||||
def __init__(self, config: GRPOConfig):
|
||||
self.enabled = config.use_swanlab
|
||||
self._swanlab = None
|
||||
if not self.enabled:
|
||||
return
|
||||
try:
|
||||
import swanlab
|
||||
except ImportError:
|
||||
print("swanlab not installed, logging disabled")
|
||||
self.enabled = False
|
||||
return
|
||||
self._swanlab = swanlab
|
||||
self._init_run(config)
|
||||
|
||||
def _init_run(self, config: GRPOConfig) -> None:
|
||||
if not self._swanlab:
|
||||
return
|
||||
base_cfg = asdict(config)
|
||||
base_cfg.pop("use_swanlab", None)
|
||||
kwargs = {"project": config.swanlab_project, "config": base_cfg}
|
||||
if config.swanlab_run_name:
|
||||
kwargs["name"] = config.swanlab_run_name
|
||||
init = getattr(self._swanlab, "init", None)
|
||||
if init is None:
|
||||
self.enabled = False
|
||||
return
|
||||
try:
|
||||
init(**kwargs)
|
||||
return
|
||||
except TypeError:
|
||||
pass
|
||||
if config.swanlab_run_name:
|
||||
kwargs_alt = {"project": config.swanlab_project, "config": base_cfg, "experiment_name": config.swanlab_run_name}
|
||||
else:
|
||||
kwargs_alt = {"project": config.swanlab_project, "config": base_cfg}
|
||||
try:
|
||||
init(**kwargs_alt)
|
||||
except TypeError:
|
||||
init(project=config.swanlab_project)
|
||||
|
||||
def log(self, metrics: Dict[str, float], step: int) -> None:
|
||||
if not self.enabled or not self._swanlab:
|
||||
return
|
||||
log_fn = getattr(self._swanlab, "log", None)
|
||||
if log_fn is None:
|
||||
return
|
||||
try:
|
||||
log_fn(metrics, step=step)
|
||||
except TypeError:
|
||||
log_fn(metrics)
|
||||
|
||||
def finish(self) -> None:
|
||||
if not self.enabled or not self._swanlab:
|
||||
return
|
||||
finish_fn = getattr(self._swanlab, "finish", None)
|
||||
if finish_fn:
|
||||
finish_fn()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
61
majiang_rl/rl/trainer.py
Normal file
61
majiang_rl/rl/trainer.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from .agents import Transition
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeResult:
|
||||
total_reward: float
|
||||
steps: int
|
||||
terminated: bool
|
||||
truncated: bool
|
||||
|
||||
|
||||
class EpisodeRunner:
|
||||
def __init__(self, env, agent, max_steps: int = 1000):
|
||||
self.env = env
|
||||
self.agent = agent
|
||||
self.max_steps = max_steps
|
||||
|
||||
def run_episode(self) -> EpisodeResult:
|
||||
obs, info = self.env.reset()
|
||||
self.agent.reset()
|
||||
total_reward = 0.0
|
||||
terminated = False
|
||||
truncated = False
|
||||
steps = 0
|
||||
for step in range(self.max_steps):
|
||||
action = self.agent.act(obs, info)
|
||||
next_obs, reward, terminated, truncated, next_info = self.env.step(action)
|
||||
transition = Transition(
|
||||
obs=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
terminated=terminated,
|
||||
truncated=truncated,
|
||||
next_obs=next_obs,
|
||||
info=next_info,
|
||||
)
|
||||
self.agent.observe(transition)
|
||||
total_reward += reward
|
||||
obs, info = next_obs, next_info
|
||||
steps = step + 1
|
||||
if terminated or truncated:
|
||||
break
|
||||
return EpisodeResult(
|
||||
total_reward=total_reward,
|
||||
steps=steps,
|
||||
terminated=terminated,
|
||||
truncated=truncated,
|
||||
)
|
||||
|
||||
|
||||
def run_training(env, agent, episodes: int = 100) -> List[EpisodeResult]:
|
||||
runner = EpisodeRunner(env, agent)
|
||||
results: List[EpisodeResult] = []
|
||||
for _ in range(episodes):
|
||||
results.append(runner.run_episode())
|
||||
return results
|
||||
277
majiang_rl/rules.py
Normal file
277
majiang_rl/rules.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
|
||||
from .tiles import FLOWER_RANGE, HONOR_RANGE, SUITED_RANGE
|
||||
|
||||
ORPHAN_INDICES = (
|
||||
0,
|
||||
8,
|
||||
9,
|
||||
17,
|
||||
18,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
)
|
||||
|
||||
|
||||
def strip_flowers(tiles: Iterable[int]) -> List[int]:
|
||||
return [tile for tile in tiles if tile not in FLOWER_RANGE]
|
||||
|
||||
|
||||
def is_seven_pairs(counts: List[int]) -> bool:
|
||||
return sum(counts) == 14 and all(c % 2 == 0 for c in counts)
|
||||
|
||||
|
||||
def is_thirteen_orphans(counts: List[int]) -> bool:
|
||||
if sum(counts) != 14:
|
||||
return False
|
||||
if any(counts[i] == 0 for i in ORPHAN_INDICES):
|
||||
return False
|
||||
if any(counts[i] > 0 for i in range(34) if i not in ORPHAN_INDICES):
|
||||
return False
|
||||
return sum(counts[i] for i in ORPHAN_INDICES) == 14
|
||||
|
||||
|
||||
def is_all_pungs(counts: List[int]) -> bool:
|
||||
if sum(counts) != 14:
|
||||
return False
|
||||
for idx, count in enumerate(counts):
|
||||
if count >= 2:
|
||||
counts[idx] -= 2
|
||||
if _can_form_pungs(counts, 4):
|
||||
counts[idx] += 2
|
||||
return True
|
||||
counts[idx] += 2
|
||||
return False
|
||||
|
||||
|
||||
def _can_form_pungs(counts: List[int], needed: int) -> bool:
|
||||
if needed == 0:
|
||||
return sum(counts) == 0
|
||||
for idx, count in enumerate(counts):
|
||||
if count > 0:
|
||||
break
|
||||
else:
|
||||
return False
|
||||
if counts[idx] >= 3:
|
||||
counts[idx] -= 3
|
||||
if _can_form_pungs(counts, needed - 1):
|
||||
counts[idx] += 3
|
||||
return True
|
||||
counts[idx] += 3
|
||||
return False
|
||||
|
||||
|
||||
def is_all_honors(counts: List[int]) -> bool:
|
||||
return sum(counts[i] for i in SUITED_RANGE) == 0 and sum(counts) == 14
|
||||
|
||||
|
||||
def _suit_presence(counts: List[int]) -> Tuple[bool, Tuple[bool, bool, bool]]:
|
||||
suit_flags = [False, False, False]
|
||||
for idx in SUITED_RANGE:
|
||||
if counts[idx] > 0:
|
||||
suit_flags[idx // 9] = True
|
||||
honors = sum(counts[i] for i in HONOR_RANGE) > 0
|
||||
return honors, tuple(suit_flags)
|
||||
|
||||
|
||||
def is_pure_one_suit(counts: List[int]) -> bool:
|
||||
honors, suits = _suit_presence(counts)
|
||||
return not honors and sum(1 for flag in suits if flag) == 1
|
||||
|
||||
|
||||
def is_half_flush(counts: List[int]) -> bool:
|
||||
honors, suits = _suit_presence(counts)
|
||||
return honors and sum(1 for flag in suits if flag) == 1
|
||||
|
||||
|
||||
def is_all_terminals_or_honors(counts: List[int]) -> bool:
|
||||
if sum(counts) != 14:
|
||||
return False
|
||||
terminals = {0, 8, 9, 17, 18, 26}
|
||||
for idx, count in enumerate(counts):
|
||||
if count == 0:
|
||||
continue
|
||||
if idx in HONOR_RANGE:
|
||||
continue
|
||||
if idx not in terminals:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def can_form_melds(counts: List[int], needed: int) -> bool:
|
||||
if needed == 0:
|
||||
return sum(counts) == 0
|
||||
for idx, count in enumerate(counts):
|
||||
if count > 0:
|
||||
break
|
||||
else:
|
||||
return False
|
||||
|
||||
# Pong
|
||||
if counts[idx] >= 3:
|
||||
counts[idx] -= 3
|
||||
if can_form_melds(counts, needed - 1):
|
||||
counts[idx] += 3
|
||||
return True
|
||||
counts[idx] += 3
|
||||
|
||||
# Chow
|
||||
if idx in SUITED_RANGE and idx % 9 <= 6:
|
||||
if counts[idx + 1] > 0 and counts[idx + 2] > 0:
|
||||
counts[idx] -= 1
|
||||
counts[idx + 1] -= 1
|
||||
counts[idx + 2] -= 1
|
||||
if can_form_melds(counts, needed - 1):
|
||||
counts[idx] += 1
|
||||
counts[idx + 1] += 1
|
||||
counts[idx + 2] += 1
|
||||
return True
|
||||
counts[idx] += 1
|
||||
counts[idx + 1] += 1
|
||||
counts[idx + 2] += 1
|
||||
return False
|
||||
|
||||
|
||||
def is_standard_hand(counts: List[int], open_melds: int) -> bool:
|
||||
if sum(counts) != (4 - open_melds) * 3 + 2:
|
||||
return False
|
||||
for idx, count in enumerate(counts):
|
||||
if count >= 2:
|
||||
counts[idx] -= 2
|
||||
if can_form_melds(counts, 4 - open_melds):
|
||||
counts[idx] += 2
|
||||
return True
|
||||
counts[idx] += 2
|
||||
return False
|
||||
|
||||
|
||||
def is_win(tiles: Iterable[int], open_melds: int) -> bool:
|
||||
counts = [0] * 34
|
||||
for tile in tiles:
|
||||
if tile in FLOWER_RANGE:
|
||||
continue
|
||||
counts[tile] += 1
|
||||
|
||||
if is_seven_pairs(counts):
|
||||
return True
|
||||
if is_thirteen_orphans(counts):
|
||||
return True
|
||||
return is_standard_hand(counts, open_melds)
|
||||
|
||||
|
||||
def fan_breakdown(tiles: Iterable[int], melds: List[List[int]]) -> Dict[str, int]:
|
||||
counts = [0] * 34
|
||||
for tile in tiles:
|
||||
if tile in FLOWER_RANGE:
|
||||
continue
|
||||
counts[tile] += 1
|
||||
for meld in melds:
|
||||
for tile in meld:
|
||||
if tile in FLOWER_RANGE:
|
||||
continue
|
||||
counts[tile] += 1
|
||||
|
||||
fans: Dict[str, int] = {}
|
||||
if is_thirteen_orphans(counts):
|
||||
fans["thirteen_orphans"] = 88
|
||||
if is_all_honors(counts):
|
||||
fans["all_honors"] = 64
|
||||
if is_pure_one_suit(counts):
|
||||
fans["pure_one_suit"] = 24
|
||||
if is_seven_pairs(counts):
|
||||
fans["seven_pairs"] = 24
|
||||
if is_all_terminals_or_honors(counts):
|
||||
fans["all_terminals_or_honors"] = 32
|
||||
if is_all_pungs(counts.copy()):
|
||||
fans["all_pungs"] = 6
|
||||
if is_half_flush(counts):
|
||||
fans["half_flush"] = 6
|
||||
if not fans:
|
||||
fans["standard"] = 1
|
||||
return fans
|
||||
|
||||
|
||||
def total_fan(tiles: Iterable[int], melds: List[List[int]]) -> int:
|
||||
return sum(fan_breakdown(tiles, melds).values())
|
||||
|
||||
|
||||
def hand_distance_to_win(hand_tiles: Iterable[int], melds: List[List[int]]) -> int:
|
||||
counts = [0] * 34
|
||||
for tile in hand_tiles:
|
||||
if tile in FLOWER_RANGE:
|
||||
continue
|
||||
counts[tile] += 1
|
||||
|
||||
melds_done = min(len(melds), 4)
|
||||
needed_melds = max(0, 4 - melds_done)
|
||||
max_used = _max_used_tiles(counts, needed_melds)
|
||||
used_tiles = min(14, 3 * melds_done + max_used)
|
||||
return max(0, 14 - used_tiles)
|
||||
|
||||
|
||||
def _max_used_tiles(counts: List[int], needed_melds: int) -> int:
|
||||
best = 3 * _max_melds(counts, needed_melds)
|
||||
for idx, count in enumerate(counts):
|
||||
if count >= 2:
|
||||
counts[idx] -= 2
|
||||
melds = _max_melds(counts, needed_melds)
|
||||
best = max(best, 2 + 3 * melds)
|
||||
counts[idx] += 2
|
||||
return best
|
||||
|
||||
|
||||
def _max_melds(counts: List[int], limit: int) -> int:
|
||||
if limit == 0:
|
||||
return 0
|
||||
for idx, count in enumerate(counts):
|
||||
if count > 0:
|
||||
break
|
||||
else:
|
||||
return 0
|
||||
|
||||
best = 0
|
||||
counts[idx] -= 1
|
||||
best = max(best, _max_melds(counts, limit))
|
||||
counts[idx] += 1
|
||||
|
||||
if counts[idx] >= 3:
|
||||
counts[idx] -= 3
|
||||
best = max(best, 1 + _max_melds(counts, limit - 1))
|
||||
counts[idx] += 3
|
||||
|
||||
if idx in SUITED_RANGE and idx % 9 <= 6:
|
||||
if counts[idx + 1] > 0 and counts[idx + 2] > 0:
|
||||
counts[idx] -= 1
|
||||
counts[idx + 1] -= 1
|
||||
counts[idx + 2] -= 1
|
||||
best = max(best, 1 + _max_melds(counts, limit - 1))
|
||||
counts[idx] += 1
|
||||
counts[idx + 1] += 1
|
||||
counts[idx + 2] += 1
|
||||
|
||||
return best
|
||||
|
||||
|
||||
def can_chi_options(counts: List[int], tile_id: int) -> List[int]:
|
||||
if tile_id not in SUITED_RANGE:
|
||||
return []
|
||||
options: List[int] = []
|
||||
rank = tile_id % 9
|
||||
if rank >= 2:
|
||||
if counts[tile_id - 1] > 0 and counts[tile_id - 2] > 0:
|
||||
options.append(0)
|
||||
if 1 <= rank <= 7:
|
||||
if counts[tile_id - 1] > 0 and counts[tile_id + 1] > 0:
|
||||
options.append(1)
|
||||
if rank <= 6:
|
||||
if counts[tile_id + 1] > 0 and counts[tile_id + 2] > 0:
|
||||
options.append(2)
|
||||
return options
|
||||
65
majiang_rl/tiles.py
Normal file
65
majiang_rl/tiles.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
SUITS = ("wan", "tong", "suo")
|
||||
WINDS = ("east", "south", "west", "north")
|
||||
DRAGONS = ("red", "green", "white")
|
||||
FLOWERS = (
|
||||
"plum",
|
||||
"orchid",
|
||||
"chrysanthemum",
|
||||
"bamboo",
|
||||
"spring",
|
||||
"summer",
|
||||
"autumn",
|
||||
"winter",
|
||||
)
|
||||
|
||||
# 34 suited + honors, plus 8 flowers = 42
|
||||
TILE_TYPES: List[str] = []
|
||||
for suit in SUITS:
|
||||
for rank in range(1, 10):
|
||||
TILE_TYPES.append(f"{suit}_{rank}")
|
||||
for wind in WINDS:
|
||||
TILE_TYPES.append(f"wind_{wind}")
|
||||
for dragon in DRAGONS:
|
||||
TILE_TYPES.append(f"dragon_{dragon}")
|
||||
for flower in FLOWERS:
|
||||
TILE_TYPES.append(f"flower_{flower}")
|
||||
|
||||
SUITED_RANGE = range(0, 27)
|
||||
HONOR_RANGE = range(27, 34)
|
||||
FLOWER_RANGE = range(34, 42)
|
||||
|
||||
def build_wall(rng) -> List[int]:
|
||||
wall: List[int] = []
|
||||
for tile_id in range(34):
|
||||
wall.extend([tile_id] * 4)
|
||||
for tile_id in FLOWER_RANGE:
|
||||
wall.append(tile_id)
|
||||
rng.shuffle(wall)
|
||||
return wall
|
||||
|
||||
|
||||
def is_flower(tile_id: int) -> bool:
|
||||
return tile_id in FLOWER_RANGE
|
||||
|
||||
|
||||
def tile_name(tile_id: int) -> str:
|
||||
return TILE_TYPES[tile_id]
|
||||
|
||||
|
||||
def counts_from_tiles(tiles: Iterable[int], size: int = 42) -> List[int]:
|
||||
counts = [0] * size
|
||||
for tile_id in tiles:
|
||||
counts[tile_id] += 1
|
||||
return counts
|
||||
|
||||
|
||||
def split_suit_index(tile_id: int) -> Tuple[int, int] | None:
|
||||
if tile_id not in SUITED_RANGE:
|
||||
return None
|
||||
suit = tile_id // 9
|
||||
rank = tile_id % 9
|
||||
return suit, rank
|
||||
3
majiang_rl/ui/__init__.py
Normal file
3
majiang_rl/ui/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .record import record_episode, write_record
|
||||
|
||||
__all__ = ["record_episode", "write_record"]
|
||||
34
majiang_rl/ui/record.py
Normal file
34
majiang_rl/ui/record.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .. import MahjongEnv
|
||||
from ..rl import RandomAgent
|
||||
|
||||
|
||||
def record_episode(seed: Optional[int] = None, max_steps: int = 200) -> List[Dict[str, object]]:
|
||||
env = MahjongEnv(seed=seed)
|
||||
agent = RandomAgent(seed=seed)
|
||||
obs, info = env.reset()
|
||||
events: List[Dict[str, object]] = [{"id": 0, "type": "start", "snapshot": env._snapshot()}]
|
||||
|
||||
for _ in range(max_steps):
|
||||
action = agent.act(obs, info)
|
||||
obs, _, terminated, truncated, info = env.step(action)
|
||||
if env.event_log:
|
||||
for event in env.event_log:
|
||||
event_copy = dict(event)
|
||||
event_copy["id"] = len(events)
|
||||
events.append(event_copy)
|
||||
env.event_log = []
|
||||
if terminated or truncated:
|
||||
break
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def write_record(path: Path, events: List[Dict[str, object]]) -> None:
|
||||
payload = {"events": events}
|
||||
path.write_text(json.dumps(payload, indent=2, ensure_ascii=True))
|
||||
230
majiang_rl/ui/static/index.html
Normal file
230
majiang_rl/ui/static/index.html
Normal file
@@ -0,0 +1,230 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Mahjong RL Viewer</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg: #f7f4ef;
|
||||
--ink: #1f1f1f;
|
||||
--accent: #b4572a;
|
||||
--panel: #ffffff;
|
||||
--muted: #6b6b6b;
|
||||
}
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: "IBM Plex Mono", "Courier New", monospace;
|
||||
background: radial-gradient(circle at 10% 10%, #fdf6e9, var(--bg));
|
||||
color: var(--ink);
|
||||
}
|
||||
header {
|
||||
padding: 24px;
|
||||
border-bottom: 2px solid #e5ddd1;
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 6px 0;
|
||||
font-size: 22px;
|
||||
letter-spacing: 1px;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
.subtitle {
|
||||
color: var(--muted);
|
||||
font-size: 13px;
|
||||
}
|
||||
.layout {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 2fr;
|
||||
gap: 20px;
|
||||
padding: 20px;
|
||||
}
|
||||
.panel {
|
||||
background: var(--panel);
|
||||
border: 2px solid #e5ddd1;
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
box-shadow: 4px 6px 0 rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
.controls {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
button {
|
||||
background: var(--accent);
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 12px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
font-weight: 600;
|
||||
}
|
||||
button.secondary {
|
||||
background: #3f3f3f;
|
||||
}
|
||||
.status {
|
||||
font-size: 14px;
|
||||
line-height: 1.4;
|
||||
}
|
||||
.board {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
gap: 12px;
|
||||
}
|
||||
.player {
|
||||
border: 1px dashed #d2c4b2;
|
||||
padding: 10px;
|
||||
border-radius: 10px;
|
||||
}
|
||||
.player h3 {
|
||||
margin: 0 0 6px 0;
|
||||
font-size: 14px;
|
||||
color: var(--accent);
|
||||
}
|
||||
.tiles {
|
||||
font-size: 12px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.label {
|
||||
display: inline-block;
|
||||
min-width: 72px;
|
||||
color: var(--muted);
|
||||
}
|
||||
@media (max-width: 900px) {
|
||||
.layout {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>Mahjong RL Viewer</h1>
|
||||
<div class="subtitle">Playback of a simulated Guobiao round</div>
|
||||
</header>
|
||||
|
||||
<div class="layout">
|
||||
<div class="panel">
|
||||
<div class="controls">
|
||||
<button id="prev">Prev</button>
|
||||
<button id="next">Next</button>
|
||||
<button id="play" class="secondary">Play</button>
|
||||
<span id="counter" class="subtitle"></span>
|
||||
</div>
|
||||
<div id="status" class="status"></div>
|
||||
</div>
|
||||
<div class="panel">
|
||||
<div id="board" class="board"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let events = [];
|
||||
let idx = 0;
|
||||
let timer = null;
|
||||
|
||||
function describe(event) {
|
||||
if (!event) return "";
|
||||
const type = event.type;
|
||||
if (type === "start") {
|
||||
return "Start of round";
|
||||
}
|
||||
if (type === "discard") {
|
||||
return `Player ${event.player} discarded ${event.tile}`;
|
||||
}
|
||||
if (type === "claim") {
|
||||
return `Player ${event.player} claimed ${event.claim} on ${event.tile} from Player ${event.from_player}`;
|
||||
}
|
||||
if (type === "win") {
|
||||
return `Player ${event.player} wins (${event.reason})`;
|
||||
}
|
||||
if (type === "draw") {
|
||||
return `Round ends in draw (${event.reason})`;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
function tilesToText(tiles) {
|
||||
if (!tiles || tiles.length === 0) return "-";
|
||||
return tiles.join(" ");
|
||||
}
|
||||
|
||||
function renderBoard(snapshot) {
|
||||
const board = document.getElementById("board");
|
||||
board.innerHTML = "";
|
||||
for (let i = 0; i < snapshot.hands.length; i += 1) {
|
||||
const player = document.createElement("div");
|
||||
player.className = "player";
|
||||
const title = document.createElement("h3");
|
||||
title.textContent = `Player ${i}`;
|
||||
const hand = document.createElement("div");
|
||||
hand.className = "tiles";
|
||||
hand.innerHTML = `<span class="label">Hand</span>${tilesToText(snapshot.hands[i])}`;
|
||||
const discards = document.createElement("div");
|
||||
discards.className = "tiles";
|
||||
discards.innerHTML = `<span class="label">Discards</span>${tilesToText(snapshot.discards[i])}`;
|
||||
const melds = document.createElement("div");
|
||||
melds.className = "tiles";
|
||||
melds.innerHTML = `<span class="label">Melds</span>${tilesToText(snapshot.melds[i].flat())}`;
|
||||
const flowers = document.createElement("div");
|
||||
flowers.className = "tiles";
|
||||
flowers.innerHTML = `<span class="label">Flowers</span>${tilesToText(snapshot.flowers[i])}`;
|
||||
player.appendChild(title);
|
||||
player.appendChild(hand);
|
||||
player.appendChild(discards);
|
||||
player.appendChild(melds);
|
||||
player.appendChild(flowers);
|
||||
board.appendChild(player);
|
||||
}
|
||||
}
|
||||
|
||||
function render() {
|
||||
const event = events[idx];
|
||||
if (!event) return;
|
||||
document.getElementById("counter").textContent = `Step ${idx + 1} / ${events.length}`;
|
||||
document.getElementById("status").textContent = `${describe(event)} | Wall: ${event.snapshot.wall_count}`;
|
||||
renderBoard(event.snapshot);
|
||||
}
|
||||
|
||||
function togglePlay() {
|
||||
const button = document.getElementById("play");
|
||||
if (timer) {
|
||||
clearInterval(timer);
|
||||
timer = null;
|
||||
button.textContent = "Play";
|
||||
return;
|
||||
}
|
||||
button.textContent = "Pause";
|
||||
timer = setInterval(() => {
|
||||
idx = Math.min(idx + 1, events.length - 1);
|
||||
render();
|
||||
if (idx >= events.length - 1) {
|
||||
togglePlay();
|
||||
}
|
||||
}, 700);
|
||||
}
|
||||
|
||||
document.getElementById("prev").addEventListener("click", () => {
|
||||
idx = Math.max(0, idx - 1);
|
||||
render();
|
||||
});
|
||||
document.getElementById("next").addEventListener("click", () => {
|
||||
idx = Math.min(events.length - 1, idx + 1);
|
||||
render();
|
||||
});
|
||||
document.getElementById("play").addEventListener("click", togglePlay);
|
||||
|
||||
fetch("steps.json")
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
events = data.events || [];
|
||||
idx = 0;
|
||||
render();
|
||||
})
|
||||
.catch(() => {
|
||||
document.getElementById("status").textContent = "Failed to load steps.json";
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
15562
majiang_rl/ui/static/steps.json
Normal file
15562
majiang_rl/ui/static/steps.json
Normal file
File diff suppressed because it is too large
Load Diff
33
majiang_rl/ui/web.py
Normal file
33
majiang_rl/ui/web.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import http.server
|
||||
import socketserver
|
||||
from pathlib import Path
|
||||
|
||||
from .record import record_episode, write_record
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Mahjong RL web viewer")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--seed", type=int, default=None)
|
||||
parser.add_argument("--max-steps", type=int, default=200)
|
||||
args = parser.parse_args()
|
||||
|
||||
static_dir = Path(__file__).parent / "static"
|
||||
static_dir.mkdir(parents=True, exist_ok=True)
|
||||
record_path = static_dir / "steps.json"
|
||||
|
||||
events = record_episode(seed=args.seed, max_steps=args.max_steps)
|
||||
write_record(record_path, events)
|
||||
|
||||
handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=str(static_dir))
|
||||
with socketserver.TCPServer(("", args.port), handler) as httpd:
|
||||
print(f"Serving UI at http://localhost:{args.port}/index.html")
|
||||
httpd.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
13
pyproject.toml
Normal file
13
pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[project]
|
||||
name = "majiang-rl"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"gymnasium>=0.29.0",
|
||||
"numpy>=1.26.0",
|
||||
"torch>=2.2.0",
|
||||
"torchvision>=0.24.1",
|
||||
"swanlab>=0.3.0",
|
||||
]
|
||||
Reference in New Issue
Block a user