adapt training script
fabian-sp committed Mar 15, 2024
1 parent 5dcc855 · commit 347f644
Showing 3 changed files with 28 additions and 25 deletions.
Binary file modified data/checkpoints/max2d.pt
38 changes: 17 additions & 21 deletions scripts/train_max_fun.py
@@ -11,6 +11,10 @@
 import torch
 from torch.optim.lr_scheduler import StepLR
 
+from ncopt.functions.max_linear import MaxOfLinear
+
+# %% Generate data
+
 c1 = np.sqrt(2)
 c2 = 2.0
 
@@ -44,28 +48,20 @@ def generate_data(grid_points):
 num_samples = len(tX)  # number of training points
 
 # %%
-
-
-class Max2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.l1 = torch.nn.Linear(2, 2)
-
-    def forward(self, x):
-        x = self.l1(x)
-        x, _ = torch.max(x, dim=-1)
-        return x
-
-
 loss_fn = torch.nn.MSELoss(reduction="mean")
-model = Max2D()
+# model = MaxOfLinear(params=(torch.tensor([[c1, 0], [0, c2]]),
+#                             torch.tensor([-1., -1.])
+#                             )
+#                     )
+
+model = MaxOfLinear(input_dim=2, output_dim=2)
 
-print(model.l1.weight)
-print(model.l1.bias)
+print(model.linear.weight.data)
+print(model.linear.bias.data)
 
 # testing
-x = torch.tensor([1.0, 4.0])
-print("True value: ", g(x[0], x[1]), ". Predicted value: ", model(x).item())
+x = torch.tensor([[1.0, 4.0]])
+print("True value: ", g(x[0, 0], x[0, 1]), ". Predicted value: ", model(x)[0].item())
 
 
 # %% Training
@@ -88,7 +84,7 @@ def sample_batch(num_samples, b):
 for t in range(num_samples // batch_size):
     S = sample_batch(num_samples, batch_size)
     x_batch = tX[S]
-    z_batch = tZ[S]
+    z_batch = tZ[S][:, None]  # dummy dimension to match model output
 
     optimizer.zero_grad()
 
@@ -101,8 +97,8 @@ def sample_batch(num_samples, b):
     scheduler.step()
 
 print("Learned parameters:")
-print(model.l1.weight)
-print(model.l1.bias)
+print(model.linear.weight.data)
+print(model.linear.bias.data)
 
 
 # %% Save checkpoint
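Note on the [:, None] change above: MaxOfLinear.forward returns shape [batch_size, 1] (via keepdim=True, see the module diff below), so the targets need a matching trailing dimension before MSELoss. A minimal sketch of that shape contract, with illustrative tensor sizes rather than the script's real data:

import torch

tZ = torch.randn(100)             # targets, shape [100] (sizes are illustrative)
S = torch.randint(0, 100, (16,))  # one sampled batch of indices

z_batch = tZ[S][:, None]          # shape [16, 1], matching the model output
assert z_batch.shape == (16, 1)   # without [:, None] it stays [16] and MSELoss broadcasts (with a warning)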
15 changes: 11 additions & 4 deletions src/ncopt/functions/max_linear.py
@@ -9,7 +9,13 @@ class MaxOfLinear(torch.nn.Module):
     where the maximum is taken over the components of Ax + b.
     """
 
-    def __init__(self, input_dim: int = None, output_dim: int = None, params: tuple = None):
+    def __init__(
+        self,
+        input_dim: int = None,
+        output_dim: int = None,
+        params: tuple = None,
+        dtype=torch.float32,
+    ):
         super().__init__()
 
         assert params is not None or (
@@ -26,14 +32,15 @@ def __init__(self, input_dim: int = None, output_dim: int = None, params: tuple
         self.linear = torch.nn.Linear(input_dim, output_dim)
 
         # Set the weights if the mapping is given
+        # Default type of torch.nn.Linear is float32
         if params is not None:
-            self.linear.weight.data = params[0]
-            self.linear.bias.data = params[1]
+            self.linear.weight.data = params[0].type(dtype)
+            self.linear.bias.data = params[1].type(dtype)
 
         return
 
     def forward(self, x):
         x = self.linear(x)
-        # make sure to have output shape [batch_siez, 1] by keepdim=True
+        # make sure to have output shape [batch_size, 1] by keepdim=True
         x, _ = torch.max(x, dim=-1, keepdim=True)
         return x
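For reference, a small usage sketch of the updated module, mirroring the commented-out call in the training script; it assumes the constructor infers the dimensions from params when input_dim/output_dim are omitted:

import torch
from ncopt.functions.max_linear import MaxOfLinear

# Parameters of the target function g: c1 = sqrt(2), c2 = 2, offsets -1
A = torch.tensor([[2.0**0.5, 0.0], [0.0, 2.0]])
b = torch.tensor([-1.0, -1.0])

model = MaxOfLinear(params=(A, b))  # cast to float32 by the new dtype default

x = torch.tensor([[1.0, 4.0]])      # batched input, shape [1, 2]
print(model(x))                     # max(sqrt(2)*1 - 1, 2*4 - 1) -> tensor([[7.]])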
