Skip to content

Commit

Permalink
Replace SIV.py
Browse files Browse the repository at this point in the history
  • Loading branch information
harryrf committed Aug 14, 2022
1 parent e6c4011 commit 7b2c17f
Showing 1 changed file with 62 additions and 81 deletions.
143 changes: 62 additions & 81 deletions SIV.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
#This code runs all models.
#NOTE: the novel part of the work with comments referencing the paper is
#in the Block section beginning at line 710.

#COMMANDS FOR ANALYSES FROM PAPER:

#MUST USE THIS VENV:
Expand All @@ -11,8 +7,10 @@
#for a in simglu ohio;do for x in 0 1 2 3 4 5 6 7 8 9 10 11;do python SIV.py $a $x bo noic;python SIV.py $a $x bo DIRECT; for b in bo SB SIVFIRST SIVLAST z nogate norestrict nodirb;do python SIV.py $a $x $b;done; done;done

#missingness analysis:
#for m in 0 .1 .2 .3 .4 .5; do python SIV.py simglu $x $m bo miss;python SIV.py simglu 9 $m miss;done;
#for s in 1 2 3 4 5;do for a in 9 0 1 2 3 4 5 6 7 8;do for b in z bo; do for c in 0 .1 .2 .3 .4 .5;do for d in miss noise ;do python SIV.py simglu $a $c $s $b $d; done;done;done;done;done

#non-cf ablations:
#for a in simglu ohio;do for x in 0 1 2 3 4 5 6 7 8 9 10 11;do for b in bo z ;do python SIV.py $a $x $b nocf;done; done;done

import os,datetime
import sys
Expand Down Expand Up @@ -78,7 +76,8 @@ def inarg(i,o):
nv=5
SIMGLU=True
missval=0

noiseval=0
seedval=1
subs=['adult#001','adult#002','adult#003','adult#004','adult#005','adult#006','adult#007','adult#008','adult#009','adult#010','child#001','child#002','child#003','child#004','child#005','child#006','child#007','child#008','child#009','child#010','adolescent#001','adolescent#002','adolescent#003','adolescent#004','adolescent#005','adolescent#006','adolescent#007','adolescent#008','adolescent#009','adolescent#010']
if int(sys.argv[2])>9:
print('No such subject')
Expand All @@ -88,7 +87,13 @@ def inarg(i,o):
if 'miss' in sys.argv:
missval=float(sys.argv[3])
outstr+='.miss'+sys.argv[3]

outstr+='.seed'+sys.argv[4]
seedval=int(sys.argv[4])
if 'noise' in sys.argv:
noiseval=float(sys.argv[3])
outstr+='.noise'+sys.argv[3]
outstr+='.seed'+sys.argv[4]
seedval=int(sys.argv[4])
#-----------------------------------------------------------------Data processing options ##################################


Expand Down Expand Up @@ -189,7 +194,10 @@ def inarg(i,o):
#-----------------------------------------------------------------Model development options



GIVMAINIC=False
nomainic,outstr=inarg('nomainic',outstr)
if not nomainic:
GIVMAINIC=True
RESTRICT=True
lrdec=1#float(sys.argv[3])
wddec=1#float(sys.argv[4])
Expand Down Expand Up @@ -229,7 +237,7 @@ def inarg(i,o):
if not BEATSONLY and not nogateresdir and not nodirb:
DIRECT=True


CLEANCARB,outstr=inarg('CLEANCARB2',outstr)
#################################### MAIN SECTION ############################################
def main():
maindir = os.getcwd()+'/'+outstr
Expand Down Expand Up @@ -709,47 +717,41 @@ class Block(nn.Module):

def __init__(self, units, device, backcast_length, forecast_length):
super(Block, self).__init__()
#set params.
self.backlen=backcast_length
self.forecast_length=forecast_length
self.input=nv
self.device = device

self.units=int(100*lstmsize)
self.bs=BATCHSIZE


#The encoder network, sigma.
#main encoder network
self.lstm=nn.LSTM(self.input,self.units, num_layers=lstmlay,batch_first=True,bidirectional=True).to(device)


#The intrinsic decoder network, theta.
#main decoder network
self.dec=nn.LSTM(self.units*2,self.units, num_layers=lstmlay,batch_first=True,bidirectional=True).to(device)

if BEATSONLY and DIRECT:
#for only direct input ablation.
self.dec=nn.LSTM(self.units*2+60,self.units, num_layers=lstmlay,batch_first=True,bidirectional=True).to(device)


if not BEATSONLY:

#DEPRICATED SECTION, IGNORE. (Left in code for reproducability purposes, randomness is effected when they are removed)
#SIV encoder network
self.lstmS=nn.LSTM(self.input,self.units, num_layers=2,batch_first=True,bidirectional=True).to(device)
self.decS=nn.LSTM(self.units*4,self.units*2, num_layers=2,batch_first=True,bidirectional=True).to(device)

self.lstmSC=nn.LSTM(self.input,self.units, num_layers=2,batch_first=True,bidirectional=True).to(device)
self.decSC=nn.LSTM(self.units*4,self.units*2, num_layers=2,batch_first=True,bidirectional=True).to(device)
self.linSC=nn.Linear(self.units *2, 1).to(device)

self.decS=nn.LSTM(self.units*4+30,self.units*2, num_layers=2,batch_first=True,bidirectional=True).to(device)
self.decSC=nn.LSTM(self.units*4+30,self.units*2, num_layers=2,batch_first=True,bidirectional=True).to(device)

#SIV decoder network (phi) one: for insulin
self.decS=nn.LSTM(self.units*2,self.units, num_layers=2,batch_first=True,bidirectional=True).to(device)
#SIV decoder network (phi) two: for carbs
self.decSC=nn.LSTM(self.units*2,self.units, num_layers=2,batch_first=True,bidirectional=True).to(device)
#Re-assign to accept shifted SIV signal (x') as input
if DIRECT:
self.decS=nn.LSTM(self.units*2+30,self.units, num_layers=2,batch_first=True,bidirectional=True).to(device)
self.decSC=nn.LSTM(self.units*2+30,self.units, num_layers=2,batch_first=True,bidirectional=True).to(device)

#output network FC
#output network
self.lin=nn.Linear(self.units *2, 1).to(device)


Expand All @@ -759,7 +761,7 @@ def forward(self, xt,xorig):



#Pad the input

x=xt.clone()
origbs=x.size()[0]
if origbs<self.bs:
Expand All @@ -770,133 +772,97 @@ def forward(self, xt,xorig):


if not BEATSONLY:
#Identify Batch position of inputs with Insulin and Carbs
bothinds=torch.sum(torch.sum(x[:,:,1:3].clone(),1),1)>0
inds=torch.sum(x[:,:,1],1)>0
Cinds=torch.sum(x[:,:,2],1)>0
xin=x.clone()
if not GIVMAINIC:
xin[:,:,1:3]=0

# if RECBACK:
# xin=xint.clone()
else:
xin=x.clone()

if DIRECT and BEATSONLY and not GIVMAINIC:
xin[:,:,1:3]=0

#calculate and reshape h_sigma
lstm_out, (h_0,c_0) = self.lstm(xin)
lstm_out=lstm_out[:,-1,:].view((500,1,-1))

#set up output matrix
outer=torch.zeros(self.bs,self.forecast_length).to(self.device)


#SET UP LSTM parameters VARIABLES
#for theta

hdec=(torch.zeros(2*lstmlay,self.bs,self.units)).to(self.device)#,
cdec = (torch.zeros(2*lstmlay,self.bs,self.units)).to(self.device)#
#for phi for Insulin network
hdecS=(torch.zeros(4,self.bs,self.units)).to(self.device)#,
cdecS = (torch.zeros(4,self.bs,self.units)).to(self.device)#,
#for phi for Carb network
hdecSC=(torch.zeros(4,self.bs,self.units)).to(self.device)#,
cdecSC = (torch.zeros(4,self.bs,self.units)).to(self.device)#,

#Set up SIV signal (x') to be input directly to the SIV decoders
#Signal is shifted (see Implemantation details)

if DIRECT:
sivDIR=torch.zeros((500,30,2)).to(self.device)#,
sivDIR[:,6:,:]=x[:,:,1:3].clone()
if not BEATSONLY:
sivDIRC=sivDIR[:,:,1:2].clone()
sivDIR=sivDIR[:,:,0:1].clone()

