Skip to content

Commit

Permalink
e3movie improvements. A few tweaks to libraries and parameters for e3…
Browse files Browse the repository at this point in the history
…make3d_gauss
  • Loading branch information
sludtke42 committed Jan 14, 2025
1 parent fe9b885 commit 7ed21be
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 56 deletions.
37 changes: 23 additions & 14 deletions libpyEM/EMAN3.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def compress_hdf(fsp,bits,nooutliers=False,level=1):
if fsp[-4:].lower()!=".hdf" : return
nm=get_temp_name()
os.rename(fsp,nm)
n=EMUtil.get_image_count(nm)
n=file_image_count(nm)
for i in range(n): EMData(nm,i).write_compressed(fsp,i,bits,nooutliers=nooutliers,level=level)
os.unlink(nm)

Expand Down Expand Up @@ -1097,7 +1097,7 @@ def parse_infile_arg(arg):
slices_inc = parse_string_to_slices(seq_inc)
slices_exc = parse_string_to_slices(seq_exc) if seq_exc else []

nimg = EMUtil.get_image_count(fname)
nimg = file_image_count(fname)
if ":" not in arg: return arg,range(nimg) #quick stopgap for performance problem

idxs = OrderedDict()
Expand Down Expand Up @@ -1840,7 +1840,7 @@ def is_2d_image_mx(filename):
a.read_image(filename,0,True)
if a.get_ndim() != 2:
return False, "Image is not 2D :", filename
elif EMUtil.get_image_count(filename) < 1:
elif file_image_count(filename) < 1:
return False, "Image has not particles in it :", filename
else:
return True, "Image is a 2D stack"
Expand All @@ -1854,7 +1854,7 @@ def check_files_are_2d_images(filenames):
return fine, message
else:
for name in filenames:
if EMUtil.get_image_count(name) > 1:
if file_image_count(name) > 1:
return False, "Image contains more than one image :", name

else:
Expand Down Expand Up @@ -1954,7 +1954,7 @@ def file_exists( file_name ):
return True
else:
try:
if db_check_dict(file_name) and EMUtil.get_image_count(file_name) != 0: # a database can exist but have no images in it, in which case we consider it to not exist
if db_check_dict(file_name) and file_image_count(file_name) != 0: # a database can exist but have no images in it, in which case we consider it to not exist
return True
else: return False
except: return False
Expand Down Expand Up @@ -2830,7 +2830,7 @@ def image_eosplit(filename):
oute=None
outo=None
else : # This means we have a regular image file as an input
n=EMUtil.get_image_count(filename)
n=file_image_count(filename)
eset=filename.rsplit(".",1)[0]+"_even.lst"
oset=filename.rsplit(".",1)[0]+"_odd.lst"

Expand Down Expand Up @@ -3061,12 +3061,20 @@ def db_write_image(self, fsp, *parms):

def db_write_images(fsp,
imgs,
idxs=0,
idx0=0,
imgtype=IMAGE_UNKNOWN,
header_only=False,
reqion=None,
region=None,
filestoragetype=EM_FLOAT,
use_host_endian=True):
"""fsp - output filename, accepts ":" syntax for compression
imgs - list of EMData objects
idx0 - index in the file of the first image, eg - 5 would write imgs[0] to position 5 in the file and imgs[1] to position 6
imgtype - use IMAGE_UNKNOWN for extension-based type
header_only - if set, only updates header and does not write image data
region - write only a portion of the image data to disk
filestoragetype - data type for writing. In most cases using ":" notation with filename is preferred
use_host_endian - rarely should be altered"""

