From 347f644a5800bb293d9d09013c9720e53a26c65d Mon Sep 17 00:00:00 2001
From: Fabian Schaipp
Date: Fri, 15 Mar 2024 12:40:02 +0100
Subject: [PATCH] adapt training script

---
 data/checkpoints/max2d.pt         | Bin 2368 -> 2368 bytes
 scripts/train_max_fun.py          | 38 +++++++++++++-----------------
 src/ncopt/functions/max_linear.py | 15 ++++++++----
 3 files changed, 28 insertions(+), 25 deletions(-)

diff --git a/data/checkpoints/max2d.pt b/data/checkpoints/max2d.pt
index 8a0b45670fed4a386338f4375783b0c1fd9f4b5c..f6f37cc891bd4a1c4371dd8ff564e23ea7fdde64 100644
GIT binary patch
delta 328
zcmX>gbU>S&Uy<8&zG+JP?Kb!32Di$|^-v9OcTYwlOtM*&CV)6x+0$Hh?
zxNV!u_iq>Nc(z?aL(CK;$+&g$L?o@OZUWXU^UXir+F%ZnJ+#j~Wb*}97DgEZQ%hr0
zGjk&&V+%_|Q$s@wLt_Jw3RM^Jg2^A)Bsk%|m@LTd%NRR3o?VFr}D)rC+r6@
bK-NwE31omwp1hx3e6klu6bsmp6&y1Fc=&6u

delta 321
zcmX>gbU;6*x13W$r_CLY)la5LdG|Y0+aig4s%_;y*`PN
znSp_kd2#@oj5t4o6GLudg;9!rQe{bMeo;zlk(-kp!{mv~`jZzh^G`m&+#vgyWvl&_
zPyA-CIxkJ1{Qqws;LXmF`*caU0MKlK$^LAzldD+V1SbBk-`@{K
p}C={xrvFXxv`0*iHW&|5lF?v#)EE?Kd?z~!o4zCklmLtaB@7m5(~&(lUv!%Sinx$
e4`hI>oBR{V0GT{_KfCy3FODb{upuirW&i;8@NY)|

diff --git a/scripts/train_max_fun.py b/scripts/train_max_fun.py
index cab709c..1286434 100755
--- a/scripts/train_max_fun.py
+++ b/scripts/train_max_fun.py
@@ -11,6 +11,10 @@ import torch
 
 from torch.optim.lr_scheduler import StepLR
 
+from ncopt.functions.max_linear import MaxOfLinear
+
+# %% Generate data
+
 c1 = np.sqrt(2)
 c2 = 2.0
 
@@ -44,28 +48,20 @@ def generate_data(grid_points):
 num_samples = len(tX)  # number of training points
 
 # %%
-
-
-class Max2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.l1 = torch.nn.Linear(2, 2)
-
-    def forward(self, x):
-        x = self.l1(x)
-        x, _ = torch.max(x, dim=-1)
-        return x
-
-
 loss_fn = torch.nn.MSELoss(reduction="mean")
-model = Max2D()
+# model = MaxOfLinear(params=(torch.tensor([[c1, 0], [0, c2]]),
+#                             torch.tensor([-1., -1.])
+#                             )
+#                     )
+
+model = MaxOfLinear(input_dim=2, output_dim=2)
 
-print(model.l1.weight)
-print(model.l1.bias)
+print(model.linear.weight.data)
+print(model.linear.bias.data)
 
 # testing
-x = torch.tensor([1.0, 4.0])
-print("True value: ", g(x[0], x[1]), ". Predicted value: ", model(x).item())
+x = torch.tensor([[1.0, 4.0]])
+print("True value: ", g(x[0, 0], x[0, 1]), ". Predicted value: ", model(x)[0].item())
 
 # %% Training
 
@@ -88,7 +84,7 @@ def sample_batch(num_samples, b):
 for t in range(num_samples // batch_size):
     S = sample_batch(num_samples, batch_size)
     x_batch = tX[S]
-    z_batch = tZ[S]
+    z_batch = tZ[S][:, None]  # dummy dimension to match model output
 
     optimizer.zero_grad()
 
@@ -101,8 +97,8 @@ def sample_batch(num_samples, b):
     scheduler.step()
 
 print("Learned parameters:")
-print(model.l1.weight)
-print(model.l1.bias)
+print(model.linear.weight.data)
+print(model.linear.bias.data)
 
 # %% Save checkpoint
 

diff --git a/src/ncopt/functions/max_linear.py b/src/ncopt/functions/max_linear.py
index 60e25a4..c2e985f 100644
--- a/src/ncopt/functions/max_linear.py
+++ b/src/ncopt/functions/max_linear.py
@@ -9,7 +9,13 @@ class MaxOfLinear(torch.nn.Module):
     where the maximum is taken over the components of Ax + b.
     """
 
-    def __init__(self, input_dim: int = None, output_dim: int = None, params: tuple = None):
+    def __init__(
+        self,
+        input_dim: int = None,
+        output_dim: int = None,
+        params: tuple = None,
+        dtype=torch.float32,
+    ):
         super().__init__()
 
         assert params is not None or (
@@ -26,14 +32,15 @@ def __init__(self, input_dim: int = None, output_dim: int = None, params: tuple
         self.linear = torch.nn.Linear(input_dim, output_dim)
 
         # Set the weights if the mapping is given
+        # Default type of torch.nn.Linear is float32
         if params is not None:
-            self.linear.weight.data = params[0]
-            self.linear.bias.data = params[1]
+            self.linear.weight.data = params[0].type(dtype)
+            self.linear.bias.data = params[1].type(dtype)
 
         return
 
     def forward(self, x):
         x = self.linear(x)
-        # make sure to have output shape [batch_siez, 1] by keepdim=True
+        # make sure to have output shape [batch_size, 1] by keepdim=True
         x, _ = torch.max(x, dim=-1, keepdim=True)
         return x