Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ObjectFormer #41

Open
shunlibiyela opened this issue Oct 8, 2024 · 2 comments
Open

ObjectFormer #41

shunlibiyela opened this issue Oct 8, 2024 · 2 comments
Assignees

Comments

@shunlibiyela
Copy link

您好,请问ObjectFormer的--init_weight_path object_former/processed_model_weights.pth 这个weight在哪里
谢谢回复

@SunnyHaze
Copy link
Contributor

@Inkyl Xuekang is responsible for this part.

@Inkyl
Copy link
Contributor

Inkyl commented Oct 8, 2024

您好,请问ObjectFormer的--init_weight_path object_former/processed_model_weights.pth 这个weight在哪里 谢谢回复

Sorry, the specific weights are not currently available, but I can provide you with a script to process and extract the relevant weights.

import math
from typing import List, Optional
import torch
import timm
import torch.nn.functional as F

# Build the timm ViT whose pretrained parameters are the source for the
# weight extraction performed further down in this script.
model = timm.create_model(
    model_name='vit_base_patch16_224',
    pretrained=True,
)

def resample_abs_pos_embed(
        posemb: torch.Tensor,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
) -> torch.Tensor:
    """Resample an absolute position embedding to a new 2D grid size.

    Args:
        posemb: Embedding indexed as ``(1, num_tokens, embed_dim)``; the first
            ``num_prefix_tokens`` entries along axis 1 are passed through
            untouched, the remainder is treated as a flattened 2D grid.
        new_size: Target grid size as ``[height, width]``.
        old_size: Current grid size as ``[height, width]``. When ``None`` the
            grid is assumed to be square and its side is derived from the
            number of grid tokens.
        num_prefix_tokens: Number of leading tokens that are not part of the
            grid and must not be interpolated.
        interpolation: Mode forwarded to ``F.interpolate``.
        antialias: Antialias flag forwarded to ``F.interpolate``.
        verbose: Accepted for interface compatibility; unused here.

    Returns:
        ``posemb`` itself when the grid already has the requested size,
        otherwise a new tensor of shape
        ``(1, num_prefix_tokens + new_size[0] * new_size[1], embed_dim)`` in
        the original dtype.

    Raises:
        ValueError: If ``old_size`` is ``None`` and the number of grid tokens
            is not a perfect square.
    """
    num_grid_tokens = posemb.shape[1] - num_prefix_tokens

    # Resolve the source grid BEFORE deciding whether any work is needed, so
    # the shortcut below compares real grid shapes rather than token counts.
    if old_size is None:
        hw = math.isqrt(num_grid_tokens)
        if hw * hw != num_grid_tokens:
            # A truncated square root would make the reshape below either fail
            # with an opaque error or silently fold tokens into the channel
            # axis, so reject the input explicitly.
            raise ValueError(
                f'Cannot infer a square grid from {num_grid_tokens} position tokens; '
                f'pass old_size explicitly.'
            )
        old_size = hw, hw

    # Nothing to do only when the grid shape is unchanged. Equal token counts
    # alone are not enough (e.g. 2x8 -> 4x4 still needs resampling).
    if tuple(old_size) == tuple(new_size):
        return posemb

    # Separate the prefix tokens if any exist
    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix = None

    # Perform interpolation on a (1, C, H, W) view of the grid tokens
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate in float32, restore dtype afterwards
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
    posemb = posemb.to(orig_dtype)

    # Concatenate back the prefix tokens if they were separated earlier
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)

    return posemb


# Take a single snapshot of the pretrained parameters. Module.state_dict()
# assembles a new mapping on every call, so it is fetched once here instead of
# once per copied key.
source_state_dict = model.state_dict()

# Initialize a dictionary to store the processed weights
processed_state_dict = {}

# Extract and resample the positional embedding: drop the first token along
# the token axis (presumably the class token - confirm against the target
# model), then resize the remaining grid to 14 x 28 with no prefix tokens.
pos_embed = source_state_dict['pos_embed'][0][1::].unsqueeze(0)
pos_embed = resample_abs_pos_embed(pos_embed, [14, 28], num_prefix_tokens=0)
processed_state_dict['pos_embed'] = pos_embed

# Copy the patch embedding projection weights
for key in ('patch_embed.proj.weight', 'patch_embed.proj.bias'):
    processed_state_dict[key] = source_state_dict[key]

# Per-block keys that are copied unchanged, listed in the order they are
# written around the q/k/v entries.
keys_before_qkv = ('norm1.weight', 'norm1.bias')
keys_after_qkv = (
    'attn.proj.weight', 'attn.proj.bias',
    'norm2.weight', 'norm2.bias',
    'mlp.fc1.weight', 'mlp.fc1.bias',
    'mlp.fc2.weight', 'mlp.fc2.bias',
)

# Process and extract weights from the first 8 transformer blocks
for i in range(8):  # Only process the first 8 blocks
    block_prefix = f'blocks.{i}.'

    for suffix in keys_before_qkv:
        processed_state_dict[f'{block_prefix}{suffix}'] = source_state_dict[f'{block_prefix}{suffix}']

    # Split the fused qkv parameters into separate q, k, v entries. The fused
    # tensor is cut into three equal parts along dim 0; each slice is cloned so
    # the saved entries own their storage instead of aliasing the fused tensor.
    qkv_weight = source_state_dict[f'{block_prefix}attn.qkv.weight']
    qkv_bias = source_state_dict[f'{block_prefix}attn.qkv.bias']
    dim = qkv_weight.shape[0] // 3
    processed_state_dict[f'{block_prefix}attn.q.weight'] = qkv_weight[:dim].clone()
    processed_state_dict[f'{block_prefix}attn.k.weight'] = qkv_weight[dim:2 * dim].clone()
    processed_state_dict[f'{block_prefix}attn.v.weight'] = qkv_weight[2 * dim:].clone()
    processed_state_dict[f'{block_prefix}attn.q.bias'] = qkv_bias[:dim].clone()
    processed_state_dict[f'{block_prefix}attn.k.bias'] = qkv_bias[dim:2 * dim].clone()
    processed_state_dict[f'{block_prefix}attn.v.bias'] = qkv_bias[2 * dim:].clone()

    for suffix in keys_after_qkv:
        processed_state_dict[f'{block_prefix}{suffix}'] = source_state_dict[f'{block_prefix}{suffix}']

# Save the processed weights to a .pth file
torch.save(processed_state_dict, 'processed_model_weights.pth')

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants