diff --git a/Cargo.toml b/Cargo.toml index ec839acac..ef74bacad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -127,8 +127,8 @@ lto = "thin" #ceno_crypto_primitives = { path = "../ceno-patch/crypto-primitives", package = "ceno_crypto_primitives" } #ceno_syscall = { path = "../ceno-patch/syscall", package = "ceno_syscall" } -# [patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] -# ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features=["bb31"] } +#[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] +#ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } #[patch."https://github.com/scroll-tech/gkr-backend"] #ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 8fb599f27..1625fb106 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1865,7 +1865,7 @@ fn create_proofs_streaming< { let mut proofs = Vec::new(); let mut proof_err = None; - let mut rx = rx; + let rx = rx; while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { if is_mock_proving { MockProver::assert_satisfied_full( diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 40bc1976c..a9c8b6f4f 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -10,19 +10,37 @@ use crate::{ }, structs::{ComposedConstrainSystem, PointAndEval, TowerProofs}, }; +use ceno_gpu::bb31::{CudaHalBB31, GpuPolynomial}; +use either::Either; use ff_ext::ExtensionField; use gkr_iop::{ - gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, - gpu::{GpuBackend, GpuProver}, + gkr::{ + self, Evaluation, GKRProof, GKRProverOutput, + layer::{LayerWitness, gpu::utils::extract_mle_relationships_from_monomial_terms}, + }, + gpu::{GpuBackend, GpuProver, gpu_prover::BB31Ext}, hal::ProverBackend, }; use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; -use multilinear_extensions::{mle::MultilinearExtension, util::ceil_log2}; -use std::{collections::BTreeMap, sync::Arc}; +use multilinear_extensions::{ + Expression, ToExpr, + mle::{FieldType, IntoMLE, MultilinearExtension}, + util::ceil_log2, + virtual_poly::{build_eq_x_r_vec, eq_eval}, +}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, +}; +use std::{ + collections::BTreeMap, + iter::{once, repeat_n}, + sync::Arc, +}; use sumcheck::{ macros::{entered_span, exit_span}, - structs::IOPProverMessage, + structs::{IOPProof, IOPProverMessage}, util::optimal_sumcheck_threads, }; use transcript::{BasicTranscript, Transcript}; @@ -33,15 +51,23 @@ use tracing::info_span; #[cfg(feature = "gpu")] use gkr_iop::gpu::gpu_prover::*; +mod util; +use util::{ + WitnessRegistry, batch_mles_take_half, expect_basic_transcript, hal_to_backend_error, + mle_filter_even_odd_batch, mle_host_to_gpu, read_septic_value_from_gpu, symbolic_from_mle, +}; pub struct GpuTowerProver; use crate::{ - scheme::{constants::NUM_FANIN, cpu::CpuEccProver}, + scheme::{ + constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, + septic_curve::SepticPoint, + }, structs::EccQuarkProof, }; use gkr_iop::{ - gpu::{ArcMultilinearExtensionGpu, MultilinearExtensionGpu}, - selector::SelectorContext, + gpu::{ArcMultilinearExtensionGpu, BB31Base, MultilinearExtensionGpu}, + selector::{SelectorContext, SelectorType}, }; // Extract out_evals from GPU-built tower witnesses @@ -106,7 +132,7 @@ impl> TraceCommitter() != std::any::TypeId::of::() { - panic!("GPU backend only supports Goldilocks base field"); + panic!("GPU backend only supports BabyBear base field"); } let span = entered_span!("[gpu] init pp", profiling_2 = true); @@ -217,6 +243,7 @@ impl> TraceCommitter( composed_cs: &ComposedConstrainSystem, input: &ProofInput<'_, GpuBackend>>, @@ -320,46 +347,43 @@ fn build_tower_witness_gpu<'buf, E: ExtensionField>( last_layer }) .collect::>() + } else if lk_denominator_last_layer.is_empty() { + vec![] } else { - if lk_denominator_last_layer.is_empty() { - vec![] - } else { - // Case when numerator is empty - create shared ones_buffer and use views - // This saves memory by having all p1, p2 polynomials reference the same buffer - let nv = lk_denominator_last_layer[0][0].num_vars(); + // Case when numerator is empty - create shared ones_buffer and use views + // This saves memory by having all p1, p2 polynomials reference the same buffer + let nv = lk_denominator_last_layer[0][0].num_vars(); - // Create one shared ones_buffer as Owned (can be 'static) - let ones_poly = GpuPolynomialExt::new_with_scalar(&cuda_hal.inner, nv, BB31Ext::ONE) - .map_err(|e| format!("Failed to create shared ones_buffer: {:?}", e)) - .unwrap(); - // SAFETY: Owned buffer can be safely treated as 'static - let ones_poly_static: GpuPolynomialExt<'static> = - unsafe { std::mem::transmute(ones_poly) }; - ones_buffer.push(ones_poly_static); + // Create one shared ones_buffer as Owned (can be 'static) + let ones_poly = GpuPolynomialExt::new_with_scalar(&cuda_hal.inner, nv, BB31Ext::ONE) + .map_err(|e| format!("Failed to create shared ones_buffer: {:?}", e)) + .unwrap(); + // SAFETY: Owned buffer can be safely treated as 'static + let ones_poly_static: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(ones_poly) }; + ones_buffer.push(ones_poly_static); - // Get reference from storage to ensure proper lifetime - let ones_poly_ref = ones_buffer.last().unwrap(); - let mle_len_bytes = ones_poly_ref.evaluations().len() * std::mem::size_of::(); + // Get reference from storage to ensure proper lifetime + let ones_poly_ref = ones_buffer.last().unwrap(); + let mle_len_bytes = ones_poly_ref.evaluations().len() * std::mem::size_of::(); - // Create views referencing the shared ones_buffer for each tower's p1, p2 - lk_denominator_last_layer - .into_iter() - .map(|lk_d_chunks| { - // Create views of ones_buffer for p1 and p2 - let p1_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); - let p2_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); - let p1_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p1_view), nv); - let p2_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p2_view), nv); - // SAFETY: views from 'static buffer can be 'static - let p1_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p1_gpu) }; - let p2_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p2_gpu) }; - // Use [p1, p2, q1, q2] format for the last layer - let mut last_layer = vec![p1_gpu, p2_gpu]; - last_layer.extend(lk_d_chunks); - last_layer - }) - .collect::>() - } + // Create views referencing the shared ones_buffer for each tower's p1, p2 + lk_denominator_last_layer + .into_iter() + .map(|lk_d_chunks| { + // Create views of ones_buffer for p1 and p2 + let p1_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); + let p2_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); + let p1_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p1_view), nv); + let p2_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p2_view), nv); + // SAFETY: views from 'static buffer can be 'static + let p1_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p1_gpu) }; + let p2_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p2_gpu) }; + // Use [p1, p2, q1, q2] format for the last layer + let mut last_layer = vec![p1_gpu, p2_gpu]; + last_layer.extend(lk_d_chunks); + last_layer + }) + .collect::>() }; if !logup_last_layers.is_empty() { let first_layer = &logup_last_layers[0]; @@ -518,7 +542,7 @@ impl> TowerProver() != std::any::TypeId::of::() { - panic!("GPU backend only supports Goldilocks base field"); + panic!("GPU backend only supports BabyBear base field"); } // Calculate r_set_len directly from constraint system @@ -731,20 +755,217 @@ impl> EccQuarkProver>>, transcript: &mut impl Transcript, ) -> Result, ZKVMError> { - // TODO implement GPU version of `create_ecc_proof` - let xs = xs.iter().map(|mle| mle.inner_to_mle().into()).collect_vec(); - let ys = ys.iter().map(|mle| mle.inner_to_mle().into()).collect_vec(); - let invs = invs + assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); + + let n = xs[0].mle.num_vars() - 1; + tracing::debug!( + "Creating EC Summation Quark proof with {} points in {n} variables", + num_instances + ); + + let out_rt = transcript.sample_and_append_vec(b"ecc", n); + + // expression with add (3 zero constraints), bypass (2 zero constraints), export (2 zero constraints) + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); + + let sel_add = SelectorType::QuarkBinaryTreeLessThan(0.into()); + let sel_add_ctx = SelectorContext { + offset: 0, + num_instances, + num_vars: n, + }; + let sel_add_mle: MultilinearExtension<'_, E> = + sel_add.compute(&out_rt, &sel_add_ctx).unwrap(); + + // the final sum is located at [1,...,1,0] (in big-endian) + let last_evaluation_index = (1 << n) - 2; + let lsi_on_hypercube = once(E::ZERO).chain(repeat_n(E::ONE, n - 1)).collect_vec(); + let mut sel_export = (0..(1 << n)) + .into_par_iter() + .map(|_| E::ZERO) + .collect::>(); + sel_export[last_evaluation_index] = eq_eval(&out_rt, lsi_on_hypercube.as_slice()); + let sel_export_mle = sel_export.into_mle(); + + // we construct sel_bypass witness here + // verifier can derive it via `sel_bypass = eq - sel_add - sel_last_onehot` + let mut sel_bypass_mle: Vec = build_eq_x_r_vec(&out_rt); + match sel_add_mle.evaluations() { + FieldType::Ext(sel_add_mle) => sel_add_mle + .par_iter() + .zip(sel_bypass_mle.par_iter_mut()) + .for_each(|(sel_add, sel_bypass)| { + if *sel_add != E::ZERO { + *sel_bypass = E::ZERO; + } + }), + _ => unreachable!(), + } + *sel_bypass_mle.last_mut().unwrap() = E::ZERO; + let sel_bypass_mle = sel_bypass_mle.into_mle(); + + let cuda_hal = get_cuda_hal().map_err(hal_to_backend_error)?; + let sel_add_gpu = mle_host_to_gpu(&cuda_hal, &sel_add_mle); + let sel_bypass_gpu = mle_host_to_gpu(&cuda_hal, &sel_bypass_mle); + let sel_export_gpu = mle_host_to_gpu(&cuda_hal, &sel_export_mle); + let split_batches = mle_filter_even_odd_batch::( + &cuda_hal, + &[(&xs, false), (&xs, true), (&ys, false), (&ys, true)], + )?; + let mut split_iter = split_batches.into_iter(); + let x0_gpu = split_iter.next().unwrap_or_default(); + let x1_gpu = split_iter.next().unwrap_or_default(); + let y0_gpu = split_iter.next().unwrap_or_default(); + let y1_gpu = split_iter.next().unwrap_or_default(); + + // build x[1,b], y[1,b], s[1,b] + let x3_gpu = batch_mles_take_half::(&xs, 1)?; + let y3_gpu = batch_mles_take_half::(&ys, 1)?; + let s_gpu = batch_mles_take_half::(&invs, 1)?; + + let mut registry: WitnessRegistry<'a, E> = WitnessRegistry::default(); + let sel_add_expr = registry.register(sel_add_gpu); + let sel_bypass_expr = registry.register(sel_bypass_gpu); + let sel_export_expr = registry.register(sel_export_gpu); + + let s = symbolic_from_mle(&mut registry, &s_gpu); + let x0 = symbolic_from_mle(&mut registry, &x0_gpu); + let y0 = symbolic_from_mle(&mut registry, &y0_gpu); + let x1 = symbolic_from_mle(&mut registry, &x1_gpu); + let y1 = symbolic_from_mle(&mut registry, &y1_gpu); + let x3 = symbolic_from_mle(&mut registry, &x3_gpu); + let y3 = symbolic_from_mle(&mut registry, &y3_gpu); + + let mut exprs_add = vec![]; + let mut exprs_bypass = vec![]; + // affine addition + // zerocheck: 0 = s[1,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1) + exprs_add.extend( + (s.clone() * (&x0 - &x1) - (&y0 - &y1)) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // zerocheck: 0 = s[1,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1) + exprs_add.extend( + ((&s * &s) - &x0 - &x1 - &x3) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + // zerocheck: 0 = s[1,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) + exprs_add.extend( + (s.clone() * (&x0 - &x3) - (&y0 + &y3)) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + let exprs_add = exprs_add.into_iter().sum::>() * sel_add_expr; + + // deal with bypass + // 0 = (x[1,b] - x[b,0]) + exprs_bypass.extend( + (&x3 - &x0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + // 0 = (y[1,b] - y[b,0]) + exprs_bypass.extend( + (&y3 - &y0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // export x[1,...,1,0], y[1,...,1,0] for final result (using big-endian notation) + let xp_gpu = batch_mles_take_half::(&xs, 1)?; + let yp_gpu = batch_mles_take_half::(&ys, 1)?; + let final_sum_x = read_septic_value_from_gpu(&xp_gpu, last_evaluation_index)?; + let final_sum_y = read_septic_value_from_gpu(&yp_gpu, last_evaluation_index)?; + // 0 = sel_export * (x[1,b] - final_sum.x) + // 0 = sel_export * (y[1,b] - final_sum.y) + let export_expr = + x3.0.iter() + .zip_eq(final_sum_x.0.iter()) + .chain(y3.0.iter().zip_eq(final_sum_y.0.iter())) + .map(|(x, final_x)| x - final_x.expr()) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE * 2)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))) + .sum::>() + * sel_export_expr; + + let exprs_bypass = exprs_bypass.into_iter().sum::>() * sel_bypass_expr; + + let zero_expr = exprs_add + exprs_bypass + export_expr; + + let monomial_terms = zero_expr.get_monomial_terms(); + let gpu_refs = registry.gpu_refs(); + let (term_coefficients, mle_indices_per_term, mle_size_info) = + extract_mle_relationships_from_monomial_terms(&monomial_terms, &gpu_refs, &[], &[]); + let max_degree = mle_indices_per_term .iter() - .map(|mle| mle.inner_to_mle().into()) - .collect_vec(); - Ok(CpuEccProver::create_ecc_proof( + .map(|indices| indices.len()) + .max() + .unwrap_or(0); + let term_coefficients_gl64: Vec = + unsafe { std::mem::transmute(term_coefficients.clone()) }; + let gpu_refs_gl64: Vec<&MultilinearExtensionGpu> = + unsafe { std::mem::transmute(gpu_refs) }; + let gpu_field_refs = gpu_refs_gl64.iter().map(|mle| &mle.mle).collect_vec(); + + let basic_transcript = expect_basic_transcript(transcript); + let (proof_gpu, evals_gpu, challenges_gpu) = cuda_hal + .prove_generic_sumcheck_gpu( + gpu_field_refs, + &mle_size_info, + &term_coefficients_gl64, + &mle_indices_per_term, + n, + max_degree, + basic_transcript, + ) + .map_err(|e| hal_to_backend_error(format!("GPU sumcheck failed: {e:?}")))?; + + drop(cuda_hal); + + let proof_gpu_e: IOPProof = unsafe { std::mem::transmute(proof_gpu) }; + let evals_gpu_e: Vec> = unsafe { std::mem::transmute(evals_gpu) }; + let mut evals = Vec::new(); + for chunk in evals_gpu_e { + evals.extend(chunk); + } + let rt: Point = unsafe { + std::mem::transmute::, Vec>( + challenges_gpu.iter().map(|c| c.elements).collect(), + ) + }; + + // 3 for sel_add, sel_bypass, sel_export + // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[1,rt] + assert_eq!(evals.len(), 3 + SEPTIC_EXTENSION_DEGREE * 7); + let final_sum = SepticPoint::from_affine(final_sum_x.clone(), final_sum_y.clone()); + assert_eq!(proof_gpu_e.extract_sum(), E::ZERO); + + Ok(EccQuarkProof { + zerocheck_proof: proof_gpu_e, num_instances, - xs, - ys, - invs, - transcript, - )) + evals, + rt, + sum: final_sum, + }) } } @@ -760,7 +981,7 @@ impl> OpeningProver + 'static), ) -> PCS::Proof { if std::any::TypeId::of::() != std::any::TypeId::of::() { - panic!("GPU backend only supports Goldilocks base field"); + panic!("GPU backend only supports BabyBear base field"); } let mut rounds = vec![]; @@ -822,7 +1043,7 @@ impl> OpeningProver() == std::any::TypeId::of::() { + if std::any::TypeId::of::() == std::any::TypeId::of::() { let transcript_any = transcript as &mut dyn std::any::Any; let basic_transcript = transcript_any .downcast_mut::>() @@ -838,9 +1059,8 @@ impl> OpeningProver>( + transcript: &mut T, +) -> &mut BasicTranscript { + let actual = std::any::type_name::(); + let expected = std::any::type_name::>(); + assert_eq!( + actual, expected, + "GPU backend requires BasicTranscript; got {actual}" + ); + unsafe { &mut *(transcript as *mut T as *mut BasicTranscript) } +} + +pub fn read_septic_value_from_gpu<'a, E: ExtensionField>( + polys: &[Arc>], + index: usize, +) -> Result, ZKVMError> { + let coords = polys + .iter() + .map(|poly| read_base_value_from_gpu(poly, index)) + .collect::, _>>()?; + Ok(coords.into()) +} + +fn read_base_value_from_gpu<'a, E: ExtensionField>( + poly: &Arc>, + index: usize, +) -> Result { + match &poly.mle { + GpuFieldType::Base(base_poly) => { + let buffer = base_poly.evaluations(); + let raw = buffer + .get(index) + .map_err(|e| hal_to_backend_error(format!("failed to read GPU buffer: {e:?}")))?; + let canonical = raw.as_canonical_u32(); + Ok(E::BaseField::from_canonical_u32(canonical)) + } + GpuFieldType::Ext(_) => Err(hal_to_backend_error( + "expected base-field polynomial for final-sum extraction", + )), + GpuFieldType::Unreachable => { + Err(hal_to_backend_error("unreachable GPU polynomial variant")) + } + } +} + +pub fn batch_mles_take_half<'a, E: ExtensionField>( + polynomials: &[Arc>], + chunk_index: usize, +) -> Result>>, ZKVMError> { + if polynomials.is_empty() { + return Ok(Vec::new()); + } + + debug_assert!( + chunk_index < 2, + "only two chunks are supported when splitting in half" + ); + debug_assert_eq!( + TypeId::of::(), + TypeId::of::(), + "GPU backend only supports BabyBear base field" + ); + + polynomials + .iter() + .map(|poly| { + let gpu_poly = match &poly.mle { + GpuFieldType::Base(base) => base + .as_view_chunk(2) + .into_iter() + .nth(chunk_index) + .expect("chunk index must be valid"), + GpuFieldType::Ext(_) => { + return Err(hal_to_backend_error( + "expected base-field polynomial for EC witness splitting", + )); + } + GpuFieldType::Unreachable => { + return Err(hal_to_backend_error("unreachable GPU polynomial variant")); + } + }; + let gpu_mle = MultilinearExtensionGpu::from_ceno_gpu_base(gpu_poly); + Ok(Arc::new(gpu_mle)) + }) + .collect() +} + +pub fn symbolic_from_mle<'a, E: ExtensionField>( + registry: &mut WitnessRegistry<'a, E>, + polys: &[Arc>], +) -> SymbolicSepticExtension { + SymbolicSepticExtension::new( + polys + .iter() + .cloned() + .map(|poly| registry.register(poly)) + .collect(), + ) +} + +#[derive(Default)] +pub struct WitnessRegistry<'a, E: ExtensionField> { + gpu_mles: Vec>>, +} + +impl<'a, E: ExtensionField> WitnessRegistry<'a, E> { + pub fn register(&mut self, mle: Arc>) -> Expression { + let idx_u16 = u16::try_from(self.gpu_mles.len()) + .expect("witness identifier overflow in EC sum quark"); + self.gpu_mles.push(mle); + Expression::WitIn(idx_u16 as WitnessId) + } + + pub fn gpu_refs(&self) -> Vec<&MultilinearExtensionGpu<'a, E>> { + self.gpu_mles.iter().map(|arc| arc.as_ref()).collect() + } +} + +pub fn hal_to_backend_error(message: impl Into) -> ZKVMError { + ZKVMError::BackendError(BackendError::CircuitError(message.into().into_boxed_str())) +} + +pub fn mle_host_to_gpu<'a, E: ExtensionField>( + cuda_hal: &CudaHalBB31, + mle: &MultilinearExtension<'a, E>, +) -> Arc> { + if TypeId::of::() != TypeId::of::() { + panic!("GPU backend only supports BabyBear base field"); + } + let gpu = MultilinearExtensionGpu::from_ceno(cuda_hal, mle); + Arc::new(unsafe { + std::mem::transmute::, MultilinearExtensionGpu<'static, E>>( + gpu, + ) + }) +} + +pub fn mle_filter_even_odd_batch<'a, E: ExtensionField>( + cuda_hal: &CudaHalBB31, + requests: &[(&[Arc>], bool)], +) -> Result>>>, ZKVMError> { + if requests.iter().all(|(polys, _)| polys.is_empty()) { + return Ok(vec![Vec::new(); requests.len()]); + } + + debug_assert_eq!( + TypeId::of::(), + TypeId::of::(), + "GPU backend only supports Babybear base field" + ); + + let mut flattened_refs = Vec::new(); + let mut flags = Vec::new(); + let mut result_num_vars = Vec::new(); + + let expected_len = requests + .first() + .map(|(polys, _)| polys.len()) + .unwrap_or_default(); + assert!( + requests + .iter() + .all(|(polys, _)| polys.len() == expected_len), + "all filter requests must contain the same number of MLEs" + ); + + for (polys, flag) in requests { + for poly in *polys { + let num_vars = poly + .mle + .num_vars() + .checked_sub(1) + .expect("polynomial must have at least one variable"); + result_num_vars.push(num_vars); + flattened_refs.push(&poly.mle); + flags.push(*flag); + } + } + + if flattened_refs.is_empty() { + return Ok(vec![Vec::new(); requests.len()]); + } + + let stride = 1usize << result_num_vars[0]; + assert!( + flattened_refs + .iter() + .zip(result_num_vars.iter()) + .all(|(poly, vars)| poly.num_vars() == vars + 1), + "all MLEs must share the same number of variables before filtering" + ); + + let mut output_buffers = flattened_refs + .iter() + .map(|_| { + cuda_hal + .alloc_elems_on_device(stride, false) + .map_err(|e| hal_to_backend_error(format!("failed to allocate GPU buffer: {e:?}"))) + }) + .collect::, _>>()?; + + let flattened_refs: Vec<&GpuFieldType<'static>> = + unsafe { std::mem::transmute(flattened_refs) }; + + filter_mle_even_odd_batch::< + CudaHalBB31, + BB31Ext, + BB31Base, + GpuFieldType<'static>, + GpuPolynomial<'static>, + >(cuda_hal, flattened_refs, &flags, &mut output_buffers) + .map_err(|e| hal_to_backend_error(format!("GPU filter kernel failed: {e:?}")))?; + + let mut outputs = Vec::with_capacity(requests.len()); + let mut idx = 0; + for _ in requests { + let mut segment = Vec::with_capacity(expected_len); + for _ in 0..expected_len { + let buf = output_buffers + .get(idx) + .expect("missing buffer for filter result") + .clone(); + let num_vars = result_num_vars[idx]; + let gpu_poly = GpuPolynomial::new(buf, num_vars); + let gpu_mle = MultilinearExtensionGpu::from_ceno_gpu_base(gpu_poly); + let gpu_mle_static = unsafe { + std::mem::transmute::< + MultilinearExtensionGpu<'_, E>, + MultilinearExtensionGpu<'static, E>, + >(gpu_mle) + }; + segment.push(Arc::new(gpu_mle_static)); + idx += 1; + } + outputs.push(segment); + } + + Ok(outputs) +} diff --git a/gkr_iop/src/gpu/mod.rs b/gkr_iop/src/gpu/mod.rs index 103dd83a2..b2743b77b 100644 --- a/gkr_iop/src/gpu/mod.rs +++ b/gkr_iop/src/gpu/mod.rs @@ -370,7 +370,7 @@ impl> let (coeffs, indices, size_info) = extract_mle_relationships_from_monomial_terms( expr, &layer_wits.iter().map(|mle| mle.as_ref()).collect_vec(), - &pub_io_evals, + pub_io_evals, challenges, ); let coeffs_gl64: Vec = unsafe { std::mem::transmute(coeffs) };