Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sign descent seems to do better than adamw? #488

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
# sign nanoGPT

Only `train.py` is modified — search (Ctrl+F) for "ayy" to locate the changes.


adamw:
```
step 5000: train loss 0.6171, val loss 1.6973
iter 5000: loss 0.8138, time 31669.98ms, mfu 4.02%
```
sign descent:
```
step 5000: train loss 1.0883, val loss 1.4821
iter 5000: loss 1.1968, time 34425.73ms, mfu 3.27%
```

Second run, with the random seed offset by 420 (`seed += 420`):

adamw:
```
step 5000: train loss 0.6116, val loss 1.6995
iter 5000: loss 0.8267, time 33304.84ms, mfu 4.56%
```
sign descent:
```
step 5000: train loss 1.0963, val loss 1.4689
iter 5000: loss 1.2122, time 30964.94ms, mfu 4.05%
```
![sign_nanoGPT](signs.png)

# nanoGPT

Expand Down
Binary file added signs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 17 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch.optim as optim

from model import GPTConfig, GPT

Expand Down Expand Up @@ -71,7 +72,7 @@
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
# NOTE(review): both the old value (True) and the new value (False) of
# `compile` appear below — diff-view residue from the scrape; in the merged
# file only the False assignment exists, and in any case the last one wins.
compile = True # use PyTorch 2.0 to compile the model to be faster
compile = False # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
# Collect every simple (int/float/bool/str) module-level global as a config
# key, then let configurator.py override them from the command line or a
# config file. NOTE(review): exec() runs arbitrary code from that local file —
# upstream nanoGPT convention, acceptable only because it is not untrusted input.
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read()) # overrides from command line or config file
Expand Down Expand Up @@ -196,7 +197,8 @@ def get_batch(split):
# GradScaler is a no-op unless we are actually training in float16
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
# The upstream AdamW optimizer is replaced with plain SGD; combined with the
# gradient-sign substitution applied later in the training loop, this turns
# the update rule into sign descent.
# optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
# lr is batch_size divided by the model's total parameter count; 10745088 is
# the count printed by the "ayy" diagnostic block further down.
# TODO(review): derive the divisor from the model instead of hard-coding it.
optimizer = optim.SGD(model.parameters(), lr=batch_size/10745088)
if init_from == 'resume':
    # restore optimizer state (momentum buffers etc.) from the checkpoint
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None # free up memory
Expand All @@ -211,6 +213,16 @@ def get_batch(split):
# wrap the model into a DistributedDataParallel container when running multi-GPU
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

# ayy
# Diagnostic: count trainable parameters so the SGD learning-rate divisor
# (10745088, used above when constructing the optimizer) can be cross-checked
# against the actual model size.
total_count = 0
for parameter in model.parameters():
    count = parameter.numel()  # product of all dimensions of the tensor
    total_count += count
    print(parameter.shape, count)
print('param count', total_count) # 10745088

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
Expand Down Expand Up @@ -307,6 +319,9 @@ def get_lr(it):
# clip the gradient norm; requires unscaling first so the clip threshold is
# applied to the true (not loss-scaled) gradients
if grad_clip != 0.0:
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# ayy
# Replace every gradient with its elementwise sign ("sign descent").
# sign() is invariant to the GradScaler's positive scale factor, so this is
# correct whether or not the grads were unscaled above.
# NOTE(review): signing discards magnitudes, so the clip_grad_norm_ call
# above is effectively redundant when this substitution is active.
for p in model.parameters():
    if p.grad is not None:  # frozen/unused params carry no grad; sign(None) would raise
        p.grad = torch.sign(p.grad) # or whatever other operation
# step the optimizer and scaler if training in fp16
scaler.step(optimizer)
scaler.update()
Expand Down