Skip to content

Commit 6fd5d59

Browse files
committed
ROCm WIP
1 parent 62a015d commit 6fd5d59

File tree

6 files changed

+297
-9
lines changed

6 files changed

+297
-9
lines changed

Cargo.lock

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

shared/subspace-proof-of-space-gpu/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ include = [
1616
blst = { version = "0.3.13", optional = true }
1717
rust-kzg-blst = { git = "https://github.com/grandinetech/rust-kzg", rev = "6c8fcc623df3d7e8c0f30951a49bfea764f90bf4", default-features = false, optional = true }
1818
# TODO: This is `rocm` branch, it is needed for ROCm support
19-
#sppark = { git = "https://github.com/dot-asm/sppark", rev = "8eeafe0f6cc0ca8211b1be93922df1b5a118bbd2", optional = true }
20-
sppark = { version = "0.1.8", optional = true }
19+
sppark = { git = "https://github.com/dot-asm/sppark", rev = "fe1237fe9eabb8aeb48a21af4d439fb4ac4f5d5d", optional = true }
20+
#sppark = { version = "0.1.8", optional = true }
2121
subspace-core-primitives = { version = "0.1.0", path = "../../crates/subspace-core-primitives", default-features = false, optional = true }
2222

2323
[dev-dependencies]
@@ -31,7 +31,7 @@ cc = "1.1.15"
3131
[features]
3232
# Only Volta+ architectures are supported (GeForce RTX 20xx consumer GPUs and newer)
3333
cuda = ["_gpu"]
34-
# TODO: ROCm can't be enabled at the same time as `cuda` feature at the moment and is not exposed on library level
34+
# TODO: ROCm can't be enabled at the same time as `cuda` feature at the moment
3535
rocm = ["_gpu"]
3636
# Internal feature, shouldn't be used directly
3737
_gpu = [

shared/subspace-proof-of-space-gpu/build.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ fn main() {
2121
hipcc.compiler(env::var("HIPCC").unwrap_or("hipcc".to_string()));
2222
hipcc.cpp(true);
2323
if cfg!(debug_assertions) {
24-
hipcc.opt_level(1);
24+
hipcc.opt_level(2);
2525
}
26-
hipcc.flag("--offload-arch=native,gfx1100,gfx1030,gfx942,gfx90a,gfx908");
27-
// 6 corresponds to the number of offload-arch
28-
hipcc.flag("-parallel-jobs=6");
26+
// hipcc.flag("--offload-arch=gfx1100,gfx1101,gfx1102,gfx1103,gfx1030,gfx942,gfx90a,gfx908");
27+
hipcc.flag("--offload-arch=gfx1102");
28+
hipcc.flag_if_supported("-parallel-jobs=16");
29+
// hipcc.flag("--offload-device-only");
2930
// This controls how error strings get handled in the FFI. When defined error strings get
3031
// returned from the FFI, and Rust must then free them. When not defined error strings are
3132
// not returned.
@@ -35,6 +36,9 @@ fn main() {
3536
hipcc.flag("-include").flag("util/cuda2hip.hpp");
3637
}
3738
hipcc.file("src/subspace_api.cu").compile("subspace_rocm");
39+
40+
// Doesn't link otherwise
41+
println!("cargo::rustc-link-lib=amdhip64");
3842
}
3943

4044
if cfg!(feature = "cuda") {
@@ -57,5 +61,6 @@ fn main() {
5761
nvcc.file("src/subspace_api.cu").compile("subspace_cuda");
5862
}
5963

64+
println!("cargo::rerun-if-changed=src");
6065
println!("cargo::rerun-if-env-changed=CXXFLAGS");
6166
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
#[cfg(feature = "cuda")]
22
pub mod cuda;
3+
#[cfg(feature = "rocm")]
4+
pub mod rocm;
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
// Copyright Supranational LLC
2+
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#[cfg(test)]
6+
mod tests;
7+
8+
use rust_kzg_blst::types::fr::FsFr;
9+
use std::ops::DerefMut;
10+
use subspace_core_primitives::crypto::Scalar;
11+
use subspace_core_primitives::{PosProof, PosSeed, Record};
12+
13+
extern "C" {
14+
/// # Returns
15+
/// * `usize` - The number of available GPUs.
16+
fn gpu_count() -> usize;
17+
18+
/// # Parameters
19+
/// * `k: The size parameter for the table.
20+
/// * `seed: A pointer to the seed data.
21+
/// * `lg_record_size: The logarithm of the record size.
22+
/// * `challenge_index: A mutable pointer to store the index of the challenge.
23+
/// * `record: A pointer to the record data.
24+
/// * `chunks_scratch: A mutable pointer to a scratch space for chunk data.
25+
/// * `proof_count: A mutable pointer to store the count of proofs.
26+
/// * `source_record_chunks: A mutable pointer to the source record chunks.
27+
/// * `parity_record_chunks: A mutable pointer to the parity record chunks.
28+
/// * `gpu_id: The ID of the GPU to use.
29+
///
30+
/// # Returns
31+
/// * `sppark::Error` - An error code indicating the result of the operation.
32+
///
33+
/// # Assumptions
34+
/// * `seed` must be a valid pointer to a 32-byte.
35+
/// * `record` must be a valid pointer to the record data (`*const Record`), with a length of `1 << lg_record_size`.
36+
/// * `source_record_chunks` and `parity_record_chunks` must be valid mutable pointers to `Scalar` elements, each with a length of `1 << lg_record_size`.
37+
/// * `chunks_scratch` must be a valid mutable pointer where up to `challenges_count` 32-byte chunks of GPU-calculated data will be written.
38+
/// * `gpu_id` must be a valid identifier of an available GPU. The available GPUs can be determined by using the `gpu_count` function.
39+
fn generate_and_encode_pospace_dispatch(
40+
k: u32,
41+
seed: *const [u8; 32],
42+
lg_record_size: u32,
43+
challenge_index: *mut u32,
44+
record: *const [u8; 32],
45+
chunks_scratch: *mut [u8; 32],
46+
proof_count: *mut u32,
47+
parity_record_chunks: *mut FsFr,
48+
gpu_id: i32,
49+
) -> sppark::Error;
50+
}
51+
52+
/// Returns [`RocmDevice`] for each available device
53+
pub fn rocm_devices() -> Vec<RocmDevice> {
54+
let num_devices = unsafe { gpu_count() };
55+
56+
(0i32..)
57+
.take(num_devices)
58+
.map(|gpu_id| RocmDevice { gpu_id })
59+
.collect()
60+
}
61+
62+
/// Wrapper data structure encapsulating a single ROCm-capable device
63+
#[derive(Debug)]
64+
pub struct RocmDevice {
65+
gpu_id: i32,
66+
}
67+
68+
impl RocmDevice {
69+
/// ROCm device ID
70+
pub fn id(&self) -> i32 {
71+
self.gpu_id
72+
}
73+
74+
/// Generates and encodes PoSpace on the GPU.
75+
///
76+
/// This function performs the generation and encoding of PoSpace
77+
/// on a GPU. It uses the specified parameters to perform the computations and
78+
/// ensures that errors are properly handled by returning a `Result` type.
79+
///
80+
/// # Parameters
81+
///
82+
/// ## Input
83+
///
84+
/// - `k`: The size parameter for the table.
85+
/// - `seed`: A 32-byte seed used for the table generation process.
86+
/// - `record`: A slice of bytes (`&[u8]`). These records are the data on which the proof of space will be generated.
87+
/// - `gpu_id`: ID of the GPU to use. This parameter specifies which GPU to use for the computation.
88+
///
89+
/// ## Output
90+
///
91+
/// - `source_record_chunks`: A mutable vector of original data chunks of type FsFr, each 32 bytes in size.
92+
/// - `parity_record_chunks`: A mutable vector of parity chunks derived from the source, each 32 bytes in size.
93+
/// - `proof_count`: A mutable reference to the proof count. This value will be updated with the number of proofs generated.
94+
/// - `chunks_scratch`: A mutable vector used to store the processed chunks. This vector holds the final results after combining record chunks and proof hashes.
95+
/// - `challenge_index`: A mutable vector used to map the challenges to specific parts of the data.
96+
pub fn generate_and_encode_pospace(
97+
&self,
98+
seed: &PosSeed,
99+
record: &mut Record,
100+
encoded_chunks_used_output: impl ExactSizeIterator<Item = impl DerefMut<Target = bool>>,
101+
) -> Result<(), String> {
102+
let record_len = Record::NUM_CHUNKS;
103+
let challenge_len = Record::NUM_S_BUCKETS;
104+
let lg_record_size = record_len.ilog2();
105+
106+
if challenge_len > u32::MAX as usize {
107+
return Err(String::from("challenge_len is too large to fit in u32"));
108+
}
109+
110+
let mut proof_count = 0u32;
111+
let mut chunks_scratch_gpu = Vec::<[u8; Scalar::FULL_BYTES]>::with_capacity(challenge_len);
112+
let mut challenge_index_gpu = Vec::<u32>::with_capacity(challenge_len);
113+
let mut parity_record_chunks = Vec::<Scalar>::with_capacity(Record::NUM_CHUNKS);
114+
115+
let error = unsafe {
116+
generate_and_encode_pospace_dispatch(
117+
u32::from(PosProof::K),
118+
&**seed,
119+
lg_record_size,
120+
challenge_index_gpu.as_mut_ptr(),
121+
record.as_ptr(),
122+
chunks_scratch_gpu.as_mut_ptr(),
123+
&mut proof_count,
124+
Scalar::slice_mut_to_repr(&mut parity_record_chunks).as_mut_ptr(),
125+
self.gpu_id,
126+
)
127+
};
128+
129+
if error.code != 0 {
130+
return Err(error.to_string());
131+
}
132+
133+
let proof_count = proof_count as usize;
134+
unsafe {
135+
chunks_scratch_gpu.set_len(proof_count);
136+
challenge_index_gpu.set_len(proof_count);
137+
parity_record_chunks.set_len(Record::NUM_CHUNKS);
138+
}
139+
140+
let mut encoded_chunks_used = vec![false; challenge_len];
141+
let source_record_chunks = record.to_vec();
142+
143+
let mut chunks_scratch = challenge_index_gpu
144+
.into_iter()
145+
.zip(chunks_scratch_gpu)
146+
.collect::<Vec<_>>();
147+
148+
chunks_scratch
149+
.sort_unstable_by(|(a_out_index, _), (b_out_index, _)| a_out_index.cmp(b_out_index));
150+
151+
// We don't need all the proofs
152+
chunks_scratch.truncate(proof_count.min(Record::NUM_CHUNKS));
153+
154+
for (out_index, _chunk) in &chunks_scratch {
155+
encoded_chunks_used[*out_index as usize] = true;
156+
}
157+
158+
encoded_chunks_used_output
159+
.zip(&encoded_chunks_used)
160+
.for_each(|(mut output, input)| *output = *input);
161+
162+
record
163+
.iter_mut()
164+
.zip(
165+
chunks_scratch
166+
.into_iter()
167+
.map(|(_out_index, chunk)| chunk)
168+
.chain(
169+
source_record_chunks
170+
.into_iter()
171+
.zip(parity_record_chunks)
172+
.flat_map(|(a, b)| [a, b.to_bytes()])
173+
.zip(encoded_chunks_used.iter())
174+
// Skip chunks that were used previously
175+
.filter_map(|(record_chunk, encoded_chunk_used)| {
176+
if *encoded_chunk_used {
177+
None
178+
} else {
179+
Some(record_chunk)
180+
}
181+
}),
182+
),
183+
)
184+
.for_each(|(output_chunk, input_chunk)| {
185+
*output_chunk = input_chunk;
186+
});
187+
188+
Ok(())
189+
}
190+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
use crate::rocm::rocm_devices;
2+
use std::num::NonZeroUsize;
3+
use std::slice;
4+
use subspace_core_primitives::crypto::{blake3_254_hash_to_scalar, blake3_hash};
5+
use subspace_core_primitives::{HistorySize, PieceOffset, Record, SectorId};
6+
use subspace_erasure_coding::ErasureCoding;
7+
use subspace_farmer_components::plotting::{CpuRecordsEncoder, RecordsEncoder};
8+
use subspace_farmer_components::sector::SectorContentsMap;
9+
use subspace_proof_of_space::chia::ChiaTable;
10+
use subspace_proof_of_space::Table;
11+
12+
type PosTable = ChiaTable;
13+
14+
#[test]
15+
fn basic() {
16+
let rocm_device = rocm_devices()
17+
.into_iter()
18+
.next()
19+
.expect("Need ROCm device to run this test");
20+
21+
let mut table_generator = PosTable::generator();
22+
let erasure_coding = ErasureCoding::new(
23+
NonZeroUsize::new(Record::NUM_S_BUCKETS.next_power_of_two().ilog2() as usize)
24+
.expect("Not zero; qed"),
25+
)
26+
.unwrap();
27+
let global_mutex = Default::default();
28+
let mut cpu_records_encoder = CpuRecordsEncoder::<PosTable>::new(
29+
slice::from_mut(&mut table_generator),
30+
&erasure_coding,
31+
&global_mutex,
32+
);
33+
34+
let sector_id = SectorId::new(blake3_hash(b"hello"), 500);
35+
let history_size = HistorySize::ONE;
36+
let mut record = Record::new_boxed();
37+
record.iter_mut().enumerate().for_each(|(index, chunk)| {
38+
*chunk = blake3_254_hash_to_scalar(&index.to_le_bytes()).to_bytes()
39+
});
40+
41+
let mut cpu_encoded_records = Record::new_zero_vec(2);
42+
for cpu_encoded_record in &mut cpu_encoded_records {
43+
cpu_encoded_record.clone_from(&record);
44+
}
45+
let cpu_sector_contents_map = cpu_records_encoder
46+
.encode_records(
47+
&sector_id,
48+
&mut cpu_encoded_records,
49+
history_size,
50+
&Default::default(),
51+
)
52+
.unwrap();
53+
54+
println!("a");
55+
56+
let mut gpu_encoded_records = Record::new_zero_vec(2);
57+
for gpu_encoded_record in &mut gpu_encoded_records {
58+
gpu_encoded_record.clone_from(&record);
59+
}
60+
let mut gpu_sector_contents_map = SectorContentsMap::new(2);
61+
rocm_device
62+
.generate_and_encode_pospace(
63+
&sector_id.derive_evaluation_seed(PieceOffset::ZERO, history_size),
64+
&mut gpu_encoded_records[0],
65+
gpu_sector_contents_map
66+
.iter_record_bitfields_mut()
67+
.next()
68+
.unwrap()
69+
.iter_mut(),
70+
)
71+
.unwrap();
72+
println!("b");
73+
rocm_device
74+
.generate_and_encode_pospace(
75+
&sector_id.derive_evaluation_seed(PieceOffset::ONE, history_size),
76+
&mut gpu_encoded_records[1],
77+
gpu_sector_contents_map
78+
.iter_record_bitfields_mut()
79+
.nth(1)
80+
.unwrap()
81+
.iter_mut(),
82+
)
83+
.unwrap();
84+
println!("c");
85+
86+
assert_eq!(
87+
cpu_sector_contents_map.iter_record_bitfields()[0],
88+
gpu_sector_contents_map.iter_record_bitfields()[0]
89+
);
90+
assert!(cpu_sector_contents_map == gpu_sector_contents_map);
91+
assert!(cpu_encoded_records == gpu_encoded_records);
92+
}

0 commit comments

Comments
 (0)