Skip to content

Commit 866523c

Browse files
committed
Updates to examples and plotting script
1 parent 9e36d9c commit 866523c

File tree

3 files changed

+64
-22
lines changed

3 files changed

+64
-22
lines changed

examples/benchmark.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ fn benchmark_fft(n: usize) {
1010
let mut imags = vec![0.0; big_n];
1111
gen_random_signal(&mut reals, &mut imags);
1212

13-
let now = std::time::Instant::now();
13+
// let now = std::time::Instant::now();
1414
fft_dif(&mut reals, &mut imags);
15-
let elapsed = now.elapsed().as_micros();
16-
println!("{elapsed}");
15+
// let elapsed = now.elapsed().as_micros();
16+
// println!("{elapsed}");
1717
}
1818

1919
fn main() {

examples/profile.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1+
use std::env;
2+
use std::str::FromStr;
3+
14
use phastft::fft_dif;
25

3-
fn bm_fft(num_qubits: usize) {
6+
fn benchmark_fft(num_qubits: usize) {
47
let n = 1 << num_qubits;
58
let mut reals: Vec<f64> = (1..=n).map(f64::from).collect();
69
let mut imags: Vec<f64> = (1..=n).map(f64::from).collect();
7-
810
fft_dif(&mut reals, &mut imags);
911
}
1012

1113
fn main() {
12-
bm_fft(25);
14+
let args: Vec<String> = env::args().collect();
15+
assert_eq!(args.len(), 2, "Usage {} <n>", args[0]);
16+
17+
let n = usize::from_str(&args[1]).unwrap();
18+
benchmark_fft(n);
1319
}

scripts/benchmark_plots.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99

1010
plt.style.use("fivethirtyeight")
1111

12+
13+
SYMBOLS = {
14+
'customary' : ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'),
15+
'customary_ext' : ('byte', 'kilo', 'mega', 'giga', 'tera', 'peta', 'exa',
16+
'zetta', 'iotta'),
17+
'iec' : ('Bi', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi'),
18+
'iec_ext' : ('byte', 'kibi', 'mebi', 'gibi', 'tebi', 'pebi', 'exbi',
19+
'zebi', 'yobi'),
20+
}
21+
22+
1223
def read_file(filepath: str) -> list[int]:
1324
y = []
1425

@@ -21,12 +32,12 @@ def read_file(filepath: str) -> list[int]:
2132

2233

2334
def get_figure_of_interest(vals: list[int]) -> float:
24-
return np.mean(vals)
35+
return np.median(vals)
2536

2637

2738
def build_and_clean_data(root_benchmark_dir: str, *names) -> defaultdict[str, list]:
2839
libs = ("rustfft", "phastft")
29-
n_range = range(4, 27)
40+
n_range = range(12, 30)
3041

3142
data = defaultdict(list)
3243

@@ -39,31 +50,56 @@ def build_and_clean_data(root_benchmark_dir: str, *names) -> defaultdict[str, li
3950
return data
4051

4152
def plot_lines(data: defaultdict[str, list]) -> None:
42-
index = list(range(4, 27))
53+
index = [bytes2human(2**n * (128 / 8)) for n in range(12, 30)]
4354
plt.figure()
4455

45-
print(len(data["phastft"]))
56+
y0 = np.asarray(data["phastft"])
57+
y1 = np.asarray(data["rustfft"])
58+
y0 = y1/y0
4659

4760
df = pd.DataFrame(
4861
{
49-
"PhastFT": data["phastft"],
50-
"RustFFT": data["rustfft"],
62+
"PhastFT": y0,
63+
"RustFFT": np.ones(len(index)),
5164
},
5265
index=index,
5366
)
5467

5568
df.plot(kind='bar', linewidth=3, rot=0)
56-
plt.xticks(fontsize=8)
57-
plt.xlabel("size")
58-
plt.ylabel("time (us)")
59-
plt.yscale("log")
60-
plt.show()
61-
# plt.tight_layout(pad=0.0)
62-
# plt.savefig("benchmarks_bar_plot.png", dpi=600)
69+
70+
plt.xticks(fontsize=9, rotation=-45)
71+
plt.yticks(fontsize=9)
72+
plt.xlabel("size of input")
73+
plt.ylabel("speedup (relative to RustFFT)")
74+
plt.legend(loc='best')
75+
plt.tight_layout()
76+
plt.savefig("benchmarks_bar_plot.png", dpi=600)
77+
78+
79+
# Source: https://stackoverflow.com/a/1094933
80+
def bytes2human(n, format='%(value).1f %(symbol)s', symbols='customary'):
81+
"""
82+
Convert n bytes into a human-readable string based on format.
83+
symbols can be either "customary", "customary_ext", "iec" or "iec_ext",
84+
see: https://goo.gl/kTQMs
85+
"""
86+
n = int(n)
87+
if n < 0:
88+
raise ValueError("n < 0")
89+
symbols = SYMBOLS[symbols]
90+
prefix = {}
91+
for i, s in enumerate(symbols[1:]):
92+
prefix[s] = 1 << (i+1)*10
93+
for symbol in reversed(symbols[1:]):
94+
if n >= prefix[symbol]:
95+
value = float(n) / prefix[symbol]
96+
return format % locals()
97+
return format % dict(symbol=symbols[0], value=n)
98+
6399

64100

65101
if __name__ == "__main__":
66102
# y = read_file("benchmark-data.2024.02.02.11-02-07/phastft/size_16")
67-
data = build_and_clean_data("benchmark-data.2024.02.02.11-27-33")
68-
print(data)
69-
# plot_lines(data)
103+
data = build_and_clean_data("benchmark-data.2024.02.02.11-43-10")
104+
# print(data)
105+
plot_lines(data)

0 commit comments

Comments
 (0)