Mahjax:基于JAX的GPU加速麻将模拟器,用于强化学习
摘要
本文介绍了Mahjax,一个完全向量化的立直麻将模拟器,基于JAX实现,用于GPU加速的强化学习,具有高吞吐量,并支持从零开始训练。
arXiv:2605.20577v1 公告类型:新
摘要:立直麻将是一种多人、不完全信息博弈,特点是随机性和高维状态空间。这些属性呈现了独特的挑战组合,反映了强化学习中复杂现实决策问题。虽然先前研究严重依赖从人类游戏日志进行监督学习以预训练策略,但能够从零开始学习的算法(即从\textit{tabula rasa})具有更大的通用潜力,如AlphaZero系列所示。为了促进此类研究,我们介绍了\textbf{Mahjax},一个完全向量化的立直麻将环境,基于JAX实现,以便在图形处理单元(GPU)上进行大规模并行化展开。我们还提供了一个高质量的可视化工具,以简化调试和与训练代理的交互。实验结果表明,在无红规则和有红规则下,Mahjax在八块NVIDIA A100 GPU上分别达到每秒高达\textbf{200万}和\textbf{100万}步的吞吐量。此外,我们通过展示代理可以有效训练以提高其相对于基线策略的排名,验证了该环境在强化学习中的实用性。
查看缓存全文
缓存时间: 2026/05/22 08:47
# Mahjax:JAX 中用于强化学习的 GPU 加速麻将模拟器 来源:https://arxiv.org/html/2605.20577 Soichiro Nishimori1,2, Shinri Okano3, Keigo Habara4, Sotetsu Koyamada5,6,7, Eason Yu8, and Masashi Sugiyama2,11东京大学,日本东京。2RIKEN AIP,日本东京。3奈良科学技术大学,日本奈良。4独立研究员。5神户大学,日本神户。6京都大学,日本京都。7ATR,日本京都。8悉尼大学,澳大利亚悉尼。通讯作者:Soichiro Nishimori。邮箱:[email protected]。本作品已提交IEEE考虑发表。版权可能在通知后转移,之后此版本可能不再可访问。 ###### 摘要 立直麻将是一款多玩家、信息不完全的游戏,其特征是随机性和高维状态空间。这些属性呈现出一系列独特的挑战,反映了强化学习中复杂的现实世界决策问题。以往的研究严重依赖于从人类对局日志进行监督学习来预训练策略,而能够从头开始学习的算法则具有更大的通用潜力,正如 AlphaZero 系列所证明的那样。为了促进这类研究,我们引入了 **Mahjax**,一个在 JAX 中实现的完全向量化立直麻将环境,旨在支持在图形处理器(GPU)上进行大规模对局并行化。我们还提供了一个高质量的视觉化工具,以简化与训练后智能体的调试和交互。实验结果表明,在无赤牌和有赤牌规则下,Mahjax 在八块 NVIDIA A100 GPU 上分别实现了高达每秒 **200 万步** 和 **100 万步** 的吞吐量。此外,我们通过证明智能体可以针对基线策略有效训练以提升其排名,验证了该环境对强化学习的实用性。代码可在 https://github.com/nissymori/mahjax 获取。 ## I. 引言 立直麻将是一款流行的牌类游戏,玩家在信息不完全的情况下竞争组成获胜手牌 [11 (https://arxiv.org/html/2605.20577#bib.bib6)]。该游戏体现了复杂的现实世界决策问题,其特点是多智能体交互、高维状态空间和随机性。因此,它在强化学习(RL)领域得到了广泛研究 [11 (https://arxiv.org/html/2605.20577#bib.bib6), 27 (https://arxiv.org/html/2605.20577#bib.bib7), 12 (https://arxiv.org/html/2605.20577#bib.bib8), 16 (https://arxiv.org/html/2605.20577#bib.bib9), 6 (https://arxiv.org/html/2605.20577#bib.bib10)]。 该领域的一个重要里程碑是 Suphx [11 (https://arxiv.org/html/2605.20577#bib.bib6)],这是第一个在麻将中达到顶级人类水平的 AI。虽然后续工作也展现了强劲的结果,但它们大多依赖于从人类对局日志进行的监督学习(SL)[11 (https://arxiv.org/html/2605.20577#bib.bib6)] 或离线 RL [10 (https://arxiv.org/html/2605.20577#bib.bib5)] 进行预训练。相比之下,AlphaZero 系列算法 [20 (https://arxiv.org/html/2605.20577#bib.bib22), 22 (https://arxiv.org/html/2605.20577#bib.bib23), 21 (https://arxiv.org/html/2605.20577#bib.bib21)] 证明了复杂游戏可以通过从头开始的自我对弈来掌握,无需人类先验知识。这种方法最近已扩展到解决基础算法问题 [4 (https://arxiv.org/html/2605.20577#bib.bib20)]。受这些成就的启发,通过纯 RL 从头解决麻将仍然是一个有前景但尚未充分探索的前沿领域。 然而,在复杂环境中的自我对弈需要大量的试错经验。例如,AlphaHoldem [26 (https://arxiv.org/html/2605.20577#bib.bib30)] 需要 65 亿步训练才能掌握单手无限注扑克。考虑到麻将涉及四名玩家且对局长度比扑克更长,现有的基于中央处理器(CPU)的模拟器在实际训练中造成了计算瓶颈 [8 (https://arxiv.org/html/2605.20577#bib.bib11)]。为了解决数据吞吐量的挑战,RL 社区已转向硬件加速环境 [9 (https://arxiv.org/html/2605.20577#bib.bib12), 1 (https://arxiv.org/html/2605.20577#bib.bib13), 5 (https://arxiv.org/html/2605.20577#bib.bib14), 18 (https://arxiv.org/html/2605.20577#bib.bib16), 14 (https://arxiv.org/html/2605.20577#bib.bib15), 17 (https://arxiv.org/html/2605.20577#bib.bib17)]。这些向量化环境使智能体能够直接在图形处理器(GPU)上以大规模批次收集经验,通常比 CPU 基线快 **100 倍以上** [9 (https://arxiv.org/html/2605.20577#bib.bib12), 14 (https://arxiv.org/html/2605.20577#bib.bib15)]。此外,它们还促进了利用大规模并行交互的新型算法 [7 (https://arxiv.org/html/2605.20577#bib.bib24), 13 (https://arxiv.org/html/2605.20577#bib.bib25)]。 在现有框架中,Pgx [9 (https://arxiv.org/html/2605.20577#bib.bib12)] 提供了一套基于 JAX 的棋盘游戏,但目前缺少对像立直麻将这样复杂信息不完全游戏的全面实现。在这项工作中,我们引入了 **Mahjax**,一个用 JAX [2 (https://arxiv.org/html/2605.20577#bib.bib4)] 编写的完全可向量化的立直麻将环境,旨在支持大规模纯 RL 研究。 我们的贡献总结如下:1)**向量化环境**:我们提供了一个高性能的麻将环境,采用 Pgx 应用程序接口(API),确保与基于 JAX 的现代 RL 流程兼容。2)**性能**:Mahjax 在多个 GPU 上高效扩展,在八块 NVIDIA A100 GPU 上,无赤牌和有赤牌规则下分别实现了高达每秒 **200 万步** 和 **100 万步** 的吞吐量。3)**易用性**:我们提供了可视化工具,以方便调试和分析。4)**验证**:我们通过成功的 RL 训练验证了该环境,证明了其可用于研究。 ## II. 相关工作 我们回顾了麻将 AI 和 GPU 加速 RL 环境的相关工作。 **RL 中的麻将**。麻将已在 RL 文献中得到广泛研究。在专注于智能体的工作中,最重要的里程碑是 Suphx [11 (https://arxiv.org/html/2605.20577#bib.bib6)],这是第一个在日本最受欢迎的麻将平台天凤 [23 (https://arxiv.org/html/2605.20577#bib.bib26)] 上达到顶级人类水平的 AI。此后,商业和开源社区开发了多个智能体。例如,由 Dwango Media Village 开发的 NAGA111https://dmv.nico/en/articles/mahjong_ai_naga/ 达到了天凤的最高等级。Mortal [3 (https://arxiv.org/html/2605.20577#bib.bib29)] 是一个用于训练麻将智能体的开源框架。这些工作的一个共同特点是,它们使用 SL 或离线 RL 在从天凤收集的人类数据上预训练策略,然后通过深度 RL 进行微调。一些研究还探索了麻将的变体。例如,Zhao 和 Holden [27 (https://arxiv.org/html/2605.20577#bib.bib7)] 开发了一个三人麻将(三麻)智能体。此外,Ogami 等人 [16 (https://arxiv.org/html/2605.20577#bib.bib9)] 提出了一种改进玩家评估的方法。 在模拟基础设施方面,Mjx [8 (https://arxiv.org/html/2605.20577#bib.bib11)] 提供了一个快速的 C++ 模拟器,吞吐量约为每小时 4 万局。类似地,Mortal 提供了一个名为 Libriichi [3 (https://arxiv.org/html/2605.20577#bib.bib29)] 的基于 Rust 的快速模拟器,其速度相当。然而,这些基于 CPU 的模拟器在尝试利用大规模自对弈训练所需的大规模并行化时面临可扩展性限制。 **GPU 加速环境**。最近,用 JAX [2 (https://arxiv.org/html/2605.20577#bib.bib4)] 原生编写的环境得到了积极开发 [9 (https://arxiv.org/html/2605.20577#bib.bib12), 1 (https://arxiv.org/html/2605.20577#bib.bib13), 5 (https://arxiv.org/html/2605.20577#bib.bib14), 18 (https://arxiv.org/html/2605.20577#bib.bib16), 14 (https://arxiv.org/html/2605.20577#bib.bib15), 17 (https://arxiv.org/html/2605.20577#bib.bib17)]。Pgx [9 (https://arxiv.org/html/2605.20577#bib.bib12)] 提供了如围棋和将棋等经典棋盘游戏,速度比基于 CPU 的对应物快 10–100 倍。其他领域涵盖 Jumanji [1 (https://arxiv.org/html/2605.20577#bib.bib13)] 中的组合优化、Brax [5 (https://arxiv.org/html/2605.20577#bib.bib14)] 中的可微物理以及 JaxMARL [18 (https://arxiv.org/html/2605.20577#bib.bib16)] 中的多智能体任务。最近,还引入了诸如用于开放式学习的 Craftax [14 (https://arxiv.org/html/2605.20577#bib.bib15)]、用于网格世界导航的 Navix [17 (https://arxiv.org/html/2605.20577#bib.bib17)] 以及用于网格世界中元学习的 XLand-Minigrid [15 (https://arxiv.org/html/2605.20577#bib.bib18)] 等环境。这些向量化环境不仅加速了模拟,还促进了利用大规模并行交互的新型 RL 算法,例如并行 Q 学习(PQN)[7 (https://arxiv.org/html/2605.20577#bib.bib24)]。 ## III. Mahjax 概述 参照标题图 1:展示 Mahjax API 的示例代码片段。在本节中,我们描述 Mahjax 的设计选择与实现细节。 ### III-A API 设计与实现 Mahjax 采用 Pgx [9 (https://arxiv.org/html/2605.20577#bib.bib12)] 的 API 设计,以确保与完全可向量化环境的兼容性。图 1 (https://arxiv.org/html/2605.20577#S3.F1) 展示了一个典型的使用示例。为了与 JAX 框架 [2 (https://arxiv.org/html/2605.20577#bib.bib4)] 保持一致,我们严格遵循函数式编程范式:`State` 数据类将包括手牌、分数、风牌、副露和掩码在内的所有游戏信息存储为不可变的 JAX 数组。这种设计与此前通常采用有状态、面向对象架构的麻将模拟器 [8 (https://arxiv.org/html/2605.20577#bib.bib11)] 形成对比,后者妨碍了在 JAX 中的实现。 至关重要的是,将游戏逻辑实现为纯函数对于 JAX 即时编译(JIT)是必不可少的。然而,麻将逻辑涉及复杂的条件分支,这可能阻碍 GPU 上的并行性能。为了缓解这一问题,我们采用了两种主要的优化技术:1)**向量化逻辑**:我们尽可能用矩阵运算替换控制流发散(例如,if-else 语句)。2)**缓存**:我们对计算密集型的评估(如役种计算)实现了缓存。具体来说,我们预先计算了所有可能花色组合的相关统计信息,并将其编码为位掩码。 参照标题图 2:基于 SVG 的 Mahjax 游戏状态可视化。 ### III-B RL 环境设计 在这里,我们描述 Mahjax 作为 RL 环境的具体配置。 **规则**。我们遵循四玩家东南立直麻将的标准规则222http://mahjong-europe.org/portal/images/docs/Riichi-rules-2025-EN.pdf。我们支持两种主要变体: - **天凤(有赤)规则**:如天凤平台 [23 (https://arxiv.org/html/2605.20577#bib.bib26)] 所使用的标准四玩家东南立直麻将规则,包括赤五牌。以往研究主要聚焦于此变体 [11 (https://arxiv.org/html/2605.20577#bib.bib6), 8 (https://arxiv.org/html/2605.20577#bib.bib11), 3 (https://arxiv.org/html/2605.20577#bib.bib29)]。我们按照 Koyamada 等人 [8 (https://arxiv.org/html/2605.20577#bib.bib11)] 的方法,使用下载的对局日志验证了实现的正确性。 - **无赤规则**:不使用赤牌的游戏变体。为简单起见并获得更高吞吐量,我们移除了若干复杂规则,如流局满贯。 **游戏模式**。为提供不同难度级别,我们提供三种模式:`single`、`east` 和 `half`。在 `single` 模式中,剧情在一局后终止,强调即时的做牌效率。相反,`east` 模式持续最多 4 轮(仅东场)。在 `half` 模式中,剧情持续最多 8 轮(东场和南场),需要长期战略规划,如顺位防守和临时合作 [11 (https://arxiv.org/html/2605.20577#bib.bib6)]。 **动作空间**。动作空间包含离散标识符,涵盖弃牌、杠以及特殊操作(如立直、荣和、碰和过)。提供 `legal_action_mask` 以过滤无效 logits。为严格遵循规则,执行非法动作会立即终止并施加惩罚(默认 −1.0)。 **观测空间**。Mahjax 为基于 Transformer 的智能体 [12 (https://arxiv.org/html/2605.20577#bib.bib8)] 提供结构化的字典观测。它包含诸如手牌索引、动作历史和标量属性(如向听数和分数)等标记化输入。所有观测均以当前玩家为中心。 ### III-C 可视化与用户界面 Mahjax 包含一个基于可缩放矢量图形(SVG)的可视化工具(图 2 (https://arxiv.org/html/2605.20577#S3.F2))和一个基于 Web 的用户界面。这些工具使用户能够定性分析智能体行为、调试环境,并交互式地对抗训练后的智能体。为促进国际研究,该可视化支持英文本地化,以适应不熟悉传统牌图的用户,如图 2 所示。 ## IV. 实验 在本节中,我们评估 Mahjax 的计算效率,并验证其作为 RL 研究平台的有效性。 ### IV-A 速度基准测试 参照标题(a)A100 x 1 参照标题(b)A100 x 8 图 3:Mahjax(有赤和无赤规则)、Pgx 将棋和 Libriichi 在不同批量大小下吞吐量(步/秒)的比较。在单 GPU 设置下,Mahjax 在批量大小约为 2^10 时达到吞吐量平台期。相反,在八 GPU 上,两种规则集下的吞吐量均随更大批量大小继续扩展。Mahjax 在 8 个 GPU 上分别针对无赤和有赤规则达到 200 万步/秒和 100 万步/秒的峰值吞吐量,比 Libriichi 快 10 倍以上,并超过了 Pgx 将棋。**设置**。我们将 Mahjax 与两个基线进行比较:1)Libriichi [3 (https://arxiv.org/html/2605.20577#bib.bib29)],一个用于 Mortal 项目 [3 (https://arxiv.org/html/2605.20577#bib.bib29)] 中有赤规则变体的基于 Rust 的 CPU 模拟器;以及 2)Pgx(将棋)[9 (https://arxiv.org/html/2605.20577#bib.bib12)],Pgx 中的将棋环境。在没有其他 GPU 加速麻将模拟器的情况下,我们引入 Pgx 将棋作为参考点,以评估 Mahjax 的可扩展性。基准测试在使用两个 Intel Xeon Platinum 8360Y CPU 和八块 NVIDIA A100 GPU 的计算节点上进行。对于 JAX 环境,我们使用 `jax.pmap` 在设备间并行化来测量吞吐量,同时利用 Rayon333https://github.com/rayon-rs/rayon 在 Libriichi 中实现多线程执行。对于 Mahjax,我们报告了在单个 GPU 和八个 GPU 上无赤和有赤规则变体的结果,以评估可扩展性。模拟使用随机策略运行 100 个批次步,批量大小从 2 到 16384(8 GPU 设置下从 8 开始)。 **结果**。图 3 (https://arxiv.org/html/2605.20577#S4.F3) 显示了结果。在单个 GPU 上,Mahjax 的吞吐量随批量大小增加到约 2^10 个环境,之后大致饱和,而 Libriichi 由于 CPU 计算限制在约 2^3 处达到平台期。
相似文章
@ekzhang1:本周末jax-js取得了一些进展! - 新的matmul基准测试 - 实时TTS演示(http://jax-js.com/tts) 快得多…
jax-js的进展包括新的matmul基准测试、更快的实时TTS演示、改进的代码生成,以及运行Gemma 3 270B的LLM演示。jax-js是一个开源Web ML框架。
我用自对弈强化学习制作了一个超人类水平的 Generals.io 智能体 [P]
使用基于 JAX 的流水线和 Vision Transformer,通过自对弈强化学习训练了一个超人类水平的 Generals.io 智能体。在人类 1v1 排行榜上排名第一;所有代码和一个快速的 JAX 模拟器均已开源。
@maxencefaldor: 对涌现、自组织或形态发生感兴趣?介绍CAX:在JAX中加速的细胞自动机,一…
CAX是一个高性能的开源库,用于细胞自动机研究,基于JAX构建。它可以将模拟加速高达2000倍,并支持任意维度的离散和连续自动机。
MuJoCo-Drones-Gym:用于控制与强化学习的GPU加速多无人机仿真器
本文介绍MuJoCo-Drones-Gym,一个基于MuJoCo的GPU加速多无人机仿真器,支持灵活的物理模型、动作接口和观测空间,适用于强化学习与控制研究。
@loganthorneloe: This is a excellent explanation of JAX. Understanding how ML frameworks work internally gives you a massive advantage w…
本文详细解释了JAX的核心思想,包括函数纯度、不可变性、显式状态管理和JIT编译,帮助读者从面向对象思维转向函数式编程以优化机器学习性能。