@@ -470,7 +470,7 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
470470 Name = getStaticDeclName (*this , D);
471471
472472 mlir::Type LTy = getTypes ().convertTypeForMem (Ty);
473- cir::AddressSpaceAttr AS =
473+ cir::AddressSpaceAttr actualAS =
474474 builder.getAddrSpaceAttr (getGlobalVarAddressSpace (&D));
475475
476476 // OpenCL variables in local address space and CUDA shared
@@ -482,8 +482,9 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
482482 !D.hasAttr <CUDASharedAttr>())
483483 Init = builder.getZeroInitAttr (convertType (Ty));
484484
485- cir::GlobalOp GV = builder.createVersionedGlobal (
486- getModule (), getLoc (D.getLocation ()), Name, LTy, false , Linkage, AS);
485+ cir::GlobalOp GV =
486+ builder.createVersionedGlobal (getModule (), getLoc (D.getLocation ()), Name,
487+ LTy, false , Linkage, actualAS);
487488 // TODO(cir): infer visibility from linkage in global op builder.
488489 GV.setVisibility (getMLIRVisibilityFromCIRLinkage (Linkage));
489490 GV.setInitialValueAttr (Init);
@@ -497,14 +498,15 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
497498
498499 setGVProperties (GV, &D);
499500
500- // OG checks if the expected address space, denoted by the type, is the
501- // same as the actual address space indicated by attributes. If they aren't
502- // the same, an addrspacecast is emitted when this variable is accessed.
503- // In CIR however, cir.get_global alreadys carries that information in
504- // !cir.ptr type - if this global is in OpenCL local address space, then its
505- // type would be !cir.ptr<..., addrspace(offload_local)>. Therefore we don't
506- // need an explicit address space cast in CIR: they will get emitted when
507- // lowering to LLVM IR.
501+ // OG checks whether the expected address space (AS), denoted by
502+ // __attributes__((addrspace(n))), is the same as the actual AS indicated by
503+ // other attributes (such as __device__ in CUDA). If they aren't the same, an
504+ // addrspacecast is emitted when this variable is accessed, which means we
505+ // need it in this function. In CIR however, since we access globals by
506+ // `cir.get_global`, we won't emit a cast for GlobalOp here. Instead, we
507+ // record the AST, and create a CastOp in
508+ // `CIRGenBaseBuilder::createGetGlobal`.
509+ GV.setAstAttr (cir::ASTVarDeclAttr::get (&getMLIRContext (), &D));
508510
509511 // Ensure that the static local gets initialized by making sure the parent
510512 // function gets emitted eventually.
@@ -617,7 +619,10 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
617619 // TODO(cir): we should have a way to represent global ops as values without
618620 // having to emit a get global op. Sometimes these emissions are not used.
619621 auto addr = getBuilder ().createGetGlobal (globalOp);
620- auto getAddrOp = mlir::cast<cir::GetGlobalOp>(addr.getDefiningOp ());
622+ auto definingOp = addr.getDefiningOp ();
623+ bool hasCast = isa<cir::CastOp>(definingOp);
624+ auto getAddrOp = mlir::cast<cir::GetGlobalOp>(
625+ hasCast ? definingOp->getOperand (0 ).getDefiningOp () : definingOp);
621626
622627 CharUnits alignment = getContext ().getDeclAlign (&D);
623628
@@ -633,7 +638,7 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
633638 llvm_unreachable (" VLAs are NYI" );
634639
635640 // Save the type in case adding the initializer forces a type change.
636- auto expectedType = addr.getType ();
641+ auto expectedType = cast<cir::PointerType>( addr.getType () );
637642
638643 auto var = globalOp;
639644
@@ -678,7 +683,25 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
678683 //
679684 // FIXME: It is really dangerous to store this in the map; if anyone
680685 // RAUW's the GV uses of this constant will be invalid.
681- auto castedAddr = builder.createBitcast (getAddrOp.getAddr (), expectedType);
686+ mlir::Value castedAddr;
687+ if (!hasCast)
688+ castedAddr = builder.createBitcast (getAddrOp.getAddr (), expectedType);
689+ else {
690+ // If there is an extra CastOp from createGetGlobal, we need to remove the
691+ // existing addrspacecast, then supply a bitcast and a new addrspacecast:
692+ // %1 = cir.get_global @addr
693+ // %2 = cir.cast(addrspacecast, %1) <--- remove
694+ // %2 = cir.cast(bitcast, %1) <--- insert
695+ // %3 = cir.cast(addrspacecast, %2) <--- insert
696+ definingOp->erase ();
697+
698+ auto expectedTypeWithAS = cir::PointerType::get (
699+ expectedType.getPointee (), getAddrOp.getType ().getAddrSpace ());
700+ auto converted =
701+ builder.createBitcast (getAddrOp.getAddr (), expectedTypeWithAS);
702+ castedAddr = builder.createAddrSpaceCast (converted, expectedType);
703+ }
704+
682705 LocalDeclMap.find (&D)->second = Address (castedAddr, elemTy, alignment);
683706 CGM.setStaticLocalDeclAddress (&D, var);
684707
0 commit comments