From abf18f05b754ecbacd62d09c2886e22deb68d25d Mon Sep 17 00:00:00 2001
From: sludtke42
Date: Fri, 31 May 2024 13:01:47 -0500
Subject: [PATCH] More work on gaussian refinement

---
 libEM/exception.h              |   2 +-
 libpyEM/EMAN3tensor.py         |  39 +++++---
 libpyEM/qtgui/embrowser.py     |  20 ++--
 programs/e2iminfo.py           |   2 +-
 programs/e3make3d_gauss.py     |  51 ++++------
 programs/e3spa_refine_gauss.py | 174 ++++++++++++++++++++++-----------
 6 files changed, 179 insertions(+), 109 deletions(-)

diff --git a/libEM/exception.h b/libEM/exception.h
index 72d36d3df4..9c0f735fd5 100644
--- a/libEM/exception.h
+++ b/libEM/exception.h
@@ -107,7 +107,7 @@ namespace EMAN {
 	int linenum;
 	string desc;
 	string objname;
-	static string msg;
+	static string msg;	// not fully threadsafe, but without this the error string doesn't persist long enough and the Python exception string ends up as garbage
 };
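The StackCache.read() change in the next file wraps the seek/parse in a try/except so a corrupt or truncated cache is reported with the filename and index. A minimal sketch of that error-reporting pattern in plain Python (hypothetical names, not the EMAN API); chaining with "from" keeps the original low-level cause visible in the traceback rather than swallowing it:

    def read_entry(fp, locs, i):
        """Read the i'th serialized blob from an offset-indexed cache file."""
        try:
            fp.seek(locs[i])
            return fp.read(locs[i + 1] - locs[i])   # IndexError/OSError on a bad entry
        except Exception as e:
            raise Exception(f"Error reading cache entry {i} from {fp.name}") from e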
diff --git a/libpyEM/EMAN3tensor.py b/libpyEM/EMAN3tensor.py
index 5b74ab62bf..0fdf07a547 100644
--- a/libpyEM/EMAN3tensor.py
+++ b/libpyEM/EMAN3tensor.py
@@ -67,6 +67,7 @@ class StackCache():

 	def __init__(self,filename,n):
 		"""Specify filename and number of images to be cached."""
+		print("cache: ",filename)
 		self.filename=filename
 		self.fp=open(filename,"wb+")		# erase file!

@@ -123,8 +124,11 @@ def read(self,nlist):

 		stack=[]
 		for i in nlist:
-			self.fp.seek(self.locs[i])
-			stack.append(tf.io.parse_tensor(self.fp.read(self.locs[i+1]-self.locs[i]),out_type=tf.complex64))
+			try:
+				self.fp.seek(self.locs[i])
+				stack.append(tf.io.parse_tensor(self.fp.read(self.locs[i+1]-self.locs[i]),out_type=tf.complex64))
+			except Exception as e:
+				raise Exception(f"Error reading cache {self.filename}: {i} -> {self.locs[i]}") from e

 		self.locked=False
 		ret=EMStack2D(tf.stack(stack))

@@ -462,6 +466,7 @@ def downsample(self,newsize):
 		current stack is in real or Fourier space. This cannot be used to upsample (make images larger) and
 		should not be used on rectangular images/volumes."""

+		if newsize==self.shape[1]: return EMStack2D(self.tensor)	# this won't copy, but since the tensor is constant should be ok?
 		return EMStack2D(tf_downsample_2d(self.tensor,newsize))	# TODO: for now we're forcing this to be a tensor, probably better to leave it in the current format

 class Orientations():
@@ -522,10 +527,13 @@ def init_from_transforms(self,xformlist):

 		return(tf.constant(tytx))

-	def transforms(self):
+	def transforms(self,tytx=None):
 		"""converts the current orientations to a list of Transform objects"""

-		return [Transform({"type":"spinvec","v1":self._data[i][0],"v2":self._data[i][1],"v3":self._data[i][2]})]
+		if tytx is not None:
+			return [Transform({"type":"spinvec","v1":self._data[i][0],"v2":self._data[i][1],"v3":self._data[i][2],"tx":tytx[i][1],"ty":tytx[i][0]}) for i in range(len(self._data))]
+
+		return [Transform({"type":"spinvec","v1":self._data[i][0],"v2":self._data[i][1],"v3":self._data[i][2]}) for i in range(len(self._data))]

 	def to_mx2d(self,swapxy=False):
 		"""Returns the current set of orientations as a 2 x 3 x N matrix which will transform a set of 3-vectors to a set of
@@ -665,12 +673,15 @@
 		dups=[self._data+tf.random.normal(self._data.shape,stddev=dev) for i in range(n)]
 		self._data=tf.concat(dups,0)

-	def norm_filter(self,sig=0.5):
-		"""Rescale the amplitudes so the maximum is 1, with amplitude below mean+sig*sigma"""
+	def norm_filter(self,sig=0.5,rad_downweight=-1):
+		"""Rescale the amplitudes so the maximum is 1, and remove any Gaussians with amplitude below mean+sig*sigma.
+		If rad_downweight>0, amplitudes are downweighted linearly with radius beyond the specified radius when the threshold is applied, so Gaussians towards the corners of the cube are more likely to be removed (eg 0.5 will downweight the corners). Downweighting assumes the Gaussian coordinates follow the standard -0.5 to 0.5 range for the box."""
 		self.coerce_tensor()
 		self._data=self._data*(1.0,1.0,1.0,1.0/tf.reduce_max(self._data[:,3]))	# "normalize" amplitudes so max amplitude is scaled to 1.0, not sure how necessary this really is
-		thr=tf.math.reduce_mean(self._data[:,3])+sig*tf.math.reduce_std(self._data[:,3])
-		self._data=tf.boolean_mask(self._data,self._data[:,3]>thr)	# remove any gaussians with amplitude below threshold
+		if rad_downweight>0:
+			famp=self._data[:,3]*(1.0-tf.nn.relu(tf.math.reduce_euclidean_norm(self._data[:,:3],1)-rad_downweight))
+		else: famp=self._data[:,3]
+		thr=tf.math.reduce_mean(famp)+sig*tf.math.reduce_std(famp)
+		self._data=tf.boolean_mask(self._data,famp>thr)	# remove any gaussians with amplitude below threshold

 	def project_simple(self,orts,boxsize,tytx=None):
 		"""Generates a tensor containing a simple 2-D projection (interpolated delta functions) of the set of Gaussians for each of N Orientations in orts.
@@ -885,7 +896,7 @@ def tf_downsample_3d(imgs,newx,stack=False):
 # two possible approaches would be to add an extra dimension to rad_img to cover image number, and handle the scatter_nd as a single operation
 # or to try making use of DataSet. I started a DataSet implementation, but decided it added too much design complexity
 def tf_frc(ima,imb,avg=0,weight=1.0,minfreq=0):
-	"""Computes the pairwise FRCs between two stacks of complex images. Returns a list of 1D FSC tensors or if avg!=0
+	"""Computes the pairwise FRCs between two stacks of complex images. imb may alternatively be a single image. Returns a list of 1D FSC tensors or if avg!=0
 	then the average of the first 'avg' values. If -1, averages through Nyquist.
 	Weight permits a frequency based weight (only for avg>0): 1-2 will upweight low frequencies, 0-1 will upweight high frequencies"""
 	if ima.dtype!=tf.complex64 or imb.dtype!=tf.complex64 : raise Exception("tf_frc requires FFTs")
@@ -923,6 +934,8 @@ def tf_frc(ima,imb,avg=0,weight=1.0,minfreq=0):
 	except:
 		raise Exception(f"failed in FRC with sizes {ima.shape} {imb.shape} {imar.shape} {imbr.shape}")

+	if len(imbr.shape)==3: single=False
+	else: single=True
 	frc=[]
 	for i in range(nimg):
 		zero=tf.zeros([nr])
@@ -933,8 +946,12 @@ def tf_frc(ima,imb,avg=0,weight=1.0,minfreq=0):
 		aprd=tf.tensor_scatter_nd_add(zero,rad_img,imar[i])
 		aprd=tf.tensor_scatter_nd_add(aprd,rad_img,imai[i])

-		bprd=tf.tensor_scatter_nd_add(zero,rad_img,imbr[i])
-		bprd=tf.tensor_scatter_nd_add(bprd,rad_img,imbi[i])
+		if single:
+			bprd=tf.tensor_scatter_nd_add(zero,rad_img,imbr)
+			bprd=tf.tensor_scatter_nd_add(bprd,rad_img,imbi)
+		else:
+			bprd=tf.tensor_scatter_nd_add(zero,rad_img,imbr[i])
+			bprd=tf.tensor_scatter_nd_add(bprd,rad_img,imbi[i])

 		frc.append(cross/tf.sqrt(aprd*bprd))
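tf_frc above bins Fourier pixels by integer radius and accumulates the cross and power terms ring-by-ring with tensor_scatter_nd_add. A rough NumPy equivalent of the same ring-correlation idea, as a sketch for orientation only (plain NumPy, not the EMAN3tensor API; rings with no pixels would divide by zero):

    import numpy as np

    def frc_numpy(a, b):
        """Ring-by-ring correlation of two same-size 2-D real images (sketch)."""
        A, B = np.fft.fft2(a), np.fft.fft2(b)
        ny, nx = A.shape
        fy = np.fft.fftfreq(ny)[:, None] * ny
        fx = np.fft.fftfreq(nx)[None, :] * nx
        r = np.round(np.hypot(fy, fx)).astype(int).ravel()      # integer ring index per pixel
        cross = np.bincount(r, (A * np.conj(B)).real.ravel())
        pa = np.bincount(r, (np.abs(A) ** 2).ravel())
        pb = np.bincount(r, (np.abs(B) ** 2).ravel())
        return cross / np.sqrt(pa * pb)                          # one value per ring, 1.0 = perfect agreement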
diff --git a/libpyEM/qtgui/embrowser.py b/libpyEM/qtgui/embrowser.py
index a4972749d8..0660fde8ee 100644
--- a/libpyEM/qtgui/embrowser.py
+++ b/libpyEM/qtgui/embrowser.py
@@ -63,8 +63,8 @@ def display_error(msg) :

 # This is a floating point number-finding regular expression
 # Documentation: https://regex101.com/r/68zUsE/4/
-renumfind = re.compile(r"(?:^|(?<=[^\d\w:.-]))[+-]?\d+\.*\d*(?:[eE][-+]?\d+|)(?=[^\d\w:.-]|$)")
-
+#renumfind = re.compile(r"(?:^|(?<=[^\d\w:.-]))[+-]?\d+\.*\d*(?:[eE][-+]?\d+|)(?=[^\d\w:.-]|$)")
+renumfind = re.compile(r"-?\d*\.?\d+[eE]?-?\d*")	# simpler, more permissive pattern than the documented one above
 # We need to sort ints and floats as themselves, not string, John Flanagan
 def safe_int(v) :
 	"""Performs a safe conversion from a string to an int. If a non int is presented we return the lowest possible value"""
@@ -846,7 +846,9 @@ def name() :
 	@staticmethod
 	def isValid(path, header) :
 		"""Returns (size, n, dim) if the referenced path is a file of this type, None if not valid. The first 4k block of data from the file is provided as well to avoid unnecessary file access."""
-		if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
+#		if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
+		try: s=header.decode("utf-8")
+		except: return False

 		try : size = os.stat(path)[6]
 		except : return False
@@ -880,7 +882,10 @@ def name() :
 	@staticmethod
 	def isValid(path, header) :
 		"""Returns (size, n, dim) if the referenced path is a file of this type, None if not valid. The first 4k block of data from the file is provided as well to avoid unnecessary file access."""
-		if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
+
+		try: s=header.decode("utf-8")
+		except: return False
+#		if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
 		if isinstance(header, bytes): header=header.decode("utf-8")
 		if not "<html>" in header.lower() : return False # For the moment, we demand an <html> tag somewhere in the first 4k
@@ -979,8 +984,7 @@ def name() :
 	@staticmethod
 	def isValid(path, header) :
 		"""Returns (size, n, dim) if the referenced path is a file of this type, None if not valid.
		The first 4k block of data from the file is provided as well to avoid unnecessary file access."""
-		try:
-			if not isprint(header) : return False
+		try: s=header.decode("utf-8")
 		except: return False

 		# We need to try to count the columns in the file
@@ -1930,7 +1934,9 @@ def isValid(path, header) :
 		ext = os.path.basename(path).split('.')[-1]
 		if ext not in proper_exts: return False

-		if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
+		try: s=header.decode("utf-8")
+		except: return False
+#		if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?

 		try : size = os.stat(path)[6]
 		except : return False
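The replacement renumfind pattern trades the old pattern's word-boundary guards for simplicity, which changes what it matches. A quick sanity check with the plain re module (outputs verified against the pattern above):

    import re
    renumfind = re.compile(r"-?\d*\.?\d+[eE]?-?\d*")

    print(renumfind.findall("x=-1.5 y=2e3 z=.25"))  # ['-1.5', '2e3', '.25']
    print(renumfind.findall("file2.txt"))           # ['2']   (the old anchored pattern matched nothing here)
    print(renumfind.findall("5-3"))                 # ['5-3'] (a hyphenated range now reads as one "number")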
diff --git a/programs/e2iminfo.py b/programs/e2iminfo.py
index 8e739260af..a7fe3ac3be 100755
--- a/programs/e2iminfo.py
+++ b/programs/e2iminfo.py
@@ -104,7 +104,7 @@ def main():
 				out.write("{:1.5f}\t{:1.1f}\t{:1.4f}\t{:1.5f}\t{:1.2f}\t{:1.1f}\t{:1.3f}\t{:1.2f}\n".format(v.defocus,v.bfactor,v.apix,v.dfdiff,v.dfang,v.voltage,v.cs,v.get_phase()))
 			elif isinstance(v,Transform) :
 				dct=v.get_params("eman")
-				out.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(v["az"],v["alt"],v["phi"],v["tx"],v["ty"],v["tz"]))
+				out.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(dct["az"],dct["alt"],dct["phi"],dct["tx"],dct["ty"],dct["tz"]))
 			elif isinstance(v,list) or isinstance(v,tuple):
 				out.write([str(i) for i in v].join("\t")+"\n")
 			else:
diff --git a/programs/e3make3d_gauss.py b/programs/e3make3d_gauss.py
index 5bb49f4cee..1441a65cf9 100755
--- a/programs/e3make3d_gauss.py
+++ b/programs/e3make3d_gauss.py
@@ -70,7 +70,7 @@ def main():
 	if options.verbose: print(f"{nptcl} particles at {nxraw}^3")

 	# definition of downsampling sequence for stages of refinement
-	# #ptcl, downsample, iter, frc weight, amp threshold, replicate, step coef
+	# 0) #ptcl, 1) downsample, 2) iter, 3) frc weight, 4) amp threshold, 5) replicate, 6) step coef, 7) FRC weighting
 	# replication skipped in final stage
 	if options.tomo:
 		stages=[
@@ -85,12 +85,12 @@ def main():
 	else:
 		stages=[
 			[500,  16,16,1.8,-3  ,1,.01, 2.0],
-			[500,  16,16,1.8, 0.5,4,.01, 1.0],
-			[1000, 32,16,1.5, 0  ,1,.005,1.5],
-			[1000, 32,16,1.5,-1  ,3,.007,1.0],
-			[2500, 64,24,1.2,-1.5,3,.005,1.0],
-			[10000,256,24,1.0,-2 ,3,.002,0.75],
-			[25000,512,12,0.8,-2 ,1,.001,0.5]
+			[500,  16,16,1.8, 0  ,3,.01, 1.0],
+			[1000, 32,16,1.5, 0  ,2,.005,1.5],
+			[1000, 32,16,1.5,-1  ,3,.007,1.0],
+			[2500, 64,24,1.2,-1.5,2,.005,1.0],
+			[10000,256,24,1.0,-2 ,2,.002,1.0],
+			[25000,512,12,0.8,-2 ,1,.001,0.75]
 		]

 	times=[time.time()]

 	ptcls=[]
 	for sn,stage in enumerate(stages):
 		if options.verbose: print(f"Stage {sn} - {local_datetime()}:")
-#
-#		# stage 1 - limit to ~1000 particles for initial low resolution work
-#		if options.verbose: print(f"\tReading Files {min(stage[0],nptcl)} ptcl")
-#		ptcls=EMStack2D(EMData.read_images(args[0],range(0,nptcl,max(1,nptcl//stage[0]))))
-#		orts,tytx=ptcls.orientations
-#		tytx/=nxraw
-#		ptclsf=ptcls.do_fft()
-#
-#		if options.verbose: print(f"\tDownsampling {min(nxraw,stage[1])} px")
-#		if stage[1]0: gaus.replicate(stage[5],stage[6])
 		g2=len(gaus)
@@ -181,12 +170,12 @@ def main():
 	out=open(options.gaussout,"w")
 	for x,y,z,a in gaus.tensor: out.write(f"{x:1.5f}\t{y:1.5f}\t{z:1.5f}\t{a:1.3f}\n")
-	vol=gaus.volume(nxraw)
+	vol=gaus.volume(nxraw).emdata[0]
 	vol["apix_x"]=apix
 	vol["apix_y"]=apix
 	vol["apix_z"]=apix
 	#vol.emdata[0].process_inplace("filter.lowpass.gauss",{"cutoff_abs":options.volfilt})
-	vol.write_images(options.volout)
+	vol.write_image(options.volout,0)

 	times=np.array(times)
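The stage tables in e3make3d_gauss.py (above) and e3spa_refine_gauss.py (below) are positional lists, which makes individual rows hard to audit by eye. A sketch of how one row could carry field names instead; the names are hypothetical, following the column comments, with column 7 labeled per the old e3spa_refine_gauss comment ("FRC weighting"), and this is not part of the patch:

    from collections import namedtuple

    # 0) #ptcl, 1) downsample, 2) iter, 3) frc weight, 4) amp threshold,
    # 5) replicate, 6) step coef, 7) FRC weighting
    Stage = namedtuple("Stage", ["nptcl", "down", "niter", "frcw", "ampsig", "nrep", "step", "frcw2"])

    stage = Stage(*[500, 16, 16, 1.8, -3, 1, .01, 2.0])
    print(stage.down, stage.ampsig)   # 16 -3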
diff --git a/programs/e3spa_refine_gauss.py b/programs/e3spa_refine_gauss.py
index 47ceee17cf..20d6fe9bd3 100755
--- a/programs/e3spa_refine_gauss.py
+++ b/programs/e3spa_refine_gauss.py
@@ -46,6 +46,7 @@ def main():
 	parser.add_argument("--volfilt", type=float, help="Lowpass filter to apply to output volume, absolute, Nyquist=0.5", default=0.3)
 	parser.add_argument("--initgauss",type=int,help="Gaussians in the first pass, scaled with stage, default=500", default=500)
 	parser.add_argument("--savesteps", action="store_true",help="Save the gaussian parameters for each refinement step, for debugging and demos")
+	parser.add_argument("--fromscratch", action="store_true",help="Ignore orientations from input file and refine from scratch")
 	parser.add_argument("--sym", type=str,help="symmetry. currently only support c and d", default="c1")
 	parser.add_argument("--gpudev",type=int,help="GPU Device, default 0", default=0)
 	parser.add_argument("--gpuram",type=int,help="Maximum GPU ram to allocate in MB, default=4096", default=4096)
@@ -71,6 +72,7 @@ def main():

 	nptcl=EMUtil.get_image_count(args[0])
 	nxraw=EMData(args[0],0,True)["nx"]
+	nxrawm2=good_size_small(nxraw-2)
 	apix=EMData(args[0],0,True)["apix_x"]

 	if options.savesteps:
@@ -79,36 +81,22 @@ def main():

 	if options.verbose: print(f"{nptcl} particles at {nxraw}^3")

-	# definition of downsampling sequence for stages of refinement
-	# 0)#ptcl, 1)downsample, 2)iter, 3)frc weight, 4)amp threshold, 5)replicate, 6)grad step coef, 7) FRC weighting 8)frc loc threshold (9 disables gradient)
-	# replication skipped in final stage. thresholds are mean+coef*std
-	stages=[
-		[1000, 16,16,1.8, 0  ,2,.01, 2.0,9],
-		[1000, 16,16,1.8, 0  ,2,.01, 2.0,9],
-		[1000, 16,16,1.8,-1  ,2,.01, 2.0,-1],
-		[1000, 16,16,1.8,-1  ,2,.01, 2.0,-1],
-		[1000, 16,16,1.8,-1  ,1,.01, 2.0,-1],
-		[2000, 32,16,1.5,-.5 ,2,.005,1.5,-3],
-		[2000, 32,16,1.5,-1  ,3,.007,1.0,-2],
-		[5000, 64,24,1.2,-1.5,3,.005,1.0,-3],
-		[10000,256,24,1.0,-2 ,3,.002,0.75,-3],
-		[25000,512,12,1.0,-2 ,1,.001,0.5,-3.0]
-	]
-
-	for l in stages: l[1]=min(l[1],nxraw)	# make sure we aren't "upsampling"

 	times=[time.time()]

 	# Cache initialization
 	if options.verbose: print("Caching particle data")
-	downs=sorted(set([s[1] for s in stages]))
+#	downs=sorted(set([s[1] for s in stages]))
+	downs=sorted(set([min(i,nxrawm2) for i in (24,32,64,256,512)]))	# note that 24 is also used in reseeding
#	caches={down:StackCache(f"tmp_{os.getpid()}_{down}.cache",nptcl) for down in downs}	# dictionary keyed by box size
 	caches={down:StackCache(f"{options.path}/tmp_{down}.cache",nptcl) for down in downs}	# dictionary keyed by box size
+	fromscratch=options.fromscratch
 	for i in range(0,nptcl,2500):
 		if options.verbose>1: print(f"Caching {i}/{nptcl}")
 		stk=EMStack2D(EMData.read_images(args[0],range(i,min(i+2500,nptcl))))
 		orts,tytx=stk.orientations
-		if orts is None:
+		if orts is None or fromscratch:
+			fromscratch=True
 			tytx=np.zeros((stk.shape[0],2))
 			orts=rand.random((stk.shape[0],3))-0.5
 		else: tytx/=nxraw
@@ -120,45 +108,77 @@
 		else: caches[down].write(stkf,i,orts,tytx)

+	# Grid of trial orientations, used below to reseed particles whose current orientations fit poorly (global search at low resolution)
+	tstorts=[]
+	for x in np.arange(-0.5,0.5,0.04):
+		for y in np.arange(-0.5,0.5,0.04):
+			for z in np.arange(-0.5,0.5,0.04):
+				if hypot(x,y,z)<=0.5: tstorts.append((x,y,z))
+	tst_orts=Orientations(np.array(tstorts))
+
 	# Forces all of the caches to share the same orientation information so we can update them simultaneously below (FRCs not jointly cached!)
 	for down in downs[1:]:
 		caches[down].orts=caches[downs[0]].orts
 		caches[down].tytx=caches[downs[0]].tytx

+	# definition of downsampling sequence for stages of refinement
+	# 0) #ptcl, 1) downsample, 2) iter, 3) frc weight, 4) amp threshold, 5) replicate, 6) replicate spread, 7) gradient scale, 8) frc loc threshold (9 disables gradient)
+	# replication skipped in final stage. thresholds are mean+coef*std
+	if fromscratch:
+		print("Notice: refining from scratch without orientations")
+		stages=[
+			[200,  24,16,1.8, 0  ,2,.05, 2.0,9],
+			[200,  24,16,1.8, 0  ,2,.05, 2.0,9],
+			[200,  24,16,1.8,-1  ,2,.01, 2.0,-1],
+			[5000, 24,16,1.8,-1  ,3,.1,  2.0,-1],
+			[5000, 24,16,1.8, 0  ,1,.01, 2.0,-1],
+			[5000, 32,16,1.5,-.5 ,2,.05, 1.5,-3],
+			[5000, 32,16,1.5,-1  ,3,.007,1.0,-2],
+			[10000, 64,12,1.2,-1.5,3,.005,1.0,-3],
+			[10000, 64,12,1.0,-2 ,3,.002,0.75,-3],
+			[10000,256,12,1.2,-1.5,3,.005,1.0,-3],
+			[10000,256,12,1.0,-2 ,3,.002,0.75,-3],
+			[25000,512, 6,1.0,-2 ,1,.001,0.5,-3.0],
+			[25000,512, 6,1.0,-2 ,1,.001,0.5,-3.0]
+		]
+	else:
+		stages=[
+			[1000, 24,16,1.8,-3  ,1,.01, 2.0, 9],
+			[1000, 24,16,1.8, 0  ,2,.03, 1.5, 9],
+			[1000, 24,16,1.8,-1  ,1,.01, 1.5, -3],
+			[1000, 24,16,1.8, 0  ,2,.01, 1.5, -2],
+			[2000, 32,16,1.5, 0  ,2,.02, 1.5, -3],
+			[2000, 32,16,1.5,-0.5,2,.01, 1.25, -2],
+			[5000, 64,24,1.2,-1  ,2,.005,1.0, -3],
+			[10000,256,24,1.0,-1 ,2,.002,1.0,-3],
+			[25000,512,12,1.0,-2 ,1,.001,1.0, -3]
+		]
+
+	for l in stages: l[1]=min(l[1],nxrawm2)	# make sure we aren't "upsampling"

 	gaus=Gaussians()
 	#Initialize Gaussians to random values with amplitudes over a narrow range
 	rnd=tf.random.uniform((options.initgauss,4))	# specify the number of Gaussians to start with here
 	rnd+=(-.5,-.5,-.5,10.0)
 	gaus._data=rnd/(1.5,1.5,1.5,100.0)	# amplitudes set to ~1.0, positions random within 2/3 box size

+	lsxin=LSXFile(args[0])
 	times.append(time.time())
 	ptcls=[]
 	for sn,stage in enumerate(stages):
 		if options.verbose: print(f"Stage {sn} - {local_datetime()}:")

 		ccache=caches[stage[1]]
-#
-#		# stage 1 - limit to ~1000 particles for initial low resolution work
-#		if options.verbose: print(f"\tReading Files {min(stage[0],nptcl)} ptcl")
-#		ptcls=EMStack2D(EMData.read_images(args[0],range(0,nptcl,max(1,nptcl//stage[0]))))
-#		orts,tytx=ptcls.orientations
-#		tytx/=nxraw
-#		ptclsf=ptcls.do_fft()
-#
-#		if options.verbose: print(f"\tDownsampling {min(nxraw,stage[1])} px")
-#		if stage[1]0:
+			frcm=np.mean(lowfrc)
+			frcsg=np.std(lowfrc)
+			reseed_idx=np.where(frcs1: print(f"{i}/{len(ptcls)}")
+				ofrcs=tf_frc(seedprojsf.tensor,ptcls[i],-1)
+				maxort=tf.argmax(ofrcs)	# best orientation for this particle
+				ccache.orts[i]=tst_orts[maxort]
+				#ccache.tytx[ii]=(0,0)	# just keep the current center?
+ print(f"{nseeded} orts reseeded ({frcm+frcsg*stage[8]} thr) {local_datetime()}") + if stage[8]<9: if options.verbose: print(f"\tIterating orientations parms x{stage[2]} with frc weight {stage[3]}\n FRC\t\tort_grad\tcen_grad") fout=open(f"{options.path}/fscs.txt","w") @@ -212,39 +261,48 @@ def main(): # filter results and prepare for next stage g0=len(gaus) - gaus.norm_filter(sig=stage[4]) # remove gaussians below threshold + gaus.norm_filter(sig=stage[4],rad_downweight=0.33) # remove gaussians below threshold g1=len(gaus) if stage[5]>0: gaus.replicate(stage[5],stage[6]) # make copies of gaussians with local perturbation g2=len(gaus) - if stage[8]<9: - # reseed orientations for low FRCs - frcs=ccache.frcs # not ideal, stealing the actual list from the object, but good enough for now - frcm=np.mean(frcs[frcs<1.5]) - frcsg=np.std(frcs[frcs<1.5]) - nseeded=0 - for ii,f in enumerate(frcs): - if f {g1} -> {g2} gaussians {nseeded} orts reseeded ({frcm+frcsg*stage[8]} thr) {local_datetime()}") - - else: print(f"Stage {sn} complete: {g0} -> {g1} -> {g2} gaussians no orts reseeded {local_datetime()}") + print(f"Stage {sn} complete: {g0} -> {g1} -> {g2} gaussians no orts reseeded {local_datetime()}") + + + + + # frcs=ccache.frcs # not ideal, stealing the actual list from the object, but good enough for now + # frcm=np.mean(frcs[frcs<1.5]) + # frcsg=np.std(frcs[frcs<1.5]) + # nseeded=0 + # for ii,f in enumerate(frcs): + # if f
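The reseeding pass above scores each low-FRC particle against projections on a fixed grid of spinvec orientations and keeps the best-scoring one. A standalone sketch of that selection logic, assuming a hypothetical score_fn in place of the patch's projection/FRC machinery (math.hypot with three arguments needs Python 3.8+):

    import numpy as np
    from math import hypot

    # Same trial grid as tst_orts above: spinvec triples inside a 0.5-radius ball
    grid = [(x, y, z)
            for x in np.arange(-0.5, 0.5, 0.04)
            for y in np.arange(-0.5, 0.5, 0.04)
            for z in np.arange(-0.5, 0.5, 0.04)
            if hypot(x, y, z) <= 0.5]

    def reseed_orientation(score_fn, particle):
        """Return the grid orientation whose projection best matches the particle."""
        scores = [score_fn(ort, particle) for ort in grid]   # e.g. an FRC-based similarity
        return grid[int(np.argmax(scores))]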