Type something to search...
自制医疗迷你深度探测器 R1:用强化学习进行微调推理,成本低于 $3!

自制医疗迷你深度探测器 R1:用强化学习进行微调推理,成本低于 $3!

使用强化学习构建自己的医疗迷你 DeepSeek R1

使用 Unsloth 和 TRL 在 T4 GPU 上微调您自己的多领域推理模型,费用低于 3 美元。

介绍

大型语言模型(LLMs)与医疗保健的交集带来了令人兴奋的机遇,但也面临独特的挑战。在本教程中,我们将探讨如何使用**组相对策略优化(GRPO)**来调整阿里巴巴的Qwen-3B模型以进行医学推理——这是DeepSeek团队最近提出的一种新兴强化学习技术,提供了稳定的训练并且具有较低的内存需求[1]

为什么这很重要

  • 🏥 患者安全第一:医疗AI中的幻觉可能是危险的
  • 💡 领域专业化:通用LLM在临床推理方面表现不佳
  • 效率:我们的3B参数模型可在消费级GPU上运行

如O3和DeepSeek R1等推理模型在许多具有挑战性的基准测试中显示出了前所未有的改进。它们改变了监督微调的趋势,转向了一种实用的强化学习(RL)。我们在深度学习领域所取得的许多突破主要来自RL,例如AlphaGo,因为该模型能够通过与不同的现实场景互动进行学习,这些场景在监督微调中往往很难提供示例。

如果您有兴趣了解更多关于推理模型或历史的详细信息,我强烈推荐Maarten的文章[2]。DeepSeek工作的美在于他们使用GRPO实现了LLM微调的实用框架。根据Maarten的文章:

该算法背后的直觉是,它使导致正确或错误答案的所有选择变得或多或少可能。这些选择可以是令牌的集合,也可以是推理步骤。

正如您在下面的图像中看到的:目标是激励模型在正确的**块中生成响应,以及一个我们可以轻松验证的最终正确答案(例如数学)。

好的,背景介绍到此为止,让我们动手实践。本文中使用的代码可以作为colab notebook轻松运行,您可以使用T4免费层。

安装 Unsloth 和 TRL

开源技术已经取得了长足的进步- 在本教程中,我们将使用两个令人惊叹的开源库:

Unsloth : 一个帮助我们从 GPU 中挤出尽可能多内存并提高训练性能的库。

TRL: 来自 huggingface 的开源库,帮助我们实现 GRPO。

我们还将使用 Qlora,这是一种能帮助我们以更高效的内存方式微调模型的技术。如果您想了解更多关于 Qlora 的信息,我强烈推荐 Sebastian 的 文章

!pip install unsloth vllm  # Memory-efficient training & inference
!pip install trl@git+https://github.com/huggingface/trl  # GRPO implementation
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

下载并初始化模型

我们将首先下载模型,并利用50%的GPU容量以及vLLM推理来加速使用Qlora的GRPO训练。

from unsloth import is_bfloat16_supported
import torch
max_seq_length = 2048 # 可以增加以获得更长的推理轨迹
lora_rank = 64 # 较大的秩 = 更智能,但更慢

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # LoRA 16bit时为False
    fast_inference = True, # 启用vLLM快速推理
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.5, # 如果内存不足则减少
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # 选择任何大于0的数字!建议8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # 如果内存不足则移除QKVO
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # 启用长上下文微调
    random_state = 3407,
)

关键选择

  • 量化:启用16/24GB GPU训练(兼容T4/A10)
  • LoRA秩64:平衡性能与内存
  • vLLM集成:在RL期间生成速度提升50%

数据策略:医疗推理鸡尾酒

我们使用 Hugging Face 的 interleave_datasets 混合三组关键数据集:

PubMedQA(占总数据的 70%):

  • 临床问题回答,答案为 是/否/可能
  • 过滤至 <1024 个标记以提高内存效率

GSM8K

  • 数学应用题以维持数字推理能力

健康基准

  • 50 多个医学专业的多项选择题
  • 类别涵盖心脏病学到疫苗接种

专业提示:权重应反映数据集的复杂性——PubMedQA 的曝光率是其他数据集的 3 倍,以处理其细微差别。我们在这里没有使用任何权重,但我们对数据集进行了洗牌,且由于我们有三倍的 PubMedQA 样本,因此我们有三倍的机会将这些示例展示给模型。

import re
from datasets import load_dataset, Dataset, interleave_datasets, concatenate_datasets

## Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

