
自制医疗迷你深度探测器 R1:用强化学习进行微调推理,成本低于 $3!
- Rifx.Online
- Machine Learning , Health , Natural Language Processing
- 05 Mar, 2025
使用强化学习构建自己的医疗迷你 DeepSeek R1
使用 Unsloth 和 TRL 在 T4 GPU 上微调您自己的多领域推理模型,费用低于 3 美元。
介绍
大型语言模型(LLMs)与医疗保健的交集带来了令人兴奋的机遇,但也面临独特的挑战。在本教程中,我们将探讨如何使用**组相对策略优化(GRPO)**来调整阿里巴巴的Qwen-3B模型以进行医学推理——这是DeepSeek团队最近提出的一种新兴强化学习技术,提供了稳定的训练并且具有较低的内存需求[1]。
为什么这很重要
- 🏥 患者安全第一:医疗AI中的幻觉可能是危险的
- 💡 领域专业化:通用LLM在临床推理方面表现不佳
- ⚡ 效率:我们的3B参数模型可在消费级GPU上运行
如O3和DeepSeek R1等推理模型在许多具有挑战性的基准测试中显示出了前所未有的改进。它们改变了监督微调的趋势,转向了一种实用的强化学习(RL)。我们在深度学习领域所取得的许多突破主要来自RL,例如AlphaGo,因为该模型能够通过与不同的现实场景互动进行学习,这些场景在监督微调中往往很难提供示例。
如果您有兴趣了解更多关于推理模型或历史的详细信息,我强烈推荐Maarten的文章[2]。DeepSeek工作的美在于他们使用GRPO实现了LLM微调的实用框架。根据Maarten的文章:
该算法背后的直觉是,它使导致正确或错误答案的所有选择变得或多或少可能。这些选择可以是令牌的集合,也可以是推理步骤。
正如您在下面的图像中看到的:目标是激励模型在正确的*
好的,背景介绍到此为止,让我们动手实践。本文中使用的代码可以作为colab notebook轻松运行,您可以使用T4免费层。
安装 Unsloth 和 TRL
开源技术已经取得了长足的进步- 在本教程中,我们将使用两个令人惊叹的开源库:
Unsloth : 一个帮助我们从 GPU 中挤出尽可能多内存并提高训练性能的库。
TRL: 来自 huggingface 的开源库,帮助我们实现 GRPO。
我们还将使用 Qlora,这是一种能帮助我们以更高效的内存方式微调模型的技术。如果您想了解更多关于 Qlora 的信息,我强烈推荐 Sebastian 的 文章。
!pip install unsloth vllm # Memory-efficient training & inference
!pip install trl@git+https://github.com/huggingface/trl # GRPO implementation
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
下载并初始化模型
我们将首先下载模型,并利用50%的GPU容量以及vLLM推理来加速使用Qlora的GRPO训练。
from unsloth import is_bfloat16_supported
import torch
max_seq_length = 2048 # 可以增加以获得更长的推理轨迹
lora_rank = 64 # 较大的秩 = 更智能,但更慢
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "Qwen/Qwen2.5-3B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = True, # LoRA 16bit时为False
fast_inference = True, # 启用vLLM快速推理
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.5, # 如果内存不足则减少
)
model = FastLanguageModel.get_peft_model(
model,
r = lora_rank, # 选择任何大于0的数字!建议8, 16, 32, 64, 128
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
], # 如果内存不足则移除QKVO
lora_alpha = lora_rank,
use_gradient_checkpointing = "unsloth", # 启用长上下文微调
random_state = 3407,
)
关键选择:
- 量化:启用16/24GB GPU训练(兼容T4/A10)
- LoRA秩64:平衡性能与内存
- vLLM集成:在RL期间生成速度提升50%
数据策略:医疗推理鸡尾酒
我们使用 Hugging Face 的 interleave_datasets
混合三组关键数据集:
PubMedQA(占总数据的 70%):
- 临床问题回答,答案为 是/否/可能
- 过滤至 <1024 个标记以提高内存效率
GSM8K:
- 数学应用题以维持数字推理能力
健康基准:
- 50 多个医学专业的多项选择题
- 类别涵盖心脏病学到疫苗接种
专业提示:权重应反映数据集的复杂性——PubMedQA 的曝光率是其他数据集的 3 倍,以处理其细微差别。我们在这里没有使用任何权重,但我们对数据集进行了洗牌,且由于我们有三倍的 PubMedQA 样本,因此我们有三倍的机会将这些示例展示给模型。
import re
from datasets import load_dataset, Dataset, interleave_datasets, concatenate_datasets
## Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
return answer.strip()
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
## uncomment middle messages for 1-shot prompting
def get_datasets(split = "train") -> Dataset:
data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
data = data.map(lambda x: { # type: ignore
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer']),
'db_set':'gsm8k'
}) # type: ignore
data = data.remove_columns(['question'])
data_qa = load_dataset("qiaojin/PubMedQA", "pqa_artificial")[split] # two times more than other datasets
data_qa = data_qa.filter(lambda x: len("\n".join(x['context']['contexts'])) < 1024) # avoid long traces
data_qa = data_qa.map(lambda x: { # type: ignore
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{
"role": "user",
"content": "Given the scientific context below:\n" +
"\n".join(x['context']['contexts']) +
"\n\nAnswer the following question:\n" +
x['question'] +
" with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering."
},
],
'answer': x['final_decision'],
'db_set': 'pubmedqa'
}) # type: ignore
data_qa = data_qa.remove_columns(['pubid', 'question', 'context', 'long_answer', 'final_decision'])
categories =['Lab_Medicine', 'Wearables', 'Dermatology', 'Gastroenterology', 'Internal_Medicine', 'Oncology', 'Orthopedics', 'General_Surgery', 'Ophthalmology', 'Audiology', 'Head_Neck_Surgery', 'Elderly_Care', 'Pediatrics', 'Allergy_Immunology', 'Rheumatology', 'Pharmacy', 'Obstetrics_Gynecology', 'Microbiology', 'Dentistry', 'Physical_Medicine_and_Rehabilitation', 'Neurology', 'Psychiatry', 'Pathology', 'Genetics', 'Rare_Diseases', 'Hematology', 'Emergency', 'Endocrinology', 'Radiology', 'Cardiology', 'Pulmonology', 'Infectious_Diseases', 'Critical_Care', 'Pediatric_Surgery', 'Neuroscience', 'Epidemiology', 'Fitness_Sports', 'Health_Education', 'Health_Economics', 'Health_Entrepreneurship', 'Hospital_Management', 'Mental_Health', 'Nutrition', 'Palliative_Care', 'Preventive_Medicine', 'Public_Health', 'Social_Media_Addiction', 'Sleep', 'Supplements', 'Vaccination', 'Work_Health', 'Wellbeing']
data_mc = concatenate_datasets([load_dataset("yesilhealth/Health_Benchmarks",i)[i] for i in categories])
data_mc = data_mc.map(lambda x: { # type: ignore
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{
"role": "user",
"content": "\n\nAnswer the following question:\n" +
x['Questions'] +
"\n With 'A', 'B', 'C' or 'D'. You need to carefully review the context and reason before answering."
},
],
'answer': x['Answers'],
'db_set': 'med_mc'
}) # type: ignore
data_mc = data_mc.remove_columns(['Answers', 'Questions'])
dataset = concatenate_datasets([data, data_qa, data_mc])
return dataset
秘密配方:奖励工程
我们的多重奖励系统同时教授 推理结构 和 医学准确性(详见 notebook 的奖励函数):
def correctness_reward(responses, answers):
# Gives 2.0 for exact matches, 1.0 for partial
return [2.0 if match else (1.0 if partial else 0.0)...]
def format_reward(completions):
# Enforces <reasoning>...</answer> structure
return [0.5 if re.match(XML_PATTERN) else 0.0...]
奖励层级:
- 正确性(50% 权重):与真实情况的对齐
- 格式化(30%):XML 风格的推理痕迹
- 中间检查(20%):有效答案类型
类比:就像教导医学住院医生——既要赞扬诊断准确性,也要重视适当的文档记录。
GRPO训练配置
这些参数大多是猜测,并未经过优化,但在我最初的实验中效果不错。欢迎根据您的使用案例进行调整和实验。
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
use_vllm = True, # use vLLM for fast inference!
learning_rate = 5e-6,
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "adamw_8bit",
logging_steps = 1,
bf16 = is_bfloat16_supported(),
fp16 = not is_bfloat16_supported(),
per_device_train_batch_size = 1,
gradient_accumulation_steps = 1, # Increase to 4 for smoother training
num_generations = 6, # Decrease if out of memory
max_prompt_length = 1024,
max_completion_length = 1024,
#num_train_epochs = 1, # Set to 1 for a full training run
max_steps = 750,
save_steps = 100,
max_grad_norm = 0.1,
report_to = "none", # Can use Weights & Biases
output_dir = "outputs",
)
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = [
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func,
],
args = training_args,
train_dataset = train_dataset,
eval_dataset=test_dataset,
)
trainer.train()
就这样——由于我们的奖励函数质量很高——您应该很快看到奖励函数快速上升(RL微调的前12个步骤)
模型开始发展推理能力
让我们看看模型在训练过程中生成的一些示例答案:
Prompt:
Given the scientific context below:
We have previously shown the detrimental effects of 19 sub-erythemal exposures to daily ultraviolet radiation (DUVR, which mimics non-extreme exposure conditions), delivered over 4 weeks to volunteers. This source had UVA (320-400 nm) to UVB (290-320 nm) irradiance ratio of 25, instead of that close to 10 that is typically the case with solar-simulated radiation (SSR) that represents summer global sunlight with a clear sky and quasi-zenith solar irradiance.
Here, we report on an extension of this previous study, in which we evaluated the photoprotection afforded by a broad-spectrum daily-care product with a low-sun protection factor (SPF 8, UVA-PF 7 and 3* rated UVA protection). We assessed cellular and molecular markers of photodamage that are relevant to skin cancer and photoageing.
This study shows that biological effects of repeated exposure to DUVR can be prevented by a broad-spectrum daily-care product and that the level of protection afforded varies with the studied endpoint.
Answer the following question:
Does a broad-spectrum sunscreen prevent cumulative damage from repeated exposure to sub-erythemal solar ultraviolet radiation representative of temperate latitudes? with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering.
Model Response:
<reasoning>
The context provided discusses a study where repeated exposures to sub-erythemal UV radiation (DUVR) can cause detrimental effects over a period of 4 weeks. It also mentions that a broad-spectrum daily-care product with an SPF 8 and UVA protection factor with a 7 and 3* rating was used. The study aims to evaluate whether this broad-spectrum sunscreen can prevent cumulative damage from such repeated exposures. The mention of repeated exposure and the presence of a broad-spectrum product strongly suggests that the sunscreen can prevent cumulative damage. However, the actual effectiveness of the sunscreen in preventing cumulative damage is not explicitly stated in the provided context. Given the experimental setup and the corrective action (broad-spectrum product) taken, it can be inferred that the answer is likely 'yes'.
The sunscreen contains SPF 8, which means it provides protection to around 92% of the UVB radiation, and it has a UVA protection factor (UVA-PF) of 7, which likely provides varying degrees of protection to different wavelengths, including those of UVA. Such broad-spectrum protection would effectively prevent the cumulative damage associated with prolonged exposure to sub-erythemal UV radiation.
</reasoning>
<answer>
yes
</answer>
正如您所看到的,模型开始将其思维组织成良好的推理块——现在您可以想象,我们可以通过使用这样的技术大大提高模型在专业领域进行检索增强生成(RAG)的能力。我们甚至可以教模型编写 SQL 代码、医学术语标准化等,潜力无穷!