在上一篇文章《44.人工智能——深度学习飞桨框架自定义数据集》中,我们已经定义好了数据集;本文将基于该数据集定义模型,实现时间和天气的多分类预测。

# Define the model: ResNet-50 backbone with two classification heads.
import paddle
from paddle.vision.models import resnet50
import numpy as np


class PWModel(paddle.nn.Layer):
    """Two-head image classifier predicting a time-of-day period and the weather.

    A ResNet-50 backbone has its final fully-connected layer replaced by
    ``paddle.nn.Identity``, so it outputs the pooled feature vector instead of
    class scores. Two independent linear heads map that shared feature to
    period logits and weather logits.

    Args:
        num_periods: number of period (time-of-day) classes. Defaults to 4.
        num_weathers: number of weather classes. Defaults to 3.
        pretrained: whether ``resnet50`` loads its pretrained weights.
            Defaults to True.
    """

    # Width of ResNet-50's pooled feature; the summary output shows the
    # Identity layer producing shape [N, 2048].
    FEATURE_DIM = 2048

    def __init__(self, num_periods=4, num_weathers=3, pretrained=True):
        super(PWModel, self).__init__()
        backbone = resnet50(pretrained=pretrained)
        # Drop the original classifier so the backbone returns features,
        # not its own class scores; the two heads below do the classifying.
        backbone.fc = paddle.nn.Identity()
        self.backbone = backbone
        # Attribute names fc1/fc2 are kept so existing checkpoints still load.
        # Head for the time-of-day (period) classification.
        self.fc1 = paddle.nn.Linear(in_features=self.FEATURE_DIM, out_features=num_periods)
        # Head for the weather classification.
        self.fc2 = paddle.nn.Linear(in_features=self.FEATURE_DIM, out_features=num_weathers)

    def forward(self, x):
        """Return ``(period_logits, weather_logits)`` for the image batch ``x``."""
        feat = self.backbone(x)
        # Both tasks share the same backbone feature.
        period = self.fc1(feat)
        weather = self.fc2(feat)
        return period, weather

# Inspect the model structure. The class defined above is PWModel;
# "WeatherModel" does not exist and would raise NameError.
model = paddle.Model(PWModel())
model.summary((1, 3, 256, 256))

显示结果:

