引导你自己潜在的自我监督学习的新方法
论文链接:https://arxiv.org/abs/2006.07733
此库为BYOL自监督学习的原理性复现代码,使用最简单易读的方式,编写,没有使用复杂的函数调用。总计两百余行代码。完全按照算法顺序编写。并给出了,网络训练好以后的冻结网络参数,续接网络层,继续训练几轮的测试代码。
该库仅仅是对其方法的介绍性复现,可能无法到达论文介绍精度。如果需要进一步使用,需要在读懂原理基础上,更进一步优化代码,训练、测试和使用。
网络更加庞大,难以训练,需要大量的标记数据来监督训练,成本过高。所以需要一种自监督学习,来训练网络,使网络更加泛化。
输入一张图片(input image),记作
再将图片
相同的图片
将参数为
我们想要网络projection空间中的,$z_{\theta}$ 和
由此优化在线网络( online network )的参数
而目标网络(target network)的参数
系数
算法流程:
在线网络,和目标网络,使用resnet18
online_net = ResNet18() # 实例化online(在线网络)
target_net = ResNet18() # 实例化target(目标网络)
投射网络
from torch import nn
class MLP(nn.Module):
"""
预测网络, 将在在线网络的输出投射至另一空间来预测目标网络的输出
"""
def __init__(self, in_features, hidden_features, projection_features):
"""
预测网络
:param in_features: 输入特征数
:param hidden_features: 隐藏特征数
:param projection_features: 投影特征数
"""
super(MLP, self).__init__()
self.layer = nn.Sequential(
nn.Linear(in_features, hidden_features),
nn.BatchNorm1d(hidden_features),
nn.ReLU(inplace=True),
nn.Linear(hidden_features, projection_features),
)
def forward(self, x):
return self.layer(x)
数据集使用CIFAR10
对图像的数据增强方式
class TransformsSimCLR:
"""
一种随机数据扩充模块,它对任意给定的数据实例进行随机转换,
得到同一实例的两个相关视图,
记为x̃i和x̃j,我们认为这是一个正对。
"""
def __init__(self, size, train=True):
"""
:param size:图片尺寸
"""
s = 1
color_jitter = torchvision.transforms.ColorJitter(
0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
)
self.train_transform = torchvision.transforms.Compose(
[
torchvision.transforms.RandomResizedCrop(size=size),
torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
torchvision.transforms.RandomApply([color_jitter], p=0.8),
torchvision.transforms.RandomGrayscale(p=0.2),
torchvision.transforms.ToTensor(),
]
)
self.test_transform = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(size=size),
torchvision.transforms.ToTensor(),
]
)
self.train = train
def __call__(self, x):
"""
:param x: 图片
:return: x̃i和x̃j,即 v、v'
"""
if self.train:
return self.train_transform(x), self.train_transform(x)
else:
return self.test_transform(x)
训练过程中,将同一张图片,经过随机数据增强得到两张不同的图片,分别输入在线网络和目标网络,得到各自得projection,同时,对于目标网络,不追踪梯度。
online_projection_one = online_net(x_i)
with torch.no_grad():
target_projection_one = target_net(x_j)
然后将在线网络输出得projection,经过prediton的变换,与目标网络的projection做损失:
prediction = MLP(in_features=1000, hidden_features=2048, projection_features=1000)
loss_one = loss_function(prediction(online_projection_one), target_projection_one.detach())
损失函数为:
def loss_function(predict, target):
"""
损失函数,比较余弦相似度。归一化的欧氏距离等价于余弦相似度
:param predict: online net输出的prediction
:param target: target网络输出的projection
:return: loss(损失)
"""
return 2-2*torch.cosine_similarity(predict, target, dim=-1)
由此,优化参数
而目标网络参数的更新,根据在线网络更新:
for target_parameter, online_parameter in zip(target_net.parameters(), online_net.parameters()):
old_weight = target_parameter.data
update = online_parameter.data
target_parameter.data = old_weight * tau + (1 - tau) * update
系数
将训练得到的在线网络的权重保存
使用过程,就是冻结网络参数,后续接上一层网络,训练微调一下,即可使用。
Python >= 3.6.0 要求安装 requirements.txt 所有依赖项:
$ pip install -r requirements.txt
运行代码,将自动下载CIFAR10数据集,经行训练
$ python train.py
运行代码,将使用训练好的权重,并不优化网络权重,额外训练一层线性分类层,测试该无监督学习在CIFAR10数据集中测试集中的表现。
$ python test.py