Skip to content

Commit

Permalink
Unrolled build for rust-lang#135581
Browse files Browse the repository at this point in the history
Rollup merge of rust-lang#135581 - EnzymeAD:refactor-codgencx, r=oli-obk

Separate Builder methods from tcx

As part of the autodiff upstreaming we noticed, that it would be nice to have various builder methods available without the TypeContext, which prevents the normal CodegenCx to be passed around between threads.
We introduce a SimpleCx which just owns the llvm module and llvm context, to encapsulate them.
The previous CodegenCx now implements deref and forwards access to the llvm module or context to it's SimpleCx sub-struct. This gives us a bit more flexibility, because now we can pass (or construct) the SimpleCx in locations where we don't have enough information to construct a CodegenCx, or are not able to pass it around due to the tcx lifetimes (and it not implementing send/sync).

This also introduces an SBuilder, similar to the SimpleCx. The SBuilder uses a SimpleCx, whereas the existing Builder uses the larger CodegenCx. I will push updates to make  implementations generic (where possible) to be implemented once and work for either of the two. I'll also clean up the leftover code.

`call` is a bit tricky, because it requires a tcx, I probably need to duplicate it after all.

Tracking:

- rust-lang#124509
  • Loading branch information
rust-timer authored Jan 25, 2025
2 parents 1e9b017 + 386c233 commit 829a110
Show file tree
Hide file tree
Showing 10 changed files with 239 additions and 56 deletions.
1 change: 0 additions & 1 deletion compiler/rustc_codegen_gcc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ impl WriteBackendMethods for GccCodegenBackend {
}
fn autodiff(
_cgcx: &CodegenContext<Self>,
_tcx: TyCtxt<'_>,
_module: &ModuleCodegen<Self::Module>,
_diff_fncs: Vec<AutoDiffItem>,
_config: &ModuleConfig,
Expand Down
150 changes: 138 additions & 12 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::borrow::Cow;
use std::borrow::{Borrow, Cow};
use std::ops::Deref;
use std::{iter, ptr};

Expand Down Expand Up @@ -31,27 +31,135 @@ use tracing::{debug, instrument};
use crate::abi::FnAbiLlvmExt;
use crate::attributes;
use crate::common::Funclet;
use crate::context::CodegenCx;
use crate::context::{CodegenCx, SimpleCx};
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
use crate::value::Value;

// All Builders must have an llfn associated with them
#[must_use]
pub(crate) struct Builder<'a, 'll, 'tcx> {
pub(crate) struct GenericBuilder<'a, 'll, CX: Borrow<SimpleCx<'ll>>> {
pub llbuilder: &'ll mut llvm::Builder<'ll>,
pub cx: &'a CodegenCx<'ll, 'tcx>,
pub cx: &'a CX,
}

impl Drop for Builder<'_, '_, '_> {
pub(crate) type SBuilder<'a, 'll> = GenericBuilder<'a, 'll, SimpleCx<'ll>>;
pub(crate) type Builder<'a, 'll, 'tcx> = GenericBuilder<'a, 'll, CodegenCx<'ll, 'tcx>>;

impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> Drop for GenericBuilder<'a, 'll, CX> {
fn drop(&mut self) {
unsafe {
llvm::LLVMDisposeBuilder(&mut *(self.llbuilder as *mut _));
}
}
}

impl<'a, 'll> SBuilder<'a, 'll> {
fn call(
&mut self,
llty: &'ll Type,
llfn: &'ll Value,
args: &[&'ll Value],
funclet: Option<&Funclet<'ll>>,
) -> &'ll Value {
debug!("call {:?} with args ({:?})", llfn, args);

let args = self.check_call("call", llty, llfn, args);
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
if let Some(funclet_bundle) = funclet_bundle {
bundles.push(funclet_bundle);
}

let call = unsafe {
llvm::LLVMBuildCallWithOperandBundles(
self.llbuilder,
llty,
llfn,
args.as_ptr() as *const &llvm::Value,
args.len() as c_uint,
bundles.as_ptr(),
bundles.len() as c_uint,
c"".as_ptr(),
)
};
call
}

fn with_scx(scx: &'a SimpleCx<'ll>) -> Self {
// Create a fresh builder from the simple context.
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(scx.llcx) };
SBuilder { llbuilder, cx: scx }
}
}
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
}

