Commit f2816dd
Initialize layers on the GPU (#51)
siddharth9820 authored Nov 7, 2023
1 parent f8773ee commit f2816dd
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions axonn/intra_layer/fully_connected.py
@@ -12,12 +12,13 @@ def divide(a, b):
     return a // b


+@torch.no_grad()
 def extract_local_params_from_full_params(
-    full_params, out_features_group, in_features_group, depth_group
+    params, out_features_group, in_features_group, depth_group
 ):
-    params = Drop.apply(torch.t(full_params).contiguous(), out_features_group)
-    params = torch.t(params).contiguous()
     params = Drop.apply(params, in_features_group)
+    params = Drop.apply(torch.t(params).contiguous(), out_features_group)
+    params = torch.t(params).contiguous()
     params = Drop.apply(params.reshape(-1), depth_group)  # create 1D view
     return params
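For context, here is a minimal sketch of what this sharding helper computes, assuming `Drop.apply(tensor, group)` keeps this rank's equal slice of the tensor along its last dimension within the given process group (the apparent semantics of axonn's `Drop` autograd function, which the transposes above work around: transpose to shard output features, slice input features directly, then shard the flattened 1-D view across the depth group). The new `@torch.no_grad()` decorator keeps these slicing ops out of the autograd graph during initialization. The helper name and the explicit rank/size arguments below are hypothetical, written with plain slicing instead of `Drop.apply`:

```python
import torch

def shard_weight_sketch(full, out_rank, out_size, in_rank, in_size,
                        depth_rank, depth_size):
    # Hypothetical equivalent of extract_local_params_from_full_params:
    # keep this rank's block of a (out_features, in_features) weight,
    # then this rank's chunk of the block's flattened 1-D view.
    out_features, in_features = full.shape
    rows = out_features // out_size   # rows kept by this out-features rank
    cols = in_features // in_size     # cols kept by this in-features rank
    local = full[out_rank * rows:(out_rank + 1) * rows,
                 in_rank * cols:(in_rank + 1) * cols].contiguous()
    flat = local.reshape(-1)          # 1-D view, as in the real helper
    chunk = flat.numel() // depth_size
    return flat[depth_rank * chunk:(depth_rank + 1) * chunk]
```

Since the row and column slices select an independent block, the reordering of the two `Drop` calls in this commit does not change which elements each rank keeps.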

@@ -30,12 +31,13 @@ def initialize_params(
     in_features_group,
     depth_group,
     init_method,
+    init_device="cuda",
 ):
-    params = torch.empty((out_features, in_features))
+    params = torch.empty((out_features, in_features), device=init_device)
     init_method(params)
     params = extract_local_params_from_full_params(
         params, out_features_group, in_features_group, depth_group
-    )
+    ).cpu()
     return params


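The practical effect of the change: the full (out_features, in_features) weight is now materialized and initialized on the GPU, which is typically much faster than CPU initialization for large layers, and only the rank-local shard is moved back to host memory by `.cpu()`. A hypothetical usage sketch follows; the process-group handles are placeholders from a distributed setup, not part of the commit:

```python
import math
import torch.nn.init as init
from axonn.intra_layer.fully_connected import initialize_params

# out_group / in_group / depth_group are assumed torch.distributed process
# groups obtained from axonn's tensor-parallel setup; placeholders here.
local_shard = initialize_params(
    out_features=4096,
    in_features=1024,
    out_features_group=out_group,
    in_features_group=in_group,
    depth_group=depth_group,
    init_method=lambda w: init.kaiming_uniform_(w, a=math.sqrt(5)),
    init_device="cuda",  # new argument: run init_method on the GPU
)
# local_shard is this rank's flattened slice of the weight, back on the CPU.
```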
