我构建了一个Mamba1变体,称为SM1,d_state=1,在Blackwell上纯PyTorch运行[P]

Reddit r/MachineLearning 模型

摘要

作者提出了SM1,一个Mamba1的变体,d_state=1,使用两个原生PyTorch操作替代选择性扫描,与d_state=16相比内存减少16倍。闭式解消除了状态维度,实现了每个token恒定内存的高效推理。

在Windows上,mamba-ssm不容易获取,也无法在sm\_120上编译。SM1(标量Mamba1)用两个原生PyTorch操作替代了整个选择性扫描:`L = torch.cumprod(dA, dim=1)` `h = L * (h0.unsqueeze(1) + torch.cumsum(dBx / L.clamp(min=1e-6), dim=1))` `y = h * C` 这是通过参数变分法得到的d\_state=1递推的精确闭式解。不是近似,它与顺序计算的浮点精度完全相同。d\_state=2会破坏它。d\_state=1是闭式解存在的边界。Mamba1扫描中间变量是(B, T, F, S)。SM1完全消除了S,与d\_state=16的Mamba1相比,扫描内存减少16倍。一个130M参数模型的推理状态大约有14,080个浮点数,56 KB,没有KV缓存,每个token永久O(1)。我目前正在使用163K个MIDI文件训练它,大约相当于我自定义格式的25亿个token。130M参数在不到我16 GB显存一半的空间中运行,显卡是RTX 5060 Ti。d\_state仅在表示尚未编码结构时才扩展表达能力。因此,如果将结构编码到token中,就不需要d\_state超过标量。
查看原文

相似文章

δ-mem:大型语言模型的高效在线记忆机制

Hugging Face Daily Papers

本文介绍了 δ-mem,这是一种轻量级的记忆机制,通过为冻结的注意力骨干网络增加一个紧凑的关联记忆状态来增强大型语言模型。实验表明,该机制在计算开销极小的情况下,在记忆密集型基准测试中实现了性能提升。