当前位置: 首页 > 手游 > 原神

多任务学习梯度冲突在电商场景的改进

来源:网络 时间:2023-07-29 06:40:22
导读电商算法团队针对多任务学习的梯度冲突问题,改进CAGrad与GradNorm,提出CAGN算法,上线后在电商信息流推荐中带来GMV +8%的提升。

本文作者:韩肖,赵呈路

编辑整理&彩蛋:CastellanZhang

电商算法团队针对多任务学习的梯度冲突问题,改进CAGrad与GradNorm,提出CAGN(Conflict-Averse Gradient Normalization for Multi-task Learning)算法,上线后在电商信息流推荐中带来GMV +8%的提升。

1. 背景

在多任务学习中,如何实现知识共享通常比较受关注,如谷歌、腾讯为实现表征层面知识共享,设计出MMOE和PLE等网络结构,再比如我们在label层面引入蒸馏学习解决CTR、CVR目标不一致问题,又或者利用因果推断技术衡量多任务之间的相互影响。但与表征层面相比,多任务优化问题却鲜有受到关注,一般而言,梯度下降及其变种是数值计算常使用的优化方法,如:GradNorm[1]提出拉齐多任务梯度量级以及平衡loss变化速率,进而提高多任务的建模效果;CAGrad[2]提出解决多任务中梯度冲突来实现多任务的最优化。然而现有算法只关注了多任务优化时的某一个方面,如GradNorm依靠拉齐量级来实现任务同等重要,但忽略了不同任务梯度之间方向存在冲突,而CAGrad解决多任务之间的梯度冲突,却忽视量级大小以及学习速率快慢,导致共享层极易偏向于某种任务。因此我们综合GradNorm与CAGrad的优点并进一步改进,提出CAGN(Conflict-Averse Gradient Normalization for Multi-task Learning)算法,同时解决多任务优化时的三大问题:

不同任务梯度量级的差异

数量级较大的任务往往会主导共享层模型参数,导致梯度量级小的任务无法得到有效的共享表征;

不同任务学习速率差异

不同任务的学习难度不同,数据量少的任务如CVR模型收敛快、学习速率快,数据量大的任务如CTR模型反而不容易收敛,在联合训练多个epoch情况下,CTR还没收敛,CVR甚至已经过拟合;

任务梯度也会相互冲突

在共享参数进行梯度更新时,不同任务产生梯度可能会存在方向上背离,导致优化抵消,进而都无法优化。

2. CAGN算法

算法整体上的思路是分两阶段,第一阶段给每个任务的损失函数配上合适的权重,目的是将不同任务的梯度量级拉齐,第二阶段设计合适的参数更新方法,解决常规优化方法的梯度冲突问题。

多任务总的损失函数  通常可以设计为如下线性形式, 为任务数, 为training step, 为  的权重,输入为batch样本  , 为  上的loss:

这里权重  为  的函数,通过我们的算法在每个时间步会自适应调整。

2.1 梯度量级

为保证各任务梯度的量级一致,参考GradNorm算法,选择各任务在共享网络最后一层产生的梯度作为目标变量,则目标变量的范数为

其中参数  为共享网络最后一层;并假设为影响量级唯一变量。

为保证任务优化的效率,区别于GradNorm采用的梯度范数的期望值作为对齐目标,而采用最大梯度范数作为各任务对齐的目标,其中:

变量  的目标函数为:

这里看作仅关于权重向量的损失函数,需计算其梯度来更新每个,用以保证多任务在共享层产生梯度量级近似于相同。

2.2 收敛速率

对齐量级后假设多任务需优化多轮,任务会因收敛速率不一致,出现任务A过拟合,任务B仍处于欠拟合现象,因此在量级对齐过程中需参数调控,控制多任务的收敛快慢。对于多任务损失函数

参考GradNorm定义以下变量,表征的学习速率:

其中  为  在训练时间步的loss比率; 在学习过程中,损失下降越快表明其收敛越快,定义变量  度量任务的逆训练速率(inverse training rate),具体来说, 的值越高,学习速率越慢,任务产生梯度量级就应该越大,以鼓励任务更快地训练,反之亦然。 是每个时间步  各任务想要达到目标,通过该目标值拉齐梯度量级大小。因此速率和最大量级两者进行结合,具体设定  为每个任务  的新目标值,即:

其中  为超参数,主要用来调节学习速率的快慢,在我们任务中这个值在0.15附近。有了目标值,只需要定义两者差的L1范数或者L2范数作为损失来更新唯一变量 :

在对损失函数求导时,需要将目标值看作常数,代码中通过stop-gradient来避免目标值发生更新。因为新模型开始训练时候,受初始化影响loss是极不稳定的,为此我们使用最近的10步均值来代替 :

2.3 梯度冲突

上述梯度设计,虽然能够保证每个任务在底层共享网络中施加相同的影响力,但他们之间仍会出现优化方向的冲突,因此借鉴CAGrad思路从本质出发,寻找一个优化向量既能减小多任务总的加权损失,又能减小每个任务的单独损失。

令  加上权重后的损失为 ,假设模型通过来更新参数,其中是步长,是更新向量,为共享层的参数,为了达到这个目的,考虑最小化多任务中表现最差的任务损失:

如果始终成立,意味着表现最差的任务其loss一直在下降,这也意味着所有损失都会下降。因此我们目标即找到这个最佳向量让最差的loss下降最多,数学表达为:

对目标函数做变换:

泰勒展开忽略高阶项

这里 ,因此最优化目标变为:

这里给  加了一个约束,其中 , 是一个预先指定的超参数,用来控制搜索半径。由于是离散值,不能使用导数求解,这里做个变化,将离散优化问题转换连续优化问题:

假设,则;

首先,对任意 ,,有

其次,对于满足,有

因此

这样就把一个离散问题转化为带约束的连续优化问题。

记,,原始问题可改写为

对应的拉格朗日函数

在互补松弛条件下,原始问题最优化等价于拉格朗日函数的最优化,又因为下式:

因此最优化变形为:

=0,w\in \mathcal{W}}\quad \hat g_w^\top d - \lambda/2(\Vert\hat g_0-d\Vert^2-\phi) " data-formula-type="block-equation" style=" display: block; text-align: center; overflow: auto; display: block; -webkit-overflow-scrolling: touch; " data-tool="mdnice编辑器">

因为变量在外侧不能直接求导,需要先对 求导,直接对 求导又回到了原来的问题,并没什么意义,但目标满足强对偶性成立的条件:目标函数为凸的且约束条件为仿射函数,求解原始问题等价于求解其对偶问题(交换min/max):

=0,w\in \mathcal{W}} \max _{d\in R^m}\quad \hat g_w^\top d - \lambda/2(\Vert\hat g_0-d\Vert^2-\phi) " data-formula-type="block-equation" style=" display: block; text-align: center; overflow: auto; display: block; -webkit-overflow-scrolling: touch; " data-tool="mdnice编辑器">

固定,关于求导求解最大值

代入原式

=0}\hat g_{w}^\top (\hat g_{0}+\hat g_{w}/\lambda) - \frac{\lambda}{2}\Vert\hat g_w/\lambda\Vert^2+\frac{\lambda\phi}{2} " data-formula-type="block-equation" style=" display: block; text-align: center; overflow: auto; display: block; -webkit-overflow-scrolling: touch; " data-tool="mdnice编辑器">

等价于

=0}\hat g_{w}^\top \hat g_{0}+\frac{1}{2\lambda}||\hat g_w||^2+\frac{\lambda\phi}{2} " data-formula-type="block-equation" style=" display: block; text-align: center; overflow: auto; display: block; -webkit-overflow-scrolling: touch; " data-tool="mdnice编辑器">

固定,的最优解为

代入后只剩下  的最优化:

我们可以设 ,变成一个  的无约束优化问题,最简单的SGD即可求解。

求得最优解  后,可根据

来确定最佳梯度方向,从而使表现最差任务的loss达到最小值。

2.4 伪代码

整个算法分为两步,首先通过优化任务权重让每个任务在共享层产生梯度量级近似,进而实现每个任务同等重要。第二步在共享层的平均梯度附近,寻找一个最优的使得表现最差的任务能优化到其能达到的最小值,从而尽量避免梯度冲突。

3. 实验

首先我们利用仿真实验直观说明CAgrad不考虑任务量级所引发的问题,然后在实际场景对比一下各个算法的效果。

3.1 CAGrad仿真实验

考虑两个量级差异较大且方向冲突的梯度向量和,根据不同的确定最佳,只有  足够大时,最佳梯度  才能和 、 都不产生冲突。但  如果太大又可能影响算法效果和收敛性。

