
Fine-tune with a new object point cloud dataset #68

@noahcao

Hi @ZENGXH ,

I was trying to fine-tune LION from the weights in unconditional/all55/checkpoints/epoch_10999_iters_2100999.pt, using the config file unconditional/all55/cfg.yml that you provide.

My basic idea is to freeze the weights of the VAE encoder and decoder and fine-tune only the two priors, imitating the behavior in train_2prior.py. I did the necessary preprocessing of the data and used the pre-trained VAE to verify that the input point clouds can be reconstructed.

However, the training does not go well, and the final results generated by demo.py look like this:

Screenshot 2024-03-25 at 19 05 40

I attach the key components of the code I wrote here:

timestep # 1->1000

def gain_x_t(timestep, noise, x0):
  t_p, var_t_p, m_t_p = self.iw_quantities(timestep)  # as in utils/diffusion_continuous.py
  x_t = m_t_p * x0 + torch.sqrt(var_t_p) * noise
  return t_p, x_t

x_start_obj_g, x_start_obj_l = LION.VAE.encode_obj(obj_points) # VAE is the pre-trained LION VAE
x_start_obj_g, x_start_obj_l = x_start_obj_g.detach(), x_start_obj_l.detach()
noise['obj_g'] = torch.rand_like(x_start_obj_g)
noise['obj_l'] = torch.rand_like(x_start_obj_l)

t_p, x_t_obj_g = gain_x_t(timestep, noise['obj_g'], x_start_obj_g)
t_p, x_t_obj_l = gain_x_t(timestep, noise['obj_l'], x_start_obj_l)

global_cond = LION.VAE.global2style(x_start_obj_g).detach()
pred_noise_g = LION.priors[0](x_t_obj_g, t_p, x0=None, clip_feat=None)
pred_noise_l = LION.priors[1](x_t_obj_l, t_p, x0=None, condition_input=global_cond, clip_feat=None)

loss_g = F.mse_loss(pred_noise_g.view(B,-1), noise['obj_g'].view(B,-1), reduction='mean')
loss_l = F.mse_loss(pred_noise_l.view(B,-1), noise['obj_l'].view(B,-1), reduction='mean')
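
The perturbation step in gain_x_t can be illustrated without the LION code base. The following is a minimal stand-alone sketch in plain Python, assuming a variance-preserving (VP) SDE with a linear beta schedule. The helper names `vp_quantities` and `perturb` and the values `beta_min=0.1` and `beta_max=20.0` are illustrative assumptions; they are not taken from cfg.yml or utils/diffusion_continuous.py.

```python
import math
import random


def vp_quantities(t, beta_min=0.1, beta_max=20.0):
    """Return (m_t, var_t) for a VP-SDE at continuous time t in [0, 1].

    Hypothetical helper: the schedule values are assumptions, not LION's config.
    """
    # Log of the mean coefficient for a linear beta schedule.
    log_mean = -0.25 * t * t * (beta_max - beta_min) - 0.5 * t * beta_min
    m_t = math.exp(log_mean)
    var_t = 1.0 - math.exp(2.0 * log_mean)
    return m_t, var_t


def perturb(x0, t, rng):
    """Compute x_t = m_t * x0 + sqrt(var_t) * noise with noise ~ N(0, 1)."""
    m_t, var_t = vp_quantities(t)
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [m_t * x + math.sqrt(var_t) * n for x, n in zip(x0, noise)]
    return x_t, noise


rng = random.Random(0)
x_t, noise = perturb([0.5, -1.0, 2.0], t=0.3, rng=rng)
```

Under this schedule, m_t squared plus var_t equals 1 at every t, which is the property that keeps a unit-variance latent at unit variance as noise is added.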

I can't find an obvious error on my side, and the training losses look good to me. However, as shown above, the fine-tuned model can't generate valid point clouds.

Screenshot 2024-03-25 at 19 17 33

Also, the dataset I am using contains several object categories, and I use no CLIP feature as the condition. I assumed this should be fine, but can you confirm? It would be great if you could share any ideas. Thanks!
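
One detail in the code above that may be worth checking is the noise sampler. torch.rand_like draws from a uniform distribution on [0, 1), while torch.randn_like draws from a standard normal, and the x_t formula in gain_x_t is usually derived for standard normal noise. The sketch below uses only the Python standard library to show how different the two distributions are; the variable names and the `summarize` helper are illustrative, and `random.random` / `random.gauss` merely stand in for the two torch calls.

```python
import random
import statistics

rng = random.Random(0)
n = 100_000

# Stand-in for torch.rand_like: uniform samples on [0, 1).
uniform = [rng.random() for _ in range(n)]
# Stand-in for torch.randn_like: standard normal samples.
gaussian = [rng.gauss(0.0, 1.0) for _ in range(n)]


def summarize(xs):
    """Return (mean, population variance) of a list of floats."""
    return statistics.fmean(xs), statistics.pvariance(xs)


u_mean, u_var = summarize(uniform)
g_mean, g_var = summarize(gaussian)

print(f"uniform : mean={u_mean:.3f}, var={u_var:.3f}")
print(f"gaussian: mean={g_mean:.3f}, var={g_var:.3f}")
```

Uniform noise has a mean near 0.5 and a variance near 1/12, so a prior trained to predict it would see a different target distribution from the zero-mean, unit-variance noise that a sampler typically draws at generation time.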
