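"""GRPO-style (Group Relative Policy Optimization) self-play training for Mahjong.

A single policy network plays all four seats of ``SelfPlayMahjong``. Each update
collects a group of episodes, standardizes rewards within the group, and applies
a policy-gradient loss with a KL penalty against a frozen pre-update copy of the
policy.
"""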
from __future__ import annotations

import argparse
import copy
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..envs.self_play_env import (
    ACTION_DECLARE_CHI,
    ACTION_DECLARE_KONG,
    ACTION_DECLARE_PONG,
    ACTION_DISCARD,
    SelfPlayMahjong,
)
from ..rules import hand_distance_to_win, total_fan


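# Training hyperparameters. ``group_size`` is the number of self-play episodes
# collected per update, ``beta_kl`` weights the KL penalty against the frozen
# reference policy, and ``max_steps=0`` lets the episode length be derived from
# the wall size (see ``_auto_max_steps``).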
@dataclass
class GRPOConfig:
    group_size: int = 16
    updates: int = 50
    max_steps: int = 0
    lr: float = 3e-4
    beta_kl: float = 0.02
    seed: int = 1
    hidden_size: int = 256
    device: str = "auto"
    use_swanlab: bool = False
    swanlab_project: str = "majiang-rl"
    swanlab_run_name: str = ""
    pong_reward: float = 0.1
    closest_bonus: float = 1.0


@dataclass
class StepSample:
    player: int
    obs: torch.Tensor
    mask: Dict[str, torch.Tensor]
    action: Dict[str, int]
    reward: float = 0.0


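# The environment returns a dict observation; flatten every field into one float
# vector. The policy's input dimension is inferred from a dummy observation in
# ``train_grpo``.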
def obs_to_tensor(obs: Dict[str, np.ndarray], device: str) -> torch.Tensor:
    parts = [
        obs["hand"].astype(np.float32),
        obs["melds"].astype(np.float32).reshape(-1),
        obs["discards"].astype(np.float32).reshape(-1),
        obs["flowers"].astype(np.float32),
        obs["wall_count"].astype(np.float32),
        obs["current_player"].astype(np.float32),
        obs["phase"].astype(np.float32),
        obs["pending_discard"].astype(np.float32),
    ]
    flat = np.concatenate(parts, axis=0)
    return torch.tensor(flat, dtype=torch.float32, device=device)


def mask_to_torch(mask: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]:
    return {
        "type": torch.tensor(mask["type"], dtype=torch.bool, device=device),
        "discard": torch.tensor(mask["discard"], dtype=torch.bool, device=device),
        "pong": torch.tensor(mask["pong"], dtype=torch.bool, device=device),
        "kong": torch.tensor(mask["kong"], dtype=torch.bool, device=device),
        "chi": torch.tensor(mask["chi"], dtype=torch.bool, device=device),
    }


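# Shared MLP trunk with three heads: a 6-way action-type head, a 42-way tile head
# (reused for discard/pong/kong targets), and a 3-way chi-variant head. Sampling
# is always restricted to the legal entries of the environment's action mask.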
class PolicyNet(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        self.type_head = nn.Linear(hidden_size, 6)
        self.tile_head = nn.Linear(hidden_size, 42)
        self.chi_head = nn.Linear(hidden_size, 3)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.net(obs)
        return self.type_head(x), self.tile_head(x), self.chi_head(x)

    def sample_action(self, obs: torch.Tensor, mask: Dict[str, torch.Tensor]) -> Dict[str, int]:
        logits_type, logits_tile, logits_chi = self(obs)
        type_mask = mask["type"]
        masked_type = masked_logits(logits_type, type_mask)
        type_dist = torch.distributions.Categorical(logits=masked_type)
        action_type = int(type_dist.sample().item())

        tile_id = 0
        chi_choice = 0
        if action_type == ACTION_DISCARD:
            tile_id = sample_from_mask(logits_tile, mask["discard"])
        elif action_type == ACTION_DECLARE_PONG:
            tile_id = sample_from_mask(logits_tile, mask["pong"])
        elif action_type == ACTION_DECLARE_KONG:
            tile_id = sample_from_mask(logits_tile, mask["kong"])
        elif action_type == ACTION_DECLARE_CHI:
            chi_choice = sample_from_mask(logits_chi, mask["chi"])

        return {"type": action_type, "tile": tile_id, "chi": chi_choice}


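# Masking helpers: illegal entries receive a large negative logit (-1e9) so that
# softmax and sampling effectively ignore them while tensor shapes stay fixed.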
def masked_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    masked = logits.clone()
    masked[~mask] = -1e9
    return masked


def sample_from_mask(logits: torch.Tensor, mask: torch.Tensor) -> int:
    if mask.sum().item() == 0:
        return 0
    masked = masked_logits(logits, mask)
    dist = torch.distributions.Categorical(logits=masked)
    return int(dist.sample().item())


def logprob_action(
    logits_type: torch.Tensor,
    logits_tile: torch.Tensor,
    logits_chi: torch.Tensor,
    action: Dict[str, int],
    mask: Dict[str, torch.Tensor],
) -> torch.Tensor:
    action_type = int(action["type"])
    logp = masked_log_softmax(logits_type, mask["type"])[action_type]

    if action_type == ACTION_DISCARD:
        logp = logp + masked_log_softmax(logits_tile, mask["discard"])[int(action["tile"])]
    elif action_type == ACTION_DECLARE_PONG:
        logp = logp + masked_log_softmax(logits_tile, mask["pong"])[int(action["tile"])]
    elif action_type == ACTION_DECLARE_KONG:
        logp = logp + masked_log_softmax(logits_tile, mask["kong"])[int(action["tile"])]
    elif action_type == ACTION_DECLARE_CHI:
        logp = logp + masked_log_softmax(logits_chi, mask["chi"])[int(action["chi"])]

    return logp


def masked_log_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    masked = masked_logits(logits, mask)
    return F.log_softmax(masked, dim=-1)


def masked_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    masked = masked_logits(logits, mask)
    return F.softmax(masked, dim=-1)


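# KL(current || reference) over the legal entries only; returns 0 for an empty
# mask so heads that are irrelevant to a step contribute nothing.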
def kl_divergence(logits: torch.Tensor, ref_logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    if mask.sum().item() == 0:
        return torch.tensor(0.0, device=logits.device)
    p = masked_softmax(logits, mask)
    logp = torch.log(p + 1e-12)
    logp_ref = masked_log_softmax(ref_logits, mask)
    return torch.sum(p * (logp - logp_ref))


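# Terminal rewards: the winner receives the hand's fan value, the other three
# players split the negative fan evenly, and every player whose hand is closest
# to winning (minimum distance) additionally receives ``closest_bonus``.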
def compute_rewards(game: SelfPlayMahjong, config: GRPOConfig) -> List[float]:
    rewards = [0.0, 0.0, 0.0, 0.0]
    if game.winner is not None:
        fan = float(total_fan(game.winning_tiles, game.winning_melds))
        rewards[game.winner] = fan
        split = -fan / 3.0
        for idx in range(4):
            if idx != game.winner:
                rewards[idx] = split

    distances = []
    for idx, player in enumerate(game.players):
        tiles = (
            game.winning_tiles if game.winner is not None and idx == game.winner else player.hand
        )
        distances.append(hand_distance_to_win(tiles, player.melds))
    min_distance = min(distances) if distances else 0
    for idx, distance in enumerate(distances):
        if distance == min_distance:
            rewards[idx] += config.closest_bonus
    return rewards


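# Roll out ``group_size`` self-play episodes with the current policy. Per-step
# shaping rewards from the environment are stored on the step samples, and each
# player's terminal reward is added to all of that player's steps afterwards.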
def collect_rollouts(
    policy: PolicyNet, config: GRPOConfig, seed_offset: int
) -> Tuple[List[StepSample], List[float], Dict[str, float]]:
    all_steps: List[StepSample] = []
    episode_rewards: List[float] = []
    final_rewards_all: List[float] = []
    shaping_rewards_all: List[float] = []

    for episode in range(config.group_size):
        game = SelfPlayMahjong(
            seed=config.seed + seed_offset + episode, pong_reward=config.pong_reward
        )
        game.reset()
        episode_steps: List[StepSample] = []
        shaping_rewards = [0.0, 0.0, 0.0, 0.0]
        max_steps = _auto_max_steps(config.max_steps, len(game.wall))
        for _ in range(max_steps):
            if game.done:
                break
            player = game.current_player
            obs = game.observe(player)
            mask = game.action_mask(player)
            obs_tensor = obs_to_tensor(obs, config.device)
            mask_tensor = mask_to_torch(mask, config.device)
            with torch.no_grad():
                action = policy.sample_action(obs_tensor, mask_tensor)
            episode_steps.append(
                StepSample(player=player, obs=obs_tensor, mask=mask_tensor, action=action)
            )
            reward = game.step(player, action)
            shaping_rewards[player] += reward
            episode_steps[-1].reward = reward

        if not game.done:
            game._end_game(None, "max_steps", [])
        rewards = compute_rewards(game, config)
        total_rewards = [rewards[i] + shaping_rewards[i] for i in range(4)]
        episode_rewards.extend(total_rewards)
        final_rewards_all.extend(rewards)
        shaping_rewards_all.extend(shaping_rewards)
        for step in episode_steps:
            step.reward += rewards[step.player]
            all_steps.append(step)

    stats = {
        "avg_final_reward": float(np.mean(final_rewards_all)) if final_rewards_all else 0.0,
        "avg_shaping_reward": float(np.mean(shaping_rewards_all)) if shaping_rewards_all else 0.0,
    }
    return all_steps, episode_rewards, stats


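# One GRPO-style update: collect a group of episodes, compute mean/std of the
# per-player returns across the group, standardize each step's reward with those
# statistics to form its advantage, and minimize
#     -advantage * log pi(a|s) + beta_kl * KL(pi || pi_ref)
# where pi_ref is a frozen copy of the policy taken before the update.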
def train_grpo(config: GRPOConfig) -> PolicyNet:
    config.device = resolve_device(config.device)
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    print(f"Training on device: {config.device}")

    dummy_game = SelfPlayMahjong(seed=config.seed)
    dummy_game.reset()
    input_dim = obs_to_tensor(dummy_game.observe(0), config.device).shape[0]

    policy = PolicyNet(input_dim, hidden_size=config.hidden_size).to(config.device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=config.lr)

    logger = SwanlabLogger(config)
    for update in range(config.updates):
        policy.eval()
        steps, episode_rewards, reward_stats = collect_rollouts(
            policy, config, update * config.group_size
        )
        if not steps:
            continue
        rewards = np.array(episode_rewards, dtype=np.float32)
        mean = float(rewards.mean())
        std = float(rewards.std() + 1e-6)

        ref_policy = copy.deepcopy(policy).eval()
        policy.train()
        losses = []

        for step in steps:
            logits_type, logits_tile, logits_chi = policy(step.obs)
            with torch.no_grad():
                ref_type, ref_tile, ref_chi = ref_policy(step.obs)

            advantage = torch.tensor((step.reward - mean) / std, device=logits_type.device)
            logp = logprob_action(logits_type, logits_tile, logits_chi, step.action, step.mask)

            tile_union = step.mask["discard"] | step.mask["pong"] | step.mask["kong"]
            kl = kl_divergence(logits_type, ref_type, step.mask["type"])
            kl += kl_divergence(logits_tile, ref_tile, tile_union)
            kl += kl_divergence(logits_chi, ref_chi, step.mask["chi"])

            loss = -(advantage * logp) + config.beta_kl * kl
            losses.append(loss)

        loss_value = torch.stack(losses).mean()
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

        logger.log(
            {
                "loss": float(loss_value.item()),
                "avg_reward": mean,
                "reward_std": std,
                "avg_final_reward": reward_stats["avg_final_reward"],
                "avg_shaping_reward": reward_stats["avg_shaping_reward"],
            },
            step=update + 1,
        )
        print(f"Update {update + 1}/{config.updates} loss={loss_value.item():.4f} avg_reward={mean:.2f}")

    logger.finish()
    return policy


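# CLI entry point. Example invocation (the module path is a placeholder and
# depends on the package layout; the flags match the arguments defined below):
#   python -m <package.module> --updates 100 --group-size 32 --swanlab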
def main() -> None:
    parser = argparse.ArgumentParser(description="GRPO self-play training for Mahjong")
    parser.add_argument("--updates", type=int, default=50)
    parser.add_argument("--group-size", type=int, default=16)
    parser.add_argument("--max-steps", type=int, default=0)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--beta-kl", type=float, default=0.02)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--hidden-size", type=int, default=256)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--swanlab", action="store_true")
    parser.add_argument("--swanlab-project", type=str, default="majiang-rl")
    parser.add_argument("--swanlab-run-name", type=str, default="")
    parser.add_argument("--pong-reward", type=float, default=0.1)
    parser.add_argument("--closest-bonus", type=float, default=1.0)
    args = parser.parse_args()

    config = GRPOConfig(
        group_size=args.group_size,
        updates=args.updates,
        max_steps=args.max_steps,
        lr=args.lr,
        beta_kl=args.beta_kl,
        seed=args.seed,
        hidden_size=args.hidden_size,
        device=args.device,
        use_swanlab=args.swanlab,
        swanlab_project=args.swanlab_project,
        swanlab_run_name=args.swanlab_run_name,
        pong_reward=args.pong_reward,
        closest_bonus=args.closest_bonus,
    )
    train_grpo(config)


def resolve_device(requested: str) -> str:
    if requested != "auto":
        return requested
    return "cuda" if torch.cuda.is_available() else "cpu"


def _auto_max_steps(configured: int, wall_len: int) -> int:
    if configured > 0:
        return configured
    return wall_len * 4 + 50


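# Optional experiment logging. If swanlab is not installed, or its init/log
# signatures differ between versions, the logger degrades to a no-op instead of
# failing the training run.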
class SwanlabLogger:
    def __init__(self, config: GRPOConfig):
        self.enabled = config.use_swanlab
        self._swanlab = None
        if not self.enabled:
            return
        try:
            import swanlab
        except ImportError:
            print("swanlab not installed, logging disabled")
            self.enabled = False
            return
        self._swanlab = swanlab
        self._init_run(config)

    def _init_run(self, config: GRPOConfig) -> None:
        if not self._swanlab:
            return
        base_cfg = asdict(config)
        base_cfg.pop("use_swanlab", None)
        kwargs = {"project": config.swanlab_project, "config": base_cfg}
        if config.swanlab_run_name:
            kwargs["name"] = config.swanlab_run_name
        init = getattr(self._swanlab, "init", None)
        if init is None:
            self.enabled = False
            return
        try:
            init(**kwargs)
            return
        except TypeError:
            pass
        if config.swanlab_run_name:
            kwargs_alt = {
                "project": config.swanlab_project,
                "config": base_cfg,
                "experiment_name": config.swanlab_run_name,
            }
        else:
            kwargs_alt = {"project": config.swanlab_project, "config": base_cfg}
        try:
            init(**kwargs_alt)
        except TypeError:
            init(project=config.swanlab_project)

    def log(self, metrics: Dict[str, float], step: int) -> None:
        if not self.enabled or not self._swanlab:
            return
        log_fn = getattr(self._swanlab, "log", None)
        if log_fn is None:
            return
        try:
            log_fn(metrics, step=step)
        except TypeError:
            log_fn(metrics)

    def finish(self) -> None:
        if not self.enabled or not self._swanlab:
            return
        finish_fn = getattr(self._swanlab, "finish", None)
        if finish_fn:
            finish_fn()


if __name__ == "__main__":
    main()