但是当拉齐量级后,解决冲突问题就相对变得很容易, 就算很小,比如0.05,依旧能够找到不产生冲突的最佳梯度 。

3.2 离线实验

我们采用小米商城线上真实数据,其中训练集2023-04-27~2023-05-03,样本量在亿级别。预测集采用线上2023-05-04日志,在我们base模型上采用cold-start方式进行训练。综合来看,CAGN相比base提升最为明显。

算法CTR-AUCCVR-AUCCTCVR-AUCBase0.7260.8260.8559GradNorm:L00.6730.8370.858GradNorm:(t-10:t-1)0.7150.8510.873CAGrad0.7280.8420.8738CAGrad+速率控制0.72770.8320.862CAGrad+量级控制0.72770.84790.876CAGN0.7370.84220.8765

3.3 线上效果

将CAGN运用在重排LTR模型,调整包括点击、成交、GMV、不同类目曝光占比等多个目标,在其他目标不降的情况下,信息流实验组相比对照组GMV提升8%。

4. 参考文献

[1] GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks, 2017

[2] Conflict-Averse Gradient Descent for Multi-task Learning, 2021

彩蛋部分

“大哥,我看这小子腹部乌黑,想必中了铁砂掌,当以真气贯注他的足阳明胃经医治。”绿袍老者说着便抓起马侬右脚,一股热气从冲阳穴透了进去。

“错了错了,二弟你只见其表,未见其里,铁砂掌中含有剧毒,须得护住心脉,以防毒气攻心。”红袍老者一指点在马侬胸口天池穴,注入一道真气。

“大哥二哥,打架小弟不及你们厉害,这疗伤的本领却是你们不及我了。我看这小子四肢僵直,神志不清,当以真气游走于督脉诸穴,方可起死回生。”蓝袍老者伸掌按在马侬头顶,催动真气注入百会穴。

马侬昏昏沉沉中只觉得体内气血翻滚,三股热气盘旋冲撞,经脉中诸处穴道犹如刀戳一般,剧痛难忍。马侬猛然间清醒过来,但四肢百骸动弹不得,张口大叫却发不出半点声音,耳听得三位老者仍在争辨,方知有人给自己疗伤,只是意见不一,各凭己法施以真气。马侬心中叫苦:“这三人好心医我,但这三道真气强弱不同,所走经脉各异,胡乱冲撞下去我命休矣。”

三位老者运功多时,见马侬气息依旧微弱,各自焦躁起来,口中互相埋怨,愈加催动内力。马侬只觉得身体越来越热,如煎如沸,一阵眩晕,恍惚中忆起刚入师门之时有次练功过勤,岔了气息,也是这般痛楚煎熬,幸亏师父在旁守护方得无虞,事后师父传了一篇禹王诀,嘱咐自己若再有走火入魔之时,可按此心法行功,顺气归源。想到此处,马侬一阵欣喜,今日之境不正如走火入魔一般么?忙运起禹王诀,封闭全身大部分穴道,阻滞三股真气乱窜,经脉中只留下真气前方主干通道,因势利导,归拢约束,逐渐汇集为一股,沛然涌入督脉之中,一路上行,过百会,竟而直冲任脉。任督二脉既通,顷刻间内息运转了十二次周天,丹田中真气充盈欲裂,哇地一口淤血吐出,登时神清气爽,坐了起来。

END

往期回顾:

【大模型慢学】GPT起源以及GPT系列采用Decoder-only架构的原因探讨

GCN、GraphSAGE、MMD、对比学习、图增强,一锅乱炖

小米电商算法秘籍:移花接木-新模型如何热启动快速超越线上旧模型

小米电商推荐算法CVR模型实践

【论文学习】面向CTR预估的自适应参数生成网络

【论文学习】通过孪⽣⾃适应掩码层高效学习特征表示

【KDD2022论文详解】阿里妈妈基于对抗梯度的探索模型

一篇旧文:E&E论文学习笔记

深度学习参数初始化详细推导:Xavier方法和kaiming方法【二】

深度学习参数初始化详细推导:Xavier方法和kaiming方法【一】

Transformer做的千层饼

都看到这儿了,何不关注一下

声明:本网页内容旨在传播知识,若有侵权等问题请及时与本网联系,我们将在第一时间删除处理。E-MAIL:704559159@qq.com

Top
加盟网