Openrlhf的DPO源码
之前只看过trl的dpo,今天看了下openrlhf的dpo的实现逻辑:跟想象中差不多,毕竟dpo公式很简单,所以实现都大差不多。\n基本就是要实现 Policy模型的好答案生成概率,坏答案的生成答案,同理生成Reference模型的这两个,计算为1234,然后log(1/3)就是好答案的隐式奖励,log(2/4)就是坏答案上的隐式奖励。\n具体实现\n overview:将log(1/3) – log(2/4) -> log(1/2) – log(3/4) = pi_logratios – ref_logratios,第一个用策略模型forward算出 pi_logratios,第二个用参考模型 foward 算出ref_logratios\n而pi_logratios = policy_chosen_logps – policy_rejected_logps,ef_logratios = reference_chosen_logps – reference_rejected_logps,关键就是要算出这个ps,而这个的话,openrlhf是通过 将 (Q, Agood) 和 (Q, Abad)给放在一个batch里,也就是batchsize变成了两倍。进而调用LLM即可。\n而dpo的亲戚们,cdpo标签平滑版,ipo等都是简单用if加上一两个因子即可\n其他:dpo的亲戚们,cdpo标签平滑版,ipo等都是简单用if加上一两个因子即可