@@ -33,6 +33,14 @@
     UserWarning,
 )
 
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from tqdm import tqdm
 from transformers import AutoConfig
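Note: the `is_torch_hpu_available` helper imported above lives in `instructlab/training/hpu_utils.py`, which is not part of this hunk. A minimal, hypothetical sketch of what such a probe could look like, assuming it only needs to detect whether the Habana PyTorch bridge is importable:

# Hypothetical sketch; the real hpu_utils implementation is not shown in this diff.
import importlib.util
from functools import lru_cache


@lru_cache(maxsize=None)
def is_torch_hpu_available() -> bool:
    """Return True when the Habana (Gaudi) PyTorch bridge can be imported."""
    if importlib.util.find_spec("habana_frameworks") is None:
        return False
    try:
        # Importing the bridge registers the torch.hpu backend as a side effect.
        import habana_frameworks.torch  # noqa: F401
    except ImportError:
        return False
    return True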
@@ -139,10 +147,19 @@ def train(
             total_length = float(torch.tensor([batch.pop("total_length")]))
             if not args.use_dolomite:
                 for k in batch:
-                    batch[k] = batch[k].to(local_rank)
+                    batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+
+            hpu_args = {}
+            if is_torch_hpu_available():
+                hpu_args = {
+                    "use_flash_attention": True,
+                    "lazy_mode": False,
+                }
+
             output = model(
                 **batch,
                 use_cache=False,
+                **hpu_args,
             )
             loss = output.loss
             log_loss = loss.detach().item()
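The forward call stays backend-agnostic by folding the Gaudi-only flags into a dict that is empty on CUDA, so there is a single call site for both devices. A stripped-down sketch of the same pattern (the `run_forward` wrapper is illustrative, not part of the change):

# Illustrative only; mirrors the conditional-kwargs pattern in the hunk above.
from instructlab.training.hpu_utils import is_torch_hpu_available


def run_forward(model, batch, local_rank):
    # Tensors go to the HPU device on Gaudi, otherwise to the local CUDA rank.
    device = "hpu" if is_torch_hpu_available() else local_rank
    batch = {k: v.to(device) for k, v in batch.items()}

    extra = {}
    if is_torch_hpu_available():
        # Gaudi-specific switches understood by optimum-habana model wrappers
        # after adapt_transformers_to_gaudi(); the CUDA path never passes them,
        # so stock transformers models are unaffected.
        extra = {"use_flash_attention": True, "lazy_mode": False}

    return model(**batch, use_cache=False, **extra)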
@@ -179,8 +196,14 @@ def train(
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
                 current_lr = accelerator.lr_scheduler.get_last_lr()[0]
-                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+                if is_torch_hpu_available():
+                    mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                    malloc_retries = 0
+                else:
+                    mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                    malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
                 global_grad_norm = (
                     model.get_global_grad_norm()
                     if hasattr(model, "get_global_grad_norm")
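If this branch is needed elsewhere, it could be folded into a small helper. A sketch under the assumption that only these two backends matter; as in the diff, the allocator retry counter has no direct HPU equivalent, hence the hard-coded 0 (`device_memory_stats` is a hypothetical name):

import torch

from instructlab.training.hpu_utils import is_torch_hpu_available


def device_memory_stats() -> tuple[float, int]:
    # Returns (GiB currently allocated, allocator retry count).
    if is_torch_hpu_available():
        return torch.hpu.memory_allocated() / (1024**3), 0
    return (
        torch.cuda.memory_allocated() / (1024**3),
        torch.cuda.memory_stats()["num_alloc_retries"],
    )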
@@ -202,8 +225,8 @@ def train(
                         "rank": torch.distributed.get_rank(),
                         "overall_throughput": overall_throughput,
                         "lr": current_lr,
-                        "cuda_mem_allocated": cuda_mem_allocated,
-                        "cuda_malloc_retries": cuda_malloc_retries,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
                         "num_loss_counted_tokens": int(num_loss_counted_tokens),
                         "num_tokens_rank0": int(total_length),
                         "batch_size": int(micro_batch_size),
@@ -236,7 +259,10 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if not is_torch_hpu_available():
+                torch.cuda.empty_cache()
+
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -314,17 +340,24 @@ def main(args):
     args.model_type = model_conf.model_type
 
     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if is_torch_hpu_available():
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])
 
     timeout = _get_collective_timeout()
-    if timeout is not None:
-        torch.distributed.init_process_group(timeout=timeout)
-    else:
-        torch.distributed.init_process_group()
+    backend = "hccl" if is_torch_hpu_available() else None
+    torch.distributed.init_process_group(backend=backend, timeout=timeout)
 
     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if is_torch_hpu_available():
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()
 
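The collapsed init call relies on recent PyTorch treating both `backend=None` (auto-select, nccl on CUDA) and `timeout=None` (library default) as "use the default". A condensed sketch of the resulting init path, assuming the same `LOCAL_RANK` environment variable set by the launcher (`init_distributed` is an illustrative name, not a function in the diff):

import os

import torch
import torch.distributed as dist

from instructlab.training.hpu_utils import is_torch_hpu_available


def init_distributed(timeout=None):
    # Pin each rank to its local device, then pick the matching collective
    # backend: "hccl" on Gaudi, None elsewhere so torch auto-selects.
    local_rank = int(os.environ["LOCAL_RANK"])
    if is_torch_hpu_available():
        torch.hpu.set_device(local_rank)
        backend = "hccl"
    else:
        torch.cuda.set_device(local_rank)
        backend = None
    dist.init_process_group(backend=backend, timeout=timeout)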