Modified Dtype of Labels to support cross-entropy on cpu
gcervantes8 committed Dec 20, 2023
1 parent 9a51b12 commit 76ade85
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/models/gan_model.py
@@ -75,21 +75,21 @@ def create_real_labels(self, b_size, labels):
             # Real_label is of size [b_size, num_classes]
             real_label = torch.nn.functional.one_hot(labels, num_classes=self.num_classes)
             # Adds 2 columns, One with all values 1, the second with all values 0. Size [b_size, num_classes + 2]
-            real_label = torch.concat((real_label, torch.ones(b_size, 1, device=self.device, dtype=labels.dtype),
-                                       torch.zeros(b_size, 1, device=self.device, dtype=labels.dtype)), dim=1)
+            real_label = torch.concat((real_label, torch.ones(b_size, 1, device=self.device, dtype=self.torch_dtype),
+                                       torch.zeros(b_size, 1, device=self.device, dtype=self.torch_dtype)), dim=1)
             # Replace the 0s and 1s with value of the fake_label, or true_label
             real_label[real_label == 0] = self.fake_label
             real_label[real_label == 1] = self.real_label
         else:
-            real_label = torch.full((b_size,), self.real_label, dtype=labels.dtype, device=self.device)
+            real_label = torch.full((b_size,), self.real_label, dtype=self.torch_dtype, device=self.device)
         return real_label
 
-    def create_fake_labels(self, b_size, labels):
+    def create_fake_labels(self, b_size):
         if self.is_omni_loss:
-            fake_label = torch.full((b_size, self.num_classes + 1), self.fake_label, dtype=labels.dtype, device=self.device)
-            fake_label = torch.concat((fake_label, torch.ones(b_size, 1, device=self.device, dtype=labels.dtype)), dim=1)
+            fake_label = torch.full((b_size, self.num_classes + 1), self.fake_label, dtype=self.torch_dtype, device=self.device)
+            fake_label = torch.concat((fake_label, torch.ones(b_size, 1, device=self.device, dtype=self.torch_dtype)), dim=1)
         else:
-            fake_label = torch.full((b_size,), self.fake_label, dtype=labels.dtype, device=self.device)
+            fake_label = torch.full((b_size,), self.fake_label, dtype=self.torch_dtype, device=self.device)
         return fake_label
 
     def train_step(self, batches_accumulated):
@@ -133,7 +133,7 @@ def discriminator_step(self, batches_accumulated):
         fake_output = self.netD(fake.detach(), labels)
 
         # Calculate D's loss on the all-fake batch
-        fake_label = self.create_fake_labels(b_size, labels)
+        fake_label = self.create_fake_labels(b_size)
         discrim_on_fake_error = self.criterion(fake_output, fake_label.reshape_as(fake_output))
         discrim_on_fake_error = discrim_on_fake_error / self.accumulation_iterations
         self.accelerator.backward(discrim_on_fake_error)
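Note: the diff swaps the integer dtype inherited from labels (typically torch.int64 after one_hot) for the model's floating-point dtype (self.torch_dtype) when building the target tensors, and drops the now-unused labels argument from create_fake_labels. A minimal sketch of the likely motivation follows; it is not code from this repository, and the shapes and variable names are illustrative only. torch.nn.functional.cross_entropy treats a target with the same shape as the logits as class probabilities and requires it to be a floating-point tensor, so integer one-hot targets fail.

import torch
import torch.nn.functional as F

# Hypothetical sizes: batch of 4, 5 target columns (num_classes + 2 as in the diff).
logits = torch.randn(4, 5)
labels = torch.randint(0, 5, (4,))              # class ids, dtype torch.int64

# Before the change: one-hot targets inherit the integer dtype of `labels`.
int_targets = F.one_hot(labels, num_classes=5)  # dtype torch.int64
# cross_entropy interprets a same-shaped target as class probabilities and
# requires a floating-point tensor, so an int64 target raises a RuntimeError:
# loss = F.cross_entropy(logits, int_targets)

# After the change: targets are built with a float dtype (self.torch_dtype in the diff).
float_targets = int_targets.to(torch.float32)
loss = F.cross_entropy(logits, float_targets)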
