my_lbfgsb.py
'''
Earlier draft, kept for reference. It redefined the closure inside the
optimization loop and recomputed the loss with an extra forward pass at the end.

import torch
from torch.optim.lbfgsb import LBFGSB

def my_L_BFGS_B(x_init, objective, low, high):
    x = x_init.clone().detach().requires_grad_(True)
    optimizer = LBFGSB([x], lower_bound=low, upper_bound=high)
    for step in range(20):
        def closure():
            if torch.is_grad_enabled():
                optimizer.zero_grad()  # Clear gradients
            loss = objective(x)        # Compute the loss
            if loss.requires_grad:
                loss.backward()        # Compute gradients
            return loss
        optimizer.step(closure)  # Perform one optimization step
    loss = objective(x)
    return x.detach(), loss.item()
'''
import torch
from torch.optim.lbfgsb import LBFGSB

def my_L_BFGS_B(x_init, objective, low, high):
    """Minimize `objective` over the box [low, high], starting from x_init."""
    # Work on a leaf copy so the caller's tensor is left untouched.
    x = x_init.clone().detach().requires_grad_(True)
    optimizer = LBFGSB([x], lower_bound=low, upper_bound=high)

    def closure():
        if torch.is_grad_enabled():
            optimizer.zero_grad()  # Clear stale gradients
        loss = objective(x)        # Forward pass: compute the loss
        if loss.requires_grad:
            loss.backward()        # Backward pass: populate x.grad
        return loss

    for step in range(20):
        # One L-BFGS-B iteration; the optimizer may call closure several times.
        optimizer.step(closure)

    # Evaluate the final loss without an extra backward pass.
    with torch.no_grad():
        loss = objective(x)
    return x.detach(), loss
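
# Usage sketch (not from the original file): minimizing a bound-constrained
# quadratic. The objective, starting point, and bounds below are illustrative
# assumptions; `low` and `high` are assumed here to be tensors with the same
# shape as x. LBFGSB comes from this repository's torch.optim fork, not from
# stock PyTorch.
if __name__ == "__main__":
    x0 = torch.tensor([3.0, -2.0])
    low = torch.tensor([0.0, 0.0])
    high = torch.tensor([5.0, 5.0])

    def quadratic(x):
        # f(x) = ||x - c||^2 with unconstrained minimizer c = (1, 2),
        # which lies inside the box, so the solver should recover it.
        return ((x - torch.tensor([1.0, 2.0])) ** 2).sum()

    x_opt, f_opt = my_L_BFGS_B(x0, quadratic, low, high)
    print(x_opt, f_opt)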