Skip to content

Commit

Permalink
cleanup explain()
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju committed Jan 21, 2025
1 parent a4749f5 commit e64a184
Showing 1 changed file with 10 additions and 20 deletions.
30 changes: 10 additions & 20 deletions python/shaprpy/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,46 +252,36 @@ def explain(
# Setting globals to simplify the loop
converged = rinternal.rx2('iter_list')[iter-1].rx2('converged')[0]

# rinternal.rx2['timing_list'].rx2['postprocess_res'] = base.Sys_time()
rinternal.rx2['timing_list'] = ro.ListVector({**dict(rinternal.rx2['timing_list'].items()), 'postprocess_res': base.Sys_time()})


# Add the current timing_list to the iter_timing_list
#iter_timing_list = list(rinternal.rx2['iter_timing_list'])
#iter_timing_list.append(rinternal.rx2['timing_list'])
#rinternal.rx2['iter_timing_list'] = ro.ListVector(iter_timing_list)

# rinternal.rx2['iter_timing_list'].rx2[iter] = rinternal.rx2['timing_list']
rinternal.rx2['iter_timing_list'] = ro.ListVector({**dict(rinternal.rx2['iter_timing_list'].items()), f'element_{iter}': rinternal.rx2['timing_list']})

iter += 1

#rinternal.rx2['main_timing_list'].rx2['main_computation'] = base.Sys_time()
rinternal.rx2['main_timing_list'] = ro.ListVector({**dict(rinternal.rx2['main_timing_list'].items()), 'main_computation': base.Sys_time()})

# Rerun after convergence to get the same output format as for the non-iterative approach
routput = shapr.finalize_explanation(rinternal)

#rinternal.rx2['main_timing_list'].rx2['finalize_explanation'] = base.Sys_time()
rinternal.rx2['main_timing_list'] = ro.ListVector({**dict(rinternal.rx2['main_timing_list'].items()), 'finalize_explanation': base.Sys_time()})


routput.rx2['timing'] = shapr.compute_time(rinternal)

# Some cleanup when doing testing
#testing = rinternal.rx2('parameters').rx2('testing')[0]
#if base.isTRUE(testing):
# routput = shapr.testing_cleanup(routput)
testing = rinternal.rx2('parameters').rx2('testing')[0]
if testing:
routput = shapr.testing_cleanup(routput)

# Convert R objects to Python objects
shapley_values_est = r2py(base.as_data_frame(routput.rx2('shapley_values_est')))
shapley_values_sd = r2py(base.as_data_frame(routput.rx2('shapley_values_sd')))
pred_explain = r2py(routput.rx2('pred_explain'))
shapley_values_est = recurse_r_tree(routput.rx2('shapley_values_est'))
shapley_values_sd = recurse_r_tree(routput.rx2('shapley_values_sd'))
pred_explain = recurse_r_tree(routput.rx2('pred_explain'))
MSEv = recurse_r_tree(routput.rx2('MSEv'))
iterative_results = recurse_r_tree(routput.rx2('iterative_results'))
#saving_path = StrVector(routput.rx2['saving_path']) # NOt sure why this is not working
saving_path = StrVector(rinternal.rx2['parameters'].rx2['output_args'].rx2['saving_path'])[0]
#internal = recurse_r_tree(routput.rx2('rinternal')) # Currently get an error with NULL elements here
saving_path = recurse_r_tree(routput.rx2['saving_path'])
internal = recurse_r_tree(routput.rx2['internal'])
timing = recurse_r_tree(routput.rx2['timing'])

return {
Expand All @@ -301,8 +291,8 @@ def explain(
"MSEv": MSEv,
"iterative_results": iterative_results,
"saving_path": saving_path,
"internal": rinternal,
"timing": timing
"internal": internal,
"timing": timing,
}


Expand Down

0 comments on commit e64a184

Please sign in to comment.