Skip to content

Commit

Permalink
attribution-wrong-gen-loc-end-coord-fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Deepro Banerjee committed May 24, 2023
1 parent 02b5429 commit b6af358
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,16 +413,16 @@ def eval_model(args, dataset_split="test"):
# non linear model interpretation
if args.model_name != "linear":
attributions, approximation_error = integrated_gradients.attribute(batch_dict['x_data'].float(), target=0, internal_batch_size=args.test_batch_size, return_convergence_delta=True, n_steps=500)
attributions = torch.sum(attributions, 0).cpu().numpy()
attributions = torch.sum(attributions, 1).cpu().numpy()
if attribution_array is None:
attribution_array = attributions
genomic_loc = batch_dict["genome_loc"]
genomic_loc = np.concatenate((np.array(genomic_loc[0]).reshape(-1,1), genomic_loc[1].cpu().numpy().reshape(-1,1), genomic_loc[1].cpu().numpy().reshape(-1,1)), axis=1)
genomic_loc = np.concatenate((np.array(genomic_loc[0]).reshape(-1,1), genomic_loc[1].cpu().numpy().reshape(-1,1), genomic_loc[2].cpu().numpy().reshape(-1,1)), axis=1)
genomic_loc_array = genomic_loc
else:
attribution_array = np.concatenate((attribution_array, attributions), axis=0)
genomic_loc = batch_dict["genome_loc"]
genomic_loc = np.concatenate((np.array(genomic_loc[0]).reshape(-1,1), genomic_loc[1].cpu().numpy().reshape(-1,1), genomic_loc[1].cpu().numpy().reshape(-1,1)), axis=1)
genomic_loc = np.concatenate((np.array(genomic_loc[0]).reshape(-1,1), genomic_loc[1].cpu().numpy().reshape(-1,1), genomic_loc[2].cpu().numpy().reshape(-1,1)), axis=1)
genomic_loc_array = np.concatenate((genomic_loc_array, genomic_loc), axis=0)

# update test bar
Expand Down

0 comments on commit b6af358

Please sign in to comment.