Skip to content

Commit

Permalink
load weights only false
Browse files Browse the repository at this point in the history
  • Loading branch information
alanzty committed Dec 10, 2024
1 parent a9888e8 commit 6f851fe
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions evals/eval_gs_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def calc_all_features_mf(model_name, model, tokenizer, doc_meta_list, preprocess


def load_model(model_name, pretrained):
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, load_weights_only=False)
model = model.to('cuda')
model.eval()
tokenizer = open_clip.get_tokenizer(model_name)
Expand Down Expand Up @@ -283,7 +283,7 @@ def run_eval(argv):
max_context_length = open_clip.factory._MODEL_CONFIGS[args.model_name]['text_cfg']['context_length']
else:
max_context_length = 77
else:
elif not args.model_name.startswith('hf-hub:'):
open_clip.factory._MODEL_CONFIGS[args.model_name]['text_cfg']['context_length'] = max_context_length
args.context_length = max_context_length

Expand Down Expand Up @@ -319,11 +319,14 @@ def run_eval(argv):
df_test[query_key] += df_test[col] + "_{!@#~}_"

logging.info(df_test)
if args.weight_key:
assert args.weight_key in df_test.columns
if (args.weight_key in df_test.columns) and len(df_test[args.weight_key].unique()) > 1:
df_test[args.weight_key] = (((df_test[args.weight_key] - df_test[args.weight_key].min()) / (df_test[args.weight_key].max() - df_test[args.weight_key].min())) * 99 + 1).astype(int)
else:
args.weight_key = "score"
df_test[args.weight_key] = 1
assert df_test[args.weight_key].min() >= 1

# get the test queries and gt_results if it is there.
if os.path.exists(args.gt_results_path):
Expand Down

0 comments on commit 6f851fe

Please sign in to comment.