#Loop through horizon
for f in range(self.forecast_length):
if DIRECT and BEATSONLY:
#Do direct input for baseline (Only Dec. SIV Input ablation)
lstm_out=torch.cat((lstm_out.view((500,1,-1)),sivDIR.clone().view(500,1,60) /.7*torch.mean(lstm_out.view((500,1,-1)),2).view(500,1,1) ),2)
for i in range(0,self.backlen+5):
sivDIR[:,i,:]=sivDIR[:,i+1,:]


#INTRINSIC DECODER (theta)
#calculate h_theta at this timepoint
lstm_outxx, (hdec,cdec) = self.dec(lstm_out,(hdec,cdec))
lstm_outx=lstm_outxx.clone()
#assign output value from FC network- will be over written if not baseline.
outer[:,f]=self.lin(lstm_outx.clone()[:,0,:]).view(-1)

#SIV DECODERS (phi)
if not BEATSONLY:


# SIV DECODER 1-INSULIN
#calculate h_phi
if DIRECT:
#input with x'
lstm_outS, (hdecS,cdecS) = self.decS(torch.cat((lstm_out,sivDIR.clone().view(500,1,30) /.7*torch.mean(lstm_out.view((500,1,-1)),2).view(500,1,1) ),2),(hdecS,cdecS))
#perform shift.
for i in range(0,self.backlen+5):
sivDIR[:,i,:]=sivDIR[:,i+1,:]
else:
#input without x' for ablation.
lstm_outS, (hdecS,cdecS) = self.decS(lstm_out,(hdecS,cdecS))

lstmotemp=lstm_outx.clone()
#No gate ablation
if NOGATE:
if RESTRICT:
lstmotemp[:,0,:]=lstm_outx[:,0,:].clone()-F.relu(lstm_outS[:,0,:self.units*2].clone())
else:
lstmotemp[:,0,:]=lstm_outx[:,0,:].clone()+lstm_outS[:,0,:self.units*2].clone()
else:
if not RESTRICT:
#update h_theta with Insulin effect- not restricted ablation
lstmotemp[inds,0,:]=lstm_outx[inds,0,:].clone()-lstm_outS[inds,0,:self.units*2].clone()
else:
#update h_theta with Insulin effect- restricted
lstmotemp[inds,0,:]=lstm_outx[inds,0,:].clone()-F.relu(lstm_outS[inds,0,:self.units*2].clone())
#update main hidden state variable
lstm_outx=lstmotemp.clone()


# SIV DECODER 1-Carbs
#calculate h_phi
if DIRECT:
#input with x'
lstm_outSC, (hdecSC,cdecSC) = self.decSC(torch.cat((lstm_out,sivDIRC.clone().view(500,1,30) /.7*torch.mean(lstm_out.view((500,1,-1)),2).view(500,1,1) ),2),(hdecSC,cdecSC))
#perform shift.
for i in range(0,self.backlen+5):
sivDIRC[:,i,:]=sivDIRC[:,i+1,:]
else:
#input without x' for ablation.
lstm_outSC, (hdecSC,cdecSC) = self.decSC(lstm_out.view((500,1,-1)),(hdecSC,cdecSC))
lstmotemp=lstm_outx.clone()
#No gate ablation
if NOGATE:
if RESTRICT:
lstmotemp[:,0,:]=lstm_outx[:,0,:].clone()+F.relu(lstm_outSC[:,0,:self.units*2].clone())
else:
lstmotemp[:,0,:]=lstm_outx[:,0,:].clone()+lstm_outSC[:,0,:self.units*2].clone()
else:
if not RESTRICT:
#update h_theta with Carb effect- not restricted
lstmotemp[Cinds,0,:]=lstm_outx[Cinds,0,:].clone()+lstm_outSC[Cinds,0,:self.units*2].clone()
else:
#update h_theta with Carb effect- restricted
lstmotemp[Cinds,0,:]=lstm_outx[Cinds,0,:].clone()+F.relu(lstm_outSC[Cinds,0,:self.units*2].clone())
#update main hidden state variable
lstm_outx=lstmotemp.clone()
#reassign output
outer[:,f]=self.lin(lstm_outx.clone()[:,0,:]).view(-1)

#not used
lstm_outST=lstm_outS.clone()
lstm_outSCT=lstm_outSC.clone()