fn ret_void(&mut self) {
unsafe {
llvm::LLVMBuildRetVoid(self.llbuilder);
}
}

fn ret(&mut self, v: &'ll Value) {
unsafe {
llvm::LLVMBuildRet(self.llbuilder, v);
}
}
}
impl<'a, 'll> SBuilder<'a, 'll> {
fn build(cx: &'a SimpleCx<'ll>, llbb: &'ll BasicBlock) -> SBuilder<'a, 'll> {
let bx = SBuilder::with_scx(cx);
unsafe {
llvm::LLVMPositionBuilderAtEnd(bx.llbuilder, llbb);
}
bx
}

fn check_call<'b>(
&mut self,
typ: &str,
fn_ty: &'ll Type,
llfn: &'ll Value,
args: &'b [&'ll Value],
) -> Cow<'b, [&'ll Value]> {
assert!(
self.cx.type_kind(fn_ty) == TypeKind::Function,
"builder::{typ} not passed a function, but {fn_ty:?}"
);

let param_tys = self.cx.func_params_types(fn_ty);

let all_args_match = iter::zip(&param_tys, args.iter().map(|&v| self.cx.val_ty(v)))
.all(|(expected_ty, actual_ty)| *expected_ty == actual_ty);

if all_args_match {
return Cow::Borrowed(args);
}

let casted_args: Vec<_> = iter::zip(param_tys, args)
.enumerate()
.map(|(i, (expected_ty, &actual_val))| {
let actual_ty = self.cx.val_ty(actual_val);
if expected_ty != actual_ty {
debug!(
"type mismatch in function call of {:?}. \
Expected {:?} for param {}, got {:?}; injecting bitcast",
llfn, expected_ty, i, actual_ty
);
self.bitcast(actual_val, expected_ty)
} else {
actual_val
}
})
.collect();

Cow::Owned(casted_args)
}
}

/// Empty string, to be used where LLVM expects an instruction name, indicating
/// that the instruction is to be left unnamed (i.e. numbered, in textual IR).
// FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer.
Expand Down Expand Up @@ -1222,6 +1330,14 @@ impl<'ll> StaticBuilderMethods for Builder<'_, 'll, '_> {
}

impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
fn build(cx: &'a CodegenCx<'ll, 'tcx>, llbb: &'ll BasicBlock) -> Builder<'a, 'll, 'tcx> {
let bx = Builder::with_cx(cx);
unsafe {
llvm::LLVMPositionBuilderAtEnd(bx.llbuilder, llbb);
}
bx
}

fn with_cx(cx: &'a CodegenCx<'ll, 'tcx>) -> Self {
// Create a fresh builder from the crate context.
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(cx.llcx) };
Expand All @@ -1231,13 +1347,16 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
pub(crate) fn llfn(&self) -> &'ll Value {
unsafe { llvm::LLVMGetBasicBlockParent(self.llbb()) }
}
}

impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
fn position_at_start(&mut self, llbb: &'ll BasicBlock) {
unsafe {
llvm::LLVMRustPositionBuilderAtStart(self.llbuilder, llbb);
}
}

}
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
fn align_metadata(&mut self, load: &'ll Value, align: Align) {
unsafe {
let md = [llvm::LLVMValueAsMetadata(self.cx.const_u64(align.bytes()))];
Expand All @@ -1259,7 +1378,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
self.set_metadata(inst, llvm::MD_unpredictable, md);
}
}

}
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
pub(crate) fn minnum(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
unsafe { llvm::LLVMRustBuildMinNum(self.llbuilder, lhs, rhs) }
}
Expand Down Expand Up @@ -1360,7 +1480,9 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
let ret = unsafe { llvm::LLVMBuildCatchRet(self.llbuilder, funclet.cleanuppad(), unwind) };
ret.expect("LLVM does not have support for catchret")
}
}

impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
fn check_call<'b>(
&mut self,
typ: &str,
Expand Down Expand Up @@ -1401,11 +1523,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {

Cow::Owned(casted_args)
}

}
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
pub(crate) fn va_arg(&mut self, list: &'ll Value, ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMBuildVAArg(self.llbuilder, list, ty, UNNAMED) }
}

}
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
pub(crate) fn call_intrinsic(&mut self, intrinsic: &str, args: &[&'ll Value]) -> &'ll Value {
let (ty, f) = self.cx.get_intrinsic(intrinsic);
self.call(ty, None, None, f, args, None, None)
Expand All @@ -1423,7 +1547,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {

self.call_intrinsic(intrinsic, &[self.cx.const_u64(size), ptr]);
}

}
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
pub(crate) fn phi(
&mut self,
ty: &'ll Type,
Expand All @@ -1443,7 +1568,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
llvm::LLVMAddIncoming(phi, &val, &bb, 1 as c_uint);
}
}

}
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
fn fptoint_sat(&mut self, signed: bool, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
let src_ty = self.cx.val_ty(val);
let (float_ty, int_ty, vector_length) = if self.cx.type_kind(src_ty) == TypeKind::Vector {
Expand Down
25 changes: 11 additions & 14 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@ use std::ptr;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_codegen_ssa::back::write::ModuleConfig;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_errors::FatalError;
use rustc_middle::ty::TyCtxt;
use rustc_session::config::Lto;
use tracing::{debug, trace};

use crate::back::write::{llvm_err, llvm_optimize};
use crate::builder::Builder;
use crate::declare::declare_raw_fn;
use crate::builder::SBuilder;
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::errors::LlvmError;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{Metadata, True};
use crate::value::Value;
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};

fn get_params(fnc: &Value) -> Vec<&Value> {
unsafe {
Expand All @@ -38,8 +37,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
fn generate_enzyme_call<'ll, 'tcx>(
cx: &context::CodegenCx<'ll, 'tcx>,
fn generate_enzyme_call<'ll>(
cx: &SimpleCx<'ll>,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
attrs: AutoDiffAttrs,
Expand Down Expand Up @@ -112,7 +111,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
let ad_fn = declare_raw_fn(
let ad_fn = declare_simple_fn(
cx,
&ad_name,
llvm::CallConv::try_from(cc).expect("invalid callconv"),
Expand All @@ -132,7 +131,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
llvm::LLVMRustEraseInstFromParent(br);

let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
let mut builder = Builder::build(cx, entry);
let mut builder = SBuilder::build(cx, entry);

let num_args = llvm::LLVMCountParams(&fn_to_diff);
let mut args = Vec::with_capacity(num_args as usize + 1);
Expand Down Expand Up @@ -236,7 +235,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
}
}

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
let call = builder.call(enzyme_ty, ad_fn, &args, None);

// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
// metadata attachted to it, but we just created this code oota. Given that the
Expand Down Expand Up @@ -274,10 +273,9 @@ fn generate_enzyme_call<'ll, 'tcx>(
}
}

pub(crate) fn differentiate<'ll, 'tcx>(
pub(crate) fn differentiate<'ll>(
module: &'ll ModuleCodegen<ModuleLlvm>,
cgcx: &CodegenContext<LlvmCodegenBackend>,
tcx: TyCtxt<'tcx>,
diff_items: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError> {
Expand All @@ -286,8 +284,7 @@ pub(crate) fn differentiate<'ll, 'tcx>(
}

let diag_handler = cgcx.create_dcx();
let (_, cgus) = tcx.collect_and_partition_mono_items(());
let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };

// Before dumping the module, we want all the TypeTrees to become part of the module.
for item in diff_items.iter() {
Expand Down
Loading

0 comments on commit 829a110

Please sign in to comment.