feat: initialize majiang-rl project
This commit is contained in:
277
majiang_rl/rules.py
Normal file
277
majiang_rl/rules.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
|
||||
from .tiles import FLOWER_RANGE, HONOR_RANGE, SUITED_RANGE
|
||||
|
||||
ORPHAN_INDICES = (
|
||||
0,
|
||||
8,
|
||||
9,
|
||||
17,
|
||||
18,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
)
|
||||
|
||||
|
||||
def strip_flowers(tiles: Iterable[int]) -> List[int]:
|
||||
return [tile for tile in tiles if tile not in FLOWER_RANGE]
|
||||
|
||||
|
||||
def is_seven_pairs(counts: List[int]) -> bool:
|
||||
return sum(counts) == 14 and all(c % 2 == 0 for c in counts)
|
||||
|
||||
|
||||
def is_thirteen_orphans(counts: List[int]) -> bool:
|
||||
if sum(counts) != 14:
|
||||
return False
|
||||
if any(counts[i] == 0 for i in ORPHAN_INDICES):
|
||||
return False
|
||||
if any(counts[i] > 0 for i in range(34) if i not in ORPHAN_INDICES):
|
||||
return False
|
||||
return sum(counts[i] for i in ORPHAN_INDICES) == 14
|
||||
|
||||
|
||||
def is_all_pungs(counts: List[int]) -> bool:
|
||||
if sum(counts) != 14:
|
||||
return False
|
||||
for idx, count in enumerate(counts):
|
||||
if count >= 2:
|
||||
counts[idx] -= 2
|
||||
if _can_form_pungs(counts, 4):
|
||||
counts[idx] += 2
|
||||
return True
|
||||
counts[idx] += 2
|
||||
return False
|
||||
|
||||
|
||||
def _can_form_pungs(counts: List[int], needed: int) -> bool:
|
||||
if needed == 0:
|
||||
return sum(counts) == 0
|
||||
for idx, count in enumerate(counts):
|
||||
if count > 0:
|
||||
break
|
||||
else:
|
||||
return False
|
||||
if counts[idx] >= 3:
|
||||
counts[idx] -= 3
|
||||
if _can_form_pungs(counts, needed - 1):
|
||||
counts[idx] += 3
|
||||
return True
|
||||
counts[idx] += 3
|
||||
return False
|
||||
|
||||
|
||||
def is_all_honors(counts: List[int]) -> bool:
|
||||
return sum(counts[i] for i in SUITED_RANGE) == 0 and sum(counts) == 14
|
||||
|
||||
|
||||
def _suit_presence(counts: List[int]) -> Tuple[bool, Tuple[bool, bool, bool]]:
|
||||
suit_flags = [False, False, False]
|
||||
for idx in SUITED_RANGE:
|
||||
if counts[idx] > 0:
|
||||
suit_flags[idx // 9] = True
|
||||
honors = sum(counts[i] for i in HONOR_RANGE) > 0
|
||||
return honors, tuple(suit_flags)
|
||||
|
||||
|
||||
def is_pure_one_suit(counts: List[int]) -> bool:
|
||||
honors, suits = _suit_presence(counts)
|
||||
return not honors and sum(1 for flag in suits if flag) == 1
|
||||
|
||||
|
||||
def is_half_flush(counts: List[int]) -> bool:
|
||||
honors, suits = _suit_presence(counts)
|
||||
return honors and sum(1 for flag in suits if flag) == 1
|
||||
|
||||
|
||||
def is_all_terminals_or_honors(counts: List[int]) -> bool:
|
||||
if sum(counts) != 14:
|
||||
return False
|
||||
terminals = {0, 8, 9, 17, 18, 26}
|
||||
for idx, count in enumerate(counts):
|
||||
if count == 0:
|
||||
continue
|
||||
if idx in HONOR_RANGE:
|
||||
continue
|
||||
if idx not in terminals:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def can_form_melds(counts: List[int], needed: int) -> bool:
|
||||
if needed == 0:
|
||||
return sum(counts) == 0
|
||||
for idx, count in enumerate(counts):
|
||||
if count > 0:
|
||||
break
|
||||
else:
|
||||
return False
|
||||
|
||||
# Pong
|
||||
if counts[idx] >= 3:
|
||||
counts[idx] -= 3
|
||||
if can_form_melds(counts, needed - 1):
|
||||
counts[idx] += 3
|
||||
return True
|
||||
counts[idx] += 3
|
||||
|
||||
# Chow
|
||||
if idx in SUITED_RANGE and idx % 9 <= 6:
|
||||
if counts[idx + 1] > 0 and counts[idx + 2] > 0:
|
||||
counts[idx] -= 1
|
||||
counts[idx + 1] -= 1
|
||||
counts[idx + 2] -= 1
|
||||
if can_form_melds(counts, needed - 1):
|
||||
counts[idx] += 1
|
||||
counts[idx + 1] += 1
|
||||
counts[idx + 2] += 1
|
||||
return True
|
||||
counts[idx] += 1
|
||||
counts[idx + 1] += 1
|
||||
counts[idx + 2] += 1
|
||||
return False
|
||||
|
||||
|
||||
def is_standard_hand(counts: List[int], open_melds: int) -> bool:
|
||||
if sum(counts) != (4 - open_melds) * 3 + 2:
|
||||
return False
|
||||
for idx, count in enumerate(counts):
|
||||
if count >= 2:
|
||||
counts[idx] -= 2
|
||||
if can_form_melds(counts, 4 - open_melds):
|
||||
counts[idx] += 2
|
||||
return True
|
||||
counts[idx] += 2
|
||||
return False
|
||||
|
||||
|
||||
def is_win(tiles: Iterable[int], open_melds: int) -> bool:
|
||||
counts = [0] * 34
|
||||
for tile in tiles:
|
||||
if tile in FLOWER_RANGE:
|
||||
continue
|
||||
counts[tile] += 1
|
||||
|
||||
if is_seven_pairs(counts):
|
||||
return True
|
||||
if is_thirteen_orphans(counts):
|
||||
return True
|
||||
return is_standard_hand(counts, open_melds)
|
||||
|
||||
|
||||
def fan_breakdown(tiles: Iterable[int], melds: List[List[int]]) -> Dict[str, int]:
|
||||
counts = [0] * 34
|
||||
for tile in tiles:
|
||||
if tile in FLOWER_RANGE:
|
||||
continue
|
||||
counts[tile] += 1
|
||||
for meld in melds:
|
||||
for tile in meld:
|
||||
if tile in FLOWER_RANGE:
|
||||
continue
|
||||
counts[tile] += 1
|
||||
|
||||
fans: Dict[str, int] = {}
|
||||
if is_thirteen_orphans(counts):
|
||||
fans["thirteen_orphans"] = 88
|
||||
if is_all_honors(counts):
|
||||
fans["all_honors"] = 64
|
||||
if is_pure_one_suit(counts):
|
||||
fans["pure_one_suit"] = 24
|
||||
if is_seven_pairs(counts):
|
||||
fans["seven_pairs"] = 24
|
||||
if is_all_terminals_or_honors(counts):
|
||||
fans["all_terminals_or_honors"] = 32
|
||||
if is_all_pungs(counts.copy()):
|
||||
fans["all_pungs"] = 6
|
||||
if is_half_flush(counts):
|
||||
fans["half_flush"] = 6
|
||||
if not fans:
|
||||
fans["standard"] = 1
|
||||
return fans
|
||||
|
||||
|
||||
def total_fan(tiles: Iterable[int], melds: List[List[int]]) -> int:
|
||||
return sum(fan_breakdown(tiles, melds).values())
|
||||
|
||||
|
||||
def hand_distance_to_win(hand_tiles: Iterable[int], melds: List[List[int]]) -> int:
|
||||
counts = [0] * 34
|
||||
for tile in hand_tiles:
|
||||
if tile in FLOWER_RANGE:
|
||||
continue
|
||||
counts[tile] += 1
|
||||
|
||||
melds_done = min(len(melds), 4)
|
||||
needed_melds = max(0, 4 - melds_done)
|
||||
max_used = _max_used_tiles(counts, needed_melds)
|
||||
used_tiles = min(14, 3 * melds_done + max_used)
|
||||
return max(0, 14 - used_tiles)
|
||||
|
||||
|
||||
def _max_used_tiles(counts: List[int], needed_melds: int) -> int:
|
||||
best = 3 * _max_melds(counts, needed_melds)
|
||||
for idx, count in enumerate(counts):
|
||||
if count >= 2:
|
||||
counts[idx] -= 2
|
||||
melds = _max_melds(counts, needed_melds)
|
||||
best = max(best, 2 + 3 * melds)
|
||||
counts[idx] += 2
|
||||
return best
|
||||
|
||||
|
||||
def _max_melds(counts: List[int], limit: int) -> int:
|
||||
if limit == 0:
|
||||
return 0
|
||||
for idx, count in enumerate(counts):
|
||||
if count > 0:
|
||||
break
|
||||
else:
|
||||
return 0
|
||||
|
||||
best = 0
|
||||
counts[idx] -= 1
|
||||
best = max(best, _max_melds(counts, limit))
|
||||
counts[idx] += 1
|
||||
|
||||
if counts[idx] >= 3:
|
||||
counts[idx] -= 3
|
||||
best = max(best, 1 + _max_melds(counts, limit - 1))
|
||||
counts[idx] += 3
|
||||
|
||||
if idx in SUITED_RANGE and idx % 9 <= 6:
|
||||
if counts[idx + 1] > 0 and counts[idx + 2] > 0:
|
||||
counts[idx] -= 1
|
||||
counts[idx + 1] -= 1
|
||||
counts[idx + 2] -= 1
|
||||
best = max(best, 1 + _max_melds(counts, limit - 1))
|
||||
counts[idx] += 1
|
||||
counts[idx + 1] += 1
|
||||
counts[idx + 2] += 1
|
||||
|
||||
return best
|
||||
|
||||
|
||||
def can_chi_options(counts: List[int], tile_id: int) -> List[int]:
|
||||
if tile_id not in SUITED_RANGE:
|
||||
return []
|
||||
options: List[int] = []
|
||||
rank = tile_id % 9
|
||||
if rank >= 2:
|
||||
if counts[tile_id - 1] > 0 and counts[tile_id - 2] > 0:
|
||||
options.append(0)
|
||||
if 1 <= rank <= 7:
|
||||
if counts[tile_id - 1] > 0 and counts[tile_id + 1] > 0:
|
||||
options.append(1)
|
||||
if rank <= 6:
|
||||
if counts[tile_id + 1] > 0 and counts[tile_id + 2] > 0:
|
||||
options.append(2)
|
||||
return options
|
||||
Reference in New Issue
Block a user