Skip to content

Commit

Permalink
Using Adam optimizer from optax for optimization. Results are significantly improved over previous simplistic gradient descent.
Browse files Browse the repository at this point in the history
  • Loading branch information
sludtke42 committed Dec 12, 2024
1 parent 0f742d1 commit 3e9124d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 26 deletions.
10 changes: 5 additions & 5 deletions libpyEM/EMAN3jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,18 +323,18 @@ def __len__(self): return len(self._data)
def shape(self):
# note that the returned shape is N,Z,Y,X regardless of representation
if isinstance(self._data,list): return(np.array((len(self._data),self._data[0]["nz"],self._data[0]["ny"],self._data[0]["nx"])))
return(self._data.shape)
return(np.array(self._data.shape))

def center_clip(self,size):
    """Return a new EMStack3D containing the central size x size x size region of every volume.

    size - edge length of the cubic region to extract; converted with int() and must be >=1

    Works on whichever representation is currently held in self._data: a list of image objects
    (clipped one by one with get_clip) or a numpy/jax array (sliced directly).
    Raises Exception if size<1. If self._data is none of the handled types the method falls
    through and returns None.
    """
    size=int(size)
    if size<1: raise Exception("center_clip(size) must be called with a positive integer")

    # self.shape is N,Z,Y,X, so shp[1],shp[2],shp[3] are the Z,Y,X offsets of the clip origin (shp[0] is unused)
    shp=(self.shape-size)//2

    if isinstance(self._data,list):
        # Region takes its origin in x,y,z order, i.e. the reverse of the Z,Y,X order used by shape
        newlst=[im.get_clip(Region(int(shp[3]),int(shp[2]),int(shp[1]),size,size,size)) for im in self._data]
        return EMStack3D(newlst)
    elif isinstance(self._data,(np.ndarray,jax.Array)):
        newary=self._data[:,shp[1]:shp[1]+size,shp[2]:shp[2]+size,shp[3]:shp[3]+size]
        return EMStack3D(newary)

def do_fft(self,keep_type=False):
"""Computes the FFT of each image and returns a new EMStack3D. If keep_type is not set, will convert to Tensor before computing FFT."""
Expand Down Expand Up @@ -426,7 +426,7 @@ def __len__(self): return len(self._data)
def shape(self):
# note that the returned shape is N,Y,X regardless of representation
if isinstance(self._data,list): return(np.array((len(self._data),self._data[0]["ny"],self._data[0]["nx"])))
return(self._data.shape)
return(np.array(self._data.shape))

@property
def orientations(self):
Expand Down
90 changes: 69 additions & 21 deletions programs/e3make3d_gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

from EMAN3 import *
from EMAN3jax import *
import jax
import optax
import numpy as np
import sys
import time
Expand Down Expand Up @@ -59,6 +61,7 @@ def main():
parser.add_argument("--dfmin", type=float, help="The minimum defocus appearing in the project, for use with --ctf",default=0.5)
parser.add_argument("--dfmax", type=float, help="The maximum defocus appearing in the project, for use with --ctf",default=2.0)
parser.add_argument("--sym", type=str,help="symmetry. currently only support c and d", default="c1")
parser.add_argument("--fscdebug", type=str,help="Compute the FSC of the final map with a reference volume for debugging",default=None)
parser.add_argument("--gpudev",type=int,help="GPU Device, default 0", default=0)
parser.add_argument("--gpuram",type=int,help="Maximum GPU ram to allocate in MB, default=4096", default=4096)
# parser.add_argument("--precache",type=str,help="Rather than perform a reconstruction, only perform caching on the input file for later use. String is the folder to put the cache files in.")
Expand Down Expand Up @@ -127,12 +130,12 @@ def main():
else:
stages=[
[512, 16,32,1.8,-3 ,1,.03, 2.0],
[512, 16,32,1.8, 0 ,4,.03, 1.0],
[512, 16,64,1.8, 0 ,4,.03, 1.0],
[1024, 32,32,1.5, 0 ,4,.02,1.5],
[1024, 32,24,1.5,-1 ,3,.02,1.0],
[1024, 32,32,1.5,-1 ,3,.02,1.0],
[4096, 64,24,1.2,-1.5,3,.01,1.0],
[8192, 256,16,1.0,-2 ,3,.005,1.0],
[32768,512,16,0.8,-2 ,1,.002,0.75]
[8192, 256,24,1.0,-2 ,3,.005,1.0],
[32768,512,24,0.8,-2 ,1,.002,0.75]
]

times=[time.time()]
Expand Down Expand Up @@ -183,6 +186,8 @@ def main():
rstep=1.0
# Optimization uses the Adam optimizer from optax: the gradient returned by gradient_step_optax() is passed to optim.update() below.
# The earlier hand-rolled gradient descent with step halving/reversal is retained only as commented-out code further down.
optim = optax.adam(.005) # parm is learning rate
optim_state=optim.init(gaus._data) # initialize with data
for i in range(stage[2]): # training epochs
if rstep<.01: break # don't continue if we've optimized well at this level
if nptcl>stage[0]: idx0=sn+i
Expand All @@ -194,10 +199,10 @@ def main():
# standard mode, optimize gaussian parms only
# if not options.tomo or sn<2:
if True:
step0,qual0,shift0,sca0=gradient_step(gaus,ptclsfds,orts,tytx,stage[3],stage[7],frc_Z)
step0,qual0,shift0,sca0=gradient_step_optax(gaus,ptclsfds,orts,tytx,stage[3],stage[7],frc_Z)
step0=jnp.nan_to_num(step0)
if j==0:
step,qual,shift,sca=step0,qual0,shift0,sca0
step,qual,shift,sca=step0,-qual0,shift0,sca0
else:
step+=step0
qual+=qual0
Expand Down Expand Up @@ -235,25 +240,30 @@ def main():
imshift+=imshift0
norm=len(nliststg)//512+1
qual/=norm
# if the quality got worse, we take smaller steps, starting by stepping back almost to the last good step
if qual<lqual:
rstep/=2.0 # if we start falling or oscillating we reduce the step within the epoch
step=-lstep*.95 # new gradient doesn't matter, first we want to mostly undo the previous step
lstep*=.05
gaus.add_array(step)
if options.savesteps: from_numpy(gaus.numpy).write_image("steps.hdf",-1)
print(f"{i}: {qual:1.5f}\t \t\t \t \t{rstep:1.5f} reverse")
continue
step*=rstep/norm
lstep=step
# # if the quality got worse, we take smaller steps, starting by stepping back almost to the last good step
# if qual<lqual:
# rstep/=2.0 # if we start falling or oscillating we reduce the step within the epoch
# step=-lstep*.95 # new gradient doesn't matter, first we want to mostly undo the previous step
# lstep*=.05
# gaus.add_array(step)
# if options.savesteps: from_numpy(gaus.numpy).write_image("steps.hdf",-1)
# print(f"{i}: {qual:1.5f}\t \t\t \t \t{rstep:1.5f} reverse")
# continue
# step*=rstep/norm
# lstep=step
# gaus.add_array(step)
# lqual=qual
shift/=norm
sca/=norm
imshift/=norm
gaus.add_array(step)
lqual=qual

update, optim_state = optim.update(step, optim_state)
gaus._data = optax.apply_updates(gaus._data, update)


if options.savesteps: from_numpy(gaus.numpy).write_image("steps.hdf",-1)

print(f"{i}: {qual:1.5f}\t{shift:1.5f}\t\t{sca:1.5f}\t{imshift:1.5f}\t{rstep:1.5f}")
print(f"{i}: {qual:1.5f}\t{shift:1.5f}\t\t{sca:1.5f}\t{imshift:1.5f}")
if qual>0.99: break

# end of epoch, save images and projections for comparison
Expand Down Expand Up @@ -293,7 +303,8 @@ def main():

outsz=min(1024,nxraw)
times.append(time.time())
vol=gaus.volume(outsz,zmax).emdata[0]
vol=gaus.volume(outsz,zmax).center_clip(outsz)
vol=vol.emdata[0]
times.append(time.time())
vol["apix_x"]=apix*nxraw/outsz
vol["apix_y"]=apix*nxraw/outsz
Expand All @@ -304,6 +315,9 @@ def main():
times.append(time.time())
vol.write_image(options.volout,0)

# this is just to save some extra processing steps
if options.fscdebug is not None:
os.system(f'e2proc3d.py {options.volout.split(":")[0]} {options.volout.rsplit(".",1)[0]}_fsc.txt --calcfsc {options.fscdebug}')

times=np.array(times)
#times-=times[0]
Expand Down Expand Up @@ -342,6 +356,40 @@ def gradient_step(gaus,ptclsfds,orts,tytx,weight=1.0,relstep=1.0,frc_Z=3.0):
return (step,float(qual),float(shift),float(sca))
# print(f"{i}) {float(qual)}\t{float(shift)}\t{float(sca)}")

def gradient_step_optax(gaus,ptclsfds,orts,tytx,weight=1.0,relstep=1.0,frc_Z=3.0):
    """Evaluates the projection/FRC loss and its gradient with respect to the Gaussian parameters, for use
    with an optax optimizer. Unlike gradient_step(), nothing is rescaled here: the raw gradient is returned
    and the optimizer decides the actual update.

    gaus - Gaussian model; its .jax array is the quantity differentiated
    ptclsfds - particle stack to compare against; its .jax array is passed to the loss
    orts - orientations; converted with to_mx2d(swapxy=True)
    tytx - per-particle translations, passed through to the loss unchanged
    weight - linear FRC weighting factor passed to the loss (def 1.0). Per the companion gradient_step()
             documentation 1 is unweighted, >1 upweights low resolution, <1 upweights high resolution.
    relstep - accepted for signature compatibility with gradient_step(); not used by this function
    frc_Z - passed through to the loss unchanged

    returns grad, qual, shift, scale
    grad - gradient of the loss with respect to gaus.jax (same layout as gaus.jax), suitable for optim.update()
    qual - value of the loss as a float. prj_frc_loss() returns the NEGATED frc, so this is -frc;
           callers wanting a quality where larger is better must negate it
    shift - std of columns 0-2 of the gradient (the xyz shift gradient)
    scale - std of column 3 of the gradient (the amplitude gradient)"""
    mx=orts.to_mx2d(swapxy=True)
    gausary=gaus.jax
    ptcls=ptclsfds.jax

    # gradvalfnl is jax.value_and_grad(prj_frc_loss): a single scalar loss plus its gradient w.r.t. gausary
    loss,grad=gradvalfnl(gausary,mx,tytx,ptcls,weight,frc_Z)

    shift=grad[:,:3].std()		# translational std
    sca=grad[:,3].std()			# amplitude std

    return (grad,float(loss),float(shift),float(sca))

def prj_frc_loss(gausary,mx2d,tytx,ptcls,weight,frc_Z):
    """Loss function differentiated during optimization. Projects the Gaussians in gausary using the
    orientation matrices mx2d and translations tytx, takes the 2-D FFT of those projections and compares
    them to the particle data in ptcls with jax_frc_jit. The comparison result is returned NEGATED, since
    optax minimizes while a better match should score higher."""

    # box size for the projections is taken from axis 1 of the particle array
    boxsize=ptcls.shape[1]

    # NOTE: the projection function is called directly; no separate jax.jit wrapper (with boxsize static) is applied here
    projections=gauss_project_simple_fn(gausary,mx2d,boxsize,tytx)
    frc=jax_frc_jit(jax_fft2d(projections),ptcls,weight,2,frc_Z)
    return -frc

# Calling gradvalfnl(gausary,mx2d,tytx,ptcls,weight,frc_Z) returns (loss, gradient), where loss is the value of
# prj_frc_loss and gradient is its derivative with respect to the first argument (gausary), the jax.value_and_grad default
gradvalfnl=jax.value_and_grad(prj_frc_loss)

def prj_frc(gausary,mx2d,tytx,ptcls,weight,frc_Z):
"""Aggregates the functions we need to calculate the gradient through. Computes the frc array resulting from the
comparison of the Gaussians in gaus to particles in known orientations."""
Expand Down

0 comments on commit 3e9124d

Please sign in to comment.