本文提出 CoLM 方法,通过构建小批量核心集匹配大批量梯度,在内存需求减少 2 倍的情况下,使 LLM 微调性能优于 4 倍批大小的常规训练,同时提升收敛速度。
Large Language Model, Fine-tuning, Efficiency, Data Augmentation, Pre-training
Dang Nguyen, Wenhan Yang, Rathul Anand, Yu Yang, Baharan Mirzasoleiman
UCLA
Generated by grok-3
Background Problem
大型语言模型(LLM)的训练和微调需要大量计算资源和 GPU 内存,尤其是在存储参数、计算梯度和优化器状态(如 Adam 的动量和历史梯度)时。以 Phi-2(27 亿参数)为例,使用批大小 128 进行完整微调至少需要 44GB GPU 内存,这使得使用更大批大小(batch size)以提升收敛速度和性能变得不可行。本文从数据视角出发,针对这一内存瓶颈问题,提出通过构建小批量核心集(mini-batch coresets)来模拟大批量训练的效果,解决的关键问题是:在内存受限的情况下,如何通过选择代表性数据子集提升 LLM 训练的收敛速度和性能,尤其是在语言数据高度不平衡的场景下。
Method
本文提出了 Coresets for Training LLMs (CoLM) 方法,核心思想是通过构建小批量核心集来匹配大批量梯度,从而在内存受限的情况下模拟大批量训练的效果。具体步骤如下:
- 处理不平衡数据:针对语言数据中来源(sources)高度不平衡的问题,CoLM 提出将小来源(样本数少于平均值的来源)的所有样本保留到核心集中,以确保其代表性;对于大来源,则通过梯度匹配选择代表性样本(medoids)。此外,所有选中样本被赋予均匀权重,以平衡大小来源的学习速度。
- 适配 Adam 优化器:由于 Adam 是 LLM 训练的标准优化器,CoLM 通过历史指数平均值对梯度进行归一化(仅基于大来源样本计算历史项),以选择更适合 Adam 的核心集。
- 降低梯度维度:针对 LLM 梯度高维问题,CoLM 使用零阶方法(如 SPSA)计算最后一层 V-投影矩阵的平滑梯度,并通过稀疏化(保留归一化梯度幅度最大的维度)降低维度,使用 ℓ1 距离计算稀疏梯度间的差异以选择 medoids。 批判性思考:虽然 CoLM 的方法在理论上针对不平衡数据和 Adam 优化器进行了优化,但保留小来源所有样本的策略可能引入噪声或无关数据,影响模型性能。此外,零阶方法和稀疏化技术的计算开销在更大规模模型上可能显著增加,论文未充分讨论这一潜在问题。
Experiment
实验主要在微调和预训练两个场景下进行评估:
- 微调实验:在 MathInstruct 数据集(包含 260K 数学指令样本,14 个高度不平衡来源)上微调 Phi-2、Phi-3、Zephyr 和 Llama-3 模型,使用 LoRA 技术。结果显示,CoLM 以批大小 64(从 128 中选择)在内存需求减少 2 倍的情况下,性能优于批大小 256 的常规微调,平均准确率提升至 56.6%(对比常规微调的 55.3%)。在 SuperGLUE 基准数据集上,CoLM 通过聚类隐藏状态定义来源,性能平均提升 4.5%,并在某些任务上优于批大小 128 的常规微调。
- 预训练实验:在 Pile 数据集的子集上预训练 Llama-60M 模型,CoLM 在下游任务准确率上平均提升 1.4%,验证困惑度与 2 倍批大小的常规预训练接近。
- 消融研究:验证了保留小来源样本、单独选择大来源 medoids 和 Adam 归一化梯度等设计的有效性,其中保留小来源样本提升约 3% 准确率。 批判性思考:实验设置较为合理,涵盖了微调和预训练场景,并通过消融研究验证了各组件的作用。然而,实验主要集中在数学推理和分类任务,缺乏对生成任务等其他场景的测试,可能限制方法的普适性。此外,虽然内存减少和性能提升显著,但与最新数据选择或内存优化方法的对比不足,实验结果可能存在选择性偏见。
Further Thoughts
CoLM 的方法从数据选择角度解决内存问题,为 LLM 训练提供了一种新思路,特别是在资源受限环境下的应用潜力值得关注。然而,其假设小来源样本必须全部保留的策略可能在实际应用中面临挑战,例如在噪声数据较多的场景下可能导致性能下降。未来可以探索结合影响函数或数据质量评估的方法,动态筛选小来源样本以提升鲁棒性。此外,CoLM 与其他内存优化技术(如 LoRA)的兼容性是一个亮点,但其与最新方法(如基于梯度低秩投影的 GaLore)的联合效果值得进一步研究,尤其是在更大规模模型和多样化任务上的表现。另一个有趣的方向是探索 CoLM 在联邦学习(Federated Learning)中的应用,特别是在客户端资源受限的情况下,如何通过核心集选择提升分布式训练效率。