add gpu initialization of fc layers
siddharth9820 committed Nov 3, 2023
1 parent f8773ee · commit ae40954
Showing 1 changed file with 8 additions and 6 deletions.
axonn/intra_layer/fully_connected.py (14 changes: 8 additions & 6 deletions)
@@ -11,13 +11,14 @@ def divide(a, b):
     assert a % b == 0
     return a // b
 
+
+@torch.no_grad()
 def extract_local_params_from_full_params(
-    full_params, out_features_group, in_features_group, depth_group
+    params, out_features_group, in_features_group, depth_group
 ):
-    params = Drop.apply(torch.t(full_params).contiguous(), out_features_group)
-    params = torch.t(params).contiguous()
-
     params = Drop.apply(params, in_features_group)
+    params = Drop.apply(torch.t(params).contiguous(), out_features_group)
+    params = torch.t(params).contiguous()
     params = Drop.apply(params.reshape(-1), depth_group)  # create 1D view
     return params
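For reference, a minimal single-process sketch of the shard-extraction order the revised extract_local_params_from_full_params follows. The helper drop_last_dim is a hypothetical stand-in for AxoNN's Drop.apply, which is assumed here to keep the calling rank's slice of its input's last dimension within a process group; the group sizes and rank positions are made up for illustration.

import torch

# Hypothetical stand-in for AxoNN's Drop.apply, assumed to keep this rank's
# slice of the last dimension of `x` within a process group of size `world`.
def drop_last_dim(x, rank, world):
    return torch.chunk(x, world, dim=-1)[rank].contiguous()

# Full (out_features, in_features) weight, as built by initialize_params.
full = torch.arange(8 * 6, dtype=torch.float32).reshape(8, 6)

# Illustrative positions of this process in each tensor-parallel group.
rank_in, world_in = 0, 2        # in_features group
rank_out, world_out = 1, 2      # out_features group
rank_depth, world_depth = 0, 2  # depth group

# Mirror the new ordering in extract_local_params_from_full_params:
p = drop_last_dim(full, rank_in, world_in)                       # shard in_features  -> (8, 3)
p = drop_last_dim(torch.t(p).contiguous(), rank_out, world_out)  # shard out_features -> (3, 4)
p = torch.t(p).contiguous()                                      # back to (rows, cols) -> (4, 3)
p = drop_last_dim(p.reshape(-1), rank_depth, world_depth)        # 1D view across depth -> (6,)
print(p.shape)  # torch.Size([6])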

@@ -30,12 +31,13 @@ def initialize_params(
     in_features_group,
     depth_group,
     init_method,
+    init_device='cuda',
 ):
-    params = torch.empty((out_features, in_features))
+    params = torch.empty((out_features, in_features), device=init_device)
     init_method(params)
     params = extract_local_params_from_full_params(
         params, out_features_group, in_features_group, depth_group
-    )
+    ).cpu()
     return params
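Below is a hedged sketch of how the updated initialize_params might be called after this commit. The single-process gloo group setup is purely illustrative, and the leading keyword names (out_features, in_features, out_features_group) are assumptions, since those parameters sit above the visible hunk.

import os
import torch
import torch.distributed as dist
from axonn.intra_layer.fully_connected import initialize_params

# Minimal single-process setup so that process-group handles exist at all;
# in real use AxoNN builds these groups for its tensor-parallel decomposition.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
group = dist.new_group([0])

weight = initialize_params(
    out_features=1024,
    in_features=512,
    out_features_group=group,
    in_features_group=group,
    depth_group=group,
    init_method=torch.nn.init.xavier_uniform_,
    # initialize on the GPU when available, matching the new init_device='cuda' default
    init_device="cuda" if torch.cuda.is_available() else "cpu",
)
print(weight.shape, weight.device)  # local 1D weight shard, moved back to the CPU

dist.destroy_process_group()

As the diff suggests, the point of the change is that the full (out_features, in_features) tensor is now initialized on the GPU, the local shard is extracted there, and only that shard is moved back to the CPU via .cpu().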
