gmm improvements (I think)
sludtke42 committed Dec 10, 2024
1 parent 7028450 commit 643d287
Showing 3 changed files with 50 additions and 40 deletions.
libpyEM/EMAN3jax.py: 7 changes (5 additions & 2 deletions)
@@ -100,8 +100,10 @@ def write(self,stack,n0,ortss=None,tytxs=None):
try: self.orts[n0:n0+len(stack)]=ortss
except: self.orts[n0:n0+len(stack)]=ortss.numpy()
if tytxs is not None:
try: self.tytx[n0:n0+len(stack)]=tytxs
except: self.tytx[n0:n0+len(stack)]=tytxs.numpy()
try: self.tytx[n0:n0+len(stack),:2]=tytxs
except:
# print(tytxs,tytxs.shape)
self.tytx[n0:n0+len(stack)]=tytxs.numpy()

# we go through the images one at a time, serialize, and write to a file with a directory
self.fp.seek(self.cloc)
@@ -381,6 +383,7 @@ class EMStack2D(EMStack):
def set_data(self,imgs):
""" """
self._xforms=None
self._df=None
if imgs is None:
self._data=None
self._npy_list=None
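The write() change above stores only the first two columns of the supplied translations and falls back to .numpy() when handed a framework tensor rather than an ndarray, and set_data() now also resets a cached _df alongside _xforms. A minimal standalone sketch of the ndarray-or-tensor fallback pattern (names hypothetical, not part of the commit):

# Sketch: fill a slice of a preallocated NumPy buffer from either a plain
# ndarray or a tensor exposing .numpy() (e.g. an eager TensorFlow tensor),
# mirroring the try/except fallback used in write() above.
import numpy as np

def store_tytx(buf, n0, tytxs):
    # Copy translations into buf[n0:n0+len(tytxs), :2], converting tensors if needed.
    try:
        buf[n0:n0+len(tytxs), :2] = tytxs          # plain ndarray
    except Exception:
        buf[n0:n0+len(tytxs), :2] = tytxs.numpy()  # eager tensor with a .numpy() method

buf = np.zeros((10, 4), dtype=np.float32)
store_tytx(buf, 2, np.ones((3, 2), dtype=np.float32))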
programs/e2gmm.py: 8 changes (4 additions & 4 deletions)
@@ -2147,13 +2147,13 @@ def new_neutral(self):
lsxs=None

# refine the neutral model against some real data in entropy training mode
er=run(f"e2gmm_refine.py --projs {self.gmm}/particles_subset.lst --npt {self.currun['ngauss']} --decoderentropy --npt {self.currun['ngauss']} --sym {sym} --maxboxsz {maxbox} --minressz {minboxp} --model {modelseg} --modelout {modelout} --niter 10 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --ampreg 0.1 --sigmareg 1.0")
er=run(f"e2gmm_refine.py --projs {self.gmm}/particles_subset.lst --npt {self.currun['ngauss']} --decoderentropy --npt {self.currun['ngauss']} --sym {sym} --maxboxsz {maxbox} --minressz {minboxp} --model {modelseg} --modelout {modelout} --niter 10 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --ampreg 0.05")
if er :
showerror("Error running e2gmm_refine, see console for details. GPU memory exhaustion is a common issue. Consider reducing the target resolution.")
return

# Now we train latent zero to the neutral conformation
er=run(f"e2gmm_refine.py --projs {self.gmm}/proj_in.hdf --decoderin {decoder} --sym {sym} --maxboxsz {maxbox} --minressz {minboxp} --model {modelseg} --modelout {modelout} --niter 20 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --modelreg {self.currun['modelreg']} --ampreg 1.0")
er=run(f"e2gmm_refine.py --projs {self.gmm}/proj_in.hdf --decoderin {decoder} --sym {sym} --maxboxsz {maxbox} --minressz {minboxp} --model {modelseg} --modelout {modelout} --niter 20 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --modelreg {self.currun['modelreg']} --ampreg 0.1")
if er :
showerror("Error running e2gmm_refine, see console for details. GPU memory exhaustion is a common issue. Consider reducing the target resolution.")
return
@@ -2234,13 +2234,13 @@ def new_neutral2(self):
lsxs=None

# refine the neutral model against some real data in entropy training mode
er=run(f"e2gmm_refine_point.py --projs {self.gmm}/particles_subset.lst --decoderentropy --npt {self.currun['ngauss']} --sym {sym} --maxboxsz {maxbox} --minressz {minboxp} --model {modelseg} --modelout {modelout} --niter 20 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --ampreg 0.05 --ptclsclip {self.jsparm['boxsize']}")
er=run(f"e2gmm_refine_point.py --projs {self.gmm}/particles_subset.lst --decoderentropy --npt {self.currun['ngauss']} --sym {sym} --maxboxsz {maxbox} --minressz {minboxp} --model {modelseg} --modelout {modelout} --niter 20 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --ampreg 0.02 --ptclsclip {self.jsparm['boxsize']}")
if er :
showerror("Error running e2gmm_refine, see console for details. GPU memory exhaustion is a common issue. Consider reducing the target resolution.")
return

# Now we train latent zero to the neutral conformation
er=run(f"e2gmm_refine_point.py --projs {self.gmm}/proj_in.hdf --decoderin {decoder} --sym {sym} --maxboxsz {maxbox} --minressz {minboxp} --model {modelseg} --modelout {modelout} --niter 20 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --modelreg {self.currun['modelreg']} --ampreg 1.0 --ptclsclip {self.jsparm['boxsize']}")
er=run(f"e2gmm_refine_point.py --projs {self.gmm}/proj_in.hdf --decoderin {decoder} --sym {sym} --maxboxsz {maxbox} --minressz {minboxp} --model {modelseg} --modelout {modelout} --niter 40 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --modelreg {self.currun['modelreg']} --ampreg 0.05 --ptclsclip {self.jsparm['boxsize']}")
#er=run(f"e2gmm_refine_point.py --projs {self.gmm}/proj_in.hdf --sym {sym} --maxboxsz {maxbox} --model {modelseg} --modelout {modelout} --niter 20 --nmid {self.currun['dim']} --evalmodel {self.gmm}/{self.currunkey}_model_projs.hdf --evalsize {self.jsparm['boxsize']} --decoderout {decoder} {conv} --modelreg {self.currun['modelreg']} --ampreg 1.0 --ndense -1 --ptclsclip {self.jsparm['boxsize']}")
if er :
showerror("Error running e2gmm_refine, see console for details. GPU memory exhaustion is a common issue. Consider reducing the target resolution.")
programs/e2gmm_refine_point.py: 75 changes (41 additions & 34 deletions)
@@ -58,7 +58,7 @@ def main():
parser.add_argument("--learnrate", type=float,help="learning rate for model training only. Default is 1e-4. ", default=1e-4)
# parser.add_argument("--sigmareg", type=float,help="regularizer for the sigma of gaussian width. Larger value means all Gaussian functions will have essentially the same width. Smaller value may help compensating local resolution difference.", default=.5)
parser.add_argument("--modelreg", type=float,help="regularizer for for Gaussian positions based on the starting model, ie the result will be biased towards the starting model when training the decoder (0-1 typ). Default 0", default=0)
parser.add_argument("--ampreg", type=float,help="regularizer for the Gaussian amplitudes in the first 1/2 of the iterations. Large values will encourage all Gaussians to have similar amplitudes. default = 0", default=0)
parser.add_argument("--ampreg", type=float,help="regularizer for Gaussian amplitudes. Large values will encourage all Gaussians towards 1.0 or -0.2. default = 0", default=0)
parser.add_argument("--niter", type=int,help="number of iterations", default=32)
parser.add_argument("--npts", type=int,help="number of points to initialize. ", default=-1)
parser.add_argument("--batchsz", type=int,help="batch size", default=128)
@@ -316,7 +316,7 @@ def main():
chunkn=nptcl

trainset=tf.data.Dataset.from_tensor_slices((dcpx[0], dcpx[1], xfsnp))
trainset=trainset.batch(bsz,drop_remainder=True)
trainset=trainset.batch(bsz,drop_remainder=False)
allscr, allgrds=calc_gradient(trainset, pts, params, options )
# allscr, allgrds=calc_gqual(trainset, pts, params, options )
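The batching calls in this commit switch drop_remainder from True to False, so the trailing partial batch is kept and no particles are silently dropped from the gradient or training passes. A minimal standalone sketch of the difference:

# Sketch: with drop_remainder=False the final partial batch is kept, so a
# dataset of 10 items batched by 4 yields batches of 4, 4 and 2 instead of 4, 4.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
print([b.numpy().tolist() for b in ds.batch(4, drop_remainder=True)])   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print([b.numpy().tolist() for b in ds.batch(4, drop_remainder=False)])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]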

@@ -375,7 +375,7 @@ def main():
#### actual training
ptclidx=allscr>-1
trainset=tf.data.Dataset.from_tensor_slices((allgrds[ptclidx], dcpx[0][ptclidx], dcpx[1][ptclidx], xfsnp[ptclidx]))
trainset=trainset.batch(bsz,drop_remainder=True)
trainset=trainset.batch(bsz,drop_remainder=False)

train_heterg(trainset, pts, encode_model, decode_model, params, options,grps)

@@ -750,25 +750,27 @@ def build_encoder(ninp,nmid,grps=None):
grps, if provided, is a dictionary of group number keys with a count of gaussians as the value. The sum of the values should equal nmid."""
l2=tf.keras.regularizers.l2(1e-3)
l1=tf.keras.regularizers.l1(1e-3)
binitsig=tf.keras.initializers.Constant(-0.5)
# binit=tf.keras.initializers.RandomNormal(0,1e-2)
# kinit=tf.keras.initializers.HeNormal()
binit="random_normal"
kinit="he_normal"
leaky=tf.keras.layers.LeakyReLU(0.1)

if grps is None:
print(f"Encoder (no groups) {max(ninp//2,nmid*8)},{max(ninp//4,nmid*4)},{max(ninp//8,nmid*4)},{max(ninp//16,nmid*4)},{max(ninp//32,nmid*2)}")
print(f"Encoder (no groups) {max(ninp//2,nmid*8)},{max(ninp//4,nmid*8)},{max(ninp//8,nmid*4)},{max(ninp//16,nmid*4)},{max(ninp//32,nmid*2)}")
layers=[
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(max(ninp//2,nmid*8), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True,bias_initializer=binit),
tf.keras.layers.Dense(max(ninp//2,nmid*16), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
# tf.keras.layers.Dropout(.2),
tf.keras.layers.Dense(max(ninp//4,nmid*8), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
tf.keras.layers.Dropout(.4),
tf.keras.layers.Dense(max(ninp//8,nmid*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
tf.keras.layers.Dropout(.4),
tf.keras.layers.Dense(max(ninp//16,nmid*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
tf.keras.layers.Dropout(.4),
tf.keras.layers.Dense(max(ninp//32,nmid*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
tf.keras.layers.Dense(nmid, kernel_regularizer=l2, kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dense(max(ninp//4,nmid*16), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
# tf.keras.layers.Dropout(.2),
tf.keras.layers.Dense(max(ninp//8,nmid*8), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
# tf.keras.layers.Dropout(.2),
tf.keras.layers.Dense(max(ninp//16,nmid*8), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
# tf.keras.layers.Dropout(.2),
tf.keras.layers.Dense(max(ninp//32,nmid*8), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
tf.keras.layers.Dense(nmid, activation=leaky,kernel_regularizer=l2,bias_initializer=binit, kernel_initializer=kinit,use_bias=True),
]

encode_model=tf.keras.Sequential(layers)
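The encoder layers above now pass a LeakyReLU(0.1) layer object as the activation instead of the "leaky_relu" string, widen the middle layers, drop the Dropout layers, and give every Dense layer an explicit bias initializer. A minimal sketch of one layer configured the same way (sizes hypothetical):

# Sketch: a single Dense layer set up like the encoder layers above; LeakyReLU(0.1)
# is passed as a callable, so the slope is explicit rather than the framework default.
import tensorflow as tf
leaky = tf.keras.layers.LeakyReLU(0.1)
dense = tf.keras.layers.Dense(64, activation=leaky, kernel_initializer="he_normal",
        kernel_regularizer=tf.keras.regularizers.l2(1e-3),
        bias_initializer="random_normal", use_bias=True)
print(dense(tf.random.normal((8, 128))).shape)  # (8, 64)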
@@ -782,16 +784,16 @@ def build_encoder(ninp,nmid,grps=None):
in2s=[]
t=0
for i in ngrp:
in2s.append(tf.keras.layers.Dense(i, activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True,bias_initializer=binit)(in1[:,t:t+i]))
in2s.append(tf.keras.layers.Dense(i, activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True,bias_initializer=binit)(in1[:,t:t+i]))
t+=i
# Add Dropout here?
# drop=[tf.keras.layers.Dropout(0.3)(in2s[i]) for i in range(len(ngrp))]
mid=[tf.keras.layers.Dense(max(ngrp[i]//2,latpergrp*8), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(in2s[i]) for i in range(len(ngrp))]
mid=[tf.keras.layers.Dense(max(ngrp[i]//2,latpergrp*16), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True)(in2s[i]) for i in range(len(ngrp))]
# drop=[tf.keras.layers.Dropout(0.25)(mid[i]) for i in range(len(ngrp))]
mid2=[tf.keras.layers.Dense(max(ngrp[i]//4,latpergrp*8), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid[i]) for i in range(len(ngrp))]
mid3=[tf.keras.layers.Dense(max(ngrp[i]//8,latpergrp*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid2[i]) for i in range(len(ngrp))]
mid4=[tf.keras.layers.Dense(max(ngrp[i]//16,latpergrp*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid3[i]) for i in range(len(ngrp))]
outs=[tf.keras.layers.Dense(latpergrp, kernel_regularizer=l2, kernel_initializer=kinit,use_bias=True)(mid4[i]) for i in range(len(ngrp))]
mid2=[tf.keras.layers.Dense(max(ngrp[i]//4,latpergrp*16), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True)(mid[i]) for i in range(len(ngrp))]
mid3=[tf.keras.layers.Dense(max(ngrp[i]//8,latpergrp*8), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True)(mid2[i]) for i in range(len(ngrp))]
mid4=[tf.keras.layers.Dense(max(ngrp[i]//16,latpergrp*8), activation=leaky, kernel_initializer=kinit, kernel_regularizer=l2,bias_initializer=binit,use_bias=True)(mid3[i]) for i in range(len(ngrp))]
outs=[tf.keras.layers.Dense(latpergrp, kernel_regularizer=l2, activation=leaky,kernel_initializer=kinit,bias_initializer=binitsig, use_bias=True)(mid4[i]) for i in range(len(ngrp))]
out=tf.keras.layers.Concatenate()(outs)
encode_model=tf.keras.Model(inputs=in1,outputs=out)
# print(in1,in2s,mid,mid2,outs,out)
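In the grouped branch each Gaussian group gets its own slice of the input and its own stack of Dense layers, and the per-group latents are concatenated into the final latent vector. A minimal sketch of that pattern (group sizes and depths hypothetical):

# Sketch: per-group Dense stacks on slices of the input, concatenated at the end,
# as in the grouped encoder above.
import tensorflow as tf
ngrp = [12, 8]      # inputs per group (hypothetical)
latpergrp = 2       # latent dims per group (hypothetical)
inp = tf.keras.Input(shape=(sum(ngrp),))
outs, t = [], 0
for n in ngrp:
    x = tf.keras.layers.Dense(n, activation=tf.keras.layers.LeakyReLU(0.1))(inp[:, t:t+n])
    x = tf.keras.layers.Dense(latpergrp*4, activation=tf.keras.layers.LeakyReLU(0.1))(x)
    outs.append(tf.keras.layers.Dense(latpergrp)(x))
    t += n
model = tf.keras.Model(inputs=inp, outputs=tf.keras.layers.Concatenate()(outs))
model.summary()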
@@ -819,25 +821,28 @@ def build_decoder(nmid, pt ):
# kinit=tf.keras.initializers.HeNormal()
binit="random_normal"
kinit="he_normal"
leaky=tf.keras.layers.LeakyReLU(0.1)
binitsig=tf.keras.initializers.Constant(-0.5)
l2=tf.keras.regularizers.l2(1e-3)
l1=tf.keras.regularizers.l1(1e-3)
# layer_output=tf.keras.layers.Dense(nout*4, kernel_initializer=kinit, activation="sigmoid",use_bias=True,kernel_constraint=Localize4())
layer_output=tf.keras.layers.Dense(nout*4, kernel_initializer=binit, activation="sigmoid",use_bias=True)
# layer_output=tf.keras.layers.Dense(nout*4, kernel_initializer=binit, activation="leaky_relu",use_bias=True)
layer_output=tf.keras.layers.Dense(nout*4, activation="sigmoid",kernel_initializer=kinit,use_bias=True)

# print(f"Decoder {max(nout//32,nmid)} {max(nout//8,nmid)} {max(nout//2,nmid)}")
layers=[
#tf.keras.layers.Dense(nmid*2,activation="leaky_relu",use_bias=True,bias_initializer=kinit,kernel_constraint=Localize1()),
#tf.keras.layers.Dense(nmid*4,activation="leaky_relu",use_bias=True,kernel_constraint=Localize2()),
#tf.keras.layers.Dense(nmid*8,activation="leaky_relu",use_bias=True,kernel_constraint=Localize3()),
tf.keras.layers.Dense(max(nmid*4,nout//64),activation="leaky_relu",kernel_initializer=kinit,use_bias=True,bias_initializer=binit),
tf.keras.layers.Dense(max(nmid*8,nout//32),activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dense(min(nmid*8,nout//16),activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dropout(.4),
tf.keras.layers.Dense(min(nmid*16,nout//8),activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dropout(.4),
tf.keras.layers.Dense(nout//8,activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dropout(.25),
tf.keras.layers.Dense(nout//4,activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dense(max(nmid*4,nout//64),activation=leaky,kernel_initializer=kinit,kernel_regularizer=l2,bias_initializer=binit, use_bias=True),
tf.keras.layers.Dense(max(nmid*8,nout//32),activation=leaky,kernel_initializer=kinit,kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
tf.keras.layers.Dense(min(nmid*8,nout//16),activation=leaky,kernel_initializer=kinit,kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
# tf.keras.layers.Dropout(.2),
tf.keras.layers.Dense(min(nmid*16,nout//8),activation=leaky,kernel_initializer=kinit,kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
# tf.keras.layers.Dropout(.2),
tf.keras.layers.Dense(nout//8,activation=leaky,kernel_initializer=kinit,kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
# tf.keras.layers.Dropout(.2),
tf.keras.layers.Dense(nout//4,activation=leaky,kernel_initializer=kinit,kernel_regularizer=l2,bias_initializer=binit,use_bias=True),
# tf.keras.layers.BatchNormalization(),
layer_output,
tf.keras.layers.Reshape((nout,4))
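The decoder body now uses the LeakyReLU object, L2 regularization and explicit bias initializers throughout, while the output head remains a sigmoid Dense layer of width nout*4 reshaped to (nout, 4), i.e. one bounded (x, y, z, amplitude) vector per Gaussian. A minimal sketch of that output head (sizes hypothetical):

# Sketch: decoder output head producing one bounded 4-vector per Gaussian,
# matching the Dense -> Reshape((nout, 4)) tail above.
import tensorflow as tf
nmid, nout = 4, 256   # latent size and Gaussian count (hypothetical)
head = tf.keras.Sequential([
    tf.keras.layers.Dense(nout//4, activation=tf.keras.layers.LeakyReLU(0.1)),
    tf.keras.layers.Dense(nout*4, activation="sigmoid", kernel_initializer="he_normal"),
    tf.keras.layers.Reshape((nout, 4)),
])
print(head(tf.zeros((2, nmid))).shape)  # (2, 256, 4)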
@@ -877,7 +882,7 @@ def train_decoder(gen_model, trainset, params, options, pts=None):
nbatch=0
for pjr,pji,xf in trainset: nbatch+=1

if options.decoderentropy: confs=tf.random.uniform((nbatch,xf.shape[0],options.nmid),minval=-0.05, maxval=0.05)
if options.decoderentropy: confs=tf.random.uniform((nbatch,xf.shape[0],options.nmid),minval=-0.005, maxval=0.005)
else: confs=tf.zeros((nbatch,xf.shape[0],options.nmid), dtype=floattype)

for itr in range(options.niter):
Expand All @@ -899,7 +904,8 @@ def train_decoder(gen_model, trainset, params, options, pts=None):
if options.modelreg>0:
#print(tf.reduce_sum(pout[0,:,:3]*pts[:,:3]),tf.reduce_sum((pout[0,:,:3]-pts[:,:3])**2),len(pts))
l+=tf.reduce_sum((pout[0,:,:3]-pts[:,:3])**2)/len(pts)*options.modelreg*20.0 # factor of 20 is a rough calibration relative to the dynamic training
if itr<options.niter//2: l+=std[3]*options.ampreg*options.ampreg
#print(pout.shape,pout[0])
if options.ampreg>0 and itr<options.niter-2: l+=tf.reduce_mean(tf.math.minimum(tf.math.abs(pout[:,:,3]-0.9),tf.math.abs(pout[:,:,3]+0.1)))*options.ampreg
# print(std)

cost.append(loss)
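The rewritten amplitude regularizer charges each Gaussian amplitude by its distance to the nearer of the two targets 0.9 and -0.1, scaled by --ampreg and applied for all but the last two iterations, replacing the earlier std-based term used in the first half of training. A standalone worked example of the term added above, with toy amplitudes:

# Sketch: the amplitude penalty from the loss above, evaluated on toy values.
import tensorflow as tf
amp = tf.constant([[0.9, 0.4, -0.1, 0.0]])   # toy per-Gaussian amplitudes
ampreg = 0.05
penalty = tf.reduce_mean(tf.math.minimum(tf.math.abs(amp-0.9), tf.math.abs(amp+0.1)))*ampreg
print(float(penalty))  # (0 + 0.5 + 0 + 0.1)/4 * 0.05 = 0.0075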
@@ -1028,7 +1034,7 @@ def coarse_align(dcpx, pts, options):
def refine_align(dcpx, xfsnp, pts, options, lr=1e-3):
nsample=dcpx[0].shape[0]
trainset=tf.data.Dataset.from_tensor_slices((dcpx[0], dcpx[1], xfsnp))
trainset=trainset.batch(options.batchsz,drop_remainder=True)
trainset=trainset.batch(options.batchsz,drop_remainder=False)
nbatch=nsample//options.batchsz

opt=tf.keras.optimizers.Adam(learning_rate=lr)
@@ -1171,9 +1177,9 @@ def train_heterg(trainset, pts, encode_model, decode_model, params, options,grps
pas=tf.constant(np.array([pas[0],pas[0],pas[0],pas[1]], dtype=floattype))

## initialize optimizer
# opt=tf.keras.optimizers.Adam(learning_rate=options.learnrate)
opt=tf.keras.optimizers.Adam(learning_rate=options.learnrate)
# opt=tf.keras.optimizers.experimental.AdamW(learning_rate=options.learnrate)
opt=tf.keras.optimizers.Adamax(learning_rate=0.02)
# opt=tf.keras.optimizers.Adamax(learning_rate=0.02)
# opt=tf.keras.optimizers.Adadelta(learning_rate=0.1)
# opt=tf.keras.optimizers.Lion()
wts=encode_model.trainable_variables + decode_model.trainable_variables
Expand All @@ -1189,6 +1195,7 @@ def train_heterg(trainset, pts, encode_model, decode_model, params, options,grps
else: ngrps=0

if ngrps<=1:
# if False:
focusepoch=min(6,max(1,(options.niter//2)//options.nmid) ) # number of epochs before adding additional middle layer coordinates
trainmidmask=tf.constant([[1.0 if j*focusepoch<=i else 0.0 for j in range(options.nmid)] for i in range(options.niter)], dtype=floattype)
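The progressive schedule above enables one latent dimension every focusepoch epochs via trainmidmask. A toy evaluation of the same expressions:

# Sketch: the per-epoch latent mask built above; dimension j switches on once
# the epoch index reaches j*focusepoch, so latent dims come online one at a time.
niter, nmid = 8, 3
focusepoch = min(6, max(1, (niter//2)//nmid))   # -> 1
mask = [[1.0 if j*focusepoch <= i else 0.0 for j in range(nmid)] for i in range(niter)]
for row in mask[:4]: print(row)
# [1.0, 0.0, 0.0]
# [1.0, 1.0, 0.0]
# [1.0, 1.0, 1.0]
# [1.0, 1.0, 1.0]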