if fsp[:4].lower() == "bdb:":
print("ERROR: BDB is not supported in this version of EMAN2. You must use EMAN2.91 or earlier to access legacy data.")
Expand Down Expand Up @@ -3145,11 +3153,11 @@ def db_write_images(fsp,

#print(f"PY {im['render_min']} - {im['render_max']} {im['minimum']} - {im['maximum']} {im['render_bits']}")
if bits < 0:
EMData.write_images_c(fsp, imgs, idxs) # bits<0 implies no compression
EMData.write_images_c(fsp, imgs, idx0) # bits<0 implies no compression
else:
EMData.write_images_c(fsp, imgs, idxs, EMUtil.ImageType.IMAGE_UNKNOWN, 0, None, EMUtil.EMDataType.EM_COMPRESSED)
EMData.write_images_c(fsp, imgs, idx0, EMUtil.ImageType.IMAGE_UNKNOWN, 0, None, EMUtil.EMDataType.EM_COMPRESSED)
else:
EMData.write_images_c(fsp, imgs, idxs, imgtype, header_only, reqion, filestoragetype, use_host_endian)
EMData.write_images_c(fsp, imgs, idx0, imgtype, header_only, region, filestoragetype, use_host_endian)

EMData.write_images_c = staticmethod(EMData.write_images)
EMData.write_images = staticmethod(db_write_images)
Expand Down Expand Up @@ -3211,7 +3219,7 @@ def im_write_compressed(self,fsp,n,bits=8,minval=0,maxval=0,nooutliers=True,leve
except: pass

if n==-1:
try: n=EMUtil.get_image_count(fsp)
try: n=file_image_count(fsp)
except: n=0

for i,im in enumerate(self):
Expand Down Expand Up @@ -3278,15 +3286,16 @@ def im_write_compressed(self,fsp,n,bits=8,minval=0,maxval=0,nooutliers=True,leve
EMData.write_compressed=im_write_compressed


def db_get_image_count(fsp):
# This way "file_image_count(fsp)" is available as a shortcut
def file_image_count(fsp):
	"""Return the number of images referred to by the file specification fsp.

	fsp - filename, optionally containing ":" syntax. When a ":" is present
		the specification is handed to parse_infile_arg() and the number
		of indices it returns is reported. Otherwise the count is obtained
		from EMUtil.get_image_count_c().
	"""
	# Plain filename: ask the library directly
	if ":" not in fsp:
		return EMUtil.get_image_count_c(fsp)

	# ":" specification: the count is the number of selected indices
	_, selected = parse_infile_arg(fsp)
	return len(selected)

EMUtil.get_image_count_c = staticmethod(EMUtil.get_image_count)
EMUtil.get_image_count = db_get_image_count
EMUtil.get_image_count = file_image_count


__doc__ = \
Expand Down
4 changes: 2 additions & 2 deletions libpyEM/EMAN3jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def write(self,stack,n0,ortss=None,tytxs=None):
stack.coerce_numpy()
if ortss is not None:
try: self.orts[n0:n0+len(stack)]=ortss
except: self.orts[n0:n0+len(stack)]=ortss.numpy()
except: self.orts[n0:n0+len(stack)]=np.array(ortss)
if tytxs is not None:
try: self.tytx[n0:n0+len(stack)]=tytxs
except:
# print(tytxs,tytxs.shape)
self.tytx[n0:n0+len(stack)]=tytxs.numpy()
self.tytx[n0:n0+len(stack)]=np.array(tytxs)

# we go through the images one at a time, serialize, and write to a file with a directory
self.fp.seek(self.cloc)
Expand Down
4 changes: 2 additions & 2 deletions libpyEM/EMAN3tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ def align_translate(ref,maxshift=-1):
on each axis"""
pass

def write_images(self,fsp=None,bits=12):
def write_images(self,fsp=None,bits=12,n_start=0):
self.coerce_emdata()
im_write_compressed(self._data,fsp,0,bits)
im_write_compressed(self._data,fsp,n_start,bits)

def downsample(self,newsize):
"""Downsamples each image/volume in Fourier space such that its real-space dimensions after downsampling
Expand Down
69 changes: 47 additions & 22 deletions programs/e3make3d_gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,21 @@ def main():
"""
parser = EMArgumentParser(usage=usage,version=EMANVERSION)
parser.add_argument("--volout", type=str,help="Volume output file", default="threed.hdf")
parser.add_argument("--volout", type=str,help="Volume output file. Note that volumes will be appended to an existing file", default="threed.hdf")
parser.add_argument("--gaussout", type=str,help="Gaussian list output file",default=None)
parser.add_argument("--volfiltlp", type=float, help="Lowpass filter to apply to output volume in A, 0 disables, default=40", default=40)
parser.add_argument("--volfilthp", type=float, help="Highpass filter to apply to output volume in A, 0 disables, default=2500", default=2500)
parser.add_argument("--frc_z", type=float, help="FRC Z threshold (mean-sigma*Z)", default=3.0)
parser.add_argument("--apix", type=float, help="A/pix override for raw data", default=-1)
parser.add_argument("--thickness", type=float, help="For tomographic data specify the Z thickness in A to limit the reconstruction domain", default=-1)
parser.add_argument("--preclip",type=int,help="Trim the input images to the specified (square) box size in pixels", default=-1)
parser.add_argument("--postclip",type=int,help="Trim the output volumes to the specified (square) box size in pixels", default=-1)
parser.add_argument("--initgauss",type=int,help="Gaussians in the first pass, scaled with stage, default=500", default=500)
parser.add_argument("--savesteps", action="store_true",help="Save the gaussian parameters for each refinement step, for debugging and demos")
parser.add_argument("--tomo", action="store_true",help="tomogram mode, changes optimization steps")
parser.add_argument("--spt", action="store_true",help="subtomogram averaging mode, changes optimization steps")
parser.add_argument("--ctf", type=int,help="0=no ctf, 1=single ctf, 2=layered ctf",default=0)
parser.add_argument("--ptcl3d_id", type=int, help="only use 2-D particles with matching ptcl3d_id parameter",default=-1)
parser.add_argument("--dfmin", type=float, help="The minimum defocus appearing in the project, for use with --ctf",default=0.5)
parser.add_argument("--dfmax", type=float, help="The maximum defocus appearing in the project, for use with --ctf",default=2.0)
parser.add_argument("--sym", type=str,help="symmetry. currently only support c and d", default="c1")
Expand All @@ -69,10 +71,14 @@ def main():

(options, args) = parser.parse_args()
jax_set_device(dev=0,maxmem=options.gpuram)

llo=E3init(sys.argv)

nptcl=EMUtil.get_image_count(args[0])
if options.ptcl3d_id>=0:
if args[0][-4:]!=".lst" : error_exit("--ptcl3d_id only works with .lst input files")
lsx=LSXFile(args[0])
selimg=[i for i in range(len(lsx)) if lsx[i][2]["ptcl3d_id"]==options.ptcl3d_id]
nptcl=len(selimg)
nxraw=EMData(args[0],0,True)["nx"]
if options.preclip>0: nxraw=options.preclip
nxrawm2=good_size_small(nxraw-2)
Expand Down Expand Up @@ -130,34 +136,49 @@ def main():
]
else:
stages=[
[512, 16,32,1.8,-3 ,1,.03, 2.0],
[512, 16,48,1.8, 0 ,4,.03, 1.0],
[1024, 32,32,1.5, 0 ,4,.02,1.5],
[1024, 32,32,1.5,-1 ,3,.02,1.0],
[4096, 64,32,1.2,-1.5,3,.01,1.0],
[8192, 256,32,1.0,-2 ,3,.005,1.0],
[32768,512,32,0.8,-2 ,1,.002,0.75]
[512, 16,32,1.8,-3 ,1,.01, 2.0],
[512, 16,32,1.8, 0 ,4,.01, 1.0],
[1024, 32,32,1.5, 0 ,4,.005,1.5],
[1024, 32,32,1.5,-1 ,3,.005,1.0],
[4096, 64,32,1.2,-1.5,3,.003,1.0],
[8192, 256,32,1.0,-2 ,3,.003,1.0],
[32768,512,32,0.8,-2 ,1,.001,0.75]
]

batchsize=256

times=[time.time()]

# Cache initialization
if options.verbose: print("Caching particle data")
downs=sorted(set([s[1] for s in stages]))
caches={down:StackCache(f"tmp_{os.getpid()}_{down}.cache",nptcl) for down in downs} # dictionary keyed by box size
for i in range(0,nptcl,1000):
if options.ptcl3d_id>=0 :
if options.verbose>1:
print(f" Caching {i}/{nptcl}",end="\r",flush=True)
sys.stdout.flush()
stk=EMStack2D(EMData.read_images(args[0],range(i,min(i+1000,nptcl))))
print(f" Caching {nptcl}")
stk=EMStack2D(EMData.read_images(args[0],selimg))
if options.preclip>0 : stk=stk.center_clip(options.preclip)
orts,tytx=stk.orientations
tytx/=jnp.array((nxraw,nxraw,1)) # Don't divide the defocus
for im in stk.emdata: im.process_inplace("normalize.edgemean")
stkf=stk.do_fft()
for down in downs:
stkfds=stkf.downsample(min(down,nxrawm2))
caches[down].write(stkfds,i,orts,tytx)
caches[down].write(stkfds,0,orts,tytx)
else:
for i in range(0,nptcl,1000):
if options.verbose>1:
print(f" Caching {i}/{nptcl}",end="\r",flush=True)
sys.stdout.flush()
stk=EMStack2D(EMData.read_images(args[0],range(i,min(i+1000,nptcl))))
if options.preclip>0 : stk=stk.center_clip(options.preclip)
orts,tytx=stk.orientations
tytx/=nxraw
for im in stk.emdata: im.process_inplace("normalize.edgemean")
stkf=stk.do_fft()
for down in downs:
stkfds=stkf.downsample(min(down,nxrawm2))
caches[down].write(stkfds,i,orts,tytx)

# Forces all of the caches to share the same orientation information so we can update them simultaneously below (FRCs not jointly cached!)
for down in downs[1:]:
Expand Down Expand Up @@ -188,15 +209,18 @@ def main():
# TODO: Ok, this should really use one of the proper optimization algorithms available from the deep learning toolkits
# this basic conjugate gradient gets the job done, but not very efficiently I suspect...
optim = optax.adam(.005) # parm is learning rate
# optim = optax.lion(.003) # tried, seems not quite as good as Adam in test, but maybe worth another try
# optim = optax.lamb(.005) # tried, slightly better than adam, worse than lion
# optim = optax.fromage(.01) # tried, not as good
optim_state=optim.init(gaus._data) # initialize with data
for i in range(stage[2]): # training epochs
if rstep<.01: break # don't continue if we've optimized well at this level
if nptcl>stage[0]: idx0=sn+i
else: idx0=0
nliststg=range(idx0,nptcl,max(1,nptcl//stage[0])) # all of the particles to use in the current epoch in the current stage, sn+i provides stochasticity
imshift=0.0
for j in range(0,len(nliststg),512): # compute the gradient step piecewise due to memory limitations, 512 particles at a time
ptclsfds,orts,tytx=caches[stage[1]].read(nliststg[j:j+512])
for j in range(0,len(nliststg),batchsize): # compute the gradient step piecewise due to memory limitations, batchsize particles at a time
ptclsfds,orts,tytx=caches[stage[1]].read(nliststg[j:j+batchsize])
# standard mode, optimize gaussian parms only
# if not options.tomo or sn<2:
if options.ctf==0:
Expand Down Expand Up @@ -241,15 +265,15 @@ def main():
step0=jnp.nan_to_num(step0)
if j==0:
step,stept,qual,shift,sca,imshift=step0,stept0,qual0,shift0,sca0,imshift0
caches[stage[1]].add_orts(nliststg[j:j+512],None,stept0*rstep) # we can immediately add the current 500 since it is per-particle
caches[stage[1]].add_orts(nliststg[j:j+batchsize],None,stept0*rstep) # we can immediately add the current batch since it is per-particle
else:
step+=step0
caches[stage[1]].add_orts(nliststg[j:j+512],None,stept0*rstep) # we can immediately add the current 500 since it is per-particle
caches[stage[1]].add_orts(nliststg[j:j+batchsize],None,stept0*rstep) # we can immediately add the current batch since it is per-particle
qual+=qual0
shift+=shift0
sca+=sca0
imshift+=imshift0
norm=len(nliststg)//512+1
norm=len(nliststg)//batchsize+1
qual/=norm
# # if the quality got worse, we take smaller steps, starting by stepping back almost to the last good step
# if qual<lqual:
Expand All @@ -268,7 +292,7 @@ def main():
sca/=norm
imshift/=norm

update, optim_state = optim.update(step, optim_state)
update, optim_state = optim.update(step, optim_state, gaus._data)
gaus._data = optax.apply_updates(gaus._data, update)

if options.savesteps: from_numpy(gaus.numpy).write_image("steps.hdf",-1)
Expand Down Expand Up @@ -349,7 +373,8 @@ def main():

outsz=min(1024,nxraw)
times.append(time.time())
vol=gaus.volume(outsz,zmax).center_clip(outsz)
if options.postclip>0 : vol=gaus.volume(outsz,zmax).center_clip(options.postclip)
else : vol=gaus.volume(outsz,zmax).center_clip(outsz)
vol=vol.emdata[0]
times.append(time.time())
vol["apix_x"]=apix*nxraw/outsz
Expand All @@ -359,7 +384,7 @@ def main():
if options.volfilthp>0: vol.process_inplace("filter.highpass.gauss",{"cutoff_freq":1.0/options.volfilthp})
if options.volfiltlp>0: vol.process_inplace("filter.lowpass.gauss",{"cutoff_freq":1.0/options.volfiltlp})
times.append(time.time())
vol.write_image(options.volout,0)
vol.write_image(options.volout,-1)

# this is just to save some extra processing steps
if options.fscdebug is not None:
Expand Down
52 changes: 36 additions & 16 deletions programs/e3movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ def main():
parser.add_argument("--frames",type=str,default=None,help="<first>,<last+1> movie frames to use, first frame is 0, '0,3' will use frames 0,1,2")
parser.add_argument("--acftest",action="store_true",default=False,help="compute ACF images for input stack")
parser.add_argument("--ccftest",action="store_true",default=False,help="compute CCF between each image and the middle image in the movie")
parser.add_argument("--ccfdtest",action="store_true",default=False,help="compute the CCF between each image and the next image in the movie, length n-1")
parser.add_argument("--ccfdtest",type=str,default=None,help="compute the CCF between each image and the next image in the movie, length n-1, provide the filename of the gain correction image")
parser.add_argument("--ccftiletest",action="store_true",default=False,help="test on tiled average of CCF")
parser.add_argument("--gpudev",type=int,help="GPU Device, default 0", default=0)
parser.add_argument("--gpuram",type=int,help="Maximum GPU ram to allocate in MB, default=4096", default=4096)
parser.add_argument("--verbose", "-v", dest="verbose", action="store", metavar="n", type=int, default=0, help="verbose level [0-9], higher number means higher level of verboseness")

(options, args) = parser.parse_args()
tf_set_device(dev=0,maxmem=options.gpuram)

pid=E3init(argv)
nmov=len(args)
Expand All @@ -70,7 +73,7 @@ def main():
if (i%10==0): E3progress(pid,i/nmov)
nimg=EMUtil.get_image_count(args[i])
print(f"{args[i]}: {nimg}")
if (nimg>500) : raise Exception("Can't deal with movies with >500 frames at present")
#if (nimg>500) : raise Exception("Can't deal with movies with >500 frames at present")
for j in range(0,nimg,50):
imgs=EMData.read_images(f"{args[i]}:{j}:{j+50}")
for img in imgs: avgr.add_image(img)
Expand Down Expand Up @@ -288,22 +291,39 @@ def main():
for im in cens.emdata: im.process_inplace("normalize.edgemean")
cens.write_images("ccfs.hdf")

if options.ccfdtest:
avg=EMData("average.hdf",0)
if options.ccfdtest is not None:
try: os.unlink("ccfs.hdf")
except: pass
try: os.unlink("ccfs1k.hdf")
except: pass
avg=EMData(options.ccfdtest,0)
avg.div(avg["mean"])
#avg.add(-avg["mean"])
nimg=EMUtil.get_image_count(args[0])
imgs=EMStack2D(EMData.read_images(f"{args[0]}:0:{min(50,nimg)}"))
for im in imgs:
im.div(avg)
ffts=imgs.do_fft()
ccfs=ffts.calc_ccf(ffts,offset=1)
ccfsr=ccfs.do_ift()
_,nx,ny=ccfsr.shape

cens=EMStack2D(ccfsr.tensor[:,nx//2-64:nx//2+64,ny//2-64:ny//2+64])
for im in cens.emdata: im.process_inplace("normalize.edgemean")
cens.write_images("ccfs.hdf")
nimg=file_image_count(args[0])
for i in range(1,nimg,25):
print(f"{i-1}:{min(i+25,nimg)}")
imgs=EMStack2D(EMData.read_images(f"{args[0]}:{i-1}:{min(i+25,nimg)}"))
for im in imgs:
im.process_inplace("math.fixgain.counting",{"gain":avg,"gainmin":3,"gainmax":3})
imgs.coerce_tensor()
# imgs=imgs.downsample(4096)
ffts=imgs.do_fft()
ccfs=ffts.calc_ccf(ffts,offset=1)
ccfsr=ccfs.do_ift()
_,nx,ny=ccfsr.shape
cens=ccfsr.center_clip(64)
#cens=EMStack2D(ccfsr.tensor[:,nx//2-64:nx//2+64,ny//2-64:ny//2+64])
for im in cens.emdata: im.process_inplace("normalize.edgemean")
cens.write_images("ccfs.hdf",bits=0,n_start=i-1)

imgs=imgs.center_clip(1024)
ffts=imgs.do_fft()
ccfs=ffts.calc_ccf(ffts,offset=1)
ccfsr=ccfs.do_ift()
_,nx,ny=ccfsr.shape
cens=ccfsr.center_clip(64)
for im in cens.emdata: im.process_inplace("normalize.edgemean")
cens.write_images("ccfs1k.hdf",bits=0,n_start=i-1)

if options.ccftiletest:
avg=EMData("average.hdf",0)
Expand Down

0 comments on commit 7ed21be

Please sign in to comment.