From ae409549118d3bef8a8a2d8a71957067394f7fda Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 3 Nov 2023 13:54:07 -0400 Subject: [PATCH] add gpu initialization of fc layers --- axonn/intra_layer/fully_connected.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index 5ca5002..f63f52a 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -11,13 +11,14 @@ def divide(a, b): assert a % b == 0 return a // b - +@torch.no_grad() def extract_local_params_from_full_params( - full_params, out_features_group, in_features_group, depth_group + params, out_features_group, in_features_group, depth_group ): - params = Drop.apply(torch.t(full_params).contiguous(), out_features_group) - params = torch.t(params).contiguous() + params = Drop.apply(params, in_features_group) + params = Drop.apply(torch.t(params).contiguous(), out_features_group) + params = torch.t(params).contiguous() params = Drop.apply(params.reshape(-1), depth_group) # create 1D view return params @@ -30,12 +31,13 @@ def initialize_params( in_features_group, depth_group, init_method, + init_device='cuda', ): - params = torch.empty((out_features, in_features)) + params = torch.empty((out_features, in_features), device=init_device) init_method(params) params = extract_local_params_from_full_params( params, out_features_group, in_features_group, depth_group - ) + ).cpu() return params