#reassign for next loop

lstm_out=lstm_outx.clone()


#recover orignal batch size.

outer=outer[:origbs,:]


Expand Down Expand Up @@ -932,7 +898,6 @@ def __init__(self,device,backcast_length,forecast_length):
def forward(self, x,target):

xorig=x.clone()
#Perform carry forward transform:
if CF:
for f in range(1,x.shape[1]):
x[:,f,1:3]+=x[:,f-1,1:3]
Expand Down Expand Up @@ -961,22 +926,22 @@ def makedata(totallength,sub):

a=a[:datalen]
ll=len(a)
if missval>0:
np.random.seed(SEED)
if missval>0 or noiseval>0:
np.random.seed(seedval)
for f in range(ll):
ff=a[f][:,:3]
ff[:,0]/=SCALEVAL
ff[:,1]/=SCALEBOL
ff[:,2]/=SCALECARB
if missval>0:
if missval>0 or noiseval>0:

for i in range(ff.shape[0]):
if not ff[i,2]==0 and not np.isnan(ff[i,2]):
temppp=np.random.uniform()
if temppp<missval:
if temppp<missval and missval>0:
ff[i,2]=0
else:
ff[i,2]=ff[i,2]*(1-missval+np.random.uniform()*missval*2)
if noiseval>0:
ff[i,2]=ff[i,2]*(1-noiseval+np.random.uniform()*noiseval*2)

t=np.arange(288)
ff=np.concatenate((ff,np.zeros([288,2])),1)
Expand Down Expand Up @@ -1014,15 +979,20 @@ def makedata(totallength,sub):
b=np.asarray(a['basal'])
d=np.asarray(a['dose'])
c=np.asarray(a['carbs'])

g[np.isnan(g)]=0
c[np.isnan(c)]=0
d[np.isnan(d)]=0

if DOTHESCALE:
d/=SCALECARB
c/=SCALEBOL


if CLEANCARB:
for i in range(len(c)):
if c[i]>0:
if len(d[i:i+36][d[i:i+36]>0])!=1 or np.max(g[i:i+36])>180/SCALEVAL or np.min(g[i:i+36])>=70/SCALEVAL:
c[i]=999



fing=np.asarray(a['finger'])/400.0
Expand Down Expand Up @@ -1063,7 +1033,7 @@ def makedata(totallength,sub):
d=np.asarray(a['dose'])
c=np.asarray(a['carbs'])


g[np.isnan(g)]=0
c[np.isnan(c)]=0
d[np.isnan(d)]=0

Expand All @@ -1072,6 +1042,11 @@ def makedata(totallength,sub):
c/=SCALEBOL


if CLEANCARB:
for i in range(len(c)):
if c[i]>0:
if len(d[i:i+36][d[i:i+36]>0])!=1 or np.max(g[i:i+36])>180/SCALEVAL or np.min(g[i:i+36])>=70/SCALEVAL:
c[i]=999


fing=np.asarray(a['finger'])/400.0
Expand Down Expand Up @@ -1140,7 +1115,9 @@ def get_x_y(ii,ic):
l[np.isnan(l)]=0
learn[:,0]=l
l[learn[:,0]==0]=0

if CLEANCARB:
if np.max(learn[:,2])>100:
return np.asarray([]),None,False
see=temp[i+backcast_length:i+backcast_length+forecast_length,0]
if TRAINICONLY or ic==2:
if np.sum(learn[:,1])+np.sum(learn[:,2])==0:
Expand All @@ -1154,6 +1131,7 @@ def get_x_y(ii,ic):
return np.asarray([]),None,False
if np.sum(learn[:,0])==0:
return np.asarray([]),None,False

return learn,see,False


Expand Down Expand Up @@ -1242,6 +1220,9 @@ def get_x_y(i):
# learn[:,0]=gaus(learn[:,0],1)
# learn[:,0][origlearn[:,0]==0]=0
see=temp[i+backcast_length:i+backcast_length+forecast_length,0]
if CLEANCARB:
if np.max(learn[:,2])>100:
return np.asarray([]),None,False
if TESTICONLY:
if np.sum(learn[:,1])+np.sum(learn[:,2])==0:
return np.asarray([]),None,False
Expand Down

0 comments on commit 7b2c17f

Please sign in to comment.