Skip to content

Commit

Permalink
average build time plot
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Dec 14, 2023
1 parent 39d2970 commit 0e9f1cc
Showing 1 changed file with 29 additions and 34 deletions.
63 changes: 29 additions & 34 deletions python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,25 +228,15 @@ def inv_fun(x):
def create_plot_build(
build_results, search_results, linestyles, fn_out, dataset
):
qps_80 = [-1] * len(linestyles)
bt_80 = [0] * len(linestyles)
i_80 = [-1] * len(linestyles)

qps_85 = [-1] * len(linestyles)
bt_85 = [0] * len(linestyles)
i_85 = [-1] * len(linestyles)

qps_90 = [-1] * len(linestyles)
bt_90 = [0] * len(linestyles)
i_90 = [-1] * len(linestyles)

qps_95 = [-1] * len(linestyles)
bt_95 = [0] * len(linestyles)
i_95 = [-1] * len(linestyles)

qps_99 = [-1] * len(linestyles)
bt_99 = [0] * len(linestyles)
i_99 = [-1] * len(linestyles)

data = OrderedDict()
colors = OrderedDict()
Expand All @@ -259,32 +249,38 @@ def mean_y(algo):

for pos, algo in enumerate(sorted(search_results.keys(), key=mean_y)):
points = np.array(search_results[algo], dtype=object)
# x is recall, ls is algo_name, idxs is index_name
xs = points[:, 2]
ys = points[:, 3]
ls = points[:, 0]
idxs = points[:, 1]
# x is recall, y is qps, ls is algo_name, idxs is index_name

len_80, len_85, len_90, len_95, len_99 = 0, 0, 0, 0, 0
for i in range(len(xs)):
if xs[i] >= 0.80 and xs[i] < 0.85 and ys[i] > qps_80[pos]:
qps_80[pos] = ys[i]
bt_80[pos] = build_results[(ls[i], idxs[i])][0][2]
i_80[pos] = idxs[i]
elif xs[i] >= 0.85 and xs[i] < 0.9 and ys[i] > qps_85[pos]:
qps_85[pos] = ys[i]
bt_85[pos] = build_results[(ls[i], idxs[i])][0][2]
i_85[pos] = idxs[i]
elif xs[i] >= 0.9 and xs[i] < 0.95 and ys[i] > qps_90[pos]:
qps_90[pos] = ys[i]
bt_90[pos] = build_results[(ls[i], idxs[i])][0][2]
i_90[pos] = idxs[i]
elif xs[i] >= 0.95 and xs[i] < 0.99 and ys[i] > qps_95[pos]:
qps_95[pos] = ys[i]
bt_95[pos] = build_results[(ls[i], idxs[i])][0][2]
i_95[pos] = idxs[i]
elif xs[i] >= 0.99 and ys[i] > qps_99[pos]:
qps_99[pos] = ys[i]
bt_99[pos] = build_results[(ls[i], idxs[i])][0][2]
i_99[pos] = idxs[i]
if xs[i] >= 0.80 and xs[i] < 0.85:
bt_80[pos] = bt_80[pos] + build_results[(ls[i], idxs[i])][0][2]
len_80 = len_80 + 1
elif xs[i] >= 0.85 and xs[i] < 0.9:
bt_85[pos] = bt_85[pos] + build_results[(ls[i], idxs[i])][0][2]
len_85 = len_85 + 1
elif xs[i] >= 0.9 and xs[i] < 0.95:
bt_90[pos] = bt_90[pos] + build_results[(ls[i], idxs[i])][0][2]
len_90 = len_90 + 1
elif xs[i] >= 0.95 and xs[i] < 0.99:
bt_95[pos] = bt_95[pos] + build_results[(ls[i], idxs[i])][0][2]
len_95 = len_95 + 1
elif xs[i] >= 0.99:
bt_99[pos] = bt_99[pos] + build_results[(ls[i], idxs[i])][0][2]
len_99 = len_99 + 1
if len_80 > 0:
bt_80[pos] = bt_80[pos] / len_80
if len_85 > 0:
bt_85[pos] = bt_85[pos] / len_85
if len_90 > 0:
bt_90[pos] = bt_90[pos] / len_90
if len_95 > 0:
bt_95[pos] = bt_95[pos] / len_95
if len_99 > 0:
bt_99[pos] = bt_99[pos] / len_99
data[algo] = [
bt_80[pos],
bt_85[pos],
Expand All @@ -309,7 +305,7 @@ def mean_y(algo):
ax = df.plot.bar(rot=0, color=colors)
fig = ax.get_figure()
print(f"writing build output to {fn_out}")
plt.title("Build Time for Highest QPS")
plt.title("Average Build Time for Recall Range")
plt.suptitle(f"{dataset}")
plt.ylabel("Build Time (s)")
fig.savefig(fn_out)
Expand Down Expand Up @@ -403,7 +399,6 @@ def load_all_results(
filter_k_bs = []
for result_filename in result_files:
filename_split = result_filename.split(",")
print(filename_split)
if (
int(filename_split[-3][1:]) == k
and int(filename_split[-2][2:]) == batch_size
Expand Down

0 comments on commit 0e9f1cc

Please sign in to comment.