From f2816dd0035860120777c186667b02b8f3de1811 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 6 Nov 2023 20:18:58 -0500 Subject: [PATCH] Initialize layers on the GPU (#51) --- axonn/intra_layer/fully_connected.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index 5ca5002..040800f 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -12,12 +12,13 @@ def divide(a, b): return a // b +@torch.no_grad() def extract_local_params_from_full_params( - full_params, out_features_group, in_features_group, depth_group + params, out_features_group, in_features_group, depth_group ): - params = Drop.apply(torch.t(full_params).contiguous(), out_features_group) - params = torch.t(params).contiguous() params = Drop.apply(params, in_features_group) + params = Drop.apply(torch.t(params).contiguous(), out_features_group) + params = torch.t(params).contiguous() params = Drop.apply(params.reshape(-1), depth_group) # create 1D view return params @@ -30,12 +31,13 @@ def initialize_params( in_features_group, depth_group, init_method, + init_device="cuda", ): - params = torch.empty((out_features, in_features)) + params = torch.empty((out_features, in_features), device=init_device) init_method(params) params = extract_local_params_from_full_params( params, out_features_group, in_features_group, depth_group - ) + ).cpu() return params