diff --git a/src/models/gan_model.py b/src/models/gan_model.py
index 14c1f9d..0631f89 100644
--- a/src/models/gan_model.py
+++ b/src/models/gan_model.py
@@ -75,21 +75,21 @@ def create_real_labels(self, b_size, labels):
             # Real_label is of size [b_size, num_classes]
             real_label = torch.nn.functional.one_hot(labels, num_classes=self.num_classes)
             # Adds 2 columns, One with all values 1, the second with all values 0. Size [b_size, num_classes + 2]
-            real_label = torch.concat((real_label, torch.ones(b_size, 1, device=self.device, dtype=labels.dtype),
-                                       torch.zeros(b_size, 1, device=self.device, dtype=labels.dtype)), dim=1)
+            real_label = torch.concat((real_label, torch.ones(b_size, 1, device=self.device, dtype=self.torch_dtype),
+                                       torch.zeros(b_size, 1, device=self.device, dtype=self.torch_dtype)), dim=1)
             # Replace the 0s and 1s with value of the fake_label, or true_label
             real_label[real_label == 0] = self.fake_label
             real_label[real_label == 1] = self.real_label
         else:
-            real_label = torch.full((b_size,), self.real_label, dtype=labels.dtype, device=self.device)
+            real_label = torch.full((b_size,), self.real_label, dtype=self.torch_dtype, device=self.device)
         return real_label
 
-    def create_fake_labels(self, b_size, labels):
+    def create_fake_labels(self, b_size):
         if self.is_omni_loss:
-            fake_label = torch.full((b_size, self.num_classes + 1), self.fake_label, dtype=labels.dtype, device=self.device)
-            fake_label = torch.concat((fake_label, torch.ones(b_size, 1, device=self.device, dtype=labels.dtype)), dim=1)
+            fake_label = torch.full((b_size, self.num_classes + 1), self.fake_label, dtype=self.torch_dtype, device=self.device)
+            fake_label = torch.concat((fake_label, torch.ones(b_size, 1, device=self.device, dtype=self.torch_dtype)), dim=1)
         else:
-            fake_label = torch.full((b_size,), self.fake_label, dtype=labels.dtype, device=self.device)
+            fake_label = torch.full((b_size,), self.fake_label, dtype=self.torch_dtype, device=self.device)
         return fake_label
 
     def train_step(self, batches_accumulated):
@@ -133,7 +133,7 @@ def discriminator_step(self, batches_accumulated):
         fake_output = self.netD(fake.detach(), labels)
 
         # Calculate D's loss on the all-fake batch
-        fake_label = self.create_fake_labels(b_size)
+        fake_label = self.create_fake_labels(b_size)
         discrim_on_fake_error = self.criterion(fake_output, fake_label.reshape_as(fake_output))
         discrim_on_fake_error = discrim_on_fake_error / self.accumulation_iterations
         self.accelerator.backward(discrim_on_fake_error)