Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
Refactor unnecessary else / elif when if block has a return s…
Browse files Browse the repository at this point in the history
…tatement (#114)

Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
  • Loading branch information
deepsource-autofix[bot] authored Jan 4, 2022
1 parent 943bfa4 commit 4831362
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 25 deletions.
6 changes: 2 additions & 4 deletions satflow/models/cloudgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
images, future_images = batch
if self.condition_time:
return self.train_per_timestep(images, future_images, optimizer_idx, batch_idx)
else:
return self.train_all_timestep(images, future_images, optimizer_idx, batch_idx)
return self.train_all_timestep(images, future_images, optimizer_idx, batch_idx)

def val_all_timestep(self, images, future_images, batch_idx):
# generate images
Expand Down Expand Up @@ -328,8 +327,7 @@ def validation_step(self, batch, batch_idx):
images, future_images = batch
if self.condition_time:
return self.val_per_timestep(images, future_images, batch_idx)
else:
return self.val_all_timestep(images, future_images, batch_idx)
return self.val_all_timestep(images, future_images, batch_idx)

def forward(self, x, **kwargs):
return self.generator.forward(x, **kwargs)
Expand Down
3 changes: 1 addition & 2 deletions satflow/models/gan/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,4 @@ def cal_gradient_penalty(
((gradients + 1e-16).norm(2, dim=1) - constant) ** 2
).mean() * lambda_gp # added eps
return gradient_penalty, gradients
else:
return 0.0, None
return 0.0, None
3 changes: 1 addition & 2 deletions satflow/models/gan/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,5 +427,4 @@ def __init__(
def forward(self, x):
if self.outermost:
return self.model(x)
else: # add skip connections
return torch.cat([x, self.model(x)], 1)
return torch.cat([x, self.model(x)], 1)
9 changes: 4 additions & 5 deletions satflow/models/layers/TimeDistributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ def forward(self, *tensors, **kwargs):
"input x with shape:(bs,seq_len,channels,width,height)"
if self.low_mem or self.tdim != 1:
return self.low_mem_forward(*tensors, **kwargs)
else:
# only support tdim=1
inp_shape = tensors[0].shape
bs, seq_len = inp_shape[0], inp_shape[1]
out = self.module(*[x.view(bs * seq_len, *x.shape[2:]) for x in tensors], **kwargs)
# only support tdim=1
inp_shape = tensors[0].shape
bs, seq_len = inp_shape[0], inp_shape[1]
out = self.module(*[x.view(bs * seq_len, *x.shape[2:]) for x in tensors], **kwargs)
return self.format_output(out, bs, seq_len)

def low_mem_forward(self, *tensors, **kwargs):
Expand Down
22 changes: 10 additions & 12 deletions satflow/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,18 @@ def reverse_space_to_depth(
dh=spatial_block_size,
dw=spatial_block_size,
)
elif len(frames.shape) == 5:
if len(frames.shape) == 5:
return einops.rearrange(
frames,
"b t h w (dt dh dw c) -> b (t dt) (h dh) (w dw) c",
dt=temporal_block_size,
dh=spatial_block_size,
dw=spatial_block_size,
)
else:
raise ValueError(
"Frames should be of rank 4 (batch, height, width, channels)"
" or rank 5 (batch, time, height, width, channels)"
)
raise ValueError(
"Frames should be of rank 4 (batch, height, width, channels)"
" or rank 5 (batch, time, height, width, channels)"
)


def space_to_depth(
Expand All @@ -57,16 +56,15 @@ def space_to_depth(
dh=spatial_block_size,
dw=spatial_block_size,
)
elif len(frames.shape) == 5:
if len(frames.shape) == 5:
return einops.rearrange(
frames,
"b (t dt) (h dh) (w dw) c -> b t h w (dt dh dw c)",
dt=temporal_block_size,
dh=spatial_block_size,
dw=spatial_block_size,
)
else:
raise ValueError(
"Frames should be of rank 4 (batch, height, width, channels)"
" or rank 5 (batch, time, height, width, channels)"
)
raise ValueError(
"Frames should be of rank 4 (batch, height, width, channels)"
" or rank 5 (batch, time, height, width, channels)"
)

0 comments on commit 4831362

Please sign in to comment.