"""GRPO self-play training for Mahjong.

Collects groups of self-play episodes, normalizes episode rewards within each
group, and updates a masked-action policy network with a KL penalty against a
frozen copy of the policy taken at the start of each update.
"""

from __future__ import annotations

import argparse
import copy
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..envs.self_play_env import (
    ACTION_DECLARE_CHI,
    ACTION_DECLARE_KONG,
    ACTION_DECLARE_PONG,
    ACTION_DISCARD,
    SelfPlayMahjong,
)
from ..rules import hand_distance_to_win, total_fan


@dataclass
class GRPOConfig:
    group_size: int = 16
    updates: int = 50
    max_steps: int = 0
    lr: float = 3e-4
    beta_kl: float = 0.02
    seed: int = 1
    hidden_size: int = 256
    device: str = "auto"
    use_swanlab: bool = False
    swanlab_project: str = "majiang-rl"
    swanlab_run_name: str = ""
    pong_reward: float = 0.1
    closest_bonus: float = 1.0


@dataclass
class StepSample:
    """One decision taken during self-play, replayed later for the policy update."""

    player: int
    obs: torch.Tensor
    mask: Dict[str, torch.Tensor]
    action: Dict[str, int]
    reward: float = 0.0


def obs_to_tensor(obs: Dict[str, np.ndarray], device: str) -> torch.Tensor:
    """Flatten the per-player observation dict into a single float tensor."""
    parts = [
        obs["hand"].astype(np.float32),
        obs["melds"].astype(np.float32).reshape(-1),
        obs["discards"].astype(np.float32).reshape(-1),
        obs["flowers"].astype(np.float32),
        obs["wall_count"].astype(np.float32),
        obs["current_player"].astype(np.float32),
        obs["phase"].astype(np.float32),
        obs["pending_discard"].astype(np.float32),
    ]
    flat = np.concatenate(parts, axis=0)
    return torch.tensor(flat, dtype=torch.float32, device=device)


def mask_to_torch(mask: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]:
    return {
        "type": torch.tensor(mask["type"], dtype=torch.bool, device=device),
        "discard": torch.tensor(mask["discard"], dtype=torch.bool, device=device),
        "pong": torch.tensor(mask["pong"], dtype=torch.bool, device=device),
        "kong": torch.tensor(mask["kong"], dtype=torch.bool, device=device),
        "chi": torch.tensor(mask["chi"], dtype=torch.bool, device=device),
    }


class PolicyNet(nn.Module):
    """Shared MLP trunk with separate heads for action type, tile choice, and chi choice."""

    def __init__(self, input_dim: int, hidden_size: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        self.type_head = nn.Linear(hidden_size, 6)
        self.tile_head = nn.Linear(hidden_size, 42)
        self.chi_head = nn.Linear(hidden_size, 3)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.net(obs)
        return self.type_head(x), self.tile_head(x), self.chi_head(x)

    def sample_action(self, obs: torch.Tensor, mask: Dict[str, torch.Tensor]) -> Dict[str, int]:
        """Sample a legal action: first the action type, then its tile or chi argument."""
        logits_type, logits_tile, logits_chi = self(obs)
        type_mask = mask["type"]
        masked_type = masked_logits(logits_type, type_mask)
        type_dist = torch.distributions.Categorical(logits=masked_type)
        action_type = int(type_dist.sample().item())
        tile_id = 0
        chi_choice = 0
        if action_type == ACTION_DISCARD:
            tile_id = sample_from_mask(logits_tile, mask["discard"])
        elif action_type == ACTION_DECLARE_PONG:
            tile_id = sample_from_mask(logits_tile, mask["pong"])
        elif action_type == ACTION_DECLARE_KONG:
            tile_id = sample_from_mask(logits_tile, mask["kong"])
        elif action_type == ACTION_DECLARE_CHI:
            chi_choice = sample_from_mask(logits_chi, mask["chi"])
        return {"type": action_type, "tile": tile_id, "chi": chi_choice}


def masked_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``logits`` with illegal entries pushed to a large negative value."""
    masked = logits.clone()
    masked[~mask] = -1e9
    return masked


def sample_from_mask(logits: torch.Tensor, mask: torch.Tensor) -> int:
    if mask.sum().item() == 0:
        return 0
    masked = masked_logits(logits, mask)
    dist = torch.distributions.Categorical(logits=masked)
    return int(dist.sample().item())


def logprob_action(
    logits_type: torch.Tensor,
    logits_tile: torch.Tensor,
    logits_chi: torch.Tensor,
    action: Dict[str, int],
    mask: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Log-probability of a sampled action under the factored (type, tile/chi) policy."""
    action_type = int(action["type"])
    logp = masked_log_softmax(logits_type, mask["type"])[action_type]
    if action_type == ACTION_DISCARD:
        logp = logp + masked_log_softmax(logits_tile, mask["discard"])[int(action["tile"])]
    elif action_type == ACTION_DECLARE_PONG:
        logp = logp + masked_log_softmax(logits_tile, mask["pong"])[int(action["tile"])]
    elif action_type == ACTION_DECLARE_KONG:
        logp = logp + masked_log_softmax(logits_tile, mask["kong"])[int(action["tile"])]
    elif action_type == ACTION_DECLARE_CHI:
        logp = logp + masked_log_softmax(logits_chi, mask["chi"])[int(action["chi"])]
    return logp


def masked_log_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    masked = masked_logits(logits, mask)
    return F.log_softmax(masked, dim=-1)


def masked_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    masked = masked_logits(logits, mask)
    return F.softmax(masked, dim=-1)


def kl_divergence(logits: torch.Tensor, ref_logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """KL(policy || reference) restricted to the legal entries of ``mask``."""
    if mask.sum().item() == 0:
        return torch.tensor(0.0, device=logits.device)
    p = masked_softmax(logits, mask)
    logp = torch.log(p + 1e-12)
    logp_ref = masked_log_softmax(ref_logits, mask)
    return torch.sum(p * (logp - logp_ref))


def compute_rewards(game: SelfPlayMahjong, config: GRPOConfig) -> List[float]:
    """Terminal rewards: the winner gets the fan total, the other three split its
    negative, and every player tied for the smallest distance-to-win also receives
    ``closest_bonus``."""
    rewards = [0.0, 0.0, 0.0, 0.0]
    if game.winner is not None:
        fan = float(total_fan(game.winning_tiles, game.winning_melds))
        rewards[game.winner] = fan
        split = -fan / 3.0
        for idx in range(4):
            if idx != game.winner:
                rewards[idx] = split
    distances = []
    for idx, player in enumerate(game.players):
        tiles = (
            game.winning_tiles
            if game.winner is not None and idx == game.winner
            else player.hand
        )
        distances.append(hand_distance_to_win(tiles, player.melds))
    min_distance = min(distances) if distances else 0
    for idx, distance in enumerate(distances):
        if distance == min_distance:
            rewards[idx] += config.closest_bonus
    return rewards


def collect_rollouts(
    policy: PolicyNet, config: GRPOConfig, seed_offset: int
) -> Tuple[List[StepSample], List[float], Dict[str, float]]:
    """Play ``group_size`` self-play games with the current policy and return all
    steps, per-player episode rewards, and aggregate reward statistics."""
    all_steps: List[StepSample] = []
    episode_rewards: List[float] = []
    final_rewards_all: List[float] = []
    shaping_rewards_all: List[float] = []
    for episode in range(config.group_size):
        game = SelfPlayMahjong(
            seed=config.seed + seed_offset + episode, pong_reward=config.pong_reward
        )
        game.reset()
        episode_steps: List[StepSample] = []
        shaping_rewards = [0.0, 0.0, 0.0, 0.0]
        max_steps = _auto_max_steps(config.max_steps, len(game.wall))
        for _ in range(max_steps):
            if game.done:
                break
            player = game.current_player
            obs = game.observe(player)
            mask = game.action_mask(player)
            obs_tensor = obs_to_tensor(obs, config.device)
            mask_tensor = mask_to_torch(mask, config.device)
            with torch.no_grad():
                action = policy.sample_action(obs_tensor, mask_tensor)
            episode_steps.append(
                StepSample(player=player, obs=obs_tensor, mask=mask_tensor, action=action)
            )
            reward = game.step(player, action)
            shaping_rewards[player] += reward
            episode_steps[-1].reward = reward
        if not game.done:
            game._end_game(None, "max_steps", [])
        rewards = compute_rewards(game, config)
        total_rewards = [rewards[i] + shaping_rewards[i] for i in range(4)]
        episode_rewards.extend(total_rewards)
        final_rewards_all.extend(rewards)
        shaping_rewards_all.extend(shaping_rewards)
        for step in episode_steps:
            step.reward += rewards[step.player]
            all_steps.append(step)
    stats = {
        "avg_final_reward": float(np.mean(final_rewards_all)) if final_rewards_all else 0.0,
        "avg_shaping_reward": float(np.mean(shaping_rewards_all)) if shaping_rewards_all else 0.0,
    }
    return all_steps, episode_rewards, stats


def train_grpo(config: GRPOConfig) -> PolicyNet:
    config.device = resolve_device(config.device)
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    print(f"Training on device: {config.device}")

    # Probe the environment once to size the policy input.
    dummy_game = SelfPlayMahjong(seed=config.seed)
    dummy_game.reset()
    input_dim = obs_to_tensor(dummy_game.observe(0), config.device).shape[0]

    policy = PolicyNet(input_dim, hidden_size=config.hidden_size).to(config.device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=config.lr)
    logger = SwanlabLogger(config)

    for update in range(config.updates):
        policy.eval()
        steps, episode_rewards, reward_stats = collect_rollouts(
            policy, config, update * config.group_size
        )
        if not steps:
            continue

        # Group-relative baseline: normalize rewards across the whole rollout group.
        rewards = np.array(episode_rewards, dtype=np.float32)
        mean = float(rewards.mean())
        std = float(rewards.std() + 1e-6)

        # Frozen reference policy for this update's KL penalty.
        ref_policy = copy.deepcopy(policy).eval()
        policy.train()

        losses = []
        for step in steps:
            logits_type, logits_tile, logits_chi = policy(step.obs)
            with torch.no_grad():
                ref_type, ref_tile, ref_chi = ref_policy(step.obs)
            advantage = torch.tensor((step.reward - mean) / std, device=logits_type.device)
            logp = logprob_action(logits_type, logits_tile, logits_chi, step.action, step.mask)
            tile_union = step.mask["discard"] | step.mask["pong"] | step.mask["kong"]
            kl = kl_divergence(logits_type, ref_type, step.mask["type"])
            kl += kl_divergence(logits_tile, ref_tile, tile_union)
            kl += kl_divergence(logits_chi, ref_chi, step.mask["chi"])
            loss = -(advantage * logp) + config.beta_kl * kl
            losses.append(loss)

        loss_value = torch.stack(losses).mean()
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

        logger.log(
            {
                "loss": float(loss_value.item()),
                "avg_reward": mean,
                "reward_std": std,
                "avg_final_reward": reward_stats["avg_final_reward"],
                "avg_shaping_reward": reward_stats["avg_shaping_reward"],
            },
            step=update + 1,
        )
        print(f"Update {update + 1}/{config.updates} loss={loss_value.item():.4f} avg_reward={mean:.2f}")

    logger.finish()
    return policy


def main() -> None:
    parser = argparse.ArgumentParser(description="GRPO self-play training for Mahjong")
    parser.add_argument("--updates", type=int, default=50)
    parser.add_argument("--group-size", type=int, default=16)
    parser.add_argument("--max-steps", type=int, default=0)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--beta-kl", type=float, default=0.02)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--hidden-size", type=int, default=256)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--swanlab", action="store_true")
    parser.add_argument("--swanlab-project", type=str, default="majiang-rl")
    parser.add_argument("--swanlab-run-name", type=str, default="")
    parser.add_argument("--pong-reward", type=float, default=0.1)
    parser.add_argument("--closest-bonus", type=float, default=1.0)
    args = parser.parse_args()
    config = GRPOConfig(
        group_size=args.group_size,
        updates=args.updates,
        max_steps=args.max_steps,
        lr=args.lr,
        beta_kl=args.beta_kl,
        seed=args.seed,
        hidden_size=args.hidden_size,
        device=args.device,
        use_swanlab=args.swanlab,
        swanlab_project=args.swanlab_project,
        swanlab_run_name=args.swanlab_run_name,
        pong_reward=args.pong_reward,
        closest_bonus=args.closest_bonus,
    )
    train_grpo(config)


def resolve_device(requested: str) -> str:
    if requested != "auto":
        return requested
    return "cuda" if torch.cuda.is_available() else "cpu"


def _auto_max_steps(configured: int, wall_len: int) -> int:
    """Use the configured step cap, or derive one from the initial wall size."""
    if configured > 0:
        return configured
    return wall_len * 4 + 50


class SwanlabLogger:
    """Thin optional wrapper around swanlab; becomes a no-op if swanlab is unavailable."""

    def __init__(self, config: GRPOConfig):
        self.enabled = config.use_swanlab
        self._swanlab = None
        if not self.enabled:
            return
        try:
            import swanlab
        except ImportError:
            print("swanlab not installed, logging disabled")
            self.enabled = False
            return
        self._swanlab = swanlab
        self._init_run(config)

    def _init_run(self, config: GRPOConfig) -> None:
        if not self._swanlab:
            return
        base_cfg = asdict(config)
        base_cfg.pop("use_swanlab", None)
        kwargs = {"project": config.swanlab_project, "config": base_cfg}
        if config.swanlab_run_name:
            kwargs["name"] = config.swanlab_run_name
        init = getattr(self._swanlab, "init", None)
        if init is None:
            self.enabled = False
            return
        try:
            init(**kwargs)
            return
        except TypeError:
            pass
        # Retry with an alternate run-name keyword if ``name`` is rejected.
        if config.swanlab_run_name:
            kwargs_alt = {
                "project": config.swanlab_project,
                "config": base_cfg,
                "experiment_name": config.swanlab_run_name,
            }
        else:
            kwargs_alt = {"project": config.swanlab_project, "config": base_cfg}
        try:
            init(**kwargs_alt)
        except TypeError:
            init(project=config.swanlab_project)

    def log(self, metrics: Dict[str, float], step: int) -> None:
        if not self.enabled or not self._swanlab:
            return
        log_fn = getattr(self._swanlab, "log", None)
        if log_fn is None:
            return
        try:
            log_fn(metrics, step=step)
        except TypeError:
            log_fn(metrics)

    def finish(self) -> None:
        if not self.enabled or not self._swanlab:
            return
        finish_fn = getattr(self._swanlab, "finish", None)
        if finish_fn:
            finish_fn()


if __name__ == "__main__":
    main()
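

# Usage sketch (not part of the original module; the package path is an assumption
# that depends on where this file lives, since it uses relative imports). The flags
# below match the argparse options defined in ``main`` above:
#
#   python -m <package>.train_grpo --updates 50 --group-size 16 --device auto
#   python -m <package>.train_grpo --swanlab --swanlab-project majiang-rl
#
# or programmatically, under the same import-path assumption:
#
#   from <package>.train_grpo import GRPOConfig, train_grpo
#   policy = train_grpo(GRPOConfig(updates=10, group_size=8))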