单张 RTX 5080 从头训练 2.35 亿参数 LLM

Reddit r/LocalLLaMA 模型

摘要

一位爱好者在单张 RTX 5080 上从头训练出 2.35 亿参数的 LLM,公开完整 PyTorch 流程并开源 Plasma 1.0。

大家好,这个项目折腾了很久,今天也发到这里分享一下。我用 PyTorch 从零写了一个小型 Transformer 语言模型:没有预训练权重,也不靠 HuggingFace 下载,所有参数都在一张消费级 GPU 上从原始文本训出来。当前版本 Plasma 1.0:2.35 亿参数、18 层、隐藏维度 1024,LLaMA 风格:16 个查询头 + 4 个 KV 头的 GQA(head_dim 64)、SwiGLU FFN 中间维度 2816、RoPE theta 10000、RMSNorm pre-norm、共享嵌入。词表用 32k SentencePiece BPE,开 bf16 混合精度 + 梯度检查点,刚好塞进 5080,序列长度 1024,训了约 50 亿 token。 整套流程也自己撸了一遍: - 数据来自 FineWeb-Edu、Wikipedia、StackExchange、代码、ArXiv - 质量 + 毒性过滤 - MinHash 去重 - 定制 SentencePiece 分词器 - 按领域加权混合数据 - 预训练 + 指令微调,带 loss mask,只让模型学 assistant token 指令微调后随手跑两句: 你:第一次世界大战什么时候爆发? 1386.ai:第一次世界大战始于 1914 年 6 月 26 日。 你:牛排是用什么做的? 1386.ai:牛排可由多种肉类制成,包括牛肉。 显然打不过 Llama 3,幻觉、奇怪输出、规模天花板都很明显。但这样从头到尾撸一遍,比单纯微调大模型学到的东西多太多。 Plasma 1.1 正在训(5 亿参数),目标提升多轮能力,词表更大并带 byte fallback。 Repo:[github.com/eb1386/1386.ai](http://github.com/eb1386/1386.ai) 欢迎问任何流程或架构细节。
查看原文

相似文章