本文提出了一种多阶段训练方案,包括大规模蒸馏、滚动偏好优化和可验证奖励的强化学习,显著提升了小型语言模型在数学推理任务中的性能,使3.8B参数的Phi-4-Mini-Reasoning模型超过了近两倍参数的开源基线模型。
Small Language Model, Chain-of-Thought, Mathematical Reasoning, Distillation, Reinforcement Learning, Preference Optimization
Haoran Xu, Baolin Peng, Hany Awadalla, Dongdong Chen, Yen-Chun Chen, Mei Gao, Young Jin Kim, Yunsheng Li, Liliang Ren, Yelong Shen, Shuohang Wang, Weijian Xu, Jianfeng Gao, Weizhu Chen
Microsoft
Generated by grok-3-mini-latest
Background Problem
大型语言模型(LLMs)通过链式思考(Chain-of-Thought, CoT)技术显著提升了推理能力,但小型语言模型(SLMs)由于模型容量有限,改进推理能力仍面临挑战。最近的研究表明,从LLM生成的合成数据中进行蒸馏可以显著提升SLM的推理性能,例如Deepseek-R1的工作将Llama-8B在MATH-500上的准确率从44.4%提高到89.1%。然而,现有的方法缺乏详细的训练方案,且简单应用某些技术(如直接蒸馏小数据集)可能导致性能下降。本文的目标是解决这一问题,提出一个系统的、多阶段的训练配方,针对SLM在数学推理任务中的能力提升,提供一个全面且有效的解决方案。
Method
- 核心思想: 本文提出一种多阶段连续训练范式,旨在通过逐步构建和优化SLM的推理能力,解决SLM容量受限的挑战。核心在于将蒸馏、偏好学习和强化学习结合,逐步从基础推理能力到高级优化。
- 如何实现: 训练过程分为四个主要步骤:
- 蒸馏作为中期训练(Distillation as Mid-Training):使用大规模的合成CoT数据(从Deepseek-R1采样),通过标准因果语言建模目标进行下一token预测训练,采用打包模式提高效率,目标是嵌入通用的推理能力。
- 蒸馏作为监督微调(Distillation as Supervised Fine-Tuning):在中期训练基础上,使用高质量的CoT数据子集进行微调,不采用打包模式,教导模型决定生成停止点,焦点在提高泛化能力。
- 滚动偏好学习(Rollout Preference Learning):利用被拒绝的LLM生成样本构建偏好对,应用直接偏好优化(DPO),公式为:,其中和分别为优选和劣选回滚。
- 使用可验证奖励的强化学习(RL with Verifiable Reward):基于PPO或GRPO算法,引入改进如提示优化、奖励再平衡(通过过采样和过滤)和温度退火(从1.0线性衰减到0.6),奖励函数基于最终答案正确性:。
- 主要步骤: 整个过程从大规模数据蒸馏开始,逐步过渡到偏好优化和在线RL,确保训练稳定性和性能提升,而不直接修改基模型。
Experiment
- 实验设置: 本文在数学推理任务上评估模型,包括AIME 2024、MATH-500和GPQA Diamond数据集。每个任务进行3次运行,报告平均性能,生成参数设置为温度0.6、top_p 0.95、最大序列长度32K。基线模型包括o1-mini、DeepSeek-R1-Distill-Qwen-7B等。训练设置:蒸馏阶段使用批量大小128、学习率1e-5、5个周期;DPO阶段学习率5e-7、1个周期;RL阶段学习率5e-7、序列长度25K。
- 数据集和原因: 使用了多样化的数据集,如AIME 2024测试高难度问题、MATH-500评估一般数学推理、GPQA Diamond检查复杂查询。选择这些数据集是因为它们可验证且覆盖不同难度水平,实验设计旨在验证每个训练阶段的贡献,通过消融研究(如表3所示)分析性能提升。
- 结果和预期匹配: Phi-4-Mini-Reasoning(3.8B参数)在MATH-500上达到94.6%、AIME上57.5%,超过了更大模型(如DeepSeek-R1-Distill-Qwen-7B高3.2点)。消融实验显示,每个阶段逐步提升性能(例如,添加RL后平均提升7点),与预期一致,证明了多阶段训练的有效性。实验全面合理,稳定性通过与DAPO的比较得到验证,pass@k和cons@16指标显示改进。
Further Thoughts
这项工作突出了数据质量和训练策略在SLM性能提升中的关键作用,可能启发其他领域如代码生成或常识推理的应用,因为类似的多阶段方法可以帮助模型在资源受限环境下实现高效学习;此外,它强调了RL稳定性的重要性,未来可以探索与其他优化技术(如元学习)的结合,以提高模型的泛化能力和部署效率,特别是在边缘计算场景中。