diff --git a/galsim_hub/generative_model.py b/galsim_hub/generative_model.py index 6def703..4270149 100644 --- a/galsim_hub/generative_model.py +++ b/galsim_hub/generative_model.py @@ -69,10 +69,11 @@ def __init__(self, file_name=None): self.sample_req_params[k] = float def sample(self, cat, noise=None, rng=None, x_interpolant=None, k_interpolant=None, - pad_factor=4, noise_pad_size=0, gsparams=None, session_config=None): + pad_factor=4, noise_pad_size=0, gsparams=None, session_config=None, flux=None): """ Samples galaxy images from the model """ + fluxes = [None] * len(cat) if flux is None else flux # If we are sampling for the first time if self.module is None: self.module = hub.Module(self.file_name) @@ -121,7 +122,8 @@ def sample(self, cat, noise=None, rng=None, x_interpolant=None, k_interpolant=N noise_pad_size=noise_pad_size, noise_pad=noise, rng=rng, - gsparams=gsparams)) + gsparams=gsparams, + flux=fluxes[i])) if len(ims) == 1: ims = ims[0]