## uncomment middle messages for 1-shot prompting
def get_datasets(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer']),
        'db_set':'gsm8k'
    }) # type: ignore
    data = data.remove_columns(['question'])
    
    data_qa = load_dataset("qiaojin/PubMedQA", "pqa_artificial")[split] # two times more than other datasets
    data_qa = data_qa.filter(lambda x: len("\n".join(x['context']['contexts'])) < 1024) # avoid long traces
    data_qa = data_qa.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {
                "role": "user",
                "content": "Given the scientific context below:\n" + 
                          "\n".join(x['context']['contexts']) + 
                          "\n\nAnswer the following question:\n" +
                          x['question'] + 
                          " with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering."
            },
        ],
        'answer': x['final_decision'],
        'db_set': 'pubmedqa'
    }) # type: ignore
    data_qa = data_qa.remove_columns(['pubid', 'question', 'context', 'long_answer', 'final_decision'])
    
    
    categories =['Lab_Medicine', 'Wearables', 'Dermatology', 'Gastroenterology', 'Internal_Medicine', 'Oncology', 'Orthopedics', 'General_Surgery', 'Ophthalmology', 'Audiology', 'Head_Neck_Surgery', 'Elderly_Care', 'Pediatrics', 'Allergy_Immunology', 'Rheumatology', 'Pharmacy', 'Obstetrics_Gynecology', 'Microbiology', 'Dentistry', 'Physical_Medicine_and_Rehabilitation', 'Neurology', 'Psychiatry', 'Pathology', 'Genetics', 'Rare_Diseases', 'Hematology', 'Emergency', 'Endocrinology', 'Radiology', 'Cardiology', 'Pulmonology', 'Infectious_Diseases', 'Critical_Care', 'Pediatric_Surgery', 'Neuroscience', 'Epidemiology', 'Fitness_Sports', 'Health_Education', 'Health_Economics', 'Health_Entrepreneurship', 'Hospital_Management', 'Mental_Health', 'Nutrition', 'Palliative_Care', 'Preventive_Medicine', 'Public_Health', 'Social_Media_Addiction', 'Sleep', 'Supplements', 'Vaccination', 'Work_Health', 'Wellbeing']
    data_mc = concatenate_datasets([load_dataset("yesilhealth/Health_Benchmarks",i)[i] for i in categories])
    data_mc = data_mc.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {
                "role": "user",
                "content": "\n\nAnswer the following question:\n" +
                          x['Questions'] + 
                          "\n With 'A', 'B', 'C' or 'D'. You need to carefully review the context and reason before answering."
            },
        ],
        'answer': x['Answers'],
        'db_set': 'med_mc'
    }) # type: ignore
    data_mc = data_mc.remove_columns(['Answers', 'Questions'])
    
    dataset = concatenate_datasets([data, data_qa, data_mc])
    return dataset

秘密配方:奖励工程

我们的多重奖励系统同时教授 推理结构医学准确性(详见 notebook 的奖励函数):

def correctness_reward(responses, answers):
    # Gives 2.0 for exact matches, 1.0 for partial
    return [2.0 if match else (1.0 if partial else 0.0)...]

def format_reward(completions):
    # Enforces <reasoning>...</answer> structure
    return [0.5 if re.match(XML_PATTERN) else 0.0...]

奖励层级

  1. 正确性(50% 权重):与真实情况的对齐
  2. 格式化(30%):XML 风格的推理痕迹
  3. 中间检查(20%):有效答案类型

类比:就像教导医学住院医生——既要赞扬诊断准确性,也要重视适当的文档记录。

GRPO训练配置

这些参数大多是猜测,并未经过优化,但在我最初的实验中效果不错。欢迎根据您的使用案例进行调整和实验。

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = 1024,
    max_completion_length = 1024,
    #num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 750,
    save_steps = 100,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()

就这样——由于我们的奖励函数质量很高——您应该很快看到奖励函数快速上升(RL微调的前12个步骤)

模型开始发展推理能力

让我们看看模型在训练过程中生成的一些示例答案:

Prompt:

Given the scientific context below:

We have previously shown the detrimental effects of 19 sub-erythemal exposures to daily ultraviolet radiation (DUVR, which mimics non-extreme exposure conditions), delivered over 4 weeks to volunteers. This source had UVA (320-400 nm) to UVB (290-320 nm) irradiance ratio of 25, instead of that close to 10 that is typically the case with solar-simulated radiation (SSR) that represents summer global sunlight with a clear sky and quasi-zenith solar irradiance.
Here, we report on an extension of this previous study, in which we evaluated the photoprotection afforded by a broad-spectrum daily-care product with a low-sun protection factor (SPF 8, UVA-PF 7 and 3* rated UVA protection). We assessed cellular and molecular markers of photodamage that are relevant to skin cancer and photoageing.
This study shows that biological effects of repeated exposure to DUVR can be prevented by a broad-spectrum daily-care product and that the level of protection afforded varies with the studied endpoint.

Answer the following question:
Does a broad-spectrum sunscreen prevent cumulative damage from repeated exposure to sub-erythemal solar ultraviolet radiation representative of temperate latitudes? with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering. 

Model Response:

