More work on gaussian refinement
sludtke42 committed May 31, 2024
1 parent 6fe71aa commit abf18f0
Showing 6 changed files with 179 additions and 109 deletions.
2 changes: 1 addition & 1 deletion libEM/exception.h
@@ -107,7 +107,7 @@ namespace EMAN {
int linenum;
string desc;
string objname;
static string msg;
static string msg; // while not completely threadsafe, without this the error string doesn't persist long enough and Python exception string is garbage
};


39 changes: 28 additions & 11 deletions libpyEM/EMAN3tensor.py
@@ -67,6 +67,7 @@ class StackCache():

def __init__(self,filename,n):
"""Specify filename and number of images to be cached."""
print("cache: ",filename)
self.filename=filename

self.fp=open(filename,"wb+") # erase file!
@@ -123,8 +124,11 @@ def read(self,nlist):

stack=[]
for i in nlist:
self.fp.seek(self.locs[i])
stack.append(tf.io.parse_tensor(self.fp.read(self.locs[i+1]-self.locs[i]),out_type=tf.complex64))
try:
self.fp.seek(self.locs[i])
stack.append(tf.io.parse_tensor(self.fp.read(self.locs[i+1]-self.locs[i]),out_type=tf.complex64))
except:
raise Exception(f"Error reading cache {self.filename}: {i} -> {self.locs[i]}")

self.locked=False
ret=EMStack2D(tf.stack(stack))
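
A minimal sketch of the offset-table pattern this cache relies on, with illustrative file and variable names (not the module's API): serialized tensors are appended to a single file while a list of byte offsets records where each image starts, so read() can seek() to any image and parse exactly its bytes — the failure the new try/except above reports on.

import tensorflow as tf

# Illustrative sketch, not the StackCache API: locs[i] is the byte offset of
# image i; locs[-1] marks the current end of file.
fp = open("cache.bin", "wb+")
locs = [0]
for _ in range(3):
    img = tf.complex(tf.random.normal([16, 16]), tf.random.normal([16, 16]))
    fp.write(tf.io.serialize_tensor(img).numpy())
    locs.append(fp.tell())

fp.seek(locs[1])                                   # jump to image 1
raw = fp.read(locs[2] - locs[1])                   # read exactly its bytes
img1 = tf.io.parse_tensor(raw, out_type=tf.complex64)
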
@@ -462,6 +466,7 @@ def downsample(self,newsize):
current stack is in real or Fourier space. This cannot be used to upsample (make images larger) and should
not be used on rectangular images/volumes."""

if newsize==self.shape[1]: return EMStack2D(self.tensor) # this won't copy, but since the tensor is constant should be ok?
return EMStack2D(tf_downsample_2d(self.tensor,newsize)) # TODO: for now we're forcing this to be a tensor, probably better to leave it in the current format

class Orientations():
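
A note on the no-copy early return above, as a short sketch: TensorFlow tensors are immutable, so wrapping the same tensor in a new EMStack2D cannot be corrupted by later operations, which always produce fresh tensors.

import tensorflow as tf

# Why aliasing a constant tensor is safe: "modifying" it rebinds the name to a
# new tensor; the shared reference keeps its original values.
t = tf.constant([1.0, 2.0])
alias = t              # shares storage, no copy
t = t * 2.0            # new tensor; alias is untouched
print(alias.numpy())   # [1. 2.]
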
@@ -522,10 +527,13 @@ def init_from_transforms(self,xformlist):

return(tf.constant(tytx))

def transforms(self):
def transforms(self,tytx=None):
"""converts the current orientations to a list of Transform objects"""

return [Transform({"type":"spinvec","v1":self._data[i][0],"v2":self._data[i][1],"v3":self._data[i][2]})]
if tytx is not None:
return [Transform({"type":"spinvec","v1":self._data[i][0],"v2":self._data[i][1],"v3":self._data[i][2],"tx":tytx[i][1],"ty":tytx[i][0]}) for i in range(len(self._data))]

return [Transform({"type":"spinvec","v1":self._data[i][0],"v2":self._data[i][1],"v3":self._data[i][2]}) for i in range(len(self._data))]

def to_mx2d(self,swapxy=False):
"""Returns the current set of orientations as a 2 x 3 x N matrix which will transform a set of 3-vectors to a set of
@@ -665,12 +673,15 @@ def replicate(self,n=2,dev=0.01):
dups=[self._data+tf.random.normal(self._data.shape,stddev=dev) for i in range(n)]
self._data=tf.concat(dups,0)

def norm_filter(self,sig=0.5):
"""Rescale the amplitudes so the maximum is 1, with amplitude below mean+sig*sigma"""
def norm_filter(self,sig=0.5,rad_downweight=-1):
"""Rescale the amplitudes so the maximum is 1, with amplitude below mean+sig*sigma removed. rad_downweight, if >0 will apply a radial linear amplitude decay beyond the specified radius to the corner of the cube. eg - 0.5 will downweight the corners. Downweighting only works if Gaussian coordinate range follows the -0.5 - 0.5 standard range for the box. """
self.coerce_tensor()
self._data=self._data*(1.0,1.0,1.0,1.0/tf.reduce_max(self._data[:,3])) # "normalize" amplitudes so max amplitude is scaled to 1.0, not sure how necessary this really is
thr=tf.math.reduce_mean(self._data[:,3])+sig*tf.math.reduce_std(self._data[:,3])
self._data=tf.boolean_mask(self._data,self._data[:,3]>thr) # remove any gaussians with amplitude below threshold
if rad_downweight>0:
famp=self._data[:,3]*(1.0-tf.nn.relu(tf.math.reduce_euclidean_norm(self._data[:,:3],1)-rad_downweight))
else: famp=self._data[:,3]
thr=tf.math.reduce_mean(famp)+sig*tf.math.reduce_std(famp)
self._data=tf.boolean_mask(self._data,famp>thr) # remove any gaussians with amplitude below threshold

def project_simple(self,orts,boxsize,tytx=None):
"""Generates a tensor containing a simple 2-D projection (interpolated delta functions) of the set of Gaussians for each of N Orientations in orts.
@@ -885,7 +896,7 @@ def tf_downsample_3d(imgs,newx,stack=False):
# two possible approaches would be to add an extra dimension to rad_img to cover image number, and handle the scatter_nd as a single operation
# or to try making use of DataSet. I started a DataSet implementation, but decided it added too much design complexity
def tf_frc(ima,imb,avg=0,weight=1.0,minfreq=0):
"""Computes the pairwise FRCs between two stacks of complex images. Returns a list of 1D FSC tensors or if avg!=0
"""Computes the pairwise FRCs between two stacks of complex images. imb may alternatively be a single image. Returns a list of 1D FSC tensors or if avg!=0
then the average of the first 'avg' values. If -1, averages through Nyquist. Weight permits a frequency based weight
(only for avg>0): 1-2 will upweight low frequencies, 0-1 will upweight high frequencies"""
if ima.dtype!=tf.complex64 or imb.dtype!=tf.complex64 : raise Exception("tf_frc requires FFTs")
@@ -923,6 +934,8 @@ def tf_frc(ima,imb,avg=0,weight=1.0,minfreq=0):
except:
raise Exception(f"failed in FRC with sizes {ima.shape} {imb.shape} {imar.shape} {imbr.shape}")

if len(imbr.shape)==3: single=False
else: single=True
frc=[]
for i in range(nimg):
zero=tf.zeros([nr])
@@ -933,8 +946,12 @@ def tf_frc(ima,imb,avg=0,weight=1.0,minfreq=0):
aprd=tf.tensor_scatter_nd_add(zero,rad_img,imar[i])
aprd=tf.tensor_scatter_nd_add(aprd,rad_img,imai[i])

bprd=tf.tensor_scatter_nd_add(zero,rad_img,imbr[i])
bprd=tf.tensor_scatter_nd_add(bprd,rad_img,imbi[i])
if single:
bprd=tf.tensor_scatter_nd_add(zero,rad_img,imbr)
bprd=tf.tensor_scatter_nd_add(bprd,rad_img,imbi)
else:
bprd=tf.tensor_scatter_nd_add(zero,rad_img,imbr[i])
bprd=tf.tensor_scatter_nd_add(bprd,rad_img,imbi[i])

frc.append(cross/tf.sqrt(aprd*bprd))
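
For reference, a toy single-ring version of the quantity tf_frc accumulates at each radius (a sketch of the formula, not the library routine): the real cross term over a ring, normalized by the geometric mean of the two ring powers. The change above additionally lets imb be a single reference image reused against every ima[i].

import tensorflow as tf

# One ring of two complex64 FFTs; a scaled copy correlates perfectly.
a = tf.complex(tf.random.normal([8]), tf.random.normal([8]))
b = 0.7 * a
cross = tf.reduce_sum(tf.math.real(a * tf.math.conj(b)))
frc = cross / tf.sqrt(tf.reduce_sum(tf.abs(a)**2) * tf.reduce_sum(tf.abs(b)**2))
print(float(frc))   # ~1.0 here; unrelated noise rings give ~0
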

20 changes: 13 additions & 7 deletions libpyEM/qtgui/embrowser.py
@@ -63,8 +63,8 @@ def display_error(msg) :

# This is a floating point number-finding regular expression
# Documentation: https://regex101.com/r/68zUsE/4/
renumfind = re.compile(r"(?:^|(?<=[^\d\w:.-]))[+-]?\d+\.*\d*(?:[eE][-+]?\d+|)(?=[^\d\w:.-]|$)")

#renumfind = re.compile(r"(?:^|(?<=[^\d\w:.-]))[+-]?\d+\.*\d*(?:[eE][-+]?\d+|)(?=[^\d\w:.-]|$)")
renumfind = re.compile(r"-?\d*\.?\d+[eE]?-?\d*")
# We need to sort ints and floats as themselves, not string, John Flanagan
def safe_int(v) :
"""Performs a safe conversion from a string to an int. If a non int is presented we return the lowest possible value"""
@@ -846,7 +846,9 @@ def name() :
@staticmethod
def isValid(path, header) :
"""Returns (size, n, dim) if the referenced path is a file of this type, None if not valid. The first 4k block of data from the file is provided as well to avoid unnecessary file access."""
if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
# if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
try: s=header.decode("utf-8")
except: return False

try : size = os.stat(path)[6]
except : return False
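
The same decode-based validity test replaces isprint() in several isValid() methods in this file; a stand-in sketch of the idea (hypothetical helper name):

def looks_like_text(header: bytes) -> bool:
    # Reject anything whose first 4k block is not valid UTF-8 text.
    try:
        header.decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False

print(looks_like_text(b"alt\taz\tphi\n"))      # True
print(looks_like_text(b"\x89PNG\r\n\x1a\n"))   # False -- binary magic bytes
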
@@ -880,7 +882,10 @@ def name() :
@staticmethod
def isValid(path, header) :
"""Returns (size, n, dim) if the referenced path is a file of this type, None if not valid. The first 4k block of data from the file is provided as well to avoid unnecessary file access."""
if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?

try: s=header.decode("utf-8")
except: return False
# if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
if isinstance(header, bytes):
header=header.decode("utf-8")
if not "<html>" in header.lower() : return False # For the moment, we demand an <html> tag somewhere in the first 4k
@@ -979,8 +984,7 @@ def name() :
@staticmethod
def isValid(path, header) :
"""Returns (size, n, dim) if the referenced path is a file of this type, None if not valid. The first 4k block of data from the file is provided as well to avoid unnecessary file access."""
try:
if not isprint(header) : return False
try: s=header.decode("utf-8")
except: return False

# We need to try to count the columns in the file
@@ -1930,7 +1934,9 @@ def isValid(path, header) :
ext = os.path.basename(path).split('.')[-1]
if ext not in proper_exts: return False

if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?
try: s=header.decode("utf-8")
except: return False
# if not isprint(header) : return False # demand printable Ascii. FIXME: what about unicode ?

try : size = os.stat(path)[6]
except : return False
2 changes: 1 addition & 1 deletion programs/e2iminfo.py
@@ -104,7 +104,7 @@ def main():
out.write("{:1.5f}\t{:1.1f}\t{:1.4f}\t{:1.5f}\t{:1.2f}\t{:1.1f}\t{:1.3f}\t{:1.2f}\n".format(v.defocus,v.bfactor,v.apix,v.dfdiff,v.dfang,v.voltage,v.cs,v.get_phase()))
elif isinstance(v,Transform) :
dct=v.get_params("eman")
out.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(v["az"],v["alt"],v["phi"],v["tx"],v["ty"],v["tz"]))
out.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(dct["az"],dct["alt"],dct["phi"],dct["tx"],dct["ty"],dct["tz"]))
elif isinstance(v,list) or isinstance(v,tuple):
out.write([str(i) for i in v].join("\t")+"\n")
else:
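
The one-line fix above: get_params("eman") returns a plain dict, so the orientation fields must be read from dct rather than indexing the Transform. With made-up values of the shape that dict takes:

dct = {"az": 30.0, "alt": 45.0, "phi": 0.0, "tx": 1.5, "ty": -2.0, "tz": 0.0}
print("{}\t{}\t{}\t{}\t{}\t{}".format(dct["az"], dct["alt"], dct["phi"],
                                      dct["tx"], dct["ty"], dct["tz"]))
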
51 changes: 20 additions & 31 deletions programs/e3make3d_gauss.py
@@ -70,7 +70,7 @@ def main():
if options.verbose: print(f"{nptcl} particles at {nxraw}^3")

# definition of downsampling sequence for stages of refinement
# #ptcl, downsample, iter, frc weight, amp threshold, replicate, step coef
# 0) #ptcl, 1) downsample, 2) iter, 3) frc weight, 4) amp threshold, 5) replicate, 6) step coef
# replication skipped in final stage
if options.tomo:
stages=[
@@ -85,12 +85,12 @@ def main():
else:
stages=[
[500, 16,16,1.8,-3 ,1,.01, 2.0],
[500, 16,16,1.8, 0.5 ,4,.01, 1.0],
[1000, 32,16,1.5,0 ,1,.005,1.5],
[1000, 32,16,1.5,-1,3,.007,1.0],
[2500, 64,24,1.2,-1.5,3,.005,1.0],
[10000,256,24,1.0,-2 ,3,.002,0.75],
[25000,512,12,0.8,-2 ,1,.001,0.5]
[500, 16,16,1.8, 0 ,3,.01, 1.0],
[1000, 32,16,1.5, 0 ,2,.005,1.5],
[1000, 32,16,1.5,-1 ,3,.007,1.0],
[2500, 64,24,1.2,-1.5,2,.005,1.0],
[10000,256,24,1.0,-2 ,2,.002,1.0],
[25000,512,12,0.8,-2 ,1,.001,0.75]
]

times=[time.time()]
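
Unpacking one row (the first non-tomo stage) under the column convention documented above. The rows actually carry eight fields: the replication noise dev sits at index 6 and the step coefficient at index 7, matching gaus.replicate(stage[5],stage[6]) and the stage[7] passed to gradient_step later in this diff. The descriptive names below are illustrative, not from the source:

stage = [500, 16, 16, 1.8, -3, 1, .01, 2.0]
nptcl_lim, boxsize, niter, frc_weight, amp_thr, nrep, rep_dev, step_coef = stage
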
@@ -120,29 +120,15 @@ def main():
ptcls=[]
for sn,stage in enumerate(stages):
if options.verbose: print(f"Stage {sn} - {local_datetime()}:")
#
# # stage 1 - limit to ~1000 particles for initial low resolution work
# if options.verbose: print(f"\tReading Files {min(stage[0],nptcl)} ptcl")
# ptcls=EMStack2D(EMData.read_images(args[0],range(0,nptcl,max(1,nptcl//stage[0]))))
# orts,tytx=ptcls.orientations
# tytx/=nxraw
# ptclsf=ptcls.do_fft()
#
# if options.verbose: print(f"\tDownsampling {min(nxraw,stage[1])} px")
# if stage[1]<nxraw: ptclsfds=ptclsf.downsample(stage[1]) # downsample specifies the final size, not the amount of downsampling
# else: ptclsfds=ptclsf # if true size is smaller than stage size, don't downsample, obviously
# ny=stage[1]
# ptcls=None # free resouces since we're committed to re-reading files for now
# ptclsf=None
# # ny=ptclsfds.shape[1]

nliststg=range(sn,nptcl,max(1,nptcl//stage[0])) # all of the particles to use in the current stage, sn start gives some stochasticity

# print(ptclsfds.shape,tytx.shape)
# nliststg=range(sn,nptcl,max(1,nptcl//stage[0])) # all of the particles to use in the current stage, sn start gives some stochasticity

if options.verbose: print(f"\tIterating x{stage[2]} with frc weight {stage[3]}\n FRC\t\tshift_grad\tamp_grad")
lqual=-1.0
rstep=1.0
for i in range(stage[2]): # training epochs
for j in range(0,len(nliststg),500): # compute the gradient step piecewise due to memory limitations, 1000 particles at a time
nliststg=range(sn+i,nptcl,max(1,nptcl//stage[0])) # all of the particles to use in the current epoch in the current stage, sn+i provides stochasticity
for j in range(0,len(nliststg),500): # compute the gradient step piecewise due to memory limitations, 500 particles at a time
ptclsfds,orts,tytx=caches[stage[1]].read(nliststg[j:j+500])
step0,qual0,shift0,sca0=gradient_step(gaus,ptclsfds,orts,tytx,stage[3],stage[7])
if j==0:
@@ -153,14 +139,16 @@ def main():
shift+=shift0
sca+=sca0
norm=len(nliststg)//500+1
step/=norm
qual/=norm
if qual<lqual: rstep/=2.0 # if we start falling or oscillating we reduce the step within the epoch
step*=rstep/norm
shift/=norm
sca/=norm
gaus.add_tensor(step)
lqual=qual
if options.savesteps: from_numpy(gaus.numpy).write_image("steps.hdf",-1)

print(f"{i}: {qual:1.4f}\t{shift:1.4f}\t\t{sca:1.4f}")
print(f"{i}: {qual:1.4f}\t{shift:1.4f}\t\t{sca:1.4f}\t{rstep:1.4f}")

# if options.savesteps:
# vol=gaus.volume(nxraw)
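
A sketch of the per-epoch step control introduced in this hunk: the applied gradient step is scaled by rstep, which is halved whenever the quality metric falls or oscillates relative to the previous epoch (made-up quality values):

rstep, lqual = 1.0, -1.0
for qual in [0.60, 0.65, 0.64, 0.66, 0.65]:
    if qual < lqual:
        rstep /= 2.0      # damp the step after a regression
    lqual = qual
    print(f"qual={qual:.2f}  rstep={rstep:.3f}")
# rstep stays 1.0 while quality climbs, then drops to 0.5 and 0.25 at the dips.
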
@@ -169,7 +157,8 @@ def main():

# filter results and prepare for stage 2
g0=len(gaus)
gaus.norm_filter(sig=stage[4])
if options.tomo: gaus.norm_filter(sig=stage[4]) # gaussians outside the box may be important!
else: gaus.norm_filter(sig=stage[4],rad_downweight=0.33)
g1=len(gaus)
if stage[5]>0: gaus.replicate(stage[5],stage[6])
g2=len(gaus)
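
A sketch of what the replicate step does between stages, mirroring the Gaussians.replicate(n, dev) code earlier in this commit: each surviving Gaussian spawns n copies jittered by Gaussian noise, growing the model for the next resolution stage.

import tensorflow as tf

data = tf.constant([[0.1, 0.2, 0.3, 1.0]])   # one x,y,z,amp row
n, dev = 3, 0.01
dups = [data + tf.random.normal(data.shape, stddev=dev) for _ in range(n)]
data = tf.concat(dups, 0)
print(data.shape)                            # (3, 4): one seed became three
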
@@ -181,12 +170,12 @@ def main():
out=open(options.gaussout,"w")
for x,y,z,a in gaus.tensor: out.write(f"{x:1.5f}\t{y:1.5f}\t{z:1.5f}\t{a:1.3f}\n")

vol=gaus.volume(nxraw)
vol=gaus.volume(nxraw).emdata[0]
vol["apix_x"]=apix
vol["apix_y"]=apix
vol["apix_z"]=apix
#vol.emdata[0].process_inplace("filter.lowpass.gauss",{"cutoff_abs":options.volfilt})
vol.write_images(options.volout)
vol.write_image(options.volout,0)


times=np.array(times)