Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/microsoft/DMs
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahalamdari committed Sep 9, 2023
2 parents 7a8f5ac + 4cec23b commit e739e8f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
3 changes: 1 addition & 2 deletions evodiff/conditional_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def main():
elif args.model_type == 'lr_ar_640M':
checkpoint = LR_AR_640M()
else:
print("Please select valid model, if you want to generate a random baseline add --random-baseline flag to any"
" model")
raise Exception("Please select either oa_dm_38M, oa_dm_640M, carp_640M, lr_ar_38M, or lr_ar_640M. You selected: ", args.model_type, ". If you want to generate a random baseline, add the --random-baseline flag to any model.")

model, collater, tokenizer, scheme = checkpoint
model.eval().cuda()
Expand Down
2 changes: 1 addition & 1 deletion evodiff/conditional_generation_msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main():
mask_id = checkpoint[2].mask_idx
pad_id = checkpoint[2].padding_idx
else:
raise Exception("Please select either msa_or_ar_randsub, msa_oa_oar_maxsub, or esm_msa_1b baseline. You selected:", args.model_type)
raise Exception("Please select either msa_oa_dm_randsub, msa_oa_dm_maxsub, or esm_msa_1b baseline. You selected:", args.model_type)

model, collater, tokenizer, scheme = checkpoint
model.eval().cuda()
Expand Down
12 changes: 6 additions & 6 deletions evodiff/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def main():
_ = torch.manual_seed(0)
np.random.seed(0)
parser = argparse.ArgumentParser()
parser.add_argument('--model-type', type=str, default='oa_ar_640M',
help='Choice of: carp_38M carp_640M esm1b_640M \
oa_ar_38M oa_ar_640M')
parser.add_argument('--model-type', type=str, default='oa_dm_640M',
help='Choice of: carp_38M carp_640M esm1b_650M \
oa_dm_38M oa_dm_640M lr_ar_38M lr_ar_640M d3pm_blosum_38M d3pm_blosum_640M d3pm_uniform_38M d3pm_uniform_640M')
parser.add_argument('-g', '--gpus', default=1, type=int,
help='number of gpus per node')
#parser.add_argument('out_fpath', type=str, nargs='?', default=os.getenv('PT_OUTPUT_DIR', '/tmp') + '/')
Expand All @@ -47,9 +47,9 @@ def main():
checkpoint = CARP_38M()
elif args.model_type=='carp_640M':
checkpoint = CARP_640M()
elif args.model_type=='oa_ar_38M':
elif args.model_type=='oa_dm_38M':
checkpoint = OA_DM_38M()
elif args.model_type=='oa_ar_640M':
elif args.model_type=='oa_dm_640M':
checkpoint = OA_DM_640M()
elif args.model_type=='lr_ar_38M':
checkpoint = LR_AR_38M()
Expand All @@ -68,7 +68,7 @@ def main():
checkpoint = D3PM_UNIFORM_640M(return_all=True)
d3pm=True
else:
print("Please select valid model")
raise Exception("Please select either carp_38M, carp_640M, esm1b_650M, oa_dm_38M, oa_dm_640M, lr_ar_38M, lr_ar_640M, d3pm_blosum_38M, d3pm_blosum_640M, d3pm_uniform_38M, or d3pm_uniform_640M. You selected:", args.model_type)

if d3pm:
model, collater, tokenizer, scheme, timestep, Q_bar, Q = checkpoint
Expand Down

0 comments on commit e739e8f

Please sign in to comment.