From 6799b2c97a76079d78a4f1fadd53cd7fb5c1bede Mon Sep 17 00:00:00 2001 From: zhaoqi10 Date: Thu, 23 Mar 2023 18:10:42 +0800 Subject: [PATCH 1/9] add mocov3 --- passl/modeling/backbones/moco_vit.py | 125 ++++++++++++++++++ passl/modeling/init.py | 188 +++++++++++++++++++++++++++ 2 files changed, 313 insertions(+) create mode 100644 passl/modeling/backbones/moco_vit.py create mode 100644 passl/modeling/init.py diff --git a/passl/modeling/backbones/moco_vit.py b/passl/modeling/backbones/moco_vit.py new file mode 100644 index 00000000..f48d82e2 --- /dev/null +++ b/passl/modeling/backbones/moco_vit.py @@ -0,0 +1,125 @@ +import math +import paddle +import paddle.nn as nn +from functools import partial, reduce +from operator import mul +from .builder import BACKBONES + +from .. import init +from vision_transformer import VisionTransformer, PatchEmbed, to_2tuple + + +@BACKBONES.register() +class VisionTransformerMoCo(VisionTransformer): + def __init__(self, stop_grad_conv1=False, **kwargs): + super().__init__(**kwargs) + # Use fixed 2D sin-cos position embedding + self.build_2d_sincos_position_embedding() + + # weight initialization + for name, m in self.named_sublayers(): + if isinstance(m, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt( + 6. / float(m.weight.shape[1] // 3 + m.weight.shape[0])) + init.uniform_(m.weight, -val, val) + else: + init.xavier_uniform_(m.weight) + init.zeros_(m.bias) + init.normal_(self.cls_token, std=1e-6) + + if isinstance(self.patch_embed, PatchEmbed): + # xavier_uniform initialization + val = math.sqrt(6. / float(3 * reduce( + mul, self.patch_embed.patch_size, 1) + self.embed_dim)) + init.uniform_(self.patch_embed.proj.weight, -val, val) + init.zeros_(self.patch_embed.proj.bias) + + if stop_grad_conv1: + self.patch_embed.proj.weight.stop_gradient = True + self.patch_embed.proj.bias.stop_gradient = True + + def build_2d_sincos_position_embedding(self, temperature=10000.): + h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0] + w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1] + grid_w = paddle.arange(w, dtype=paddle.float32) + grid_h = paddle.arange(h, dtype=paddle.float32) + grid_w, grid_h = paddle.meshgrid(grid_w, grid_h) + assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + pos_dim = self.embed_dim // 4 + omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim + omega = 1. / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @omega[None] + out_h = grid_h.flatten()[..., None] @omega[None] + pos_emb = paddle.concat( + [ + paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h), + paddle.cos(out_h) + ], + axis=1)[None, :, :] + pe_token = paddle.zeros([1, 1, self.embed_dim], dtype=paddle.float32) + + pos_embed = paddle.concat([pe_token, pos_emb], axis=1) + self.pos_embed = self.create_parameter(shape=pos_embed.shape) + self.pos_embed.set_value(pos_embed) + self.pos_embed.stop_gradient = True + + +class ConvStem(nn.Layer): + """ + ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. 
https://arxiv.org/abs/2106.14881 + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True): + super().__init__() + + assert patch_size == 16, 'ConvStem only supports patch size of 16' + assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' + + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + # build stem, similar to the design in https://arxiv.org/abs/2106.14881 + stem = [] + input_dim, output_dim = 3, embed_dim // 8 + for l in range(4): + stem.append( + nn.Conv2D( + input_dim, + output_dim, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False)) + stem.append(nn.BatchNorm2D(output_dim)) + stem.append(nn.ReLU()) + input_dim = output_dim + output_dim *= 2 + stem.append(nn.Conv2D(input_dim, embed_dim, kernel_size=1)) + self.proj = nn.Sequential(*stem) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose((0, 2, 1)) # BCHW -> BNC + x = self.norm(x) + return x \ No newline at end of file diff --git a/passl/modeling/init.py b/passl/modeling/init.py new file mode 100644 index 00000000..92aa94fe --- /dev/null +++ b/passl/modeling/init.py @@ -0,0 +1,188 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
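# The helpers below port torch.nn.init-style, in-place initializers to
# Paddle; moco_vit.py above consumes them via `from .. import init`.
# A minimal usage sketch, assuming the module resolves as
# passl.modeling.init (an import path implied by this patch, not shown in it):
#
#     import paddle.nn as nn
#     from passl.modeling import init
#
#     fc = nn.Linear(768, 768)
#     init.xavier_uniform_(fc.weight)  # fan-based uniform fill, in place
#     init.zeros_(fc.bias)             # zero the bias, in place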
+ +import math +import warnings +import paddle + + +@paddle.no_grad() +def zeros_(x): + return x.zero_() + + +@paddle.no_grad() +def ones_(x): + return x.fill_(1.0) + + +@paddle.no_grad() +def constant_(x, value): + return x.fill_(value) + + +@paddle.no_grad() +def normal_(x, mean=0., std=1.): + temp_value = paddle.tensor.random.gaussian( + shape=x.shape, mean=mean, std=std, dtype=x.dtype) + x.copy_(temp_value, False) + return x + + +@paddle.no_grad() +def uniform_(x, a=0., b=1.): + temp_value = paddle.tensor.random.uniform( + shape=x.shape, min=a, max=b, dtype=x.dtype) + x.copy_(temp_value, False) + return x + + +def _calculate_fan_in_and_fan_out(tensor): + dimensions = tensor.dim() + if dimensions < 2: + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) + + num_input_fmaps = tensor.shape[1] + num_output_fmaps = tensor.shape[0] + receptive_field_size = 1 + if tensor.dim() > 2: + # math.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in tensor.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def _calculate_correct_fan(tensor, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format( + mode, valid_modes)) + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + return fan_in if mode == 'fan_in' else fan_out + + +def calculate_gain(nonlinearity, param=None): + linear_fns = [ + 'linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', + 'conv_transpose2d', 'conv_transpose3d' + ] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + elif nonlinearity == 'tanh': + return 5.0 / 3 + elif nonlinearity == 'relu': + return math.sqrt(2.0) + elif nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance( + param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format( + param)) + return math.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == 'selu': + return 3.0 / 4 # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) + else: + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + + +@paddle.no_grad() +def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt( + 3.0) * std # Calculate uniform bounds from standard deviation + return uniform_(tensor, -bound, bound) + + +@paddle.no_grad() +def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + normal_(tensor, 0, std) + + +@paddle.no_grad() +def xavier_uniform_(tensor, gain=1.): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt( + 3.0) * 
std # Calculate uniform bounds from standard deviation + return uniform_(tensor, -a, a) + + +@paddle.no_grad() +def xavier_normal_(tensor, gain=1.): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + return normal_(tensor, 0., std) + + +@paddle.no_grad() +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # Modified based on PyTorch nn.init.trunc_normal_ + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tmp = paddle.zeros_like(tensor, dtype='float32') + tmp.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tmp.erfinv_() + + # Transform to proper mean, std + tmp.scale_(std * math.sqrt(2.)) + tmp.add_(paddle.to_tensor(mean, dtype='float32')) + + # Clip to ensure it's in the proper range + tmp.clip_(min=a, max=b) + tmp = tmp.astype(tensor.dtype) + tensor.copy_(tmp, False) + return tensor From 00123e03aabea411f49f0363feab3427079fcaf1 Mon Sep 17 00:00:00 2001 From: zhaoqi10 Date: Fri, 24 Mar 2023 22:06:17 +0800 Subject: [PATCH 2/9] fix config and fluid api. --- configs/simclr/simclr_r50_IM.yaml | 9 +- passl/modeling/architectures/BYOL.py | 137 ++++++++++-------- passl/modeling/architectures/simclr.py | 26 +--- .../modeling/heads/simclr_contrastive_head.py | 30 ++-- passl/modeling/necks/base_neck.py | 108 ++++++++------ 5 files changed, 169 insertions(+), 141 deletions(-) diff --git a/configs/simclr/simclr_r50_IM.yaml b/configs/simclr/simclr_r50_IM.yaml index d555021c..efac4949 100755 --- a/configs/simclr/simclr_r50_IM.yaml +++ b/configs/simclr/simclr_r50_IM.yaml @@ -2,6 +2,7 @@ epochs: 100 use_simclr_iters: True global_batch_size: 4096 output_dir: output_dir +device: gpu model: name: SimCLR @@ -21,7 +22,9 @@ model: dataloader: train: - num_workers: 6 + loader: + num_workers: 6 + use_shared_memory: True sampler: batch_size: 32 shuffle: true @@ -83,7 +86,9 @@ dataloader: std: [0.229, 0.224, 0.225] val: - num_workers: 4 + loader: + num_workers: 4 + use_shared_memory: True sampler: batch_size: 512 shuffle: false diff --git a/passl/modeling/architectures/BYOL.py b/passl/modeling/architectures/BYOL.py index 01959526..37cf06d3 100644 --- a/passl/modeling/architectures/BYOL.py +++ b/passl/modeling/architectures/BYOL.py @@ -33,6 +33,7 @@ import paddle import paddle.fluid.layers as layers + def single_random_gaussian_blur(image, height, width, p=1.0): """Randomly blur an image. 
Args: @@ -53,22 +54,23 @@ def single_random_gaussian_blur(image, height, width, p=1.0): x = paddle.arange(-radius, radius + 1, 1, "float32") blur_filter = paddle.exp(-paddle.pow(x, 2.0) / (2.0 * paddle.pow(sigma, 2.0))) - blur_filter /= layers.reduce_sum(blur_filter) - blur_v = layers.reshape(blur_filter, [1, 1, kernel_size, 1]) - blur_h = layers.reshape(blur_filter, [1, 1, 1, kernel_size]) + blur_filter /= layers.nn.reduce_sum(blur_filter) + blur_v = paddle.reshape(blur_filter, [1, 1, kernel_size, 1]) + blur_h = paddle.reshape(blur_filter, [1, 1, 1, kernel_size]) num_channels = 3 blur_h = paddle.tile(blur_h, [num_channels, 1, 1, 1]) blur_v = paddle.tile(blur_v, [num_channels, 1, 1, 1]) - + expand_batch_dim = len(image.shape) == 3 if expand_batch_dim: - image = paddle.unsqueeze(image.transpose((2,0,1)), axis=0) + image = paddle.unsqueeze(image.transpose((2, 0, 1)), axis=0) blurred = paddle.nn.functional.conv2d( - image, blur_h, stride=1, padding=padding,groups=3) + image, blur_h, stride=1, padding=padding, groups=3) blurred = paddle.nn.functional.conv2d( - blurred, blur_v, stride=1, padding=padding,groups=3) - return blurred.transpose((0,2,3,1)) + blurred, blur_v, stride=1, padding=padding, groups=3) + return blurred.transpose((0, 2, 3, 1)) + def random_gaussian_blur(image, height, width, p=1.0): """Randomly blur an image. @@ -82,26 +84,29 @@ def random_gaussian_blur(image, height, width, p=1.0): """ res = [] for i in range(image.shape[0]): - res.append(single_random_gaussian_blur(image[i],height,width,p)) - return paddle.concat(res,axis=0) + res.append(single_random_gaussian_blur(image[i], height, width, p)) + return paddle.concat(res, axis=0) -def random_solarization(img,threshold=0.5): - img = paddle.where(img < threshold, img, 1 -img) + +def random_solarization(img, threshold=0.5): + img = paddle.where(img < threshold, img, 1 - img) return img -def img_normalize(img,mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]): + +def img_normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): mean = paddle.to_tensor(mean, dtype='float32').reshape([1, 1, 1, 3]) std = paddle.to_tensor(std, dtype='float32').reshape([1, 1, 1, 3]) return (img - mean) / std + def to_chw(img): - return img.transpose((0,3,1,2)) + return img.transpose((0, 3, 1, 2)) -def batch_random_blur_solariza_normalize_chw( - view1, - view2, - blur_probability=(1.0,0.1), - solariza_probability=(0.0,0.2) ): + +def batch_random_blur_solariza_normalize_chw(view1, + view2, + blur_probability=(1.0, 0.1), + solariza_probability=(0.0, 0.2)): """Apply efficient batch data transformations. Args: images_list: a list of image tensors. 
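# The selector pattern used throughout this function: draw a per-sample
# Bernoulli mask, then blend augmented and original views so each image
# keeps its augmentation with probability p. A standalone sketch of the
# same idea (illustrative only; `aug_fn` is a placeholder, not a function
# in this patch):
#
#     import paddle
#
#     def apply_with_prob(images, aug_fn, p):
#         # images: [B, H, W, C]; aug_fn returns an equally shaped batch
#         keep = paddle.cast(
#             paddle.uniform([images.shape[0], 1, 1, 1], min=0., max=1.) < p,
#             "float32")
#         return aug_fn(images) * keep + images * (1. - keep)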
@@ -114,23 +119,23 @@ def batch_random_blur_solariza_normalize_chw( def generate_selector(p, bsz): shape = [bsz, 1, 1, 1] - p_tensor = layers.fill_constant( + p_tensor = paddle.tensor.fill_constant( shape=shape, dtype="float32", value=p) - selector = layers.cast( - layers.less_than( - layers.uniform_random( + selector = paddle.cast( + paddle.less_than( + paddle.uniform( shape=shape, min=0, max=1, dtype="float32"), p_tensor), "float32") return selector - - B,H,W,C = view1.shape + + B, H, W, C = view1.shape img1 = view1 img1_new = random_gaussian_blur(img1, H, W, p=1.0) - selector = generate_selector(blur_probability[0],B) + selector = generate_selector(blur_probability[0], B) img1_blur_res = img1_new * selector + img1 * (1 - selector) - - selector = generate_selector(solariza_probability[0],B) + + selector = generate_selector(solariza_probability[0], B) img1_sola_res = random_solarization(img1_blur_res) img1_sola_res = img1_sola_res * selector + img1_blur_res * (1 - selector) img1_sola_res = paddle.clip(img1_sola_res, min=0., max=1.) @@ -140,24 +145,26 @@ def generate_selector(p, bsz): img2 = view2 img2_new = random_gaussian_blur(img2, H, W, p=1.0) - selector = generate_selector(blur_probability[1],B) + selector = generate_selector(blur_probability[1], B) img2_blur_res = img2_new * selector + img2 * (1 - selector) - - selector = generate_selector(solariza_probability[1],B) + + selector = generate_selector(solariza_probability[1], B) img2_sola_res = random_solarization(img2_blur_res) img2_sola_res = img2_sola_res * selector + img2_blur_res * (1 - selector) img2_sola_res = paddle.clip(img2_sola_res, min=0., max=1.) - img2_sola_res.stop_gradient = True + img2_sola_res.stop_gradient = True img2_tran_res = to_chw(img_normalize(img2_sola_res)) return img1_tran_res, img2_tran_res + @MODELS.register() class BYOL(nn.Layer): """ Build a MoCo model with: a query encoder, a key encoder, and a queue https://arxiv.org/abs/1911.05722 """ + def __init__(self, backbone, neck=None, @@ -169,8 +176,7 @@ def __init__(self, target_decay_method='fixed', target_decay_rate=0.996, align_init_network=True, - use_synch_bn=False - ): + use_synch_bn=False): """ Args: backbone (dict): config of backbone. @@ -184,77 +190,82 @@ def __init__(self, self.towers = nn.LayerList() self.base_m = target_decay_rate self.target_decay_method = target_decay_method - + neck1 = build_neck(neck) neck2 = build_neck(neck) - + self.towers.append(nn.Sequential(build_backbone(backbone), neck1)) self.towers.append(nn.Sequential(build_backbone(backbone), neck2)) self.net_init(self.towers) self.predictor = build_neck(predictor) self.net_init(self.predictor) - self.classifier = nn.Linear(embedding_dim,num_classes) + self.classifier = nn.Linear(embedding_dim, num_classes) self.net_init(self.classifier) self.backbone = self.towers[0][0] # self.neck1 = self.towers[0][1] # TODO IMPORTANT! 
Explore if the initialization requires to be synchronized - for param_q, param_k in zip(self.towers[0].parameters(),self.towers[1].parameters()): + for param_q, param_k in zip(self.towers[0].parameters(), + self.towers[1].parameters()): param_k.stop_gradient = True if align_init_network: - for param_q, param_k in zip(self.towers[0].parameters(),self.towers[1].parameters()): + for param_q, param_k in zip(self.towers[0].parameters(), + self.towers[1].parameters()): param_k.set_value(param_q) # initialize - + # Convert BatchNorm*d to SyncBatchNorm*d if use_synch_bn: - self.towers[0] = nn.SyncBatchNorm.convert_sync_batchnorm(self.towers[0]) - self.towers[1] = nn.SyncBatchNorm.convert_sync_batchnorm(self.towers[1]) + self.towers[0] = nn.SyncBatchNorm.convert_sync_batchnorm( + self.towers[0]) + self.towers[1] = nn.SyncBatchNorm.convert_sync_batchnorm( + self.towers[1]) #self.predictor = nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor) self.head = build_head(head) - - def net_init(self,network): + + def net_init(self, network): for m in network.sublayers(): if isinstance(m, nn.Conv2D): - init.kaiming_init(m,mode="fan_in",nonlinearity="conv2d") + init.kaiming_init(m, mode="fan_in", nonlinearity="conv2d") if isinstance(m, nn.Conv2D): - init.kaiming_init(m,mode="fan_in",nonlinearity="linear") + init.kaiming_init(m, mode="fan_in", nonlinearity="linear") def train_iter(self, *inputs, **kwargs): - + current_iter = kwargs['current_iter'] - total_iters = kwargs['total_iters'] - + total_iters = kwargs['total_iters'] + if self.target_decay_method == 'cosine': - self.m = 1 - (1-self.base_m) * (1 + math.cos(math.pi*(current_iter-0)/total_iters))/2.0 # 47.0 + self.m = 1 - (1 - self.base_m) * (1 + math.cos(math.pi * ( + current_iter - 0) / total_iters)) / 2.0 # 47.0 elif self.target_decay_method == 'fixed': - self.m = self.base_m # 55.7 + self.m = self.base_m # 55.7 else: raise NotImplementedError # self.update_target_network() img_a, img_b, label = inputs - img_a, img_b = batch_random_blur_solariza_normalize_chw(img_a,img_b) + img_a, img_b = batch_random_blur_solariza_normalize_chw(img_a, img_b) embedding = self.towers[0][0](img_a) online_project_view1 = self.towers[0][1](embedding) online_predict_view1 = self.predictor(online_project_view1) online_project_view2 = self.towers[0](img_b) online_predict_view2 = self.predictor(online_project_view2) - + clone_x = embedding.clone() - clone_x.stop_gradient = True + clone_x.stop_gradient = True classif_out = self.classifier(clone_x.squeeze()) - + with paddle.no_grad(): target_project_view1 = self.towers[1](img_a).clone().detach() target_project_view2 = self.towers[1](img_b).clone().detach() a1 = nn.functional.normalize(online_predict_view1, axis=1) b1 = nn.functional.normalize(target_project_view2, axis=1) - b1.stop_gradient = True + b1.stop_gradient = True a2 = nn.functional.normalize(online_predict_view2, axis=1) b2 = nn.functional.normalize(target_project_view1, axis=1) @@ -286,7 +297,9 @@ def update_target_network(self): def update_target_network_L1(self): for param_q, param_k in zip(self.towers[0].parameters(), self.towers[1].parameters()): - paddle.assign(param_k - (1-self.m)*paddle.sign(param_k-param_q), param_k) + paddle.assign(param_k - + (1 - self.m) * paddle.sign(param_k - param_q), + param_k) param_k.stop_gradient = True # L2 + L1 @@ -294,7 +307,10 @@ def update_target_network_L1(self): def update_target_network_clip(self): for param_q, param_k in zip(self.towers[0].parameters(), self.towers[1].parameters()): - paddle.assign(param_k - (1-self.m) * 
paddle.clip((param_k - param_q), min=-1.0, max=1.0) , param_k) + paddle.assign( + param_k - (1 - self.m) * paddle.clip( + (param_k - param_q), min=-1.0, max=1.0), + param_k) param_k.stop_gradient = True @paddle.no_grad() @@ -302,5 +318,8 @@ def update_target_network_LN_clip(self): for param_q, param_k in zip(self.towers[0].parameters(), self.towers[1].parameters()): paddle.assign((param_k * self.m + param_q * (1. - self.m)), param_k) - paddle.assign(param_k - (1-self.m) * paddle.clip((param_k - param_q), min=-1.0, max=1.0) , param_k) + paddle.assign( + param_k - (1 - self.m) * paddle.clip( + (param_k - param_q), min=-1.0, max=1.0), + param_k) param_k.stop_gradient = True diff --git a/passl/modeling/architectures/simclr.py b/passl/modeling/architectures/simclr.py index 98d78b3b..f94623b8 100755 --- a/passl/modeling/architectures/simclr.py +++ b/passl/modeling/architectures/simclr.py @@ -23,42 +23,35 @@ import paddle.nn.functional as F import paddle.fluid.layers as layers - LARGE_NUM = 1e9 + @MODELS.register() class SimCLR(nn.Layer): """ Simple image SimCLR. """ - def __init__(self, - backbone, - neck=None, - head=None, - dim=128, - T=0.5): + def __init__(self, backbone, neck=None, head=None, dim=128, T=0.5): super(SimCLR, self).__init__() self.T = T - self.encoder = nn.Sequential(build_backbone(backbone), - build_neck(neck)) - + self.encoder = nn.Sequential(build_backbone(backbone), build_neck(neck)) + self.backbone = self.encoder[0] self.head = build_head(head) - - def train_iter(self, *inputs, **kwargs): img_q, img_k = inputs img_con = [img_q, img_k] img_con = paddle.concat(img_con) con = self.encoder(img_con) - con = layers.l2_normalize(con, -1) - q, k = layers.split(con, num_or_sections=2, dim=0) + con = paddle.nn.functional.normalize(con, axis=-1) + q, k = paddle.split(con, num_or_sections=2, axis=0) outputs = self.head(q, k) - + return outputs + def test_iter(self, *inputs, **kwargs): with paddle.no_grad(): img, label = inputs @@ -76,6 +69,3 @@ def forward(self, *inputs, mode='train', **kwargs): return self.backbone(*inputs) else: raise Exception("No such mode: {}".format(mode)) - - - diff --git a/passl/modeling/heads/simclr_contrastive_head.py b/passl/modeling/heads/simclr_contrastive_head.py index 3906d4ac..96fc50f6 100755 --- a/passl/modeling/heads/simclr_contrastive_head.py +++ b/passl/modeling/heads/simclr_contrastive_head.py @@ -55,22 +55,24 @@ def forward(self, pos, neg): hidden1_large = hidden1 hidden2_large = hidden2 labels = F.one_hot( - paddle.reshape(paddle.arange(0, batch_size, 1, "int32"), - [batch_size]), batch_size * 2) + paddle.reshape( + paddle.arange(0, batch_size, 1, "int32"), [batch_size]), + batch_size * 2) masks = F.one_hot( - paddle.reshape(paddle.arange(0, batch_size, 1, "int32"), - [batch_size]), batch_size) + paddle.reshape( + paddle.arange(0, batch_size, 1, "int32"), [batch_size]), + batch_size) - logits_aa = paddle.matmul(hidden1, hidden1_large, - transpose_y=True) / self.temperature + logits_aa = paddle.matmul( + hidden1, hidden1_large, transpose_y=True) / self.temperature logits_aa = logits_aa - masks * LARGE_NUM - logits_bb = paddle.matmul(hidden2, hidden2_large, - transpose_y=True) / self.temperature + logits_bb = paddle.matmul( + hidden2, hidden2_large, transpose_y=True) / self.temperature logits_bb = logits_bb - masks * LARGE_NUM - logits_ab = paddle.matmul(hidden1, hidden2_large, - transpose_y=True) / self.temperature - logits_ba = paddle.matmul(hidden2, hidden1_large, - transpose_y=True) / self.temperature + logits_ab = paddle.matmul( + hidden1, 
hidden2_large, transpose_y=True) / self.temperature + logits_ba = paddle.matmul( + hidden2, hidden1_large, transpose_y=True) / self.temperature loss_a = paddle.nn.functional.softmax_with_cross_entropy( paddle.concat([logits_ab, logits_aa], 1), labels, soft_label=True) @@ -91,10 +93,10 @@ def forward(self, pos, neg): co2_loss = 1 * (kl_1 + kl_2) total_contrast_loss = contrast_loss + 3 * co2_loss - loss = layers.reduce_mean(total_contrast_loss) + loss = paddle.mean(total_contrast_loss) contrastive_label = paddle.unsqueeze(paddle.argmax(labels, axis=1), 1) - acc1 = layers.accuracy(input=logits_ab, label=contrastive_label) + acc1 = paddle.metric.accuracy(input=logits_ab, label=contrastive_label) outputs = dict() outputs['loss'] = loss outputs['acc1'] = acc1 diff --git a/passl/modeling/necks/base_neck.py b/passl/modeling/necks/base_neck.py index 21f16255..021874a3 100644 --- a/passl/modeling/necks/base_neck.py +++ b/passl/modeling/necks/base_neck.py @@ -80,9 +80,9 @@ def __init__(self, if with_avg_pool: self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) - self.mlp = nn.Sequential(nn.Linear(in_channels, - hid_channels), nn.ReLU(), - nn.Linear(hid_channels, out_channels)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), + nn.ReLU(), nn.Linear(hid_channels, out_channels)) # init_backbone_weight(self.mlp) self.init_parameters() @@ -113,9 +113,12 @@ def __init__(self, if with_avg_pool: self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) - self.mlp = nn.Sequential(nn.Linear(in_channels, hid_channels, bias_attr=with_bias), - nn.BatchNorm1D(hid_channels), nn.ReLU(), - nn.Linear(hid_channels, out_channels)) + self.mlp = nn.Sequential( + nn.Linear( + in_channels, hid_channels, bias_attr=with_bias), + nn.BatchNorm1D(hid_channels), + nn.ReLU(), + nn.Linear(hid_channels, out_channels)) # init_backbone_weight(self.mlp) # self.init_parameters() @@ -190,9 +193,9 @@ def __init__(self, self.conv = BottleneckBlock(in_channels, in_channels // 4) - self.mlp = nn.Sequential(nn.Linear(in_channels, - hid_channels), nn.ReLU(), - nn.Linear(hid_channels, out_channels)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), + nn.ReLU(), nn.Linear(hid_channels, out_channels)) init_backbone_weight(self.mlp) @@ -220,12 +223,14 @@ def __init__(self, self.with_avg_pool = with_avg_pool if with_avg_pool: self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) - self.mlp = nn.Sequential(nn.Linear(in_channels, hid_channels), - nn.BatchNorm1D(hid_channels), nn.ReLU(), - nn.Linear(hid_channels, hid_channels), - nn.BatchNorm1D(hid_channels), nn.ReLU(), - nn.Linear(hid_channels, out_channels), - nn.BatchNorm1D(out_channels)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), + nn.BatchNorm1D(hid_channels), + nn.ReLU(), + nn.Linear(hid_channels, hid_channels), + nn.BatchNorm1D(hid_channels), + nn.ReLU(), + nn.Linear(hid_channels, out_channels), nn.BatchNorm1D(out_channels)) init_backbone_weight_simclr(self.mlp) @@ -233,9 +238,9 @@ def init_parameters(self, init_linear='normal'): _init_parameters(self, init_linear) def forward(self, x): - x = layers.squeeze(x, axes=[]) + x = paddle.squeeze(x) hidden = self.mlp(x) - hidden = layers.l2_normalize(hidden, -1) + hidden = paddle.nn.functional.normalize(hidden, axis=-1) return hidden @@ -255,13 +260,21 @@ def __init__(self, self.with_avg_pool = with_avg_pool if with_avg_pool: self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) - self.mlp = nn.Sequential(nn.Linear(in_channels, hid_channels, bias_attr=with_bias), - nn.BatchNorm1D(hid_channels), nn.ReLU(), - nn.Linear(hid_channels, 
hid_channels, bias_attr=with_bias), - nn.BatchNorm1D(hid_channels), nn.ReLU(), - nn.Linear(hid_channels, out_channels, bias_attr=with_bias), - nn.BatchNorm1D(out_channels, - weight_attr=with_last_bn_affine, bias_attr=with_last_bn_affine)) + self.mlp = nn.Sequential( + nn.Linear( + in_channels, hid_channels, bias_attr=with_bias), + nn.BatchNorm1D(hid_channels), + nn.ReLU(), + nn.Linear( + hid_channels, hid_channels, bias_attr=with_bias), + nn.BatchNorm1D(hid_channels), + nn.ReLU(), + nn.Linear( + hid_channels, out_channels, bias_attr=with_bias), + nn.BatchNorm1D( + out_channels, + weight_attr=with_last_bn_affine, + bias_attr=with_last_bn_affine)) init_backbone_weight_simclr(self.mlp) @@ -278,6 +291,7 @@ def forward(self, x): class SwAVNeck(nn.Layer): """The non-linear neck in SwAV: fc-bn-relu-fc-normalization. """ + def __init__(self, in_channels, hid_channels, @@ -297,9 +311,8 @@ def __init__(self, else: self.projection_neck = nn.Sequential( nn.Linear(in_channels, hid_channels), - nn.BatchNorm1D(hid_channels), nn.ReLU(), - nn.Linear(hid_channels, out_channels) - ) + nn.BatchNorm1D(hid_channels), + nn.ReLU(), nn.Linear(hid_channels, out_channels)) def forward_projection(self, x): if self.projection_neck is not None: @@ -330,20 +343,22 @@ class MLP2d(nn.Layer): def __init__(self, in_channels, hid_channels=4096, out_channels=256): super(MLP2d, self).__init__() - self.linear1 = nn.Conv2D(in_channels, - hid_channels, - kernel_size=1, - stride=1, - padding=0, - bias_attr=True) + self.linear1 = nn.Conv2D( + in_channels, + hid_channels, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True) self.bn1 = nn.BatchNorm2D(hid_channels) self.relu1 = nn.ReLU() - self.linear2 = nn.Conv2D(hid_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - bias_attr=True) + self.linear2 = nn.Conv2D( + hid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True) self.init_parameters() def init_parameters(self, init_linear='kaiming'): @@ -363,23 +378,20 @@ def forward(self, x): class DenseCLNeck(nn.Layer): """The non-linear neck in DenseCL: fc-relu-fc, conv-relu-conv. """ - def __init__(self, - in_channels, - hid_channels, - out_channels, - num_grid=None): + + def __init__(self, in_channels, hid_channels, out_channels, num_grid=None): super(DenseCLNeck, self).__init__() self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) self.mlp = nn.Sequential( - nn.Linear(in_channels,hid_channels), nn.ReLU(), - nn.Linear(hid_channels, out_channels)) + nn.Linear(in_channels, hid_channels), + nn.ReLU(), nn.Linear(hid_channels, out_channels)) self.with_pool = num_grid != None if self.with_pool: self.pool = nn.AdaptiveAvgPool2D((num_grid, num_grid)) self.mlp2 = nn.Sequential( - nn.Conv2D(in_channels, hid_channels, 1), nn.ReLU(), - nn.Conv2D(hid_channels, out_channels, 1)) + nn.Conv2D(in_channels, hid_channels, 1), + nn.ReLU(), nn.Conv2D(hid_channels, out_channels, 1)) self.avgpool2 = nn.AdaptiveAvgPool2D((1, 1)) # init_backbone_weight(self.mlp and self.mlp2) From c8d3e8ce0d28e20692707249d107a915e96db1ac Mon Sep 17 00:00:00 2001 From: zhaoqi10 Date: Sun, 26 Mar 2023 22:03:44 +0800 Subject: [PATCH 3/9] fix simclr config. 
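Patch 3 moves SimCLR from sqrt to linear learning-rate scaling, lowers the
base rate to the paper's 0.3, and drops LARS weight decay to 1e-6. Assuming
simclrCosineWarmup applies the usual SimCLR convention of scaling by the
global batch size (an assumption about the scheduler's internals, which this
diff does not show), the peak rate works out as:

    # linear rule: peak_lr = end_lr * global_batch_size / 256
    peak_lr = 0.3 * 4096 / 256  # 4.8 at the configured global batch of 4096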
--- configs/simclr/simclr_r50_IM.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/simclr/simclr_r50_IM.yaml b/configs/simclr/simclr_r50_IM.yaml index efac4949..54e0d339 100755 --- a/configs/simclr/simclr_r50_IM.yaml +++ b/configs/simclr/simclr_r50_IM.yaml @@ -23,7 +23,7 @@ model: dataloader: train: loader: - num_workers: 6 + num_workers: 4 use_shared_memory: True sampler: batch_size: 32 @@ -90,7 +90,7 @@ dataloader: num_workers: 4 use_shared_memory: True sampler: - batch_size: 512 + batch_size: 256 shuffle: false drop_last: false dataset: @@ -110,18 +110,18 @@ dataloader: lr_scheduler: name: simclrCosineWarmup - learning_rate_scaling: sqrt + learning_rate_scaling: linear total_images: 1281167 warmup_epochs: 10 start_lr: 0 - end_lr: 1.0 + end_lr: 0.3 T_max: 200 optimizer: name: LarsMomentumOptimizer momentum: 0.9 - lars_weight_decay: 0.0001 + lars_weight_decay: 1e-6 exclude_from_weight_decay: ["scale","offset",".bias"] log_config: From 8f2304e64793cd9a56faff7fc24ed63e4893edb7 Mon Sep 17 00:00:00 2001 From: zhaoqi10 Date: Tue, 28 Mar 2023 16:13:22 +0800 Subject: [PATCH 4/9] fix byol lr scheduler with warmup and global_batch_size setting in trainer. --- configs/byol/byol_r50_IM.yaml | 15 ++++++++------- passl/engine/trainer.py | 12 ++++-------- passl/modeling/architectures/BYOL.py | 4 ++-- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/configs/byol/byol_r50_IM.yaml b/configs/byol/byol_r50_IM.yaml index 56ae2970..a8a59fb8 100644 --- a/configs/byol/byol_r50_IM.yaml +++ b/configs/byol/byol_r50_IM.yaml @@ -1,5 +1,5 @@ epochs: 300 -use_byol_iters: True +use_simclr_iters: True total_images: 1281167 global_batch_size: 4096 output_dir: output_dir @@ -84,12 +84,13 @@ dataloader: lr_scheduler: - name: CosineWarmup - learning_rate: 4.8 - T_max: 93835 - warmup_steps: 3127 - start_lr: 0.0001 - end_lr: 4.8 + name: simclrCosineWarmup + learning_rate_scaling: linear + total_images: 1281167 + warmup_epochs: 10 + start_lr: 0 + end_lr: 0.3 + T_max: 300 optimizer: diff --git a/passl/engine/trainer.py b/passl/engine/trainer.py index 649459a6..296ea737 100644 --- a/passl/engine/trainer.py +++ b/passl/engine/trainer.py @@ -145,21 +145,17 @@ def __init__(self, cfg): self.train_dataloader, self.mixup_fn = build_dataloader( cfg.dataloader.train, self.device) self.iters_per_epoch = len(self.train_dataloader) - + self.batch_size = cfg.dataloader.train.sampler.batch_size + self.global_batch_size = self.batch_size * dist.get_world_size() # use byol iters if self.use_byol_iters: - self.global_batch_size = cfg.global_batch_size self.byol_total_iters = self.epochs * cfg.total_images // self.global_batch_size - - if self.use_byol_iters: self.lr_scheduler = build_lr_scheduler(cfg.lr_scheduler, self.byol_total_iters) elif self.use_simclr_iters: - self.batch_size = cfg.dataloader.train.sampler.batch_size - self.global_batch_size = cfg.global_batch_size self.epochs = cfg.epochs self.lr_scheduler = build_lr_scheduler_simclr( - cfg.lr_scheduler, self.iters_per_epoch, self.batch_size * 8, + cfg.lr_scheduler, self.iters_per_epoch, self.global_batch_size, cfg.epochs, self.current_iter) else: self.lr_scheduler = build_lr_scheduler(cfg.lr_scheduler, @@ -224,7 +220,7 @@ def __init__(self, cfg): self.add_train_hooks() self.add_custom_hooks() self.hooks = sorted(self.hooks, key=lambda x: x.priority) - + print("hooks: ", self.hooks) if self.epochs: self.total_iters = self.epochs * self.iters_per_epoch self.by_epoch = True diff --git a/passl/modeling/architectures/BYOL.py 
b/passl/modeling/architectures/BYOL.py index 37cf06d3..236dd976 100644 --- a/passl/modeling/architectures/BYOL.py +++ b/passl/modeling/architectures/BYOL.py @@ -161,8 +161,8 @@ def generate_selector(p, bsz): @MODELS.register() class BYOL(nn.Layer): """ - Build a MoCo model with: a query encoder, a key encoder, and a queue - https://arxiv.org/abs/1911.05722 + Build a BYOL model referenced from paper + https://arxiv.org/abs/2006.07733 """ def __init__(self, From 994b8a5ff751384d3c3853903a242d4b9478b785 Mon Sep 17 00:00:00 2001 From: zhaoqi10 Date: Fri, 7 Apr 2023 19:30:26 +0800 Subject: [PATCH 5/9] add CI tests for ssl models. --- tests/CI/cash.sh | 102 ++++++++++++++++++ tests/CI/ssl/byol/byol_r50_IM_linear.sh | 9 ++ tests/CI/ssl/byol/byol_r50_IM_pretrain.sh | 9 ++ tests/CI/ssl/moco/moco_v1_r50_linear.sh | 9 ++ tests/CI/ssl/moco/moco_v1_r50_pretrain.sh | 8 ++ tests/CI/ssl/moco/moco_v2_r50_linear.sh | 8 ++ tests/CI/ssl/moco/moco_v2_r50_pretrain.sh | 0 tests/CI/ssl/simclr/simclr_r50_IM_linear.sh | 9 ++ tests/CI/ssl/simclr/simclr_r50_IM_pretrain.sh | 6 ++ tests/CI/ssl/simsiam/simsiam_r50_IM_linear.sh | 9 ++ .../CI/ssl/simsiam/simsiam_r50_IM_pretrain.sh | 8 ++ tests/CI/ssl/swav/swav_r50_IM_linear.sh | 9 ++ tests/CI/ssl/swav/swav_r50_IM_pretrain.sh | 6 ++ 13 files changed, 192 insertions(+) create mode 100644 tests/CI/cash.sh create mode 100644 tests/CI/ssl/byol/byol_r50_IM_linear.sh create mode 100644 tests/CI/ssl/byol/byol_r50_IM_pretrain.sh create mode 100644 tests/CI/ssl/moco/moco_v1_r50_linear.sh create mode 100644 tests/CI/ssl/moco/moco_v1_r50_pretrain.sh create mode 100644 tests/CI/ssl/moco/moco_v2_r50_linear.sh create mode 100644 tests/CI/ssl/moco/moco_v2_r50_pretrain.sh create mode 100644 tests/CI/ssl/simclr/simclr_r50_IM_linear.sh create mode 100644 tests/CI/ssl/simclr/simclr_r50_IM_pretrain.sh create mode 100644 tests/CI/ssl/simsiam/simsiam_r50_IM_linear.sh create mode 100644 tests/CI/ssl/simsiam/simsiam_r50_IM_pretrain.sh create mode 100644 tests/CI/ssl/swav/swav_r50_IM_linear.sh create mode 100644 tests/CI/ssl/swav/swav_r50_IM_pretrain.sh diff --git a/tests/CI/cash.sh b/tests/CI/cash.sh new file mode 100644 index 00000000..b8a6a85a --- /dev/null +++ b/tests/CI/cash.sh @@ -0,0 +1,102 @@ +set -e + +export passl_path=/paddle/PASSL/tests/CI +export log_path=/paddle/log_passl +passl_gpu_model_list=( \ + moco_v1_r50_pretrain \ + moco_v1_r50_linear \ + moco_v2_r50_pretrain \ + moco_v2_r50_linear \ + simclr_r50_IM_pretrain \ + simclr_r50_IM_linear \ + byol_r50_IM_pretrain \ + byol_r50_IM_linear \ + simsiam_r50_IM_pretrain \ + simsiam_r50_IM_linear \ + swav_r50_IM_pretrain \ + swav_r50_IM_linear \ +) + +function moco_v1_r50_pretrain(){ + cd ${passl_path} + rm -rf log + bash ./ssl/moco/moco_v1_r50_pretrain.sh + loss=`tail log/workerlog.0 | grep "50/200" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function moco_v1_r50_linear(){ + cd ${passl_path} + rm -rf log + bash ./ssl/moco/moco_v1_r50_linear.sh + loss=`tail log/workerlog.0 | grep "50/200" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function moco_v2_r50_pretrain(){ + cd ${passl_path} + rm -rf log + bash ./ssl/moco/moco_v2_r50_pretrain.sh + loss=`tail log/workerlog.0 | grep "50/200" | cut -d " " -f17 ` + check_result 7.0774e+00 ${loss%?} $FUNCNAME} + +function moco_v2_r50_linear(){ + cd ${passl_path} + rm -rf log + bash ./ssl/moco/moco_v2_r50_linear.sh + loss=`tail log/workerlog.0 | grep "50/200" | cut -d " " -f17 ` + check_result 7.0774e+00 ${loss%?} $FUNCNAME} + +function 
simclr_r50_IM_pretrain(){ + cd ${passl_path} + rm -rf log + bash ./ssl/simclr/simclr_r50_IM_pretrain.sh + loss=`tail log/workerlog.0 | grep "50/200" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function simclr_r50_IM_linear(){ + cd ${passl_path} + rm -rf log + bash ./ssl/simclr/simclr_r50_IM_linear.sh + loss=`tail log/workerlog.0 | grep "50/200" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function byol_r50_IM_pretrain(){ + cd ${passl_path} + rm -rf log + bash ./ssl/byol/byol_r50_IM_pretrain.sh + loss=`tail log/workerlog.0 | grep "50/200" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function byol_r50_IM_linear(){ + cd ${passl_path} + rm -rf log + bash ./ssl/byol/byol_r50_IM_linear.sh + loss=`tail log/workerlog.0 | grep "50/200" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function simsiam_r50_IM_pretrain(){ + cd ${passl_path} + rm -rf log + bash ./ssl/simsiam/simsiam_r50_IM_pretrain.sh + loss=`tail log/workerlog.0 | grep "50/100" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function simsiam_r50_IM_linear(){ + cd ${passl_path} + rm -rf log + bash ./ssl/simsiam/simsiam_r50_IM_linear.sh + loss=`tail log/workerlog.0 | grep "50/100" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function swav_r50_IM_pretrain(){ + cd ${passl_path} + rm -rf log + bash ./ssl/swav/swav_r50_IM_pretrain.sh + loss=`tail log/workerlog.0 | grep "50/100" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} + +function swav_r50_IM_linear(){ + cd ${passl_path} + rm -rf log + bash ./ssl/swav/swav_r50_IM_linear.sh + loss=`tail log/workerlog.0 | grep "50/100" | cut -d " " -f17 ` + check_result 1.3840e+00 ${loss%?} $FUNCNAME} \ No newline at end of file diff --git a/tests/CI/ssl/byol/byol_r50_IM_linear.sh b/tests/CI/ssl/byol/byol_r50_IM_linear.sh new file mode 100644 index 00000000..4df653a9 --- /dev/null +++ b/tests/CI/ssl/byol/byol_r50_IM_linear.sh @@ -0,0 +1,9 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/byol/byol_clas_r50.yaml \ + -o epochs=50 \ + --pretrain ./pretrained/ssl/pretrain/byol_r50_backbone.pd + diff --git a/tests/CI/ssl/byol/byol_r50_IM_pretrain.sh b/tests/CI/ssl/byol/byol_r50_IM_pretrain.sh new file mode 100644 index 00000000..5e590e64 --- /dev/null +++ b/tests/CI/ssl/byol/byol_r50_IM_pretrain.sh @@ -0,0 +1,9 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/byol/byol_r50_IM.yaml \ + -o epochs=50 \ + --pretrain ./pretrained/ssl/byol_r50_backbone.pd + diff --git a/tests/CI/ssl/moco/moco_v1_r50_linear.sh b/tests/CI/ssl/moco/moco_v1_r50_linear.sh new file mode 100644 index 00000000..9a843149 --- /dev/null +++ b/tests/CI/ssl/moco/moco_v1_r50_linear.sh @@ -0,0 +1,9 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/moco/moco_clas_r50.yaml \ + -o epochs=50 \ + --pretrain ./pretrained/ssl/moco_v1_r50_backbone.pd 
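# Every CI script in this patch follows the same shape: pin the GPUs, then
# launch tools/train.py through paddle.distributed.launch with a config and
# `-o` overrides (plus --pretrain for the linear-probe runs). A single-card
# variant for local smoke testing might look like this (hypothetical, not
# one of the checked-in scripts):
#
#     export CUDA_VISIBLE_DEVICES=0
#     python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES \
#         ../../../../tools/train.py \
#         -c ../../../../configs/moco/moco_v1_r50.yaml \
#         -o epochs=1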
+ diff --git a/tests/CI/ssl/moco/moco_v1_r50_pretrain.sh b/tests/CI/ssl/moco/moco_v1_r50_pretrain.sh new file mode 100644 index 00000000..addcc226 --- /dev/null +++ b/tests/CI/ssl/moco/moco_v1_r50_pretrain.sh @@ -0,0 +1,8 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/moco/moco_v1_r50.yaml \ + -o epochs=50 + diff --git a/tests/CI/ssl/moco/moco_v2_r50_linear.sh b/tests/CI/ssl/moco/moco_v2_r50_linear.sh new file mode 100644 index 00000000..ae915664 --- /dev/null +++ b/tests/CI/ssl/moco/moco_v2_r50_linear.sh @@ -0,0 +1,8 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/moco/moco_clas_r50.yaml \ + -o epochs=50 \ + --pretrain ./pretrained/ssl/moco_v2_r50_backbone.pd \ No newline at end of file diff --git a/tests/CI/ssl/moco/moco_v2_r50_pretrain.sh b/tests/CI/ssl/moco/moco_v2_r50_pretrain.sh new file mode 100644 index 00000000..e69de29b diff --git a/tests/CI/ssl/simclr/simclr_r50_IM_linear.sh b/tests/CI/ssl/simclr/simclr_r50_IM_linear.sh new file mode 100644 index 00000000..c710f9de --- /dev/null +++ b/tests/CI/ssl/simclr/simclr_r50_IM_linear.sh @@ -0,0 +1,9 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/simclr/simclr_clas_r50.yaml \ + -o epochs=50 \ + --pretrain ./pretrained/ssl/simclr_r50_backbone.pd + diff --git a/tests/CI/ssl/simclr/simclr_r50_IM_pretrain.sh b/tests/CI/ssl/simclr/simclr_r50_IM_pretrain.sh new file mode 100644 index 00000000..a0ea8c11 --- /dev/null +++ b/tests/CI/ssl/simclr/simclr_r50_IM_pretrain.sh @@ -0,0 +1,6 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py -c ../../../../configs/simclr/simclr_r50_IM.yaml + diff --git a/tests/CI/ssl/simsiam/simsiam_r50_IM_linear.sh b/tests/CI/ssl/simsiam/simsiam_r50_IM_linear.sh new file mode 100644 index 00000000..1181f01a --- /dev/null +++ b/tests/CI/ssl/simsiam/simsiam_r50_IM_linear.sh @@ -0,0 +1,9 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/simsiam/simsiam_clas_r50.yaml \ + -o epochs=50 \ + --pretrain ./pretrained/ssl/simsiam_r50_backbone.pd + diff --git a/tests/CI/ssl/simsiam/simsiam_r50_IM_pretrain.sh b/tests/CI/ssl/simsiam/simsiam_r50_IM_pretrain.sh new file mode 100644 index 00000000..84062bb6 --- /dev/null +++ b/tests/CI/ssl/simsiam/simsiam_r50_IM_pretrain.sh @@ -0,0 +1,8 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/simsiam/simsiam_r50.yaml \ + -o 
epochs=50 \ + diff --git a/tests/CI/ssl/swav/swav_r50_IM_linear.sh b/tests/CI/ssl/swav/swav_r50_IM_linear.sh new file mode 100644 index 00000000..ac39c8ed --- /dev/null +++ b/tests/CI/ssl/swav/swav_r50_IM_linear.sh @@ -0,0 +1,9 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py \ + -c ../../../../configs/swav/swav_clas_r50.yaml \ + -o epochs=50 \ + --pretrain ./pretrained/ssl/swav_r50_backbone.pd + diff --git a/tests/CI/ssl/swav/swav_r50_IM_pretrain.sh b/tests/CI/ssl/swav/swav_r50_IM_pretrain.sh new file mode 100644 index 00000000..a4e1f201 --- /dev/null +++ b/tests/CI/ssl/swav/swav_r50_IM_pretrain.sh @@ -0,0 +1,6 @@ +FLAGS_cudnn_exhaustive_search=0 +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../../../tools/train.py -c ../../../../configs/simclr/swav_r50_100ep.yaml + From d7ffef88d6d1c193b706e185a459f3e72ba2708f Mon Sep 17 00:00:00 2001 From: zhaoqi10 Date: Mon, 10 Apr 2023 20:37:10 +0800 Subject: [PATCH 6/9] rm modeling in new passl. --- passl/modeling/backbones/moco_vit.py | 125 ------------------ passl/modeling/init.py | 188 --------------------------- 2 files changed, 313 deletions(-) delete mode 100644 passl/modeling/backbones/moco_vit.py delete mode 100644 passl/modeling/init.py diff --git a/passl/modeling/backbones/moco_vit.py b/passl/modeling/backbones/moco_vit.py deleted file mode 100644 index f48d82e2..00000000 --- a/passl/modeling/backbones/moco_vit.py +++ /dev/null @@ -1,125 +0,0 @@ -import math -import paddle -import paddle.nn as nn -from functools import partial, reduce -from operator import mul -from .builder import BACKBONES - -from .. import init -from vision_transformer import VisionTransformer, PatchEmbed, to_2tuple - - -@BACKBONES.register() -class VisionTransformerMoCo(VisionTransformer): - def __init__(self, stop_grad_conv1=False, **kwargs): - super().__init__(**kwargs) - # Use fixed 2D sin-cos position embedding - self.build_2d_sincos_position_embedding() - - # weight initialization - for name, m in self.named_sublayers(): - if isinstance(m, nn.Linear): - if 'qkv' in name: - # treat the weights of Q, K, V separately - val = math.sqrt( - 6. / float(m.weight.shape[1] // 3 + m.weight.shape[0])) - init.uniform_(m.weight, -val, val) - else: - init.xavier_uniform_(m.weight) - init.zeros_(m.bias) - init.normal_(self.cls_token, std=1e-6) - - if isinstance(self.patch_embed, PatchEmbed): - # xavier_uniform initialization - val = math.sqrt(6. 
/ float(3 * reduce( - mul, self.patch_embed.patch_size, 1) + self.embed_dim)) - init.uniform_(self.patch_embed.proj.weight, -val, val) - init.zeros_(self.patch_embed.proj.bias) - - if stop_grad_conv1: - self.patch_embed.proj.weight.stop_gradient = True - self.patch_embed.proj.bias.stop_gradient = True - - def build_2d_sincos_position_embedding(self, temperature=10000.): - h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0] - w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1] - grid_w = paddle.arange(w, dtype=paddle.float32) - grid_h = paddle.arange(h, dtype=paddle.float32) - grid_w, grid_h = paddle.meshgrid(grid_w, grid_h) - assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' - pos_dim = self.embed_dim // 4 - omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim - omega = 1. / (temperature**omega) - - out_w = grid_w.flatten()[..., None] @omega[None] - out_h = grid_h.flatten()[..., None] @omega[None] - pos_emb = paddle.concat( - [ - paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h), - paddle.cos(out_h) - ], - axis=1)[None, :, :] - pe_token = paddle.zeros([1, 1, self.embed_dim], dtype=paddle.float32) - - pos_embed = paddle.concat([pe_token, pos_emb], axis=1) - self.pos_embed = self.create_parameter(shape=pos_embed.shape) - self.pos_embed.set_value(pos_embed) - self.pos_embed.stop_gradient = True - - -class ConvStem(nn.Layer): - """ - ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881 - """ - - def __init__(self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True): - super().__init__() - - assert patch_size == 16, 'ConvStem only supports patch size of 16' - assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' - - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], - img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - - # build stem, similar to the design in https://arxiv.org/abs/2106.14881 - stem = [] - input_dim, output_dim = 3, embed_dim // 8 - for l in range(4): - stem.append( - nn.Conv2D( - input_dim, - output_dim, - kernel_size=3, - stride=2, - padding=1, - bias_attr=False)) - stem.append(nn.BatchNorm2D(output_dim)) - stem.append(nn.ReLU()) - input_dim = output_dim - output_dim *= 2 - stem.append(nn.Conv2D(input_dim, embed_dim, kernel_size=1)) - self.proj = nn.Sequential(*stem) - - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x): - B, C, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x) - if self.flatten: - x = x.flatten(2).transpose((0, 2, 1)) # BCHW -> BNC - x = self.norm(x) - return x \ No newline at end of file diff --git a/passl/modeling/init.py b/passl/modeling/init.py deleted file mode 100644 index 92aa94fe..00000000 --- a/passl/modeling/init.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import warnings -import paddle - - -@paddle.no_grad() -def zeros_(x): - return x.zero_() - - -@paddle.no_grad() -def ones_(x): - return x.fill_(1.0) - - -@paddle.no_grad() -def constant_(x, value): - return x.fill_(value) - - -@paddle.no_grad() -def normal_(x, mean=0., std=1.): - temp_value = paddle.tensor.random.gaussian( - shape=x.shape, mean=mean, std=std, dtype=x.dtype) - x.copy_(temp_value, False) - return x - - -@paddle.no_grad() -def uniform_(x, a=0., b=1.): - temp_value = paddle.tensor.random.uniform( - shape=x.shape, min=a, max=b, dtype=x.dtype) - x.copy_(temp_value, False) - return x - - -def _calculate_fan_in_and_fan_out(tensor): - dimensions = tensor.dim() - if dimensions < 2: - raise ValueError( - "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" - ) - - num_input_fmaps = tensor.shape[1] - num_output_fmaps = tensor.shape[0] - receptive_field_size = 1 - if tensor.dim() > 2: - # math.prod is not always available, accumulate the product manually - # we could use functools.reduce but that is not supported by TorchScript - for s in tensor.shape[2:]: - receptive_field_size *= s - fan_in = num_input_fmaps * receptive_field_size - fan_out = num_output_fmaps * receptive_field_size - - return fan_in, fan_out - - -def _calculate_correct_fan(tensor, mode): - mode = mode.lower() - valid_modes = ['fan_in', 'fan_out'] - if mode not in valid_modes: - raise ValueError("Mode {} not supported, please use one of {}".format( - mode, valid_modes)) - - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - return fan_in if mode == 'fan_in' else fan_out - - -def calculate_gain(nonlinearity, param=None): - linear_fns = [ - 'linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', - 'conv_transpose2d', 'conv_transpose3d' - ] - if nonlinearity in linear_fns or nonlinearity == 'sigmoid': - return 1 - elif nonlinearity == 'tanh': - return 5.0 / 3 - elif nonlinearity == 'relu': - return math.sqrt(2.0) - elif nonlinearity == 'leaky_relu': - if param is None: - negative_slope = 0.01 - elif not isinstance(param, bool) and isinstance( - param, int) or isinstance(param, float): - # True/False are instances of int, hence check above - negative_slope = param - else: - raise ValueError("negative_slope {} not a valid number".format( - param)) - return math.sqrt(2.0 / (1 + negative_slope**2)) - elif nonlinearity == 'selu': - return 3.0 / 4 # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) - else: - raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) - - -@paddle.no_grad() -def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): - if 0 in tensor.shape: - warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = _calculate_correct_fan(tensor, mode) - gain = calculate_gain(nonlinearity, a) - std = gain / math.sqrt(fan) - bound = math.sqrt( - 3.0) * std # Calculate uniform bounds from standard deviation - return uniform_(tensor, -bound, bound) - - -@paddle.no_grad() -def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): - if 0 in tensor.shape: - 
warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = _calculate_correct_fan(tensor, mode) - gain = calculate_gain(nonlinearity, a) - std = gain / math.sqrt(fan) - normal_(tensor, 0, std) - - -@paddle.no_grad() -def xavier_uniform_(tensor, gain=1.): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - a = math.sqrt( - 3.0) * std # Calculate uniform bounds from standard deviation - return uniform_(tensor, -a, a) - - -@paddle.no_grad() -def xavier_normal_(tensor, gain=1.): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - return normal_(tensor, 0., std) - - -@paddle.no_grad() -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # Modified based on PyTorch nn.init.trunc_normal_ - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tmp = paddle.zeros_like(tensor, dtype='float32') - tmp.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tmp.erfinv_() - - # Transform to proper mean, std - tmp.scale_(std * math.sqrt(2.)) - tmp.add_(paddle.to_tensor(mean, dtype='float32')) - - # Clip to ensure it's in the proper range - tmp.clip_(min=a, max=b) - tmp = tmp.astype(tensor.dtype) - tensor.copy_(tmp, False) - return tensor From 1269b6d515448575823f3de0342bed774dff79ba Mon Sep 17 00:00:00 2001 From: zhaoqi10 Date: Tue, 11 Apr 2023 09:59:02 +0800 Subject: [PATCH 7/9] fix syntax error in case. 
---
 tests/CI/case.sh | 90 ++++++++++++++++++++++++++----------------------
 1 file changed, 49 insertions(+), 41 deletions(-)

diff --git a/tests/CI/case.sh b/tests/CI/case.sh
index aa79d4e2..a2e6621a 100644
--- a/tests/CI/case.sh
+++ b/tests/CI/case.sh
@@ -210,65 +210,73 @@ function cae_base_patch16_224_lp_in1k_1n8c_dp_fp16o1() {
     ips=`cat log/workerlog.0 |grep time: |awk -F: '{print $10}' |cut -d " " -f2|awk 'NR>20 {print}' | awk '{a+=$1}END{print a/NR}'`
     check_result 6.7196 ${loss} 1.07848 ${ips} $FUNCNAME
 }
-###### SimCLR ######
+####### SimCLR ######
 function simclr_r50_IM_pretrain(){
-  cd ${passl_path}
-  rm -rf log
-  bash ./ssl/simclr/simclr_r50_IM_pretrain.sh
-  loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
-  check_result 2.8107e+01 ${loss%?} $FUNCNAME}
+    cd ${passl_path}
+    rm -rf log
+    bash ./ssl/simclr/simclr_r50_IM_pretrain.sh
+    loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
+    check_result 2.8107e+01 ${loss%?} $FUNCNAME
+}
 
 function simclr_r50_IM_linear(){
-  cd ${passl_path}
-  rm -rf log
-  bash ./ssl/simclr/simclr_r50_IM_linear.sh
-  loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f17 `
-  check_result 6.8498e+00 ${loss%?} $FUNCNAME}
+    cd ${passl_path}
+    rm -rf log
+    bash ./ssl/simclr/simclr_r50_IM_linear.sh
+    loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f17 `
+    check_result 6.8498e+00 ${loss%?} $FUNCNAME
+}
 
 ###### BYOL ######
 function byol_r50_IM_pretrain(){
-  cd ${passl_path}
-  rm -rf log
-  bash ./ssl/byol/byol_r50_IM_pretrain.sh
-  loss=`tail log/workerlog.0 | grep "50/1251" | cut -d " " -f33 `
-  check_result 8.9050e+00 ${loss%?} $FUNCNAME}
+    cd ${passl_path}
+    rm -rf log
+    bash ./ssl/byol/byol_r50_IM_pretrain.sh
+    loss=`tail log/workerlog.0 | grep "50/1251" | cut -d " " -f33 `
+    check_result 8.9050e+00 ${loss%?} $FUNCNAME
+}
 
 function byol_r50_IM_linear(){
-  cd ${passl_path}
-  rm -rf log
-  bash ./ssl/byol/byol_r50_IM_linear.sh
-  loss=`tail log/workerlog.0 | grep "50/1252" | cut -d " " -f15 `
-  check_result 1.0264e+08 ${loss%?} $FUNCNAME}
+    cd ${passl_path}
+    rm -rf log
+    bash ./ssl/byol/byol_r50_IM_linear.sh
+    loss=`tail log/workerlog.0 | grep "50/1252" | cut -d " " -f15 `
+    check_result 1.0264e+08 ${loss%?} $FUNCNAME
+}
 
 ###### SimSiam ######
 function simsiam_r50_IM_pretrain(){
-  cd ${passl_path}
-  rm -rf log
-  bash ./ssl/simsiam/simsiam_r50_IM_pretrain.sh
-  loss=`tail log/workerlog.0 | grep "50/2502" | cut -d " " -f17 `
-  check_result -3.9680e-01 ${loss%?} $FUNCNAME}
+    cd ${passl_path}
+    rm -rf log
+    bash ./ssl/simsiam/simsiam_r50_IM_pretrain.sh
+    loss=`tail log/workerlog.0 | grep "50/2502" | cut -d " " -f17 `
+    check_result -3.9680e-01 ${loss%?} $FUNCNAME
+}
 
 function simsiam_r50_IM_linear(){
-  cd ${passl_path}
-  rm -rf log
-  bash ./ssl/simsiam/simsiam_r50_IM_linear.sh
-  loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f17 `
-  check_result 6.8936e+00 ${loss%?} $FUNCNAME}
+    cd ${passl_path}
+    rm -rf log
+    bash ./ssl/simsiam/simsiam_r50_IM_linear.sh
+    loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f17 `
+    check_result 6.8936e+00 ${loss%?} $FUNCNAME
+}
 
 ###### SWAV ######
 function swav_r50_IM_pretrain(){
-  cd ${passl_path}
-  rm -rf log
-  bash ./ssl/swav/swav_r50_IM_pretrain.sh
-  loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
-  check_result 8.3500e+00 ${loss%?} $FUNCNAME}
+    cd ${passl_path}
+    rm -rf log
+    bash ./ssl/swav/swav_r50_IM_pretrain.sh
+    loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
+    check_result 8.3500e+00 ${loss%?} $FUNCNAME
+}
 
 function swav_r50_IM_linear(){
-  cd ${passl_path}
-  rm -rf log
-  bash ./ssl/swav/swav_r50_IM_linear.sh
-  loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
-  check_result 4.7808e+00 ${loss%?} $FUNCNAME}
+    cd ${passl_path}
+    rm -rf log
+    bash ./ssl/swav/swav_r50_IM_linear.sh
+    loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
+    check_result 4.7808e+00 ${loss%?} $FUNCNAME
+}
 
 
 function check_result() {

From 5cca2e71bbdd2c944ea2ea4825021c2948727d23 Mon Sep 17 00:00:00 2001
From: zhaoqi10
Date: Tue, 11 Apr 2023 16:02:21 +0800
Subject: [PATCH 8/9] fix case names and update requirements.

---
 requirements.txt                              |  1 +
 tests/CI/case.sh                              | 48 +++++++++----------
 ...ar.sh => byol_r50_lp_in1k_1n8c_dp_fp32.sh} |  0
 ...in.sh => byol_r50_pt_in1k_1n8c_dp_fp32.sh} |  0
 ....sh => simclr_r50_lp_in1k_1n8c_dp_fp32.sh} |  0
 ....sh => simclr_r50_pt_in1k_1n8c_dp_fp32.sh} |  0
 ...sh => simsiam_r50_lp_in1k_1n8c_dp_fp32.sh} |  0
 ...sh => simsiam_r50_pt_in1k_1n8c_dp_fp32.sh} |  0
 ...ar.sh => swav_r50_lp_in1k_1n8c_dp_fp32.sh} |  0
 ...in.sh => swav_r50_pt_in1k_1n8c_dp_fp32.sh} |  0
 10 files changed, 25 insertions(+), 24 deletions(-)
 rename tests/CI/ssl/byol/{byol_r50_IM_linear.sh => byol_r50_lp_in1k_1n8c_dp_fp32.sh} (100%)
 rename tests/CI/ssl/byol/{byol_r50_IM_pretrain.sh => byol_r50_pt_in1k_1n8c_dp_fp32.sh} (100%)
 rename tests/CI/ssl/simclr/{simclr_r50_IM_linear.sh => simclr_r50_lp_in1k_1n8c_dp_fp32.sh} (100%)
 rename tests/CI/ssl/simclr/{simclr_r50_IM_pretrain.sh => simclr_r50_pt_in1k_1n8c_dp_fp32.sh} (100%)
 rename tests/CI/ssl/simsiam/{simsiam_r50_IM_linear.sh => simsiam_r50_lp_in1k_1n8c_dp_fp32.sh} (100%)
 rename tests/CI/ssl/simsiam/{simsiam_r50_IM_pretrain.sh => simsiam_r50_pt_in1k_1n8c_dp_fp32.sh} (100%)
 rename tests/CI/ssl/swav/{swav_r50_IM_linear.sh => swav_r50_lp_in1k_1n8c_dp_fp32.sh} (100%)
 rename tests/CI/ssl/swav/{swav_r50_IM_pretrain.sh => swav_r50_pt_in1k_1n8c_dp_fp32.sh} (100%)

diff --git a/requirements.txt b/requirements.txt
index b7317462..b047c537 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 pillow
 numpy
 easydict
+einops
 scikit-image
 scipy
 requests

diff --git a/tests/CI/case.sh b/tests/CI/case.sh
index a2e6621a..50d79cc1 100644
--- a/tests/CI/case.sh
+++ b/tests/CI/case.sh
@@ -18,6 +18,14 @@ set -e
 export passl_path=/paddle/PASSL/tests/CI
 export log_path=/paddle/log_passl
 passl_gpu_model_list=( \
+    simclr_r50_pt_in1k_1n8c_dp_fp32 \
+    simclr_r50_lp_in1k_1n8c_dp_fp32 \
+    byol_r50_pt_in1k_1n8c_dp_fp32 \
+    byol_r50_lp_in1k_1n8c_dp_fp32 \
+    simsiam_r50_pt_in1k_1n8c_dp_fp32 \
+    simsiam_r50_lp_in1k_1n8c_dp_fp32 \
+    swav_r50_pt_in1k_1n8c_dp_fp32 \
+    swav_r50_lp_in1k_1n8c_dp_fp32 \
     ViT_base_patch16_224_in1k_1n8c_dp_fp16o2 \
     ViT_base_patch16_384_ft_in1k_1n8c_dp_fp16o2 \
     DeiT_base_patch16_224_in1k_1n8c_dp_fp32 \
@@ -34,14 +42,6 @@ passl_gpu_model_list=( \
     cae_base_patch16_224_pt_in1k_1n8c_dp_fp16o1 \
     cae_base_patch16_224_ft_in1k_1n8c_dp_fp16o1 \
     cae_base_patch16_224_lp_in1k_1n8c_dp_fp16o1 \
-    simclr_r50_IM_pretrain \
-    simclr_r50_IM_linear \
-    byol_r50_IM_pretrain \
-    byol_r50_IM_linear \
-    simsiam_r50_IM_pretrain \
-    simsiam_r50_IM_linear \
-    swav_r50_IM_pretrain \
-    swav_r50_IM_linear \
 )
 
 
@@ -211,69 +211,69 @@ function cae_base_patch16_224_lp_in1k_1n8c_dp_fp16o1() {
     check_result 6.7196 ${loss} 1.07848 ${ips} $FUNCNAME
 }
 ####### SimCLR ######
-function simclr_r50_IM_pretrain(){
+function simclr_r50_pt_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
-    bash ./ssl/simclr/simclr_r50_IM_pretrain.sh
+    bash ./ssl/simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh
     loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
     check_result 2.8107e+01 ${loss%?} $FUNCNAME
 }
 
-function simclr_r50_IM_linear(){
+function simclr_r50_lp_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
-    bash ./ssl/simclr/simclr_r50_IM_linear.sh
+    bash ./ssl/simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh
     loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f17 `
     check_result 6.8498e+00 ${loss%?} $FUNCNAME
 }
 
 ###### BYOL ######
-function byol_r50_IM_pretrain(){
+function byol_r50_pt_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
-    bash ./ssl/byol/byol_r50_IM_pretrain.sh
+    bash ./ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh
     loss=`tail log/workerlog.0 | grep "50/1251" | cut -d " " -f33 `
     check_result 8.9050e+00 ${loss%?} $FUNCNAME
 }
 
-function byol_r50_IM_linear(){
+function byol_r50_lp_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
-    bash ./ssl/byol/byol_r50_IM_linear.sh
+    bash ./ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh
     loss=`tail log/workerlog.0 | grep "50/1252" | cut -d " " -f15 `
     check_result 1.0264e+08 ${loss%?} $FUNCNAME
 }
 
 ###### SimSiam ######
-function simsiam_r50_IM_pretrain(){
+function simsiam_r50_pt_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
-    bash ./ssl/simsiam/simsiam_r50_IM_pretrain.sh
+    bash ./ssl/simsiam/simsiam_r50_pt_in1k_1n8c_dp_fp32.sh
     loss=`tail log/workerlog.0 | grep "50/2502" | cut -d " " -f17 `
     check_result -3.9680e-01 ${loss%?} $FUNCNAME
 }
 
-function simsiam_r50_IM_linear(){
+function simsiam_r50_lp_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
-    bash ./ssl/simsiam/simsiam_r50_IM_linear.sh
+    bash ./ssl/simsiam/simsiam_r50_lp_in1k_1n8c_dp_fp32.sh
     loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f17 `
     check_result 6.8936e+00 ${loss%?} $FUNCNAME
 }
 
 ###### SWAV ######
-function swav_r50_IM_pretrain(){
+function swav_r50_pt_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
-    bash ./ssl/swav/swav_r50_IM_pretrain.sh
+    bash ./ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh
     loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
     check_result 8.3500e+00 ${loss%?} $FUNCNAME
 }
 
-function swav_r50_IM_linear(){
+function swav_r50_lp_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
-    bash ./ssl/swav/swav_r50_IM_linear.sh
+    bash ./ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh
     loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
     check_result 4.7808e+00 ${loss%?} $FUNCNAME
 }

diff --git a/tests/CI/ssl/byol/byol_r50_IM_linear.sh b/tests/CI/ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh
similarity index 100%
rename from tests/CI/ssl/byol/byol_r50_IM_linear.sh
rename to tests/CI/ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh
diff --git a/tests/CI/ssl/byol/byol_r50_IM_pretrain.sh b/tests/CI/ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh
similarity index 100%
rename from tests/CI/ssl/byol/byol_r50_IM_pretrain.sh
rename to tests/CI/ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh
diff --git a/tests/CI/ssl/simclr/simclr_r50_IM_linear.sh b/tests/CI/ssl/simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh
similarity index 100%
rename from tests/CI/ssl/simclr/simclr_r50_IM_linear.sh
rename to tests/CI/ssl/simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh
diff --git a/tests/CI/ssl/simclr/simclr_r50_IM_pretrain.sh b/tests/CI/ssl/simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh
similarity index 100%
rename from tests/CI/ssl/simclr/simclr_r50_IM_pretrain.sh
rename to tests/CI/ssl/simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh
diff --git a/tests/CI/ssl/simsiam/simsiam_r50_IM_linear.sh b/tests/CI/ssl/simsiam/simsiam_r50_lp_in1k_1n8c_dp_fp32.sh
similarity index 100%
rename from tests/CI/ssl/simsiam/simsiam_r50_IM_linear.sh
rename to tests/CI/ssl/simsiam/simsiam_r50_lp_in1k_1n8c_dp_fp32.sh
diff --git a/tests/CI/ssl/simsiam/simsiam_r50_IM_pretrain.sh b/tests/CI/ssl/simsiam/simsiam_r50_pt_in1k_1n8c_dp_fp32.sh
similarity index 100%
rename from tests/CI/ssl/simsiam/simsiam_r50_IM_pretrain.sh
rename to tests/CI/ssl/simsiam/simsiam_r50_pt_in1k_1n8c_dp_fp32.sh
diff --git a/tests/CI/ssl/swav/swav_r50_IM_linear.sh b/tests/CI/ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh
similarity index 100%
rename from tests/CI/ssl/swav/swav_r50_IM_linear.sh
rename to tests/CI/ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh
diff --git a/tests/CI/ssl/swav/swav_r50_IM_pretrain.sh b/tests/CI/ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh
similarity index 100%
rename from tests/CI/ssl/swav/swav_r50_IM_pretrain.sh
rename to tests/CI/ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh

From 89deae208865ba388f37c3bb6907616d12ac2cdd Mon Sep 17 00:00:00 2001
From: zhaoqi10
Date: Wed, 12 Apr 2023 14:05:19 +0800
Subject: [PATCH 9/9] fix tests and trainer.

---
 passl_v110/datasets/builder.py                |  7 ++++
 passl_v110/engine/trainer.py                  |  5 +--
 passl_v110/hooks/log_hook.py                  |  3 +-
 passl_v110/modeling/architectures/BYOL.py     |  2 +-
 tests/CI/case.sh                              | 32 +++++++++----------
 .../ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh |  3 ++
 .../ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh |  5 ++-
 .../simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh |  2 ++
 .../simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh |  5 ++-
 .../simsiam_r50_lp_in1k_1n8c_dp_fp32.sh       |  2 ++
 .../simsiam_r50_pt_in1k_1n8c_dp_fp32.sh       |  1 +
 .../ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh |  2 ++
 .../ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh |  4 ++-
 13 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/passl_v110/datasets/builder.py b/passl_v110/datasets/builder.py
index b4cb0687..04ea688c 100644
--- a/passl_v110/datasets/builder.py
+++ b/passl_v110/datasets/builder.py
@@ -16,6 +16,7 @@ import numpy as np
 import math
 import paddle
+import random
 from paddle.io import DistributedBatchSampler
 
 from ..utils.registry import Registry, build_from_config
@@ -92,11 +93,17 @@
 
     sampler_name = sampler_cfg.pop('name', 'DistributedBatchSampler')
 
+    def worker_init_fn(worker_id):
+        """ set seed in subprocess for dataloader when num_workers > 0"""
+        np.random.seed(cfg.seed + worker_id)
+        random.seed(cfg.seed + worker_id)
+
     sampler = eval("{}".format(sampler_name))(dataset, **sampler_cfg)
 
     dataloader = paddle.io.DataLoader(dataset,
                                       batch_sampler=sampler,
                                       places=device,
+                                      worker_init_fn=worker_init_fn,
                                       **loader_cfg)
 
     #setup mixup / cutmix

diff --git a/passl_v110/engine/trainer.py b/passl_v110/engine/trainer.py
index 225889ff..aa5f5415 100644
--- a/passl_v110/engine/trainer.py
+++ b/passl_v110/engine/trainer.py
@@ -116,7 +116,7 @@ def __init__(self, cfg):
         self.log_interval = cfg.log_config.interval
 
         # set seed
-        seed = cfg.get('seed', False)
+        seed = cfg.get('seed', 2023)
         if seed:
             seed += dp_rank
             paddle.seed(seed)
@@ -309,7 +309,8 @@ def train(self):
             self.inner_iter = self.current_iter % self.iters_per_epoch
             self.current_iter += 1
             self.current_epoch = iter_loader.epoch
-
+            if hasattr(self.train_dataloader.batch_sampler, "set_epoch"):
+                self.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
             data = next(iter_loader)
 
             self.call_hook('train_iter_begin')

diff --git a/passl_v110/hooks/log_hook.py b/passl_v110/hooks/log_hook.py
index 51346e4b..dccfae06 100644
--- a/passl_v110/hooks/log_hook.py
+++ b/passl_v110/hooks/log_hook.py
@@ -92,7 +92,7 @@ def _log_info(self, log_dict, trainer):
                 log_items.append(val)
 
         log_str += ', '.join(log_items)
-
+        print(log_str)
         trainer.logger.info(log_str)
 
     def _round_float(self, items):
@@ -151,6 +151,7 @@ def train_iter_end(self, trainer):
                 trainer.logs[k].update(float(v))
 
         if self.by_epoch and self.every_n_inner_iters(trainer, self.interval):
+            print('train_iter_end >>>>>>>>>>>>>.')
             self.print_log(trainer)
 
     def train_epoch_end(self, trainer):

diff --git a/passl_v110/modeling/architectures/BYOL.py b/passl_v110/modeling/architectures/BYOL.py
index d01abbb1..ffe13da5 100644
--- a/passl_v110/modeling/architectures/BYOL.py
+++ b/passl_v110/modeling/architectures/BYOL.py
@@ -54,7 +54,7 @@ def single_random_gaussian_blur(image, height, width, p=1.0):
     x = paddle.arange(-radius, radius + 1, 1, "float32")
     blur_filter = paddle.exp(-paddle.pow(x, 2.0) /
                              (2.0 * paddle.pow(sigma, 2.0)))
-    blur_filter /= layers.nn.reduce_sum(blur_filter)
+    blur_filter /= paddle.sum(blur_filter)
     blur_v = paddle.reshape(blur_filter, [1, 1, kernel_size, 1])
     blur_h = paddle.reshape(blur_filter, [1, 1, 1, kernel_size])
     num_channels = 3

diff --git a/tests/CI/case.sh b/tests/CI/case.sh
index 50d79cc1..bd81e69d 100644
--- a/tests/CI/case.sh
+++ b/tests/CI/case.sh
@@ -215,16 +215,16 @@ function simclr_r50_pt_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
     bash ./ssl/simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh
-    loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
-    check_result 2.8107e+01 ${loss%?} $FUNCNAME
+    loss=`tail log/workerlog.0 | grep "20/5004" | cut -d " " -f11 `
+    check_result 2.8602e+01 ${loss} 0 0 $FUNCNAME
 }
 
 function simclr_r50_lp_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
     bash ./ssl/simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh
-    loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f17 `
-    check_result 6.8498e+00 ${loss%?} $FUNCNAME
+    loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f13 `
+    check_result 6.8497e+00 ${loss} 0 0 $FUNCNAME
 }
 
 ###### BYOL ######
@@ -232,16 +232,16 @@ function byol_r50_pt_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
     bash ./ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh
-    loss=`tail log/workerlog.0 | grep "50/1251" | cut -d " " -f33 `
-    check_result 8.9050e+00 ${loss%?} $FUNCNAME
+    loss=`tail log/workerlog.0 | grep "20/1251" | cut -d " " -f27 `
+    check_result 1.0883e+01 ${loss} 0 0 $FUNCNAME
 }
 
 function byol_r50_lp_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
     bash ./ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh
-    loss=`tail log/workerlog.0 | grep "50/1252" | cut -d " " -f15 `
-    check_result 1.0264e+08 ${loss%?} $FUNCNAME
+    loss=`tail log/workerlog.0 | grep "50/1252" | cut -d " " -f12 `
+    check_result 8.6801e+07 ${loss} 0 0 $FUNCNAME
 }
 
 ###### SimSiam ######
@@ -249,16 +249,16 @@ function simsiam_r50_pt_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
     bash ./ssl/simsiam/simsiam_r50_pt_in1k_1n8c_dp_fp32.sh
-    loss=`tail log/workerlog.0 | grep "50/2502" | cut -d " " -f17 `
-    check_result -3.9680e-01 ${loss%?} $FUNCNAME
+    loss=`tail log/workerlog.0 | grep "20/2502" | cut -d " " -f12 `
+    check_result -3.9680e-01 ${loss} 0 0 $FUNCNAME
 }
 
 function simsiam_r50_lp_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
     bash ./ssl/simsiam/simsiam_r50_lp_in1k_1n8c_dp_fp32.sh
-    loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f17 `
-    check_result 6.8936e+00 ${loss%?} $FUNCNAME
+    loss=`tail log/workerlog.0 | grep "50/312" | cut -d " " -f14 `
+    check_result 6.8936e+00 ${loss} 0 0 $FUNCNAME
 }
 
 ###### SWAV ######
@@ -266,16 +266,16 @@ function swav_r50_pt_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
     bash ./ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh
-    loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
-    check_result 8.3500e+00 ${loss%?} $FUNCNAME
+    loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f11 `
+    check_result 8.3689e+00 ${loss} 0 0 $FUNCNAME
 }
 
 function swav_r50_lp_in1k_1n8c_dp_fp32(){
     cd ${passl_path}
     rm -rf log
     bash ./ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh
-    loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f17 `
-    check_result 4.7808e+00 ${loss%?} $FUNCNAME
+    loss=`tail log/workerlog.0 | grep "50/5004" | cut -d " " -f11 `
+    check_result 4.6752e+00 ${loss} 0 0 $FUNCNAME
 }

diff --git a/tests/CI/ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh b/tests/CI/ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh
index dcc03e7c..be74d61d 100644
--- a/tests/CI/ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh
+++ b/tests/CI/ssl/byol/byol_r50_lp_in1k_1n8c_dp_fp32.sh
@@ -13,12 +13,15 @@
 # limitations under the License.
 
 FLAGS_cudnn_exhaustive_search=0
+FLAGS_cudnn_deterministic=1
 export PADDLE_NNODES=1
 export PADDLE_MASTER="127.0.0.1:12538"
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 
 python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../tools_v110/train.py \
   -c ../../configs/byol/byol_clas_r50.yaml \
   -o total_iters=51 \
+  -o seed=2023 \
   --pretrain ./pretrained/byol/byol_r50_ext_backbone.pd \
+

diff --git a/tests/CI/ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh b/tests/CI/ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh
index 72aed5fc..8c9f2d0b 100644
--- a/tests/CI/ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh
+++ b/tests/CI/ssl/byol/byol_r50_pt_in1k_1n8c_dp_fp32.sh
@@ -13,9 +13,12 @@
 # limitations under the License.
 
 FLAGS_cudnn_exhaustive_search=0
+FLAGS_cudnn_deterministic=1
 export PADDLE_NNODES=1
 export PADDLE_MASTER="127.0.0.1:12538"
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 
 python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../tools_v110/train.py \
   -c ../../configs/byol/byol_r50_IM.yaml \
-  -o total_iters=51
+  -o total_iters=51 \
+  -o seed=2023 \
+  -o dataloader.train.loader.num_workers=0

diff --git a/tests/CI/ssl/simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh b/tests/CI/ssl/simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh
index a4c917e3..d01e5fe8 100644
--- a/tests/CI/ssl/simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh
+++ b/tests/CI/ssl/simclr/simclr_r50_lp_in1k_1n8c_dp_fp32.sh
@@ -13,10 +13,12 @@
 # limitations under the License.
 
 FLAGS_cudnn_exhaustive_search=0
+FLAGS_cudnn_deterministic=1
 export PADDLE_NNODES=1
 export PADDLE_MASTER="127.0.0.1:12538"
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 
 python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../tools_v110/train.py \
   -c ../../configs/simclr/simclr_clas_r50.yaml \
   -o total_iters=51 \
+  -o seed=2023 \
   --pretrain ./pretrained/simclr/simclr_r50_backbone.pd

diff --git a/tests/CI/ssl/simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh b/tests/CI/ssl/simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh
index e779f689..db4e452c 100644
--- a/tests/CI/ssl/simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh
+++ b/tests/CI/ssl/simclr/simclr_r50_pt_in1k_1n8c_dp_fp32.sh
@@ -13,9 +13,12 @@
 # limitations under the License.
 
 FLAGS_cudnn_exhaustive_search=0
+FLAGS_cudnn_deterministic=1
 export PADDLE_NNODES=1
 export PADDLE_MASTER="127.0.0.1:12538"
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 
 python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../tools_v110/train.py \
   -c ../../configs/simclr/simclr_r50_IM.yaml \
-  -o total_iters=51
+  -o total_iters=51 \
+  -o dataloader.train.dataset.dataroot=../../data/ILSVRC2012/val \
+  -o dataloader.train.loader.num_workers=2 \

diff --git a/tests/CI/ssl/simsiam/simsiam_r50_lp_in1k_1n8c_dp_fp32.sh b/tests/CI/ssl/simsiam/simsiam_r50_lp_in1k_1n8c_dp_fp32.sh
index f1677852..53860cac 100644
--- a/tests/CI/ssl/simsiam/simsiam_r50_lp_in1k_1n8c_dp_fp32.sh
+++ b/tests/CI/ssl/simsiam/simsiam_r50_lp_in1k_1n8c_dp_fp32.sh
@@ -13,10 +13,12 @@
 # limitations under the License.
 
 FLAGS_cudnn_exhaustive_search=0
+FLAGS_cudnn_deterministic=1
 export PADDLE_NNODES=1
 export PADDLE_MASTER="127.0.0.1:12538"
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 
 python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../tools_v110/train.py \
   -c ../../configs/simsiam/simsiam_clas_r50.yaml \
   -o total_iters=51 \
+  -o seed=2023 \
   --pretrain ./pretrained/simsiam/simsiam_r50_ext_backbone.pd

diff --git a/tests/CI/ssl/simsiam/simsiam_r50_pt_in1k_1n8c_dp_fp32.sh b/tests/CI/ssl/simsiam/simsiam_r50_pt_in1k_1n8c_dp_fp32.sh
index 307c8a34..b67f5b2a 100644
--- a/tests/CI/ssl/simsiam/simsiam_r50_pt_in1k_1n8c_dp_fp32.sh
+++ b/tests/CI/ssl/simsiam/simsiam_r50_pt_in1k_1n8c_dp_fp32.sh
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 FLAGS_cudnn_exhaustive_search=0
+FLAGS_cudnn_deterministic=1
 export PADDLE_NNODES=1
 export PADDLE_MASTER="127.0.0.1:12538"
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

diff --git a/tests/CI/ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh b/tests/CI/ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh
index 92ffb818..2d62b6c3 100644
--- a/tests/CI/ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh
+++ b/tests/CI/ssl/swav/swav_r50_lp_in1k_1n8c_dp_fp32.sh
@@ -13,10 +13,12 @@
 # limitations under the License.
 
 FLAGS_cudnn_exhaustive_search=0
+FLAGS_cudnn_deterministic=1
 export PADDLE_NNODES=1
 export PADDLE_MASTER="127.0.0.1:12538"
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 
 python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../tools_v110/train.py \
   -c ../../configs/swav/swav_clas_r50.yaml \
   -o total_iters=51 \
+  -o seed=2023 \
   --pretrain ./pretrained/swav/swav_r50_ext_backbone.pd

diff --git a/tests/CI/ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh b/tests/CI/ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh
index 094174c7..49d85b43 100644
--- a/tests/CI/ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh
+++ b/tests/CI/ssl/swav/swav_r50_pt_in1k_1n8c_dp_fp32.sh
@@ -13,9 +13,11 @@
 # limitations under the License.
 
 FLAGS_cudnn_exhaustive_search=0
+FLAGS_cudnn_deterministic=1
 export PADDLE_NNODES=1
 export PADDLE_MASTER="127.0.0.1:12538"
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 
 python -m paddle.distributed.launch --devices=$CUDA_VISIBLE_DEVICES ../../tools_v110/train.py \
   -c ../../configs/swav/swav_r50_100ep.yaml \
-  -o total_iters=51
+  -o total_iters=51 \
+  -o seed=2023
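
Taken together, this last patch makes the CI cases reproducible through
three pieces: `FLAGS_cudnn_deterministic=1` in the launch scripts, a fixed
trainer seed (defaulting to 2023), and per-worker reseeding via
`worker_init_fn` in the dataloader builder. The recipe in isolation looks
roughly like the sketch below (a minimal, self-contained Paddle example
with a made-up `ToyDataset`; PASSL itself wires these pieces through its
config system):

    import random

    import numpy as np
    import paddle
    from paddle.io import DataLoader, Dataset

    SEED = 2023  # same default the patched trainer falls back to


    class ToyDataset(Dataset):
        # stand-in dataset; PASSL builds real datasets from its configs
        def __getitem__(self, idx):
            return np.random.rand(3, 224, 224).astype('float32')

        def __len__(self):
            return 16


    def worker_init_fn(worker_id):
        # reseed python/numpy inside each loader subprocess, mirroring
        # the build_dataloader change above (which uses cfg.seed)
        np.random.seed(SEED + worker_id)
        random.seed(SEED + worker_id)


    paddle.seed(SEED)      # framework RNG: weight init, dropout, ...
    np.random.seed(SEED)   # parent-process numpy RNG
    random.seed(SEED)      # parent-process python RNG

    loader = DataLoader(ToyDataset(),
                        batch_size=4,
                        num_workers=2,
                        worker_init_fn=worker_init_fn)

    for batch in loader:
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        print(images.shape)  # [4, 3, 224, 224], deterministic values per run

Offsetting by `worker_id` matters: forked workers inherit identical RNG
state, so seeding only the parent would make every worker produce the same
"random" augmentations; `seed + worker_id` keeps workers deterministic but
decorrelated.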