
Bundle Adjustment on VGGT outputs #78

Open

dsvilarkovic opened this issue Apr 1, 2025 · 8 comments

@dsvilarkovic

Hello,
I am trying to work with your model and am trying to figure out how to create the data needed for global Bundle Adjustment, as you do in VGGSfM, as suggested by this issue comment. So far I understand that I can create a global point cloud with the predictions_to_glb function from vggs/visual_util.py, but I don't understand how to create the tracks for init_BA or global_BA from VGGSfM's triangulation.py script.

I am ready to work towards a pull request for bundle adjustment on the VGGT repo if I can get some help with this.

@jytime
Contributor

jytime commented Apr 1, 2025

Hi @dsvilarkovic, thanks for your interest! In VGGSfM, we use Aliked/Sift/Superpoint to extract keypoints, and run tracking over those keypoints to get tracks across the other frames. In VGGT, you can conduct tracking as shown in the Detailed Usage section, e.g.,

    # Predict Tracks
    # choose your own points to track, with shape (N, 2) for one scene
    query_points = torch.FloatTensor([[100.0, 200.0], 
                                        [60.72, 259.94]]).to(device)
    track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])

(I noticed the released checkpoint has lower tracking accuracy than our original version and am working to fix it, but it should still give you decent results.)
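
For reference, the snippet above assumes aggregated_tokens_list, images, and ps_idx already exist. A minimal setup sketch based on the Detailed Usage section (import paths and the checkpoint name follow the README; the image paths are placeholders), after which the track_head call above can be run inside the same autocast context:

    import torch
    from vggt.models.vggt import VGGT
    from vggt.utils.load_fn import load_and_preprocess_images

    device = "cuda"
    dtype = torch.bfloat16  # bfloat16 on Ampere+ GPUs; float16 otherwise

    model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)

    # Placeholder frame paths; load_and_preprocess_images returns (S, 3, H, W)
    images = load_and_preprocess_images(["frame_000.png", "frame_001.png"]).to(device)

    with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
        images = images[None]  # add a batch dimension -> (1, S, 3, H, W)
        # Run the aggregator once; its outputs feed the track_head call above
        aggregated_tokens_list, ps_idx = model.aggregator(images)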

@NeutrinoLiu

NeutrinoLiu commented Apr 3, 2025

> Hi @dsvilarkovic, thanks for your interest! In VGGSfM, we use Aliked/Sift/Superpoint to extract keypoints [...] (quoting @jytime's reply above in full)

Hi, may I know how keypoint sampling works when you perform BA? Did you uniformly sample keypoints from the first frame and use their tracking info to perform BA, or is there some other specific policy? Also, how did you decide the number of keypoints (N)? (Obviously it's some sort of trade-off: more points bring more accuracy but may be slower.)

@jytime
Contributor

jytime commented Apr 3, 2025

Hi @dsvilarkovic ,

We used Aliked to extract keypoints. Usually N=1024 is enough. If you are aiming for better accuracy, you could try N=2048; in our observations, higher numbers do not bring much further improvement.
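
As a concrete starting point (an assumption, not necessarily the exact extractor used internally), the ALIKED wrapper shipped with the lightglue package can produce keypoints in the format the track head expects:

    import torch
    from lightglue import ALIKED            # assumed dependency: pip install lightglue
    from lightglue.utils import load_image

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Cap the detector at N=1024 keypoints, as suggested above
    extractor = ALIKED(max_num_keypoints=1024).eval().to(device)

    image = load_image("frame_000.png").to(device)   # placeholder path; (3, H, W) in [0, 1]
    feats = extractor.extract(image)
    keypoints = feats["keypoints"]                   # (1, N, 2), (x, y) pixel coordinates

    # These can then be passed to the VGGT track head as query_points
    # (make sure the coordinates match the resolution of the frames fed to VGGT)
    query_points = keypoints[0]                      # (N, 2)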

@dsvilarkovic
Author

dsvilarkovic commented Apr 7, 2025

@jytime what do you use for inlier_masks and valid_tracks? Do you put all ones, or do you use some values from VGGT's output for this?

I am building a BA setup using global_BA from VGGSfM.

I extract keypoints using hloc's SuperPoint (1024 of them on the first image) and then track those keypoints with VGGT's tracker across the other S-1 images. For filtering the tracks I do this:

        keypoints_np = self.extract_keypoints(first_frame)
        # Keep query points on the same device as the images
        keypoints = torch.from_numpy(keypoints_np).float().to(images.device)

        # 3. Track keypoints across all frames
        aggregated_tokens_list, ps_idx = self.model.aggregator(images)
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=self.dtype):
            track_list, vis_score, conf_score = self.model.track_head(
                aggregated_tokens_list,
                images,
                ps_idx,
                query_points=keypoints[None]
            )
        tracks = track_list[-1]

        # 4. Filter tracks based on visibility and confidence
        B, S, K, _ = tracks.shape  # Batch, Sequence, Keypoints

        print("Filtering keypoint tracks...")
        self.keypoint_conf_thresh = 0.2
        self.visibility_thresh = 0.2

        valid_track_indices = []
        for i in range(K):
            # Check whether this keypoint is confident and visible in each frame
            is_valid = [(conf_score[0, j, i] > self.keypoint_conf_thresh) and
                        (vis_score[0, j, i] > self.visibility_thresh)
                        for j in range(S)]

            # Only keep tracks that are valid in at least 60% of frames
            if sum(is_valid) >= S * 0.6:
                valid_track_indices.append(i)

        filtered_tracks = tracks[0, :, valid_track_indices]
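
For what it's worth, the per-keypoint loop above can also be written in a vectorized form (same thresholds, same 60% criterion), assuming conf_score and vis_score both have shape (1, S, K):

        # Vectorized equivalent of the filtering loop above
        valid = (conf_score[0] > self.keypoint_conf_thresh) & \
                (vis_score[0] > self.visibility_thresh)        # (S, K) bool
        keep = valid.float().mean(dim=0) >= 0.6                # valid in >= 60% of frames
        filtered_tracks = tracks[0][:, keep]                   # (S, N_kept, 2)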

The triangulated points are the filtered subset of the 1024 keypoints, with their 3D values taken from the world point map predicted by VGGT's point head:

        # 5. Get 3D points for the tracks using the first frame's point map
        # (assuming world_points has shape (S, H, W, 3) after the squeeze)
        world_points = predictions["world_points"].squeeze(0)
        world_points_conf = predictions["world_points_conf"].squeeze(0)

        # Get pixel coordinates of the filtered tracks in the first frame
        first_frame_coords = filtered_tracks[0].long()  # Shape: [N, 2], (x, y)

        # Clamp against the spatial dimensions (H, W) of the point map
        y_coords = first_frame_coords[:, 1].clamp(0, world_points.shape[1] - 1)
        x_coords = first_frame_coords[:, 0].clamp(0, world_points.shape[2] - 1)

        # Look up the 3D points in the first frame's predicted point map
        triangulated_points = world_points[0, y_coords, x_coords]  # Shape: [N, 3]

filtered_tracks is used as the input for pred_tracks, and for the inlier_mask of global_BA I just use all ones:

        # N = number of filtered tracks, S = number of frames
        inlier_mask = torch.ones((N, S), dtype=torch.bool)
        valid_tracks = torch.ones(N, dtype=torch.bool)

        points3D_opt, extrinsics_opt, intrinsics_opt, extra_params_opt, recon = global_BA(
            triangulated_points=triangulated_points,
            valid_tracks=valid_tracks,
            pred_tracks=filtered_tracks,
            inlier_mask=inlier_mask,
            extrinsics=extrinsics,
            intrinsics=intrinsics,
            extra_params=None,
            image_size=image_size,
            shared_camera=False,
            camera_type="SIMPLE_PINHOLE",
        )

TL;DR: How do you set the values for inlier_mask and pred_tracks?

@jytime
Contributor

jytime commented Apr 10, 2025

Hi @dsvilarkovic ,

In this case, the inlier_mask can be computed by something like (conf_score > thres_1) & (vis_score > thres_2). The pred_tracks can be constructed using the tracking head of VGGT, the tracking head of VGGSfM, or tracks built by any other two-view matching method.
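
Concretely, with the conf_score and vis_score returned by the track head (shapes assumed here to be (1, S, N)), that could look roughly like:

    # thres_1 / thres_2 are values to tune for your data
    thres_1, thres_2 = 0.2, 0.2

    # Per-frame, per-track inlier flags from the track head's confidence and visibility
    inlier_mask = (conf_score[0] > thres_1) & (vis_score[0] > thres_2)  # (S, N) bool

    # Transpose to (N, S) if global_BA expects a track-major layout, and keep
    # only tracks that are inliers in at least two frames (a hypothetical minimum)
    inlier_mask = inlier_mask.transpose(0, 1)                           # (N, S)
    valid_tracks = inlier_mask.sum(dim=1) >= 2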

@supertan0204

Hi @dsvilarkovic, any progress?

@dsvilarkovic
Author

@supertan0204 Not yet, I am a bit busy with private work.

If you have some BA datasets that you want me to try this on, I would appreciate the guidance.

@supertan0204

supertan0204 commented Apr 21, 2025

> @jytime what do you use for inlier_masks and valid_tracks? [...] TL;DR: How do you set the values for inlier_mask and pred_tracks? (quoting @dsvilarkovic's Apr 7 comment above in full)

Hi @dsvilarkovic, it would be super helpful if you could kindly share how you performed the first two steps (i.e., # 1 and # 2 in your comments) of the code you shared :)