------------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # =============================================================================== Conv2D-180 [[1, 3, 256, 256]] [1, 64, 128, 128] 9,408 BatchNorm2D-180 [[1, 64, 128, 128]] [1, 64, 128, 128] 256 ReLU-61 [[1, 64, 128, 128]] [1, 64, 128, 128] 0 MaxPool2D-5 [[1, 64, 128, 128]] [1, 64, 64, 64] 0 Conv2D-182 [[1, 64, 64, 64]] [1, 64, 64, 64] 4,096 BatchNorm2D-182 [[1, 64, 64, 64]] [1, 64, 64, 64] 256 ReLU-62 [[1, 256, 64, 64]] [1, 256, 64, 64] 0 Conv2D-183 [[1, 64, 64, 64]] [1, 64, 64, 64] 36,864 BatchNorm2D-183 [[1, 64, 64, 64]] [1, 64, 64, 64] 256 Conv2D-184 [[1, 64, 64, 64]] [1, 256, 64, 64] 16,384 BatchNorm2D-184 [[1, 256, 64, 64]] [1, 256, 64, 64] 1,024 Conv2D-181 [[1, 64, 64, 64]] [1, 256, 64, 64] 16,384 BatchNorm2D-181 [[1, 256, 64, 64]] [1, 256, 64, 64] 1,024 BottleneckBlock-49 [[1, 64, 64, 64]] [1, 256, 64, 64] 0 Conv2D-185 [[1, 256, 64, 64]] [1, 64, 64, 64] 16,384 BatchNorm2D-185 [[1, 64, 64, 64]] [1, 64, 64, 64] 256 ReLU-63 [[1, 256, 64, 64]] [1, 256, 64, 64] 0 Conv2D-186 [[1, 64, 64, 64]] [1, 64, 64, 64] 36,864 BatchNorm2D-186 [[1, 64, 64, 64]] [1, 64, 64, 64] 256 Conv2D-187 [[1, 64, 64, 64]] [1, 256, 64, 64] 16,384 BatchNorm2D-187 [[1, 256, 64, 64]] [1, 256, 64, 64] 1,024 BottleneckBlock-50 [[1, 256, 64, 64]] [1, 256, 64, 64] 0 Conv2D-188 [[1, 256, 64, 64]] [1, 64, 64, 64] 16,384 BatchNorm2D-188 [[1, 64, 64, 64]] [1, 64, 64, 64] 256 ReLU-64 [[1, 256, 64, 64]] [1, 256, 64, 64] 0 Conv2D-189 [[1, 64, 64, 64]] [1, 64, 64, 64] 36,864 BatchNorm2D-189 [[1, 64, 64, 64]] [1, 64, 64, 64] 256 Conv2D-190 [[1, 64, 64, 64]] [1, 256, 64, 64] 16,384 BatchNorm2D-190 [[1, 256, 64, 64]] [1, 256, 64, 64] 1,024 BottleneckBlock-51 [[1, 256, 64, 64]] [1, 256, 64, 64] 0 Conv2D-192 [[1, 256, 64, 64]] [1, 128, 64, 64] 32,768 BatchNorm2D-192 [[1, 128, 64, 64]] [1, 128, 64, 64] 512 ReLU-65 [[1, 512, 32, 32]] [1, 512, 32, 32] 0 Conv2D-193 [[1, 128, 64, 64]] [1, 128, 
32, 32] 147,456 BatchNorm2D-193 [[1, 128, 32, 32]] [1, 128, 32, 32] 512 Conv2D-194 [[1, 128, 32, 32]] [1, 512, 32, 32] 65,536 BatchNorm2D-194 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048 Conv2D-191 [[1, 256, 64, 64]] [1, 512, 32, 32] 131,072 BatchNorm2D-191 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048 BottleneckBlock-52 [[1, 256, 64, 64]] [1, 512, 32, 32] 0 Conv2D-195 [[1, 512, 32, 32]] [1, 128, 32, 32] 65,536 BatchNorm2D-195 [[1, 128, 32, 32]] [1, 128, 32, 32] 512 ReLU-66 [[1, 512, 32, 32]] [1, 512, 32, 32] 0 Conv2D-196 [[1, 128, 32, 32]] [1, 128, 32, 32] 147,456 BatchNorm2D-196 [[1, 128, 32, 32]] [1, 128, 32, 32] 512 Conv2D-197 [[1, 128, 32, 32]] [1, 512, 32, 32] 65,536 BatchNorm2D-197 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048 BottleneckBlock-53 [[1, 512, 32, 32]] [1, 512, 32, 32] 0 Conv2D-198 [[1, 512, 32, 32]] [1, 128, 32, 32] 65,536 BatchNorm2D-198 [[1, 128, 32, 32]] [1, 128, 32, 32] 512 ReLU-67 [[1, 512, 32, 32]] [1, 512, 32, 32] 0 Conv2D-199 [[1, 128, 32, 32]] [1, 128, 32, 32] 147,456 BatchNorm2D-199 [[1, 128, 32, 32]] [1, 128, 32, 32] 512 Conv2D-200 [[1, 128, 32, 32]] [1, 512, 32, 32] 65,536 BatchNorm2D-200 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048 BottleneckBlock-54 [[1, 512, 32, 32]] [1, 512, 32, 32] 0 Conv2D-201 [[1, 512, 32, 32]] [1, 128, 32, 32] 65,536 BatchNorm2D-201 [[1, 128, 32, 32]] [1, 128, 32, 32] 512 ReLU-68 [[1, 512, 32, 32]] [1, 512, 32, 32] 0 Conv2D-202 [[1, 128, 32, 32]] [1, 128, 32, 32] 147,456 BatchNorm2D-202 [[1, 128, 32, 32]] [1, 128, 32, 32] 512 Conv2D-203 [[1, 128, 32, 32]] [1, 512, 32, 32] 65,536 BatchNorm2D-203 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048 BottleneckBlock-55 [[1, 512, 32, 32]] [1, 512, 32, 32] 0 Conv2D-205 [[1, 512, 32, 32]] [1, 256, 32, 32] 131,072 BatchNorm2D-205 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024 ReLU-69 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-206 [[1, 256, 32, 32]] [1, 256, 16, 16] 589,824 BatchNorm2D-206 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 Conv2D-207 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144 
BatchNorm2D-207 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096 Conv2D-204 [[1, 512, 32, 32]] [1, 1024, 16, 16] 524,288 BatchNorm2D-204 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096 BottleneckBlock-56 [[1, 512, 32, 32]] [1, 1024, 16, 16] 0 Conv2D-208 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144 BatchNorm2D-208 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 ReLU-70 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-209 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824 BatchNorm2D-209 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 Conv2D-210 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144 BatchNorm2D-210 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096 BottleneckBlock-57 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-211 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144 BatchNorm2D-211 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 ReLU-71 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-212 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824 BatchNorm2D-212 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 Conv2D-213 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144 BatchNorm2D-213 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096 BottleneckBlock-58 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-214 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144 BatchNorm2D-214 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 ReLU-72 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-215 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824 BatchNorm2D-215 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 Conv2D-216 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144 BatchNorm2D-216 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096 BottleneckBlock-59 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-217 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144 BatchNorm2D-217 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 ReLU-73 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-218 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824 BatchNorm2D-218 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 Conv2D-219 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144 BatchNorm2D-219 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096 BottleneckBlock-60 [[1, 
1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-220 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144 BatchNorm2D-220 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 ReLU-74 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-221 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824 BatchNorm2D-221 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024 Conv2D-222 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144 BatchNorm2D-222 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096 BottleneckBlock-61 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0 Conv2D-224 [[1, 1024, 16, 16]] [1, 512, 16, 16] 524,288 BatchNorm2D-224 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048 ReLU-75 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0 Conv2D-225 [[1, 512, 16, 16]] [1, 512, 8, 8] 2,359,296 BatchNorm2D-225 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048 Conv2D-226 [[1, 512, 8, 8]] [1, 2048, 8, 8] 1,048,576 BatchNorm2D-226 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 8,192 Conv2D-223 [[1, 1024, 16, 16]] [1, 2048, 8, 8] 2,097,152 BatchNorm2D-223 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 8,192 BottleneckBlock-62 [[1, 1024, 16, 16]] [1, 2048, 8, 8] 0 Conv2D-227 [[1, 2048, 8, 8]] [1, 512, 8, 8] 1,048,576 BatchNorm2D-227 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048 ReLU-76 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0 Conv2D-228 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,359,296 BatchNorm2D-228 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048 Conv2D-229 [[1, 512, 8, 8]] [1, 2048, 8, 8] 1,048,576 BatchNorm2D-229 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 8,192 BottleneckBlock-63 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0 Conv2D-230 [[1, 2048, 8, 8]] [1, 512, 8, 8] 1,048,576 BatchNorm2D-230 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048 ReLU-77 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0 Conv2D-231 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,359,296 BatchNorm2D-231 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048 Conv2D-232 [[1, 512, 8, 8]] [1, 2048, 8, 8] 1,048,576 BatchNorm2D-232 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 8,192 BottleneckBlock-64 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0 AdaptiveAvgPool2D-5 [[1, 2048, 8, 8]] [1, 2048, 1, 1] 0 Identity-5 [[1, 2048]] [1, 2048] 0 ResNet-5 [[1, 3, 256, 256]] 
[1, 2048] 0 Linear-14 [[1, 2048]] [1, 4] 8,196 Linear-15 [[1, 2048]] [1, 3] 6,147 =============================================================================== Total params: 23,575,495 Trainable params: 23,469,255 Non-trainable params: 106,240 ------------------------------------------------------------------------------- Input size (MB): 0.75 Forward/backward pass size (MB): 341.55 Params size (MB): 89.93 Estimated Total Size (MB): 432.23 ------------------------------------------------------------------------------- {'total_params': 23575495, 'trainable_params': 23469255}

可以看到,最后的 Linear-14 和 Linear-15 两个全连接层分别输出时间(4 类,输出形状 [1, 4])和天气(3 类,输出形状 [1, 3])两组分类结果。

ResNet 模型原理(45.人工智能——以ResNet为backbone的多标签分类模型搭建)(1)

ResNet 网络结构

,