-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ObjectFormer #41
Comments
@Inkyl Xuekang is responsible for this part.
Sorry, the specific weights are not currently available, but I can provide you with a script to process and extract the relevant weights.
import math
from typing import List, Optional
import torch
import timm
import torch.nn.functional as F
# Load a pre-trained Vision Transformer (ViT-B/16, 224px) from timm;
# its state_dict is the source for all weight extraction below.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
def resample_abs_pos_embed(
    posemb,
    new_size: List[int],
    old_size: Optional[List[int]] = None,
    num_prefix_tokens: int = 1,
    interpolation: str = 'bicubic',
    antialias: bool = True,
    verbose: bool = False,
):
    """Resample an absolute position embedding to a new 2D grid size.

    Args:
        posemb: Position embedding tensor of shape (1, num_tokens, dim).
        new_size: Target grid size as [height, width].
        old_size: Source grid size; inferred as a square grid when omitted.
        num_prefix_tokens: Leading tokens (e.g. class token) excluded from
            interpolation and re-attached afterwards.
        interpolation: Mode passed to F.interpolate.
        antialias: Antialiasing flag passed to F.interpolate.
        verbose: Unused; kept for interface compatibility.

    Returns:
        Position embedding resampled to the new grid, dtype preserved.
    """
    total_tokens = posemb.shape[1]
    target_tokens = new_size[0] * new_size[1] + num_prefix_tokens

    # Nothing to do when the token count already matches a square target grid.
    if target_tokens == total_tokens and new_size[0] == new_size[1]:
        return posemb

    # Without an explicit source size, assume the grid tokens form a square.
    if old_size is None:
        side = int(math.sqrt(total_tokens - num_prefix_tokens))
        old_size = side, side

    # Peel off prefix tokens (class token etc.) so only grid tokens are resampled.
    prefix = None
    if num_prefix_tokens:
        prefix = posemb[:, :num_prefix_tokens]
        posemb = posemb[:, num_prefix_tokens:]

    dim = posemb.shape[-1]
    source_dtype = posemb.dtype

    # Interpolate in float32 on an NCHW layout, then restore layout and dtype.
    grid = posemb.float().reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_size, mode=interpolation, antialias=antialias)
    resampled = grid.permute(0, 2, 3, 1).reshape(1, -1, dim).to(source_dtype)

    # Re-attach any prefix tokens in front of the resampled grid tokens.
    if prefix is not None:
        resampled = torch.cat([prefix, resampled], dim=1)
    return resampled
# Initialize a dictionary to store the processed weights.
processed_state_dict = {}

# Snapshot the source weights once: Module.state_dict() rebuilds the whole
# mapping on every call, so the original's ~30 repeated calls were wasteful.
state_dict = model.state_dict()

# Extract the positional embedding, drop the leading class-token slot, and
# resample the remaining 14x14 grid tokens to a 14x28 grid.
pos_embed = state_dict['pos_embed'][0][1:].unsqueeze(0)
pos_embed = resample_abs_pos_embed(pos_embed, [14, 28], num_prefix_tokens=0)
processed_state_dict['pos_embed'] = pos_embed

# Copy the patch embedding projection weights unchanged.
processed_state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight']
processed_state_dict['patch_embed.proj.bias'] = state_dict['patch_embed.proj.bias']

# Process and extract weights from the first 8 transformer blocks only.
for i in range(8):
    block_prefix = f'blocks.{i}.'

    # Split the fused qkv projection into separate q, k, v tensors.
    qkv_weight = state_dict[f'{block_prefix}attn.qkv.weight']
    qkv_bias = state_dict[f'{block_prefix}attn.qkv.bias']
    dim = qkv_weight.shape[0] // 3
    processed_state_dict[f'{block_prefix}attn.q.weight'] = qkv_weight[:dim]
    processed_state_dict[f'{block_prefix}attn.k.weight'] = qkv_weight[dim:2 * dim]
    processed_state_dict[f'{block_prefix}attn.v.weight'] = qkv_weight[2 * dim:]
    processed_state_dict[f'{block_prefix}attn.q.bias'] = qkv_bias[:dim]
    processed_state_dict[f'{block_prefix}attn.k.bias'] = qkv_bias[dim:2 * dim]
    processed_state_dict[f'{block_prefix}attn.v.bias'] = qkv_bias[2 * dim:]

    # Norm layers, attention output projection, and MLP copy straight through.
    for key in (
        'norm1.weight', 'norm1.bias',
        'attn.proj.weight', 'attn.proj.bias',
        'norm2.weight', 'norm2.bias',
        'mlp.fc1.weight', 'mlp.fc1.bias',
        'mlp.fc2.weight', 'mlp.fc2.bias',
    ):
        processed_state_dict[block_prefix + key] = state_dict[block_prefix + key]

# Save the processed weights to a .pth file.
torch.save(processed_state_dict, 'processed_model_weights.pth')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
您好,请问ObjectFormer的--init_weight_path object_former/processed_model_weights.pth 这个weight在哪里
谢谢回复
The text was updated successfully, but these errors were encountered: