diff --git a/dgl_ptm/dgl_ptm/model/initialize_model.py b/dgl_ptm/dgl_ptm/model/initialize_model.py index f037cfb..db9a959 100644 --- a/dgl_ptm/dgl_ptm/model/initialize_model.py +++ b/dgl_ptm/dgl_ptm/model/initialize_model.py @@ -66,7 +66,7 @@ def sample_distribution_tensor(type, dist_parameters, n_samples, round=False, de cdf_min = (1 + torch.erf(trunc_val_min / torch.sqrt(torch.tensor(2.0))))/2 cdf_max = (1 + torch.erf(trunc_val_max / torch.sqrt(torch.tensor(2.0))))/2 - uniform_samples = torch.rand(size) + uniform_samples = torch.rand(n_samples) inverse_transform = torch.erfinv( 2 *(cdf_min + (cdf_max - cdf_min) * uniform_samples) - 1 )