Commit

update script
guochengqian committed Dec 1, 2023
1 parent f2f7532 commit d567fc6
Showing 7 changed files with 789 additions and 27 deletions.
6 changes: 3 additions & 3 deletions cfgs/s3dis_pix4point/pix4point.yaml
@@ -13,19 +13,19 @@ model:
     add_pos_each_block: True
     qkv_bias: True
     act_args:
-      act: 'gelu' # better than relu
+      act: 'gelu'
     norm_args:
       norm: 'ln'
       eps: 1.0e-6
     embed_args:
       NAME: P3Embed
-      feature_type: 'dp_df' # show an abaltion study of this.
+      feature_type: 'dp_df'
       reduction: 'max'
       sample_ratio: 0.0625
       normalize_dp: False
       group_size: 32
       subsample: 'fps' # random, FPS
-      group: 'knn' # change it to group args.
+      group: 'knn'
     conv_args:
       order: conv-norm-act
       layers: 4
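For context on the embedding settings kept above: `subsample: 'fps'` with `sample_ratio: 0.0625` picks a subset of the points as group centers via farthest point sampling, and `group: 'knn'` with `group_size: 32` gathers a fixed-size neighbourhood around each center. Below is a minimal NumPy sketch of that idea; it is illustrative only, not the repository's P3Embed implementation.

```python
import numpy as np

def farthest_point_sample(xyz, n_samples):
    """Greedy FPS: repeatedly pick the point farthest from the set chosen so far."""
    n = xyz.shape[0]
    chosen = np.zeros(n_samples, dtype=np.int64)
    min_dist = np.full(n, np.inf)
    farthest = 0  # start from the first point; a random start also works
    for i in range(n_samples):
        chosen[i] = farthest
        d = np.sum((xyz - xyz[farthest]) ** 2, axis=1)
        min_dist = np.minimum(min_dist, d)
        farthest = int(np.argmax(min_dist))
    return chosen

def knn_group(xyz, centers, k):
    """Indices of the k nearest points to each center (brute force)."""
    d = np.sum((centers[:, None, :] - xyz[None, :, :]) ** 2, axis=-1)  # (M, N)
    return np.argsort(d, axis=1)[:, :k]

points = np.random.rand(1024, 3).astype(np.float32)          # toy point cloud
centers = farthest_point_sample(points, int(1024 * 0.0625))  # sample_ratio: 0.0625 -> 64 groups
groups = knn_group(points, points[centers], k=32)            # group_size: 32
print(groups.shape)  # (64, 32); each row is one point "patch" handed to the embedding
```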
6 changes: 3 additions & 3 deletions cfgs/s3dis_pix4point/pix4point_bert.yaml
@@ -13,19 +13,19 @@ model:
     add_pos_each_block: True
     qkv_bias: True
     act_args:
-      act: 'gelu' # better than relu
+      act: 'gelu'
     norm_args:
       norm: 'ln'
       eps: 1.0e-6
     embed_args:
       NAME: P3Embed
-      feature_type: 'dp_df' # show an abaltion study of this.
+      feature_type: 'dp_df'
       reduction: 'max'
       sample_ratio: 0.0625
       normalize_dp: False
       group_size: 32
       subsample: 'fps' # random, FPS
-      group: 'knn' # change it to group args.
+      group: 'knn'
     conv_args:
       order: conv-norm-act
       layers: 4
Changes to a third configuration file (path not shown in this listing):
@@ -1,7 +1,7 @@
 model:
   NAME: BaseSeg
   encoder_args:
-    NAME: InvPointViT
+    NAME: PointViT
     in_channels: 7
     embed_dim: 384
     depth: 12
@@ -13,19 +13,19 @@ model:
     add_pos_each_block: True
     qkv_bias: True
     act_args:
-      act: 'gelu' # better than relu
+      act: 'gelu'
     norm_args:
       norm: 'ln'
       eps: 1.0e-6
     embed_args:
       NAME: P3Embed
-      feature_type: 'dp_df' # show an abaltion study of this.
+      feature_type: 'dp_df'
       reduction: 'max'
       sample_ratio: 0.0625
       normalize_dp: False
       group_size: 32
       subsample: 'fps' # random, FPS
-      group: 'knn' # change it to group args.
+      group: 'knn'
     conv_args:
       order: conv-norm-act
       layers: 4
@@ -44,5 +44,5 @@ model:
     norm_args:
       norm: 'ln1d'

-mode: finetune_encoder_inv
+mode: finetune_encoder_freeze_blocks
 pretrained_path: pretrained/imagenet/mae_s.pth
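The renamed mode, `finetune_encoder_freeze_blocks`, suggests fine-tuning in which part of the pretrained ViT encoder is kept frozen. The snippet below only illustrates that general pattern in PyTorch, under the assumption that the pretrained transformer blocks live at a hypothetical `model.encoder.blocks` attribute; it is not taken from the repository.

```python
import torch

def freeze_encoder_blocks(model):
    # Illustration only: stop gradients for the pretrained transformer blocks
    # (hypothetical attribute path), leaving the point embedding and the
    # segmentation head trainable.
    for p in model.encoder.blocks.parameters():
        p.requires_grad = False

def build_optimizer(model, lr=1e-4):
    # Optimize only the parameters that remain trainable after freezing.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)
```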
24 changes: 15 additions & 9 deletions examples/segmentation/main.py
@@ -238,17 +238,18 @@ def main(gpu, cfg):

     val_miou, val_macc, val_oa, val_ious, val_accs = 0., 0., 0., [], []
     best_val, macc_when_best, oa_when_best, ious_when_best, best_epoch = 0., 0., 0., [], 0
+    total_iter = 0
     for epoch in range(cfg.start_epoch, cfg.epochs + 1):
         if cfg.distributed:
             train_loader.sampler.set_epoch(epoch)
         if hasattr(train_loader.dataset, 'epoch'): # some dataset sets the dataset length as a fixed steps.
             train_loader.dataset.epoch = epoch - 1
-        train_loss, train_miou, train_macc, train_oa, _, _ = \
-            train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch, cfg)
+        train_loss, train_miou, train_macc, train_oa, _, _, total_iter = \
+            train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch, total_iter, cfg)

         is_best = False
         if epoch % cfg.val_freq == 0:
-            val_miou, val_macc, val_oa, val_ious, val_accs = validate_fn(model, val_loader, cfg, epoch=epoch)
+            val_miou, val_macc, val_oa, val_ious, val_accs = validate_fn(model, val_loader, cfg, epoch=epoch, total_iter=total_iter)
             if val_miou > best_val:
                 is_best = True
                 best_val = val_miou
@@ -337,7 +338,7 @@ def main(gpu, cfg):
         wandb.finish(exit_code=True)


-def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch, cfg):
+def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch, total_iter, cfg):
     loss_meter = AverageMeter()
     cm = ConfusionMatrix(num_classes=cfg.num_classes, ignore_index=cfg.ignore_index)
     model.train()  # set model to training mode
@@ -356,6 +357,8 @@ def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler
             end of debug """
         data['x'] = get_features_by_keys(data, cfg.feature_keys)
         data['epoch'] = epoch
+        total_iter += 1
+        data['iter'] = total_iter
         with torch.cuda.amp.autocast(enabled=cfg.use_amp):
             logits = model(data)
             loss = criterion(logits, target) if 'mask' not in cfg.criterion_args.NAME.lower() \
@@ -380,7 +383,9 @@ def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler
             optimizer.zero_grad()
             if not cfg.sched_on_epoch:
                 scheduler.step(epoch)
-
+            # mem = torch.cuda.max_memory_allocated() / 1024. / 1024.
+            # print(f"Memory after backward is {mem}")
+
         # update confusion matrix
         cm.update(logits.argmax(dim=1), target)
         loss_meter.update(loss.item())
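The two comment lines added in this hunk reference PyTorch's peak-memory counter. As a side note, measuring peak GPU memory with the standard `torch.cuda` API looks roughly like this (generic example, not repository code):

```python
import torch

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(1024, 1024, device='cuda')
    y = x @ x  # some GPU work whose peak allocation we want to observe
    peak_mb = torch.cuda.max_memory_allocated() / 1024. / 1024.
    print(f"Peak allocated memory: {peak_mb:.1f} MB")
```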
@@ -389,11 +394,11 @@ def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler
             pbar.set_description(f"Train Epoch [{epoch}/{cfg.epochs}] "
                                  f"Loss {loss_meter.val:.3f} Acc {cm.overall_accuray:.2f}")
     miou, macc, oa, ious, accs = cm.all_metrics()
-    return loss_meter.avg, miou, macc, oa, ious, accs
+    return loss_meter.avg, miou, macc, oa, ious, accs, total_iter


 @torch.no_grad()
-def validate(model, val_loader, cfg, num_votes=1, data_transform=None, epoch=-1):
+def validate(model, val_loader, cfg, num_votes=1, data_transform=None, epoch=-1, total_iter=-1):
     model.eval()  # set model to eval mode
     cm = ConfusionMatrix(num_classes=cfg.num_classes, ignore_index=cfg.ignore_index)
     pbar = tqdm(enumerate(val_loader), total=val_loader.__len__(), desc='Val')
@@ -404,6 +409,7 @@ def validate(model, val_loader, cfg, num_votes=1, data_transform=None, epoch=-1)
         target = data['y'].squeeze(-1)
         data['x'] = get_features_by_keys(data, cfg.feature_keys)
         data['epoch'] = epoch
+        data['iter'] = total_iter
         logits = model(data)
         if 'mask' not in cfg.criterion_args.NAME or cfg.get('use_maks', False):
             cm.update(logits.argmax(dim=1), target)
@@ -438,7 +444,7 @@ def validate(model, val_loader, cfg, num_votes=1, data_transform=None, epoch=-1)


 @torch.no_grad()
-def validate_sphere(model, val_loader, cfg, num_votes=1, data_transform=None, epoch=-1):
+def validate_sphere(model, val_loader, cfg, num_votes=1, data_transform=None, epoch=-1, total_iter=-1):
     """
     validation for sphere sampled input points with mask.
     in this case, between different batches, there are overlapped points.
@@ -460,6 +466,7 @@ def validate_sphere(model, val_loader, cfg, num_votes=1, data_transform=None, ep
             data[key] = data[key].cuda(non_blocking=True)
         data['x'] = get_features_by_keys(data, cfg.feature_keys)
         data['epoch'] = epoch
+        data['iter'] = total_iter
         logits = model(data)
         all_logits.append(logits)
         idx_points.append(data['input_inds'])
@@ -703,7 +710,6 @@ def test(model, data_list, cfg, num_votes=1):
         cfg.mode,
         cfg.cfg_basename, # cfg file name
         f'ngpus{cfg.world_size}',
-        f'seed{cfg.seed}',
     ]
     opt_list = []  # for checking experiment configs from logging file
     for i, opt in enumerate(opts):
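Most of the `main.py` changes above thread a single global iteration counter through the training loop: `main()` owns `total_iter`, `train_one_epoch()` increments it per batch and returns it, and the validation functions receive it so that anything reading `data['iter']` sees a consistent global step. A stripped-down sketch of that flow (toy code, not the repository's):

```python
def train_one_epoch(loader, epoch, total_iter):
    for batch in loader:
        total_iter += 1
        batch['iter'] = total_iter   # consumers (model hooks, loggers) see the global step
        # ... forward / loss / backward would go here ...
    return total_iter                # hand the updated counter back to the caller

def validate(loader, epoch, total_iter=-1):
    for batch in loader:
        batch['iter'] = total_iter   # evaluation batches are tagged with the same step
        # ... forward / metrics would go here ...

total_iter = 0
for epoch in range(1, 3):
    loader = [{'x': i} for i in range(4)]  # stand-in for a DataLoader
    total_iter = train_one_epoch(loader, epoch, total_iter)
    validate(loader, epoch, total_iter=total_iter)
print(total_iter)  # 8 batches seen across 2 epochs
```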