From 612f059dd4f77489a23abb3054de8877b628d260 Mon Sep 17 00:00:00 2001 From: "vincent.rocher" Date: Fri, 3 Oct 2025 17:03:28 +0200 Subject: [PATCH] patching supervised loss weight because it is not passed to NeuralAdmixture class --- neural_admixture/model/train.py | 4 ++-- neural_admixture/src/main.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/neural_admixture/model/train.py b/neural_admixture/model/train.py index a17ee77..673394d 100644 --- a/neural_admixture/model/train.py +++ b/neural_admixture/model/train.py @@ -18,7 +18,7 @@ def train(epochs: int, batch_size: int, learning_rate: float, K: int, seed: int, data: torch.Tensor, device: torch.device, num_gpus: int, hidden_size: int, - master: bool, V: np.ndarray, pops : np.ndarray, min_k: int=None, max_k: int=None, n_components: int=None) -> Tuple[torch.Tensor, torch.Tensor, torch.nn.Module]: + master: bool, V: np.ndarray, pops : np.ndarray, min_k: int=None, max_k: int=None, n_components: int=None, supervised_loss_weight: int=100) -> Tuple[torch.Tensor, torch.Tensor, torch.nn.Module]: """ Initializes P and Q matrices and trains a neural admixture model using GMM. @@ -128,7 +128,7 @@ def train(epochs: int, batch_size: int, learning_rate: float, K: int, seed: int, pack2bit = None packed_data = data - model = NeuralAdmixture(K, epochs, batch_size, learning_rate, device, seed, num_gpus, master, pack2bit, min_k, max_k) + model = NeuralAdmixture(K, epochs, batch_size, learning_rate, device, seed, num_gpus, master, pack2bit, min_k, max_k,supervised_loss_weight) Qs, Ps, model = model.launch_training(P_init, packed_data, hidden_size, V.shape[1], V, M, N, pops) if master: diff --git a/neural_admixture/src/main.py b/neural_admixture/src/main.py index 487e64e..070d52a 100644 --- a/neural_admixture/src/main.py +++ b/neural_admixture/src/main.py @@ -32,8 +32,9 @@ def fit_model(args: argparse.Namespace, data: torch.Tensor, device: torch.device min_k = int(args.min_k) max_k = int(args.max_k) K = None - - Ps, Qs, model = train(epochs, batch_size, learning_rate, K, seed, data, device, num_gpus, hidden_size, master, V, pops, min_k, max_k, n_components) + if args.supervised_loss_weight is None: + args.supervised_loss_weight = 100 + Ps, Qs, model = train(epochs, batch_size, learning_rate, K, seed, data, device, num_gpus, hidden_size, master, V, pops, min_k, max_k, n_components, args.supervised_loss_weight) if master: Path(save_dir).mkdir(parents=True, exist_ok=True)