feat: initialize majiang-rl project

This commit is contained in: game-loader
2026-01-14 10:49:00 +08:00
commit b29a18b459
21 changed files with 18895 additions and 0 deletions

majiang_rl/rl/__init__.py (new file, 13 lines)

@@ -0,0 +1,13 @@
from .agents import RandomAgent, Transition
from .grpo import GRPOConfig, train_grpo
from .trainer import EpisodeResult, EpisodeRunner, run_training
__all__ = [
"RandomAgent",
"Transition",
"EpisodeResult",
"EpisodeRunner",
"run_training",
"GRPOConfig",
"train_grpo",
]
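For reference, a minimal sketch of consuming this public API once the package is importable (the import path follows the file location above):

from majiang_rl.rl import RandomAgent, EpisodeRunner, run_training, GRPOConfig, train_grpo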

majiang_rl/rl/agents.py (new file, 76 lines)

@@ -0,0 +1,76 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Protocol
import numpy as np
class Agent(Protocol):
def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
...
def observe(self, transition: "Transition") -> None:
...
def reset(self) -> None:
...
@dataclass
class Transition:
obs: Dict[str, Any]
action: Dict[str, int]
reward: float
terminated: bool
truncated: bool
next_obs: Dict[str, Any]
info: Dict[str, Any]
class RandomAgent:
def __init__(self, seed: int | None = None):
self.rng = np.random.default_rng(seed)
def reset(self) -> None:
return None
def act(self, obs: Dict[str, Any], info: Dict[str, Any]) -> Dict[str, int]:
mask = info.get("action_mask")
if mask is None:
return {"type": 0, "tile": 0, "chi": 0}
return sample_action_from_mask(mask, self.rng)
def observe(self, transition: Transition) -> None:
return None
def sample_action_from_mask(mask: Dict[str, np.ndarray], rng: np.random.Generator) -> Dict[str, int]:
type_mask = np.asarray(mask["type"], dtype=bool)
discard_mask = np.asarray(mask["discard"], dtype=bool)
pong_mask = np.asarray(mask["pong"], dtype=bool)
kong_mask = np.asarray(mask["kong"], dtype=bool)
chi_mask = np.asarray(mask["chi"], dtype=bool)
valid_types = np.flatnonzero(type_mask)
if len(valid_types) == 0:
return {"type": 3, "tile": 0, "chi": 0}
action_type = int(rng.choice(valid_types))
tile_id = 0
chi_choice = 0
if action_type == 0:
valid_tiles = np.flatnonzero(discard_mask)
if len(valid_tiles) > 0:
tile_id = int(rng.choice(valid_tiles))
if action_type == 2:
valid_tiles = np.flatnonzero(kong_mask)
if len(valid_tiles) > 0:
tile_id = int(rng.choice(valid_tiles))
if action_type == 4:
valid_tiles = np.flatnonzero(pong_mask)
if len(valid_tiles) > 0:
tile_id = int(rng.choice(valid_tiles))
if action_type == 5:
valid_chi = np.flatnonzero(chi_mask)
if len(valid_chi) > 0:
chi_choice = int(rng.choice(valid_chi))
return {"type": action_type, "tile": tile_id, "chi": chi_choice}

majiang_rl/rl/grpo.py (new file, 415 lines)

@@ -0,0 +1,415 @@
from __future__ import annotations
import argparse
import copy
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..envs.self_play_env import (
ACTION_DECLARE_CHI,
ACTION_DECLARE_KONG,
ACTION_DECLARE_PONG,
ACTION_DISCARD,
SelfPlayMahjong,
)
from ..rules import hand_distance_to_win, total_fan
@dataclass
class GRPOConfig:
group_size: int = 16
updates: int = 50
max_steps: int = 0
lr: float = 3e-4
beta_kl: float = 0.02
seed: int = 1
hidden_size: int = 256
device: str = "auto"
use_swanlab: bool = False
swanlab_project: str = "majiang-rl"
swanlab_run_name: str = ""
pong_reward: float = 0.1
closest_bonus: float = 1.0
@dataclass
class StepSample:
player: int
obs: torch.Tensor
mask: Dict[str, torch.Tensor]
action: Dict[str, int]
reward: float = 0.0
def obs_to_tensor(obs: Dict[str, np.ndarray], device: str) -> torch.Tensor:
parts = [
obs["hand"].astype(np.float32),
obs["melds"].astype(np.float32).reshape(-1),
obs["discards"].astype(np.float32).reshape(-1),
obs["flowers"].astype(np.float32),
obs["wall_count"].astype(np.float32),
obs["current_player"].astype(np.float32),
obs["phase"].astype(np.float32),
obs["pending_discard"].astype(np.float32),
]
flat = np.concatenate(parts, axis=0)
return torch.tensor(flat, dtype=torch.float32, device=device)
def mask_to_torch(mask: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]:
return {
"type": torch.tensor(mask["type"], dtype=torch.bool, device=device),
"discard": torch.tensor(mask["discard"], dtype=torch.bool, device=device),
"pong": torch.tensor(mask["pong"], dtype=torch.bool, device=device),
"kong": torch.tensor(mask["kong"], dtype=torch.bool, device=device),
"chi": torch.tensor(mask["chi"], dtype=torch.bool, device=device),
}
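# PolicyNet below: a shared two-layer MLP trunk feeding three heads, one per action component:
# action type (6-way), tile id (42-way; presumably 34 standard tiles plus 8 flower tiles),
# and chi variant (3-way).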
class PolicyNet(nn.Module):
def __init__(self, input_dim: int, hidden_size: int = 256):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
)
self.type_head = nn.Linear(hidden_size, 6)
self.tile_head = nn.Linear(hidden_size, 42)
self.chi_head = nn.Linear(hidden_size, 3)
def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = self.net(obs)
return self.type_head(x), self.tile_head(x), self.chi_head(x)
def sample_action(self, obs: torch.Tensor, mask: Dict[str, torch.Tensor]) -> Dict[str, int]:
logits_type, logits_tile, logits_chi = self(obs)
type_mask = mask["type"]
masked_type = masked_logits(logits_type, type_mask)
type_dist = torch.distributions.Categorical(logits=masked_type)
action_type = int(type_dist.sample().item())
tile_id = 0
chi_choice = 0
if action_type == ACTION_DISCARD:
tile_id = sample_from_mask(logits_tile, mask["discard"])
elif action_type == ACTION_DECLARE_PONG:
tile_id = sample_from_mask(logits_tile, mask["pong"])
elif action_type == ACTION_DECLARE_KONG:
tile_id = sample_from_mask(logits_tile, mask["kong"])
elif action_type == ACTION_DECLARE_CHI:
chi_choice = sample_from_mask(logits_chi, mask["chi"])
return {"type": action_type, "tile": tile_id, "chi": chi_choice}
def masked_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
masked = logits.clone()
masked[~mask] = -1e9
return masked
def sample_from_mask(logits: torch.Tensor, mask: torch.Tensor) -> int:
if mask.sum().item() == 0:
return 0
masked = masked_logits(logits, mask)
dist = torch.distributions.Categorical(logits=masked)
return int(dist.sample().item())
def logprob_action(
logits_type: torch.Tensor,
logits_tile: torch.Tensor,
logits_chi: torch.Tensor,
action: Dict[str, int],
mask: Dict[str, torch.Tensor],
) -> torch.Tensor:
action_type = int(action["type"])
logp = masked_log_softmax(logits_type, mask["type"])[action_type]
if action_type == ACTION_DISCARD:
logp = logp + masked_log_softmax(logits_tile, mask["discard"])[int(action["tile"])]
elif action_type == ACTION_DECLARE_PONG:
logp = logp + masked_log_softmax(logits_tile, mask["pong"])[int(action["tile"])]
elif action_type == ACTION_DECLARE_KONG:
logp = logp + masked_log_softmax(logits_tile, mask["kong"])[int(action["tile"])]
elif action_type == ACTION_DECLARE_CHI:
logp = logp + masked_log_softmax(logits_chi, mask["chi"])[int(action["chi"])]
return logp
def masked_log_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
masked = masked_logits(logits, mask)
return F.log_softmax(masked, dim=-1)
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
masked = masked_logits(logits, mask)
return F.softmax(masked, dim=-1)
def kl_divergence(logits: torch.Tensor, ref_logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
if mask.sum().item() == 0:
return torch.tensor(0.0, device=logits.device)
p = masked_softmax(logits, mask)
logp = torch.log(p + 1e-12)
logp_ref = masked_log_softmax(ref_logits, mask)
return torch.sum(p * (logp - logp_ref))
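# Terminal reward scheme (compute_rewards below): the winner receives the hand's total fan and
# the other three players split the negative of that value, keeping the game zero-sum; every
# player tied for the minimum hand-distance-to-win additionally receives `closest_bonus`.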
def compute_rewards(game: SelfPlayMahjong, config: GRPOConfig) -> List[float]:
rewards = [0.0, 0.0, 0.0, 0.0]
if game.winner is not None:
fan = float(total_fan(game.winning_tiles, game.winning_melds))
rewards[game.winner] = fan
split = -fan / 3.0
for idx in range(4):
if idx != game.winner:
rewards[idx] = split
distances = []
for idx, player in enumerate(game.players):
tiles = (
game.winning_tiles if game.winner is not None and idx == game.winner else player.hand
)
distances.append(hand_distance_to_win(tiles, player.melds))
min_distance = min(distances) if distances else 0
for idx, distance in enumerate(distances):
if distance == min_distance:
rewards[idx] += config.closest_bonus
return rewards
def collect_rollouts(
policy: PolicyNet, config: GRPOConfig, seed_offset: int
) -> Tuple[List[StepSample], List[float], Dict[str, float]]:
all_steps: List[StepSample] = []
episode_rewards: List[float] = []
final_rewards_all: List[float] = []
shaping_rewards_all: List[float] = []
for episode in range(config.group_size):
game = SelfPlayMahjong(
seed=config.seed + seed_offset + episode, pong_reward=config.pong_reward
)
game.reset()
episode_steps: List[StepSample] = []
shaping_rewards = [0.0, 0.0, 0.0, 0.0]
max_steps = _auto_max_steps(config.max_steps, len(game.wall))
for _ in range(max_steps):
if game.done:
break
player = game.current_player
obs = game.observe(player)
mask = game.action_mask(player)
obs_tensor = obs_to_tensor(obs, config.device)
mask_tensor = mask_to_torch(mask, config.device)
with torch.no_grad():
action = policy.sample_action(obs_tensor, mask_tensor)
episode_steps.append(
StepSample(player=player, obs=obs_tensor, mask=mask_tensor, action=action)
)
reward = game.step(player, action)
shaping_rewards[player] += reward
episode_steps[-1].reward = reward
if not game.done:
game._end_game(None, "max_steps", [])
rewards = compute_rewards(game, config)
total_rewards = [rewards[i] + shaping_rewards[i] for i in range(4)]
episode_rewards.extend(total_rewards)
final_rewards_all.extend(rewards)
shaping_rewards_all.extend(shaping_rewards)
for step in episode_steps:
step.reward += rewards[step.player]
all_steps.append(step)
stats = {
"avg_final_reward": float(np.mean(final_rewards_all)) if final_rewards_all else 0.0,
"avg_shaping_reward": float(np.mean(shaping_rewards_all)) if shaping_rewards_all else 0.0,
}
return all_steps, episode_rewards, stats
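# GRPO update (train_grpo below): total per-player rewards from a group of self-play games are
# normalized into a group-relative advantage, (reward - group mean) / group std, and each sampled
# step contributes -(advantage * log-prob) plus a beta_kl-weighted KL penalty toward a frozen
# copy of the policy taken at the start of the update.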
def train_grpo(config: GRPOConfig) -> PolicyNet:
config.device = resolve_device(config.device)
torch.manual_seed(config.seed)
np.random.seed(config.seed)
print(f"Training on device: {config.device}")
dummy_game = SelfPlayMahjong(seed=config.seed)
dummy_game.reset()
input_dim = obs_to_tensor(dummy_game.observe(0), config.device).shape[0]
policy = PolicyNet(input_dim, hidden_size=config.hidden_size).to(config.device)
optimizer = torch.optim.Adam(policy.parameters(), lr=config.lr)
logger = SwanlabLogger(config)
for update in range(config.updates):
policy.eval()
steps, episode_rewards, reward_stats = collect_rollouts(
policy, config, update * config.group_size
)
if not steps:
continue
rewards = np.array(episode_rewards, dtype=np.float32)
mean = float(rewards.mean())
std = float(rewards.std() + 1e-6)
ref_policy = copy.deepcopy(policy).eval()
policy.train()
losses = []
for step in steps:
logits_type, logits_tile, logits_chi = policy(step.obs)
with torch.no_grad():
ref_type, ref_tile, ref_chi = ref_policy(step.obs)
advantage = torch.tensor((step.reward - mean) / std, device=logits_type.device)
logp = logprob_action(logits_type, logits_tile, logits_chi, step.action, step.mask)
tile_union = step.mask["discard"] | step.mask["pong"] | step.mask["kong"]
kl = kl_divergence(logits_type, ref_type, step.mask["type"])
kl += kl_divergence(logits_tile, ref_tile, tile_union)
kl += kl_divergence(logits_chi, ref_chi, step.mask["chi"])
loss = -(advantage * logp) + config.beta_kl * kl
losses.append(loss)
loss_value = torch.stack(losses).mean()
optimizer.zero_grad()
loss_value.backward()
optimizer.step()
logger.log(
{
"loss": float(loss_value.item()),
"avg_reward": mean,
"reward_std": std,
"avg_final_reward": reward_stats["avg_final_reward"],
"avg_shaping_reward": reward_stats["avg_shaping_reward"],
},
step=update + 1,
)
print(f"Update {update + 1}/{config.updates} loss={loss_value.item():.4f} avg_reward={mean:.2f}")
logger.finish()
return policy
def main() -> None:
parser = argparse.ArgumentParser(description="GRPO self-play training for Mahjong")
parser.add_argument("--updates", type=int, default=50)
parser.add_argument("--group-size", type=int, default=16)
parser.add_argument("--max-steps", type=int, default=0)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--beta-kl", type=float, default=0.02)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--hidden-size", type=int, default=256)
parser.add_argument("--device", type=str, default="auto")
parser.add_argument("--swanlab", action="store_true")
parser.add_argument("--swanlab-project", type=str, default="majiang-rl")
parser.add_argument("--swanlab-run-name", type=str, default="")
parser.add_argument("--pong-reward", type=float, default=0.1)
parser.add_argument("--closest-bonus", type=float, default=1.0)
args = parser.parse_args()
config = GRPOConfig(
group_size=args.group_size,
updates=args.updates,
max_steps=args.max_steps,
lr=args.lr,
beta_kl=args.beta_kl,
seed=args.seed,
hidden_size=args.hidden_size,
device=args.device,
use_swanlab=args.swanlab,
swanlab_project=args.swanlab_project,
swanlab_run_name=args.swanlab_run_name,
pong_reward=args.pong_reward,
closest_bonus=args.closest_bonus,
)
train_grpo(config)
def resolve_device(requested: str) -> str:
if requested != "auto":
return requested
return "cuda" if torch.cuda.is_available() else "cpu"
def _auto_max_steps(configured: int, wall_len: int) -> int:
if configured > 0:
return configured
return wall_len * 4 + 50
class SwanlabLogger:
def __init__(self, config: GRPOConfig):
self.enabled = config.use_swanlab
self._swanlab = None
if not self.enabled:
return
try:
import swanlab
except ImportError:
print("swanlab not installed, logging disabled")
self.enabled = False
return
self._swanlab = swanlab
self._init_run(config)
def _init_run(self, config: GRPOConfig) -> None:
if not self._swanlab:
return
base_cfg = asdict(config)
base_cfg.pop("use_swanlab", None)
kwargs = {"project": config.swanlab_project, "config": base_cfg}
if config.swanlab_run_name:
kwargs["name"] = config.swanlab_run_name
init = getattr(self._swanlab, "init", None)
if init is None:
self.enabled = False
return
try:
init(**kwargs)
return
except TypeError:
pass
if config.swanlab_run_name:
kwargs_alt = {"project": config.swanlab_project, "config": base_cfg, "experiment_name": config.swanlab_run_name}
else:
kwargs_alt = {"project": config.swanlab_project, "config": base_cfg}
try:
init(**kwargs_alt)
except TypeError:
init(project=config.swanlab_project)
def log(self, metrics: Dict[str, float], step: int) -> None:
if not self.enabled or not self._swanlab:
return
log_fn = getattr(self._swanlab, "log", None)
if log_fn is None:
return
try:
log_fn(metrics, step=step)
except TypeError:
log_fn(metrics)
def finish(self) -> None:
if not self.enabled or not self._swanlab:
return
finish_fn = getattr(self._swanlab, "finish", None)
if finish_fn:
finish_fn()
if __name__ == "__main__":
main()
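For reference, a small training run could be driven either from Python or via the module's CLI, assuming majiang_rl is on the Python path; the values below are illustrative only:

from majiang_rl.rl.grpo import GRPOConfig, train_grpo

# Short smoke-test run on CPU: 4 updates of 8 self-play games each, SwanLab logging off.
config = GRPOConfig(updates=4, group_size=8, device="cpu", use_swanlab=False)
policy = train_grpo(config)
# The returned PolicyNet can then pick actions with policy.sample_action(obs_tensor, mask_tensor).

# Equivalent CLI invocation (argparse entry point above):
#   python -m majiang_rl.rl.grpo --updates 4 --group-size 8 --device cpu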

majiang_rl/rl/trainer.py (new file, 61 lines)

@@ -0,0 +1,61 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List
from .agents import Transition
@dataclass
class EpisodeResult:
total_reward: float
steps: int
terminated: bool
truncated: bool
class EpisodeRunner:
def __init__(self, env, agent, max_steps: int = 1000):
self.env = env
self.agent = agent
self.max_steps = max_steps
def run_episode(self) -> EpisodeResult:
obs, info = self.env.reset()
self.agent.reset()
total_reward = 0.0
terminated = False
truncated = False
steps = 0
for step in range(self.max_steps):
action = self.agent.act(obs, info)
next_obs, reward, terminated, truncated, next_info = self.env.step(action)
transition = Transition(
obs=obs,
action=action,
reward=reward,
terminated=terminated,
truncated=truncated,
next_obs=next_obs,
info=next_info,
)
self.agent.observe(transition)
total_reward += reward
obs, info = next_obs, next_info
steps = step + 1
if terminated or truncated:
break
return EpisodeResult(
total_reward=total_reward,
steps=steps,
terminated=terminated,
truncated=truncated,
)
def run_training(env, agent, episodes: int = 100) -> List[EpisodeResult]:
runner = EpisodeRunner(env, agent)
results: List[EpisodeResult] = []
for _ in range(episodes):
results.append(runner.run_episode())
return results
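A minimal sketch exercising EpisodeRunner/run_training with RandomAgent and a hypothetical stub environment that follows the Gymnasium-style reset()/step() contract the runner assumes (ToyEnv is illustrative, not part of the repo):

from majiang_rl.rl import RandomAgent, run_training

class ToyEnv:
    """Hypothetical stand-in env: 5 steps per episode, reward 1.0 per step."""
    def __init__(self):
        self.t = 0
    def reset(self):
        self.t = 0
        return {}, {}  # (obs, info); no "action_mask" in info, so RandomAgent falls back to its default action
    def step(self, action):
        self.t += 1
        return {}, 1.0, self.t >= 5, False, {}  # (obs, reward, terminated, truncated, info)

results = run_training(ToyEnv(), RandomAgent(seed=0), episodes=3)
print([r.total_reward for r in results])  # [5.0, 5.0, 5.0]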