点击上方关注,All in AI中国

在上一篇文章中,我在Keras(这里)做了一些多任务学习,在完成那个之后,我想做一个关于在Pytorch中进行多任务学习的后续帖子。这主要是因为我认为在另一个框架中构建它对我来说是一个很好的练习,但是在这篇文章中,我将介绍在构建模型之后我做了一些额外的优化构建模型。

我还使用这个模型作为我在另一系列帖子(第1部分和第2部分)中构建的面部相似性管道的一部分。

这里提供一个快速的回顾,多任务模型是指单个模型优化以解决一系列通常相关的问题。在机械上,这是通过将模型管道的某些核心部分的输出馈送到一系列输出“磁头”来完成的,这些输出“磁头”的损失可以通过加法进行评分和组合,然后网络可以根据总计的总损耗调整其权重。

我注意到训练其中的一些问题的是,由于用于优化的损失函数是通过将各个损失函数相加而创建的,因此有点难以专注于改进单个任务头。然而,通过对Pytorch网络的一些早期实验,我发现传统的调整策略是相当有效的。

我使用的基本调整策略是吴恩达在他的在线讲座中概述的策略,其中大多数机器学习算法在方差和偏差之间存在折衷,但是对于神经网络,情况并非如此。通过神经网络,这种权衡并不是一种担忧,因为我们可以使用不同的机制来解决这两种问题。在网络不合适的情况下,您可以为其添加额外的计算能力,并且在过度拟合的情况下,您可以应用正规化,如丢失或批量标准化。通过平衡这两种机制的应用程序,您可以优化您的网络。

性别:女,地区:欧洲,格斗风格:近战,阵营:LG,主要颜色:['Silver','Gold','Bl

数据集和管道

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(1)

性别:女,地区:亚洲,格斗风格:近战,阵营:CG,主要颜色:['红色','黑色','白色']

该项目的数据集与我之前基于Keras的多任务学习帖相同,它包含来自手机游戏Fate Grand Order(FGO)的大约400个角色图像。数据集由大约40个不同的字符和26个不同的标签组成,以创建多标签样式数据集。

这些类别涵盖了角色的性别、所属地区、战斗风格、图像的主要颜色以及对应的角色。

我必须做的唯一其他真正的修改是定制一个自定义Pytorch数据集类,它接受一系列列表并输出一个图像,5个目标引导模型。 Pytorch可以轻松获取数据集类并根据需要对其进行修改。通常只编写自己的init,getitem和len函数。我的大多数修改都出现在getitem部分,我在其中指定如何读取图像并从目标列表列表中获取相应的目标,我将其称为“king_of_lists”。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(2)

我喜欢彩色代码与您可以添加到博客文章中的灰色代码片段

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(3)

性别:女,地区:欧洲,格斗风格:近战,阵营:NE

构建基本Pytorch模型

对于这个,我的第一步是从一个非常基本的模型开始。我开始制作一个模型,将Resnet50作为主干,将其输出提供给5个输出头。

通过初始化需要在init部分中优化的图层和事物,在Pytorch中构建自定义网络非常简单。然后,您可以定义数据在前向部分中如何流经模型。对于这个用例,我真正做的是初始化该核心模型(resnet50),然后将输出输入我创建的5个头(y1o、y2o、y3o、y4o、y5o)。这些是模型的输出,而不是您通常会看到的标准单输出。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(4)

第一款只有Resnet50和5个输出头的型号

要查看训练循环,请随时查看笔记本(此处)。主要是我修改了标准的Pytorch网络。然而,有趣的部分是如何计算损失函数。我认为这实际上会更复杂,但它实际上非常简单,基本上我只计算每个的损失(loss0,loss1,loss2,loss3,loss4),将它们加在一起然后用它来进行反向传播。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(5)

我使用Adam优化器训练了50个epoch的基本模型,学习率为.0001,随着时间的推移衰减并保存了具有最佳验证损失的模型。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(6)

总的来说并不可怕,但也不是很好。但总体而言,它的表现优于我之前发布的基于Keras VGG的网络。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(7)

基本Pytorch模型通常优于Keras模型。此轮中的Resnet正在数据集上进行微调,这是有道理的。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(8)

性别:男,地区:中东,战斗风格:魔术,阵营:LG,主要颜色:['白色','蓝色','紫色']

这是一个不错的开始,但我认为我可以改善模型的全面性能。虽然我认为Pytorch模型仍然不合适。为此,我添加了两个更为密集的256层,并将其输入模型中。以下代码段有修改,基本上只需添加两个大小为256的x1和x2层。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(9)

增加了两层x1和x2以获得额外的火力

在训练这个新模型后,与我建立的基础Keras和Pytorch模型相比,训练和验证准确性有所提高,整体性能更好。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(10)

一方面,我们可以就此打住,但我在这里注意到的另一件事是,虽然模型总体上表现更好,但现在已经过度拟合了训练集。请参阅下面的第二个模型的分数。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(11)

既然它是过度拟合的,我想我可以添加一些正则化来试图抵消它。这需要一些修补,但我发现在这种情况下,相对较高水平的批量标准化是有用的。在这次运行中,我最终使用了2e-1。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(12)

添加bn1和bn2作为正则化

在第三轮之后,增加批量标准化的模型显示出比先前最佳对准精度大约10个百分点的增加,我认为这是最难的类别,并且战斗风格增加了5个百分点。然而,性别和颜色则成下降趋势。它与特定角色的原始区域相关联。总的来说,我会说这部分的成功是喜忧参半的,但在难度较大的领域确实有所帮助。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(13)

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(14)

值得注意的是,虽然我添加了批量规范来尝试减少过度拟合,但是训练和验证之间的差距与以前类似......模型的预测似乎比以前更好地推广,这也是添加正则化的目标结果之一。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(15)

性别:女,地区:中东,战斗风格:远程,阵营:NE,主要颜色:['蓝色','金色','黑色']

结论和结束

我认为这个简单的调优过程是一个很好的指示器,说明您在普通网络上使用的策略仍然可以应用于多任务模型。这方面的一个难点是难以针对多任务模型中的特定缺陷,现在我只针对更大的总体问题(所有节点的过拟合和欠拟合)。在这些头上添加额外的图层和节点是一种选择,但随后成为您需要调整的其他超参数。

因此,最后两个模型的表现相当类似,基础更深的网络在颜色上表现更好,在性别上略有差异,在区域上表现得更好,而批处理规范化模型在战斗风格和对齐方面表现得更好。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(16)

这引发了选择什么模型的问题。您可以计算像Fbeta这样的东西来处理它,以尝试在所有不同的任务中制定组合指标。如果您的目标是拥有单一的最佳模型,那么这个单一的度量标准就有意义了。

如果您愿意使用多个模型,则另一个选择是采用性能良好的模型,并将它们集成起来进行预测。这将允许您利用每个不同模型中性能更好的区域。我认为在这种情况下可行是可行的,一种模型在对齐类别上做得更好,而第二种模型在很多不同类别中做得更好。

在另一种情况下,您的任务执行得较差,在这种情况下,color的效果不是很好,您可以在集成中添加一些专门的模型,以尝试提高这些领域的性能。

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(17)

性别:女,地区:欧洲,格斗风格:魔术,阵营:LG,主要颜色:['蓝色','白色','银色']

fate 工作建议(在FateGrandOrder上调整多任务Pytorch网络)(18)

编译出品

,