Hi @ZENGXH,
I am trying to fine-tune LION starting from the weights in unconditional/all55/checkpoints/epoch_10999_iters_2100999.pt, using the config file unconditional/all55/cfg.yml that you provide.
My basic idea is to freeze the weights of the VAE encoder and decoder and fine-tune only the two priors, imitating the behavior of train_2prior.py. I preprocessed my data as required and used the pre-trained VAE to verify that the input point clouds can be reconstructed.
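The freezing step described above can be sketched roughly as follows. This is a minimal illustration with stand-in modules, not LION's actual classes: `TinyModel`, its `vae` and `priors` attributes, the helper `freeze_vae_train_priors`, and the learning rate are all hypothetical names and values chosen for this example.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in: a 'vae' plus two 'priors' (not LION's real classes)."""

    def __init__(self):
        super().__init__()
        self.vae = nn.Linear(8, 4)
        self.priors = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


def freeze_vae_train_priors(model, lr=1e-4):
    # Freeze the VAE: no gradients, and eval mode so that any
    # dropout / batch-norm layers keep their inference behavior.
    for p in model.vae.parameters():
        p.requires_grad_(False)
    model.vae.eval()
    # Only the prior parameters are handed to the optimizer.
    prior_params = [p for prior in model.priors for p in prior.parameters()]
    return torch.optim.Adam(prior_params, lr=lr)


model = TinyModel()
optimizer = freeze_vae_train_priors(model)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
```

Note that calling `model.train()` later would put the frozen sub-module back into training mode, so in a setup like this the VAE would need `eval()` again after such a call.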
However, training does not go well, and the final results generated by demo.py look like this:
Here are the key components of the code I wrote:
timestep  # 1 -> 1000
def gain_x_t(timestep, noise, x0):
    t_p, var_t_p, m_t_p = self.iw_quantities(timestep)  # as in utils/diffusion_continuous.py
    x_t = m_t_p * x0 + torch.sqrt(var_t_p) * noise
    return t_p, x_t
x_start_obj_g, x_start_obj_l = LION.VAE.encode_obj(obj_points) # VAE is the pre-trained LION VAE
x_start_obj_g, x_start_obj_l = x_start_obj_g.detach(), x_start_obj_l.detach()
noise['obj_g'] = torch.rand_like(x_start_obj_g)
noise['obj_l'] = torch.rand_like(x_start_obj_l)
t_p, x_t_obj_g = gain_x_t(timestep, noise['obj_g'], x_start_obj_g)
t_p, x_t_obj_l = gain_x_t(timestep, noise['obj_l'], x_start_obj_l)
global_cond = LION.VAE.global2style(x_start_obj_g).detach()
pred_noise_g = LION.priors[0](x_t_obj_g, t_p, x0=None, clip_feat=None)
pred_noise_l = LION.priors[1](x_t_obj_l, t_p, x0=None, condition_input=global_cond, clip_feat=None)
loss_g = F.mse_loss(pred_noise_g.view(B,-1), noise['obj_g'].view(B,-1), reduction='mean')
loss_l = F.mse_loss(pred_noise_l.view(B,-1), noise['obj_l'].view(B,-1), reduction='mean')
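One detail in the snippet above that may be worth checking: torch.rand_like draws from a uniform distribution on [0, 1), while torch.randn_like draws from a standard normal. Noise-prediction diffusion training usually assumes Gaussian noise; whether train_2prior.py does the same should be checked against the repository. The sketch below only contrasts the two draws. The function `perturb` and the constant `m_t`, `var_t` values are hypothetical placeholders, not the outputs of iw_quantities.

```python
import torch

torch.manual_seed(0)
x0 = torch.zeros(4, 10000)  # dummy latent, stands in for the VAE encoding

uniform_noise = torch.rand_like(x0)    # U[0, 1): mean ~0.5, variance ~1/12
gaussian_noise = torch.randn_like(x0)  # N(0, 1): mean ~0, variance ~1


def perturb(x0, noise, m_t, var_t):
    # Same form as gain_x_t above: x_t = m_t * x0 + sqrt(var_t) * noise.
    return m_t * x0 + torch.sqrt(var_t) * noise


# Placeholder schedule values (assumed for illustration only).
m_t, var_t = torch.tensor(0.6), torch.tensor(0.64)
x_t_uniform = perturb(x0, uniform_noise, m_t, var_t)
x_t_gaussian = perturb(x0, gaussian_noise, m_t, var_t)

print(uniform_noise.mean().item(), uniform_noise.var().item())
print(gaussian_noise.mean().item(), gaussian_noise.var().item())
```

If the sampler used by demo.py starts from Gaussian noise, a prior trained on uniform noise would be given inputs at sampling time that it never saw during training, which might explain invalid samples despite a low training loss.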
I can't find an obvious error on my side, and the training losses look good to me. However, as shown above, the fine-tuned model can't generate valid point clouds.
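Whether a noise-prediction loss looks good depends on the distribution of the noise targets, so it can help to compare it against a trivial predictor. The sketch below is only an illustration and not part of LION: it computes the MSE that a constant prediction achieves on uniform targets (as drawn by torch.rand_like in the snippet above) and on standard normal targets. `constant_baseline_mse` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def constant_baseline_mse(noise):
    # Under MSE, the best constant prediction is the mean of the targets,
    # so this loss equals the (biased) variance of the noise.
    pred = torch.full_like(noise, noise.mean().item())
    return F.mse_loss(pred, noise, reduction='mean').item()


torch.manual_seed(0)
shape = (16, 4096)
baseline_uniform = constant_baseline_mse(torch.rand(shape))    # close to 1/12
baseline_gaussian = constant_baseline_mse(torch.randn(shape))  # close to 1
```

A training loss near 0.08 would therefore mean very different things in the two cases: with uniform targets, it is roughly what a network that ignores its input can reach.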
Also, my dataset contains several object categories, and I do not use CLIP features as a condition. I assumed this would be fine; could you confirm? It would be great if you could share any ideas. Thanks!
