import torch |
import numpy as np |
from functools import reduce |
from copy import deepcopy |
from torch.optim import Optimizer |
def is_legal(v):
    """
    Checks that tensor is not NaN or Inf.

    Inputs:
        v (tensor): tensor to be checked

    Outputs:
        legal (bool): True if no entry of v is NaN or +/-Inf, False otherwise

    """
    # Both checks are reduced with .any() so that tensors with more than one
    # element can be tested without an ambiguous truth-value error.
    has_nan = bool(torch.isnan(v).any())
    has_inf = bool(torch.isinf(v).any())
    legal = not (has_nan or has_inf)

    return legal
def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
    """
    Gives the minimizer of the interpolating polynomial over given points
    based on function and derivative information. Defaults to bisection if no critical
    points are valid.

    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
    modifications.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 12/6/18.

    Inputs:
        points (nparray): two-dimensional array with each point of form [x f g]
        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
        plot (bool): plot interpolating polynomial

    Outputs:
        x_sol (float): minimizer of interpolating polynomial

    Note:
        . Set f or g to np.nan if they are unknown
        . The minimum value of the polynomial is only computed internally (it is
          used for plotting); it is not returned.

    """
    no_points = points.shape[0]
    # polynomial degree = (number of known f and g values) - 1
    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1

    if x_min_bound is None:
        x_min_bound = np.min(points[:, 0])
    if x_max_bound is None:
        x_max_bound = np.max(points[:, 0])

    # bisection point, used whenever interpolation gives no usable answer
    midpoint = (x_min_bound + x_max_bound)/2

    if no_points == 2 and order == 2 and plot is False:
        # closed-form quadratic interpolation from (x0, f0, g0) and (x1, f1)
        x0, f0, g0 = points[0, 0], points[0, 1], points[0, 2]
        x1, f1 = points[1, 0], points[1, 1]

        if x0 == 0:
            x_sol = -g0*x1**2/(2*(f1 - f0 - g0*x1))
        else:
            a = -(f0 - f1 - g0*(x0 - x1))/(x0 - x1)**2
            x_sol = x0 - g0/(2*a)

        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)

        # a degenerate fit (e.g. 0/0) yields NaN; fall back to bisection
        if np.isnan(x_sol):
            x_sol = midpoint

    elif no_points == 2 and order == 3 and plot is False:
        # closed-form cubic interpolation from (x0, f0, g0) and (x1, f1, g1)
        x0, f0, g0 = points[0, 0], points[0, 1], points[0, 2]
        x1, f1, g1 = points[1, 0], points[1, 1], points[1, 2]

        d1 = g0 + g1 - 3*((f0 - f1)/(x0 - x1))
        discriminant = d1**2 - g0*g1

        # np.sqrt of a negative real returns NaN (not a complex number), so the
        # sign of the discriminant has to be tested explicitly.
        if discriminant >= 0:
            d2 = np.sqrt(discriminant)
            x_sol = x1 - (x1 - x0)*((g1 + d2 - d1)/(g1 - g0 + 2*d2))
            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
            if np.isnan(x_sol):
                x_sol = midpoint
        else:
            x_sol = midpoint

    else:
        # general case: solve the linear system for the polynomial coefficients
        # (highest power first, as expected by np.polyval / np.roots)
        rows = []
        rhs = []

        # constraints from known function values
        for i in range(no_points):
            if not np.isnan(points[i, 1]):
                constraint = np.zeros(order + 1)
                for j in range(order, -1, -1):
                    constraint[order - j] = points[i, 0]**j
                rows.append(constraint)
                rhs.append(points[i, 1])

        # constraints from known derivative values
        for i in range(no_points):
            if not np.isnan(points[i, 2]):
                constraint = np.zeros(order + 1)
                for j in range(order):
                    constraint[j] = (order - j)*points[i, 0]**(order - j - 1)
                rows.append(constraint)
                rhs.append(points[i, 2])

        if rows:
            A = np.vstack(rows)
        else:
            A = np.zeros((0, order + 1))
        b = np.array(rhs, dtype=float)

        coeff = None
        x_sol = midpoint
        f_min = np.inf

        # only a square, full-rank system determines the polynomial uniquely;
        # otherwise keep the bisection point
        if A.shape[0] == A.shape[1] and np.linalg.matrix_rank(A) == A.shape[0]:
            coeff = np.linalg.solve(A, b)

            # coefficients of the derivative polynomial
            dcoeff = np.zeros(order)
            for i in range(len(coeff) - 1):
                dcoeff[i] = coeff[i]*(order - i)

            # candidates: the bounds, the given points, and the stationary points
            crit_pts = np.array([x_min_bound, x_max_bound])
            crit_pts = np.append(crit_pts, points[:, 0])
            if not np.isinf(dcoeff).any():
                crit_pts = np.append(crit_pts, np.roots(dcoeff))

            for crit_pt in crit_pts:
                if not np.isreal(crit_pt):
                    continue
                # drop the zero imaginary part before comparing with the bounds
                crit_pt = np.real(crit_pt)
                if x_min_bound <= crit_pt <= x_max_bound:
                    F_cp = np.polyval(coeff, crit_pt)
                    if np.isreal(F_cp) and F_cp < f_min:
                        x_sol = crit_pt
                        f_min = np.real(F_cp)

        # plotting requires the fitted coefficients, so it is skipped when the
        # system could not be solved
        if plot and coeff is not None:
            import matplotlib.pyplot as plt
            plt.figure()
            x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
            f = np.polyval(coeff, x)
            plt.plot(x, f)
            plt.plot(x_sol, f_min, 'x')

    return x_sol
class LBFGS(Optimizer):
    """
    Implements the L-BFGS algorithm. Compatible with multi-batch and full-overlap
    L-BFGS implementations and (stochastic) Powell damping. Partly based on the
    original L-BFGS implementation in PyTorch, Mark Schmidt's minFunc MATLAB code,
    and Michael Overton's weak Wolfe line search MATLAB code.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 12/6/18.

    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.

    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode

    References:
    [1] Berahas, Albert S., Jorge Nocedal, and Martin Takác. "A Multi-Batch L-BFGS
        Method for Machine Learning." Advances in Neural Information Processing
        Systems. 2016.
    [2] Bollapragada, Raghu, et al. "A Progressive Batching L-BFGS Method for Machine
        Learning." International Conference on Machine Learning. 2018.
    [3] Lewis, Adrian S., and Michael L. Overton. "Nonsmooth Optimization via Quasi-Newton
        Methods." Mathematical Programming 141.1-2 (2013): 135-163.
    [4] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for
        Large Scale Optimization." Mathematical Programming 45.1-3 (1989): 503-528.
    [5] Nocedal, Jorge. "Updating Quasi-Newton Matrices With Limited Storage."
        Mathematics of Computation 35.151 (1980): 773-782.
    [6] Nocedal, Jorge, and Stephen J. Wright. "Numerical Optimization." Springer New York,
        2006.
    [7] Schmidt, Mark. "minFunc: Unconstrained Differentiable Multivariate Optimization
        in Matlab." Software available at http://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
        (2005).
    [8] Schraudolph, Nicol N., Jin Yu, and Simon Günter. "A Stochastic Quasi-Newton
        Method for Online Convex Optimization." Artificial Intelligence and Statistics.
        2007.
    [9] Wang, Xiao, et al. "Stochastic Quasi-Newton Methods for Nonconvex Stochastic
        Optimization." SIAM Journal on Optimization 27.2 (2017): 927-956.

    """

    def __init__(self, params, lr=1, history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):
        # validate hyperparameters before handing them to the base Optimizer
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0 <= history_size:
            raise ValueError("Invalid history size: {}".format(history_size))
        if line_search not in ['Armijo', 'Wolfe', 'None']:
            raise ValueError("Invalid line search: {}".format(line_search))

        defaults = dict(lr=lr, history_size=history_size, line_search=line_search,
                        dtype=dtype, debug=debug)
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("L-BFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        # all optimizer state lives under a single 'global_state' entry
        state = self.state['global_state']
        state.setdefault('n_iter', 0)
        state.setdefault('curv_skips', 0)
        state.setdefault('fail_skips', 0)
        state.setdefault('H_diag', 1)
        # 'fail' starts as True so that no curvature pair is formed before a
        # step has been taken
        state.setdefault('fail', True)

        state['old_dirs'] = []
        state['old_stps'] = []

    def _numel(self):
        """Returns the total number of parameter elements (cached after first call)."""
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _gather_flat_grad(self):
        """
        Concatenates the gradients of all parameters into one 1-D tensor.
        Parameters without a gradient contribute zeros; sparse gradients are
        densified.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new_zeros(p.data.numel())
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().reshape(-1)
            else:
                # reshape also handles non-contiguous gradients
                view = p.grad.data.reshape(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_update(self, step_size, update):
        """
        Adds step_size * update to the parameters in place.

        Inputs:
            step_size (float): scalar multiplier applied to update
            update (tensor): 1-D tensor holding the update for all parameters,
                laid out in parameter order
        """
        offset = 0
        for p in self._params:
            numel = p.numel()
            # keyword form of add_; the positional (scalar, tensor) overload is deprecated
            p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _copy_params(self):
        """Returns a list with a copy of the data of every parameter."""
        return [param.data.clone() for param in self._params]

    def _load_params(self, current_params):
        """
        Overwrites the parameter data with previously copied values.

        Inputs:
            current_params (list): tensors in parameter order, as returned by
                _copy_params
        """
        for param, saved in zip(self._params, current_params):
            # copy_ (rather than slice assignment) also works for 0-dim parameters
            param.data.copy_(saved)

    def line_search(self, line_search):
        """
        Switches line search option.

        Inputs:
            line_search (str): designates line search to use
                Options:
                    'None': uses steplength designated in algorithm
                    'Armijo': uses Armijo backtracking line search
                    'Wolfe': uses Armijo-Wolfe bracketing line search

        Note:
            . The value is stored without validation; any value other than
              'Armijo' or 'Wolfe' makes the step use no line search.

        """
        group = self.param_groups[0]
        group['line_search'] = line_search

    def two_loop_recursion(self, vec):
        """
        Performs two-loop recursion on given vector to obtain Hv.

        Inputs:
            vec (tensor): 1-D tensor to apply two-loop recursion to (not modified)

        Output:
            r (tensor): matrix-vector product Hv

        """
        group = self.param_groups[0]
        history_size = group['history_size']

        state = self.state['global_state']
        old_dirs = state.get('old_dirs')    # stored steps s
        old_stps = state.get('old_stps')    # stored gradient differences y
        H_diag = state.get('H_diag')

        num_old = len(old_dirs)

        # scratch lists are kept in the state and reused between calls
        if 'rho' not in state:
            state['rho'] = [None] * history_size
            state['alpha'] = [None] * history_size
        rho = state['rho']
        alpha = state['alpha']

        for i in range(num_old):
            rho[i] = 1. / old_stps[i].dot(old_dirs[i])

        # work on a copy so that the in-place updates below do not alter the
        # caller's vector
        q = vec.clone()

        # first loop: newest pair to oldest
        for i in range(num_old - 1, -1, -1):
            alpha[i] = old_dirs[i].dot(q) * rho[i]
            q.add_(old_stps[i], alpha=-alpha[i])

        # apply the initial (scaled identity) inverse Hessian approximation
        r = torch.mul(q, H_diag)

        # second loop: oldest pair to newest
        for i in range(num_old):
            beta = old_stps[i].dot(r) * rho[i]
            r.add_(old_dirs[i], alpha=alpha[i] - beta)

        return r

    def curvature_update(self, flat_grad, eps=1e-2, damping=False):
        """
        Performs curvature update.

        Inputs:
            flat_grad (tensor): 1-D tensor of flattened gradient for computing
                gradient difference with previously stored gradient
            eps (float): constant for curvature pair rejection or damping (default: 1e-2)
            damping (bool): flag for using Powell damping (default: False)

        Raises:
            ValueError: if eps is not positive

        """
        assert len(self.param_groups) == 1

        if eps <= 0:
            raise ValueError('Invalid eps; must be positive.')

        group = self.param_groups[0]
        history_size = group['history_size']
        debug = group['debug']

        state = self.state['global_state']
        fail = state.get('fail')

        # a failed line search leaves the iterate unchanged, so there is no
        # new curvature information to record
        if fail:
            state['fail_skips'] += 1
            if debug:
                print('Line search failed; curvature pair update skipped')
            return

        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        prev_flat_grad = state.get('prev_flat_grad')
        Bs = state.get('Bs')

        # curvature pair (s, y)
        y = flat_grad.sub(prev_flat_grad)
        s = d.mul(t)

        sBs = s.dot(Bs)
        ys = y.dot(s)

        if ys > eps*sBs or damping == True:

            # Powell damping replaces y when the curvature condition is violated
            if damping == True and ys < eps*sBs:
                if debug:
                    print('Applying Powell damping...')
                theta = ((1 - eps)*sBs)/(sBs - ys)
                y = theta*y + (1 - theta)*Bs

            # history is full: drop the oldest pair
            if len(old_dirs) == history_size:
                old_dirs.pop(0)
                old_stps.pop(0)

            old_dirs.append(s)
            old_stps.append(y)

            # scaling of the initial inverse Hessian approximation
            H_diag = ys / y.dot(y)

            state['old_dirs'] = old_dirs
            state['old_stps'] = old_stps
            state['H_diag'] = H_diag

        else:
            state['curv_skips'] += 1
            if debug:
                print('Curvature pair skipped due to failed criterion')

    def _store_step(self, d, prev_flat_grad, t, g_Sk, fail):
        """Updates Bs = -t * g_Sk and records the quantities of the finished step."""
        state = self.state['global_state']
        Bs = state.get('Bs')
        if Bs is None:
            Bs = g_Sk.mul(-t)
        else:
            # reuse the existing buffer
            Bs.copy_(g_Sk.mul(-t))

        state['d'] = d
        state['prev_flat_grad'] = prev_flat_grad
        state['t'] = t
        state['Bs'] = Bs
        state['fail'] = fail

    def _step(self, p_k, g_Ok, g_Sk=None, options=None):
        """
        Performs a single optimization step.

        Inputs:
            p_k (tensor): 1-D tensor specifying search direction
            g_Ok (tensor): 1-D tensor of flattened gradient over overlap O_k used
                            for gradient differencing in curvature pair update
            g_Sk (tensor): 1-D tensor of flattened gradient over full sample S_k
                            used for curvature pair damping or rejection criterion,
                            if None, will use g_Ok (default: None)
            options (dict): contains options for performing line search
                            (default: None, treated as an empty dict)

        Options for Armijo backtracking line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (tensor): factor for decreasing steplength > 0 (default: 2)
            'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Options for Wolfe line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (float): factor for extrapolation (default: 2)
            'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
            'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Outputs (depends on line search):
          . No line search:
                t (float): steplength
          . Armijo backtracking line search:
                F_new (tensor): loss function at new iterate
                t (tensor): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
          . Wolfe line search:
                F_new (tensor): loss function at new iterate
                g_new (tensor): gradient at new iterate
                t (float): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                grad_eval (int): number of gradient evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded

        Raises:
            ValueError: if a line search is selected and options is empty, has no
                'closure', or contains an out-of-range value

        Notes:
          . If encountering line search failure in the deterministic setting, one
            should try increasing the maximum number of line search steps max_ls.

        """
        if options is None:
            options = {}

        assert len(self.param_groups) == 1

        group = self.param_groups[0]
        lr = group['lr']
        line_search = group['line_search']
        dtype = group['dtype']
        debug = group['debug']

        state = self.state['global_state']
        prev_flat_grad = state.get('prev_flat_grad')

        state['n_iter'] += 1

        # search direction
        d = p_k

        # remember the gradient used for differencing in the next curvature update
        if prev_flat_grad is None:
            prev_flat_grad = g_Ok.clone()
        else:
            prev_flat_grad.copy_(g_Ok)

        # initial steplength
        t = lr

        closure_eval = 0

        if g_Sk is None:
            g_Sk = g_Ok.clone()

        if line_search == 'Armijo':

            # ---- parse options ----
            if not options:
                raise ValueError('Options are not specified; need closure evaluating function.')
            if 'closure' not in options:
                raise ValueError('closure option not specified.')
            closure = options['closure']

            if 'gtd' in options:
                gtd = options['gtd']
            else:
                gtd = g_Ok.dot(d)

            if 'current_loss' in options:
                F_k = options['current_loss']
            else:
                F_k = closure()
                closure_eval += 1

            eta = options.get('eta', 2)
            if eta <= 0:
                raise ValueError('Invalid eta; must be positive.')

            c1 = options.get('c1', 1e-4)
            if c1 >= 1 or c1 <= 0:
                raise ValueError('Invalid c1; must be strictly between 0 and 1.')

            max_ls = options.get('max_ls', 10)
            if max_ls <= 0:
                raise ValueError('Invalid max_ls; must be positive.')

            interpolate = options.get('interpolate', True)
            inplace = options.get('inplace', True)
            ls_debug = options.get('ls_debug', False)

            # ---- initialize ----
            if interpolate:
                # NaN marks "no previous trial point yet"; it is only inspected
                # through is_legal() and .item(), so its device is irrelevant
                F_prev = torch.tensor(np.nan, dtype=dtype)

            ls_step = 0
            t_prev = 0
            fail = False

            if ls_debug:
                print('==================================== Begin Armijo line search ===================================')
                print('F(x): %.8e g*d: %.8e' % (F_k, gtd))

            if gtd >= 0:
                desc_dir = False
                if debug:
                    print('Not a descent direction!')
            else:
                desc_dir = True

            # without in-place updates the starting point is restored from a copy
            if not inplace:
                current_params = self._copy_params()

            # first trial point
            self._add_update(t, d)
            F_new = closure()
            closure_eval += 1

            if ls_debug:
                print('LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e'
                      % (ls_step, t, F_new, F_k + c1*t*gtd, F_k))

            # ---- backtrack until sufficient decrease holds ----
            while F_new > F_k + c1*t*gtd or not is_legal(F_new):

                if ls_step >= max_ls:
                    # give up: return to the starting point and report failure
                    if inplace:
                        self._add_update(-t, d)
                    else:
                        self._load_params(current_params)

                    t = 0
                    F_new = closure()
                    closure_eval += 1
                    fail = True
                    break

                t_new = t

                if ls_step == 0 or not interpolate or not is_legal(F_new):
                    # plain backtracking
                    t = t/eta
                elif ls_step == 1 or not is_legal(F_prev):
                    # interpolate using the start point and the latest trial
                    t = polyinterp(np.array([[0, F_k.item(), gtd.item()],
                                             [t_new, F_new.item(), np.nan]]))
                else:
                    # interpolate using the start point and the last two trials
                    t = polyinterp(np.array([[0, F_k.item(), gtd.item()],
                                             [t_new, F_new.item(), np.nan],
                                             [t_prev, F_prev.item(), np.nan]]))

                if interpolate:
                    # keep the new steplength within [1e-3, 0.6] * previous one
                    if t < 1e-3*t_new:
                        t = 1e-3*t_new
                    elif t > 0.6*t_new:
                        t = 0.6*t_new

                    F_prev = F_new
                    t_prev = t_new

                # move to the new trial point and reevaluate
                if inplace:
                    self._add_update(t - t_new, d)
                else:
                    self._load_params(current_params)
                    self._add_update(t, d)

                F_new = closure()
                closure_eval += 1
                ls_step += 1

                if ls_debug:
                    print('LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e'
                          % (ls_step, t, F_new, F_k + c1*t*gtd, F_k))

            self._store_step(d, prev_flat_grad, t, g_Sk, fail)

            if ls_debug:
                print('Final Steplength:', t)
                print('===================================== End Armijo line search ====================================')

            return F_new, t, ls_step, closure_eval, desc_dir, fail

        elif line_search == 'Wolfe':

            # ---- parse options ----
            if not options:
                raise ValueError('Options are not specified; need closure evaluating function.')
            if 'closure' not in options:
                raise ValueError('closure option not specified.')
            closure = options['closure']

            if 'current_loss' in options:
                F_k = options['current_loss']
            else:
                F_k = closure()
                closure_eval += 1

            if 'gtd' in options:
                gtd = options['gtd']
            else:
                gtd = g_Ok.dot(d)

            eta = options.get('eta', 2)
            if eta <= 1:
                raise ValueError('Invalid eta; must be greater than 1.')

            c1 = options.get('c1', 1e-4)
            if c1 >= 1 or c1 <= 0:
                raise ValueError('Invalid c1; must be strictly between 0 and 1.')

            # c2 is only validated when the caller supplies it
            if 'c2' in options:
                c2 = options['c2']
                if c2 >= 1 or c2 <= 0:
                    raise ValueError('Invalid c2; must be strictly between 0 and 1.')
                elif c2 <= c1:
                    raise ValueError('Invalid c2; must be strictly larger than c1.')
            else:
                c2 = 0.9

            max_ls = options.get('max_ls', 10)
            if max_ls <= 0:
                raise ValueError('Invalid max_ls; must be positive.')

            interpolate = options.get('interpolate', True)
            inplace = options.get('inplace', True)
            ls_debug = options.get('ls_debug', False)

            # ---- initialize ----
            ls_step = 0
            grad_eval = 0
            t_prev = 0

            # bracket [alpha, beta] on the steplength
            alpha = 0
            beta = float('Inf')
            fail = False

            if interpolate:
                # function value / directional derivative at the bracket ends;
                # NaN marks "unknown" and is only read via is_legal() and .item()
                F_a = F_k
                g_a = gtd
                F_b = torch.tensor(np.nan, dtype=dtype)
                g_b = torch.tensor(np.nan, dtype=dtype)

            if ls_debug:
                print('==================================== Begin Wolfe line search ====================================')
                print('F(x): %.8e g*d: %.8e' % (F_k, gtd))

            if gtd >= 0:
                desc_dir = False
                if debug:
                    print('Not a descent direction!')
            else:
                desc_dir = True

            # without in-place updates the starting point is restored from a copy
            if not inplace:
                current_params = self._copy_params()

            # first trial point
            self._add_update(t, d)
            F_new = closure()
            closure_eval += 1

            # ---- main loop ----
            while True:

                if ls_step >= max_ls:
                    # give up: return to the starting point and report failure
                    if inplace:
                        self._add_update(-t, d)
                    else:
                        self._load_params(current_params)

                    t = 0
                    F_new = closure()
                    F_new.backward()
                    g_new = self._gather_flat_grad()
                    closure_eval += 1
                    grad_eval += 1
                    fail = True
                    break

                if ls_debug:
                    print('LS Step: %d t: %.8e alpha: %.8e beta: %.8e'
                          % (ls_step, t, alpha, beta))
                    print('Armijo: F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e'
                          % (F_new, F_k + c1*t*gtd, F_k))

                if F_new > F_k + c1*t*gtd:
                    # sufficient decrease fails: t becomes the upper end
                    beta = t
                    t_prev = t

                    if interpolate:
                        F_b = F_new
                        g_b = torch.tensor(np.nan, dtype=dtype)

                else:
                    # sufficient decrease holds: check the curvature condition
                    F_new.backward()
                    g_new = self._gather_flat_grad()
                    grad_eval += 1
                    gtd_new = g_new.dot(d)

                    if ls_debug:
                        print('Wolfe: g(x+td)*d: %.8e c2*g*d: %.8e gtd: %.8e'
                              % (gtd_new, c2*gtd, gtd))

                    if gtd_new < c2*gtd:
                        # curvature condition fails: t becomes the lower end
                        alpha = t
                        t_prev = t

                        if interpolate:
                            F_a = F_new
                            g_a = gtd_new
                    else:
                        # both conditions hold
                        break

                # ---- choose the next steplength ----
                if not interpolate or not is_legal(F_b):
                    if beta == float('Inf'):
                        # no upper end yet: extrapolate
                        t = eta*t
                    else:
                        # bisect the bracket
                        t = (alpha + beta)/2.0
                else:
                    t = polyinterp(np.array([[alpha, F_a.item(), g_a.item()],
                                             [beta, F_b.item(), g_b.item()]]))

                    # safeguard the interpolated value
                    if beta == float('Inf'):
                        if t > 2*eta*t_prev:
                            t = 2*eta*t_prev
                        elif t < eta*t_prev:
                            t = eta*t_prev
                    else:
                        # keep t inside [alpha + 0.2*(beta - alpha), (alpha + beta)/2];
                        # the upper limit must be the bracket midpoint so that t
                        # cannot fall below alpha when alpha > 0
                        if t < alpha + 0.2*(beta - alpha):
                            t = alpha + 0.2*(beta - alpha)
                        elif t > (alpha + beta)/2.0:
                            t = (alpha + beta)/2.0

                    # nonsensical value from interpolation: bisect instead
                    if t <= 0:
                        t = (alpha + beta)/2.0

                # move to the new trial point and reevaluate
                if inplace:
                    self._add_update(t - t_prev, d)
                else:
                    self._load_params(current_params)
                    self._add_update(t, d)

                F_new = closure()
                closure_eval += 1
                ls_step += 1

            self._store_step(d, prev_flat_grad, t, g_Sk, fail)

            if ls_debug:
                print('Final Steplength:', t)
                print('===================================== End Wolfe line search =====================================')

            return F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail

        else:
            # no line search: take the step with the fixed steplength
            self._add_update(t, d)
            self._store_step(d, prev_flat_grad, t, g_Sk, False)

            return t

    def step(self, p_k, g_Ok, g_Sk=None, options=None):
        """Performs a single optimization step; see _step for inputs and outputs."""
        return self._step(p_k, g_Ok, g_Sk, options)
class FullBatchLBFGS(LBFGS):
    """
    Implements full-batch or deterministic L-BFGS algorithm. Compatible with
    Powell damping. Can be used when evaluating a deterministic function and
    gradient. Wraps the LBFGS optimizer. Performs the two-loop recursion,
    updating, and curvature updating in a single step.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 11/15/18.

    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.

    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode

    """

    def __init__(self, params, lr=1, history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):
        super(FullBatchLBFGS, self).__init__(params, lr, history_size, line_search,
                                             dtype, debug)

    def step(self, options=None):
        """
        Performs a single optimization step.

        Inputs:
            options (dict): contains options for performing line search
                            (default: None, treated as an empty dict)

        General Options:
            'eps' (float): constant for curvature pair rejection or damping (default: 1e-2)
            'damping' (bool): flag for using Powell damping (default: False)

        Options for Armijo backtracking line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (tensor): factor for decreasing steplength > 0 (default: 2)
            'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Options for Wolfe line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (float): factor for extrapolation (default: 2)
            'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
            'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Outputs (depends on line search):
          . No line search:
                t (float): steplength
          . Armijo backtracking line search:
                F_new (tensor): loss function at new iterate
                t (tensor): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
          . Wolfe line search:
                F_new (tensor): loss function at new iterate
                g_new (tensor): gradient at new iterate
                t (float): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                grad_eval (int): number of gradient evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded

        Notes:
          . If encountering line search failure in the deterministic setting, one
            should try increasing the maximum number of line search steps max_ls.

        """
        # a fresh dict per call instead of a shared mutable default argument
        if options is None:
            options = {}

        damping = options.get('damping', False)
        eps = options.get('eps', 1e-2)

        # gradient at the current iterate, read from the parameters' .grad
        grad = self._gather_flat_grad()

        # from the second step on, update the curvature history before
        # computing the new direction
        state = self.state['global_state']
        if state['n_iter'] > 0:
            self.curvature_update(grad, eps, damping)

        # quasi-Newton search direction p = -H * grad
        p = self.two_loop_recursion(-grad)

        return self._step(p, grad, options=options)