Skip to content

Direct Preference Optimization (DPO)

什么是DPO

DPO是一种用于大语言模型(LLM)对齐的新方法,它的主要目的是让AI模型的输出更好地符合人类偏好。 它是RLHF(基于人类反馈的强化学习)的一种替代方案,该方法于2023年提出, 显著简化了语言模型与人类偏好的对齐过程

在DPO之前提出之前

在 DPO 提出之前,RLHF 的训练过程通常分为两个阶段:首先是监督微调(SFT),然后是强化学习(RL)。

第一步是训练 reward model。

训练数据是同一个 prompt 的 2 个回答,让人或 GPT4 标注哪个回答更好,reward model 会去优化如下的 loss

其中 就是 reward model 用来给回答打分。 是训练数据集, 是 prompt, 和 分别是好的回答和不好的回答。也就是说,要尽可能让好的回答的得分比不好的回答高,拉大他们之间的差别。

第二步是用 RL 算法来提升模型的得分。 。

loss 如下:

其中是我们在训练的 LLM,是训练的初始值。这个 loss 意思是希望 LLM 输出的回答的评分能尽可能高,同时不要偏离太多,保证它还能正常做回答,不要训成一个评分很高但是回答乱码的东西。

DPO 的贡献

DPO 的公式推导

如果我们归一化一下分母,即取

TIP

注意这里是关于的函数与无关,这里很重要因为我们的loss是不需要计算关于输入的损失的 所以下面变换时可以忽略

也就可以构造出一个新的概率分布:

我们将归一化后的新概率分布替换原来的分子那么我们可以将优化目标重新表示为:

TIP

这里有一个非常重要的point我们将 表达式中的 巧妙的转换成其他新loss的KL散度的形式,

The key point is 我们不需要再面对公式中 sample操作,这样我们可以进行梯度下降backprop 同时我们将loss转换成了的KL散度。这非常的amazing啊! 我们目前只需要最小化新的概率分布 就完成了之前的复杂的RL算法

现在的问题是如何get的概率分布,我们可以通过的定义式 get到的关系,那么我们做一些简单的变换可以得到

这时候我们将的表达式直接带入到上面的reward model的loss function

就会得到下面的表达式

OK 现在我们只差最后一步,在上面推导出的表达式 中我们可以发现我们实现需要的是最小化的散度, 那么当我们令散度等于0时实际我么就是在优化

所以我们不如直接将loss直接写成这样

这非常的amazing啊!我们竟然可以通过这个trick将两阶段的train简化成一阶段,同时还不用优化了sampling的操作让损失可以直接backprop

实际应用

  • 可以用于大语言模型的微调和对齐

  • 在实践中比传统RLHF方法更容易实现和部署

  • 已经在多个开源项目中得到应用和验证

结论

通过以上的推导,我们可以看到DPO方法通过将RLHF的两阶段训练简化为单阶段训练,大大简化了训练过程。 同时使用KL 散度的形式来优化训练过程使得模型可以直接backprop而不需要用trick去解决sample的问题并使得训练内存需求得以缓解。