<reasoning>
The context provided discusses a study where repeated exposures to sub-erythemal UV radiation (DUVR) can cause detrimental effects over a period of 4 weeks. It also mentions that a broad-spectrum daily-care product with an SPF 8 and UVA protection factor with a 7 and 3* rating was used. The study aims to evaluate whether this broad-spectrum sunscreen can prevent cumulative damage from such repeated exposures. The mention of repeated exposure and the presence of a broad-spectrum product strongly suggests that the sunscreen can prevent cumulative damage. However, the actual effectiveness of the sunscreen in preventing cumulative damage is not explicitly stated in the provided context. Given the experimental setup and the corrective action (broad-spectrum product) taken, it can be inferred that the answer is likely 'yes'.

The sunscreen contains SPF 8, which means it provides protection to around 92% of the UVB radiation, and it has a UVA protection factor (UVA-PF) of 7, which likely provides varying degrees of protection to different wavelengths, including those of UVA. Such broad-spectrum protection would effectively prevent the cumulative damage associated with prolonged exposure to sub-erythemal UV radiation.

</reasoning>

<answer>
yes
</answer>

正如您所看到的,模型开始将其思维组织成良好的推理块——现在您可以想象,我们可以通过使用这样的技术大大提高模型在专业领域进行检索增强生成(RAG)的能力。我们甚至可以教模型编写 SQL 代码、医学术语标准化等,潜力无穷!

完整笔记本在这里!

Related Posts

结合chatgpt-o3-mini与perplexity Deep Research的3步提示:提升论文写作质量的终极指南

结合chatgpt-o3-mini与perplexity Deep Research的3步提示:提升论文写作质量的终极指南

AI 研究报告和论文写作 合并两个系统指令以获得两个模型的最佳效果 Perplexity AI 的 Deep Research 工具提供专家级的研究报告,而 OpenAI 的 ChatGPT-o3-mini-high 擅长推理。我发现你可以将它们结合起来生成令人难以置信的论文,这些论文比任何一个模型单独撰写的都要好。你只需要将这个一次性提示复制到 **

阅读更多
让 Excel 过时的 10 种 Ai 工具:实现数据分析自动化,节省手工作业时间

让 Excel 过时的 10 种 Ai 工具:实现数据分析自动化,节省手工作业时间

Non members click here作为一名软件开发人员,多年来的一个发现总是让我感到惊讶,那就是人们还在 Excel

阅读更多
使用 ChatGPT 搜索网络功能的 10 种创意方法

使用 ChatGPT 搜索网络功能的 10 种创意方法

例如,提示和输出 你知道可以使用 ChatGPT 的“搜索网络”功能来完成许多任务,而不仅仅是基本的网络搜索吗? 对于那些不知道的人,ChatGPT 新的“搜索网络”功能提供实时信息。 截至撰写此帖时,该功能仅对使用 ChatGPT 4o 和 4o-mini 的付费会员开放。 ![](https://images.weserv.nl/?url=https://cdn-im

阅读更多
掌握Ai代理:解密Google革命性白皮书的10个关键问题解答

掌握Ai代理:解密Google革命性白皮书的10个关键问题解答

10 个常见问题解答 本文是我推出的一个名为“10 个常见问题解答”的新系列的一部分。在本系列中,我旨在通过回答关于该主题的十个最常见问题来分解复杂的概念。我的目标是使用简单的语言和相关的类比,使这些想法易于理解。 图片来自 [Solen Feyissa](https://unsplash.com/@solenfeyissa?utm_source=medium&utm_medi

阅读更多
在人工智能和技术领域保持领先地位的 10 项必学技能 📚

在人工智能和技术领域保持领先地位的 10 项必学技能 📚

在人工智能和科技这样一个动态的行业中,保持领先意味着不断提升你的技能。无论你是希望深入了解人工智能模型性能、掌握数据分析,还是希望通过人工智能转变传统领域如法律,这些课程都是你成功的捷径。以下是一个精心策划的高价值课程列表,可以助力你的职业发展,并让你始终处于创新的前沿。 1. 生成性人工智能简介课程: [生成性人工智能简介](https://genai.works

阅读更多
揭开真相!深度探悉DeepSeek AI的十大误区,您被误导了吗?

揭开真相!深度探悉DeepSeek AI的十大误区,您被误导了吗?

在AI军备竞赛中分辨事实与虚构 DeepSeek AI真的是它所宣传的游戏规则改变者,还是仅仅聪明的营销和战略炒作?👀 虽然一些人将其视为AI效率的革命性飞跃,但另一些人则认为它的成功建立在借用(甚至窃取的)创新和可疑的做法之上。传言称,DeepSeek的首席执行官在疫情期间像囤积卫生纸一样囤积Nvidia芯片——这只是冰山一角。 从其声称的550万美元培训预算到使用Open

阅读更多
Type something to search...