diff --git a/robin/train/train_mem.py b/robin/train/train_mem.py index dd1598c..b2d25c2 100644 --- a/robin/train/train_mem.py +++ b/robin/train/train_mem.py @@ -9,9 +9,12 @@ USE_FLASH_ATTN_2 = False if __name__ == "__main__": - hostname = os.environ.get('HOSTNAME') if os.environ.get('HOSTNAME') != None else os.uname()[1] - if hostname == None: hostname = os.environ.get('HOST_NAME') - + hostname = os.environ.get('HOST_NAME') + if hostname == None: + hostname = os.environ.get('HOSTNAME') + if hostname == None: + hostname = os.uname()[1] + print('Running on cluster:', end=' ') match hostname.lower(): case x if 'frontier' in x: