diff --git a/evodiff/conditional_generation.py b/evodiff/conditional_generation.py index 6931823..a7bcba2 100644 --- a/evodiff/conditional_generation.py +++ b/evodiff/conditional_generation.py @@ -73,8 +73,7 @@ def main(): elif args.model_type == 'lr_ar_640M': checkpoint = LR_AR_640M() else: - print("Please select valid model, if you want to generate a random baseline add --random-baseline flag to any" - " model") + raise Exception("Please select either oa_dm_38M, oa_dm_640M, carp_640M, lr_ar_38M, or lr_ar_640M. You selected: ", args.model_type, ". If you want to generate a random baseline, add the --random-baseline flag to any model.") model, collater, tokenizer, scheme = checkpoint model.eval().cuda() diff --git a/evodiff/conditional_generation_msa.py b/evodiff/conditional_generation_msa.py index 4f29c7f..39911da 100644 --- a/evodiff/conditional_generation_msa.py +++ b/evodiff/conditional_generation_msa.py @@ -72,7 +72,7 @@ def main(): mask_id = checkpoint[2].mask_idx pad_id = checkpoint[2].padding_idx else: - raise Exception("Please select either msa_or_ar_randsub, msa_oa_oar_maxsub, or esm_msa_1b baseline. You selected:", args.model_type) + raise Exception("Please select either msa_oa_dm_randsub, msa_oa_dm_maxsub, or esm_msa_1b baseline. You selected:", args.model_type) model, collater, tokenizer, scheme = checkpoint model.eval().cuda() diff --git a/evodiff/generate.py b/evodiff/generate.py index 96dee7e..306061f 100644 --- a/evodiff/generate.py +++ b/evodiff/generate.py @@ -21,9 +21,9 @@ def main(): _ = torch.manual_seed(0) np.random.seed(0) parser = argparse.ArgumentParser() - parser.add_argument('--model-type', type=str, default='oa_ar_640M', - help='Choice of: carp_38M carp_640M esm1b_640M \ - oa_ar_38M oa_ar_640M') + parser.add_argument('--model-type', type=str, default='oa_dm_640M', + help='Choice of: carp_38M carp_640M esm1b_650M \ + oa_dm_38M oa_dm_640M lr_ar_38M lr_ar_640M d3pm_blosum_38M d3pm_blosum_640M d3pm_uniform_38M d3pm_uniform_640M') parser.add_argument('-g', '--gpus', default=1, type=int, help='number of gpus per node') #parser.add_argument('out_fpath', type=str, nargs='?', default=os.getenv('PT_OUTPUT_DIR', '/tmp') + '/') @@ -47,9 +47,9 @@ def main(): checkpoint = CARP_38M() elif args.model_type=='carp_640M': checkpoint = CARP_640M() - elif args.model_type=='oa_ar_38M': + elif args.model_type=='oa_dm_38M': checkpoint = OA_DM_38M() - elif args.model_type=='oa_ar_640M': + elif args.model_type=='oa_dm_640M': checkpoint = OA_DM_640M() elif args.model_type=='lr_ar_38M': checkpoint = LR_AR_38M() @@ -68,7 +68,7 @@ def main(): checkpoint = D3PM_UNIFORM_640M(return_all=True) d3pm=True else: - print("Please select valid model") + raise Exception("Please select either carp_38M, carp_640M, esm1b_650M, oa_dm_38M, oa_dm_640M, lr_ar_38M, lr_ar_640M, d3pm_blosum_38M, d3pm_blosum_640M, d3pm_uniform_38M, or d3pm_uniform_640M. You selected:", args.model_type) if d3pm: model, collater, tokenizer, scheme, timestep, Q_bar, Q = checkpoint