Skip to content

Commit

Permalink
Fix ROCm backend API
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed May 16, 2023
1 parent cd84343 commit bfaa546
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ExaTron"
uuid = "28b18bf8-76f9-41ea-81fa-0f922810b349"
authors = ["Youngdae Kim <youngdae@anl.gov>", "François Pacaud <fpacaud@anl.gov>", "Kibaek Kim <kimk@anl.gov>", "Michel Schanen <mschanen@anl.gov>"]
version = "3.0.0"
version = "3.0.1"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand Down
6 changes: 3 additions & 3 deletions examples/admm/acopf_admm_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ function get_generator_data(data::OPFData, ::CUDABackend)
return pgmin,pgmax,qgmin,qgmax,c2,c1,c0
end

function get_generator_data(data::OPFData, ::ROCDevice)
function get_generator_data(data::OPFData, ::ROCBackend)
ngen = length(data.generators)

pgmin = ROCArray{Float64}(undef, ngen)
Expand Down Expand Up @@ -143,7 +143,7 @@ function get_bus_data(data::OPFData, ::CUDABackend)
return cuFrStart,cuFrIdx,cuToStart,cuToIdx,cuGenStart,cuGenIdx,cuPd,cuQd
end

function get_bus_data(data::OPFData, ::ROCDevice)
function get_bus_data(data::OPFData, ::ROCBackend)
nbus = length(data.buses)

FrIdx = [l for b=1:nbus for l in data.FromLines[b]]
Expand Down Expand Up @@ -232,7 +232,7 @@ function get_branch_data(data::OPFData, device::CUDABackend)
cuYttR, cuYttI, cuYtfR, cuYtfI, cuFrBound, cuToBound
end

function get_branch_data(data::OPFData, device::ROCDevice)
function get_branch_data(data::OPFData, device::ROCBackend)
buses = data.buses
lines = data.lines
BusIdx = data.BusIdx
Expand Down
2 changes: 1 addition & 1 deletion examples/admm/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ AdmmEnv(opfdata::OPFData, device::CUDABackend, rho_pq, rho_va; options...) = Adm
opfdata, rho_pq, rho_va; device=device, options...
)

AdmmEnv(opfdata::OPFData, device::ROCDevice, rho_pq, rho_va; options...) = AdmmEnv{Float64, ROCArray{Float64, 1}, ROCArray{Int, 1}, ROCArray{Float64, 2}}(
AdmmEnv(opfdata::OPFData, device::ROCBackend, rho_pq, rho_va; options...) = AdmmEnv{Float64, ROCArray{Float64, 1}, ROCArray{Int, 1}, ROCArray{Float64, 2}}(
opfdata, rho_pq, rho_va; device=device, options...
)

Expand Down
2 changes: 1 addition & 1 deletion examples/opf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ max_iter = parse(Int, ARGS[3])
# Indicate which GPU device to use
device = CPU()
# device = CUDABackend()
# device = ROCDevice()
# device = ROCBackend()
# verbose = 0: No output
# verbose = 1: Final result metrics
# verbose = 2: Iteration output
Expand Down
1 change: 0 additions & 1 deletion test/KA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ if has_cuda_gpu()
AT = CuArray
elseif has_rocm_gpu()
# Set for crusher login node to avoid other users
AMDGPU.default_device!(AMDGPU.devices()[2])
device = AMDGPU.ROCBackend()
AT = ROCArray
else
Expand Down
2 changes: 1 addition & 1 deletion test/admmtest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ function one_level_admm(case::String, device)
return nothing
end

# one_level_admm(CASE, ROCDevice())
# one_level_admm(CASE, ROCBackend())

0 comments on commit bfaa546

Please sign in to comment.