MedQA:在AMD ROCm上微调临床AI——无需CUDA
摘要
一个教程和项目,演示在AMD MI300X上使用ROCm对Qwen3-1.7B进行LoRA微调,用于临床问答,为医疗AI开发提供无需CUDA的替代方案。
查看缓存全文
缓存时间: 2026/05/08 12:26
MedQA:在 AMD ROCm 上微调临床 AI —— 无需 CUDA
来源:https://huggingface.co/blog/lablab-ai-amd-developer-hackathon/medqa 返回文章 (https://huggingface.co/blog)
Harikrishna 的头像 (https://huggingface.co/HK2184)
- 创意来源
- 为什么选择 AMD ROCm?
- 数据集:MedMCQA
- 模型:Qwen3-1.7B
- 提示格式
- 使用 LoRA 进行训练
- 推理
- 从 HuggingFace Hub 加载
- 挑战与修复
- 结果
- 自己动手尝试
- 下一步计划
- 结论
在 lablab.ai 的 AMD 开发者黑客马拉松中,使用 AMD MI300X 对 MedMCQA 数据集进行 Qwen3-1.7B LoRA 微调的完整指南。
创意来源
医学问答是一项风险极高的任务。如果模型在临床单选题中自信地选错了答案,这不仅是一个错误——它可能带来危险。与此同时,大多数开源医学 AI 工作都假设你拥有 NVIDIA GPU。CUDA 是默认选项,其他一切都只是事后考虑。
本项目挑战了这一假设。
MedQA 是一个完全基于 AMD 硬件并使用 ROCm 构建的 LoRA 微调临床问答模型。它接收一道多项选择的医学问题,并返回正确答案的字母以及临床推理的解释。整个训练流程——从数据加载到适配器导出——都在 AMD Instinct MI300X 上运行,没有任何 CUDA 依赖。
- 🤗 HuggingFace Hub 上的模型: HK2184/medqa-qwen3-lora (https://huggingface.co/HK2184/medqa-qwen3-lora)
- 🚀 在线演示: HuggingFace Spaces (https://huggingface.co/spaces/lablab-ai-amd-developer-hackathon/MedQA-Medical-AI-on-AMD-ROCm)
- 💻 GitHub: MedQA-Medical-AI-on-AMD-ROCm (https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm)
为什么选择 AMD ROCm?
AMD Instinct MI300X 是一款非凡的硬件:单设备拥有 192 GB HBM3 内存。对于 LLM 微调来说,VRAM 往往是主要的限制因素——它决定了批量大小、序列长度,以及你是否需要量化。有了 192 GB 的内存,我们可以在完全 fp16 精度下使用 LoRA 训练 Qwen3-1.7B,而无需任何 4-bit 或 8-bit 量化技巧。
更重要的是,我们的目标是证明 HuggingFace 生态系统——Transformers、PEFT、TRL、Accelerate——能够完美地运行在 ROCm 上。事实确实如此。同样的训练代码,只需设置三个环境变量,就能在 ROCm 上运行,就像在 CUDA 上一样:
os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
仅此而已。不需要修改代码,不需要自定义内核,不需要 CUDA 兼容层。
数据集:MedMCQA
MedMCQA (https://huggingface.co/datasets/openlifescienceai/medmcqa) 是一个大规模的医学多项选择题数据集,源自印度的医学入学考试(AIIMS、USMLE 风格)。每条数据包含:
- 一道临床问题
- 四个选项(A–D)
- 正确答案的索引
- 可选的自由文本解释(
exp字段)
在本项目中,我们使用了 2,000 个训练样本——故意选取一个较小的切片,以证明可以在短时间内完成有意义的微调。在 MI300X 上训练大约需要 5 分钟。
模型:Qwen3-1.7B
基础模型是 Qwen/Qwen3-1.7B (https://huggingface.co/Qwen/Qwen3-1.7B)——阿里巴巴最新发布的小规模语言模型。拥有 17 亿参数,它足够小巧,可以低成本微调,但又有足够的能力生成连贯的临床推理。它支持 trust_remote_code=True,并能通过 HuggingFace Transformers 干净地加载。
提示格式
提示格式的一致性对于指令微调至关重要。每个训练样本和每次推理调用都使用相同的模板:
### Question:
{question}
### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}
### Answer:
{answer_letter}) {answer_text}
### Explanation:
{explanation}
在训练期间,模型会看到包含答案和解释的完整序列。在推理时,我们提供直到 ### Answer:\n 之前的所有内容,然后让模型从那里开始生成。
使用 LoRA 进行训练
我们没有微调全部 15 亿参数,而是通过 PEFT 库使用 LoRA(低秩适应)。LoRA 将可训练的小型秩分解矩阵注入到注意力层中,保持基础权重不变。
LoRA 配置
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"],
bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443
只有模型 15 亿参数中的 约 220 万 被训练。这保持了较低的内存使用量并加快了训练速度。
训练参数
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="./outputs",
num_train_epochs=2,
per_device_train_batch_size=4,
gradient_accumulation_steps=4, # effective batch size = 16
learning_rate=2e-4,
fp16=True,
bf16=False,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
gradient_checkpointing=True,
optim="adamw_torch",
warmup_ratio=0.05,
lr_scheduler_type="cosine",
report_to="none",
)
有几点值得注意:
fp16=True, bf16=False—— 我们使用标准的 fp16。在早期试验 bfloat16 时遇到了 NaN 损失;切换到 fp16 后完全解决。gradient_checkpointing=True—— 用计算换内存。在 MI300X 上 192 GB VRAM 并非绝对必要,但为了在更小的 GPU 上复现,这是一个好习惯。gradient_accumulation_steps=4—— 实际物理批量大小为 4,有效批量大小为 16。- 余弦学习率调度 + 预热 —— 对于短训练运行,比固定调度收敛更平滑。
完整训练流程
from transformers import DataCollatorForSeq2Seq, Trainer
collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
padding=True,
pad_to_multiple_of=8,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=collator,
)
trainer.train()
# Save adapter + tokenizer
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")
训练结束后,./outputs 目录包含 LoRA 适配器权重——只有几 MB 的文件,而不是完整的数 GB 模型检查点。
推理
在推理时,我们加载基础模型,附加 LoRA 适配器,并可以选择合并权重:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()
生成使用贪婪解码(do_sample=False),并设置了重复惩罚以防止模型循环重复:
def generate(prompt, model, tokenizer):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=200,
do_sample=False,
temperature=1.0,
repetition_penalty=1.1,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
new_tokens = output[0][inputs["input_ids"].shape[-1]:]
return tokenizer.decode(new_tokens, skip_special_tokens=True)
输出示例
问题: 以下哪项是高血压急症的一线治疗?
A) 口服氨氯地平
B) 静脉注射拉贝洛尔或静脉注射硝普钠
C) 舌下含服硝苯地平
D) 肌肉注射肼屈嗪
模型输出:
B) 静脉注射拉贝洛尔或静脉注射硝普钠
解释:
静脉注射拉贝洛尔(β受体阻滞剂)或硝普钠可在紧急情况下快速降低血压。口服药物起效太慢,无法应对需要立即控制血压以避免终末器官损伤的高血压急症。
模型不仅输出一个字母——它还解释了原因,这正是其临床价值所在。
从 HuggingFace Hub 加载
微调后的适配器已公开可用。你可以直接加载,无需克隆仓库:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-1.7B", trust_remote_code=True
)
base = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()
挑战与修复
没有一个 AMD ROCm 项目是没有“战斗故事”章节的。以下是我们遇到的一些问题:
| 挑战 | 根本原因 | 修复方法 |
|---|---|---|
| NaN 损失 | 混合精度不稳定 | 从 bfloat16 切换到 fp16 |
| GPU 未检测到 | 缺少 ROCm 环境变量 | 设置 ROCR_VISIBLE_DEVICES,HIP_VISIBLE_DEVICES,HSA_OVERRIDE_GFX_VERSION |
| bitsandbytes 不支持 | 没有针对 ROCm 的 bitsandbytes 构建 | 完全放弃量化——MI300X 有足够 VRAM |
| 推理输出乱码 | Tokenizer 填充设置错误 | 设置 pad_token = eos_token 并修正 padding_side |
| Trainer 评估错误 | Transformers 版本不匹配 | 固定 transformers>=4.40.0 |
bitsandbytes 的问题值得说明:在 NVIDIA 硬件上,通常需要 4-bit 量化才能将模型装入内存。在拥有 192 GB HBM3 的 MI300X 上,这完全没有必要。这是一个真正的硬件优势——训练更干净,没有量化伪影。
结果
| 指标 | 数值 |
|---|---|
| 可训练参数 | 约 220 万(占总参数的 0.15%) |
| MI300X 训练时间 | 约 5 分钟 |
| 使用的数据集大小 | 2,000 个样本 |
| MedMCQA 基础准确率 | 约 45% |
| 框架 | PyTorch + ROCm 6.1 |
自己动手尝试
没有 GPU?没问题。 在线 Gradio 演示在 HuggingFace Spaces 上运行(CPU 推理):
拥有 AMD 硬件? 克隆仓库并本地运行:
git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py # 约 5 分钟
python infer.py # 运行示例问题
python app.py # 启动 Gradio UI
下一步计划
这个项目证明了整个流程是可行的。下一步是进行扩展和加固:
- 更大数据集 —— 使用完整的 MedMCQA 语料库(约 18 万道题)进行训练,并加入 PubMedQA
- 置信度评分 —— 在答案旁边添加校准的置信度估计
- RAG 集成 —— 将答案基于实时医学文献检索
- 评估框架 —— 在训练分割之外进行适当的保留集准确率基准测试
结论
MedQA 表明,在开源 AMD 硬件上构建一个能力强、可解释的医学 AI 不仅是可能的——而且非常直接。HuggingFace 生态系统对 ROCm 的兼容性确实很好。MI300X 的内存余量消除了整整一类工程问题。而 LoRA 使得微调一个 1.7B 模型成为只需 5 分钟的工作。
如果你正在 AMD ROCm 上构建项目并且遇到了障碍,上面的修复方法应该能为你节省数小时。如果你正在构建医学 AI,那么强调解释而非单纯准确率是值得认真对待的。
为 lablab.ai 上的 AMD 开发者黑客马拉松构建 · 由 AMD ROCm + HuggingFace 生态系统驱动
— Harikrishna Sivanand Iyer 和 Srijan Sivaram A
Screenshot From 2026-05-07 14-26-07 (https://cdn-uploads.huggingface.co/production/uploads/649d94f965079a8cc70d468f/RHAC3pe5_ng5_8RqJ2Sd9.png)
相似文章
TurboQuant+MTP在ROCm(Llama CPP)上的实现
一位开发者成功在llama.cpp中让TurboQuant TBQ4 KV缓存和多Token预测在AMD ROCm上针对RDNA3 GPU运行,实现在24GB显存上支持64k上下文,并具有有竞争力的token速率。
双Radeon R9700——在llama.cpp上运行Qwen 3.6 27B Q8 MTP
关于在使用ROCm的llama.cpp上,于双AMD Radeon R9700配置下运行Qwen 3.6 27B Q8模型的技术报告,包括性能基准测试和配置详情。
@leopardracer: https://x.com/leopardracer/status/2055341758523883631
一位用户分享了他们搭建双GPU本地AI实验室的经验,使用了RTX 4080 Super和5060 Ti,通过llama.cpp和llama-swap运行Qwen 3.6模型,以降低API成本并实现无限制的实验。
hipEngine:面向RDNA3(Strix Halo、7900 XTX)的快速原生Qwen 3.6推理引擎
hipEngine是一个新的开源、ROCm原生LLM推理引擎,专为AMD RDNA3 GPU设计,在Qwen 3.6模型上相比llama.cpp提供有竞争力的预填充和解码性能。
2026年中ROCm状态 [D]
作者询问2026年中AMD的ROCm生态系统在AI训练领域的当前可行性,将其与NVIDIA的CUDA进行比较,并询问它是否已达到PyTorch的“开箱即用”阶段。