Improve support for spot scale-sets
Add an interactive thread and a loop on that thread that checks for spot eviction.
samtkaplan committed Feb 14, 2024
1 parent 3432d73 commit f8bd033
Showing 4 changed files with 193 additions and 49 deletions.
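
The change works roughly as follows (a minimal sketch using plain HTTP.jl, for illustration only; the actual implementation in src/AzManagers.jl below uses LibCURL so that the request stays on the interactive thread-pool, and the helper name poll_for_preempt here is hypothetical): each spot worker polls the Azure instance-metadata "scheduledevents" endpoint from a task on the interactive thread-pool and raises an error when a Preempt event names its instance; process 1 then removes the worker from the cluster.

using HTTP, JSON

# Poll Azure's scheduled-events metadata endpoint for a Preempt event that names
# this VM's instance id. The endpoint and the "Metadata: true" header match what
# the commit uses; the helper itself is illustrative.
function poll_for_preempt(instanceid; interval=1)
    while true
        r = HTTP.request("GET", "http://169.254.169.254/metadata/scheduledevents?api-version=2020-07-01", ["Metadata"=>"true"])
        events = get(JSON.parse(String(r.body)), "Events", [])
        for event in events
            if get(event, "EventType", "") == "Preempt" && instanceid ∈ get(event, "Resources", [])
                error("spot preemption scheduled, not before $(event["NotBefore"])")
            end
        end
        sleep(interval)
    end
end

# On Julia 1.9+ the loop can be pinned to the interactive thread-pool:
# task = Threads.@spawn :interactive poll_for_preempt(instanceid)
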
3 changes: 2 additions & 1 deletion .gitignore
@@ -2,4 +2,5 @@
Manifest.toml
deps/build.log
deps/usr
docs/build
docs/build
.vscode
2 changes: 2 additions & 0 deletions Project.toml
@@ -12,6 +12,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
JSONWebTokens = "9b8beb19-0777-58c6-920b-28f749fee4d3"
LibCURL = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
@@ -29,6 +30,7 @@ CodecZlib = "0.7"
HTTP = "1"
JSON = "0.21"
JSONWebTokens = "1"
LibCURL = "0.6"
MPI = "0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20"
TOML = "1"
julia = "^1.4"
195 changes: 161 additions & 34 deletions src/AzManagers.jl
@@ -1,6 +1,6 @@
module AzManagers

using AzSessions, Base64, CodecZlib, Dates, Distributed, HTTP, JSON, JSONWebTokens, LibGit2, Logging, MPI, Pkg, Printf, Random, Serialization, Sockets, TOML
using AzSessions, Base64, CodecZlib, Dates, Distributed, HTTP, JSON, JSONWebTokens, LibCURL, LibGit2, Logging, MPI, Pkg, Printf, Random, Serialization, Sockets, TOML

function logerror(e, loglevel=Logging.Info)
io = IOBuffer()
@@ -185,6 +185,7 @@ mutable struct AzManager <: ClusterManager
pending_down::Dict{ScaleSet,Set{String}}
deleted::Dict{ScaleSet,Dict{String,DateTime}}
pruned::Dict{ScaleSet,Set{String}}
preempted::Dict{ScaleSet,Set{String}}
port::UInt16
server::Sockets.TCPServer
worker_socket::TCPSocket
@@ -213,6 +214,7 @@ function azmanager!(session, nretry, verbose, show_quota)
_manager.pending_down = Dict{ScaleSet,Set{String}}()
_manager.deleted = Dict{ScaleSet,Dict{String,DateTime}}()
_manager.pruned = Dict{ScaleSet,Set{String}}()
_manager.preempted = Dict{ScaleSet,Set{String}}()
_manager.scalesets = Dict{ScaleSet,Int}()
_manager.task_add = @async add_pending_connections()
_manager.task_process = @async process_pending_connections()
@@ -539,11 +541,41 @@ function process_pending_connections()
pids = addprocs(manager; sockets)
empty!(sockets)

@sync for pid in pids
for pid in pids
@async begin
wrkr = Distributed.map_pid_wrkr[pid]
if isdefined(wrkr, :config) && isdefined(wrkr.config, :userdata) && lowercase(get(wrkr.config.userdata, "priority", "")) == "spot"
remote_do(machine_prempt_loop, pid) # monitor for Azure spot evictions on each machine
try
remotecall_fetch(machine_preempt_loop, pid)
catch e
if isa(e, RemoteException) && isa(e.captured.ex, TaskFailedException) && isa(e.captured.ex.task.result.ex, SpotPreemptException)
ex = e.captured.ex.task.result.ex
notbefore = DateTime(ex.notbefore, dateformat"e, dd u yyyy HH:MM:SS \G\M\T")
@info "caught preempt exception for $(ex.clusterid), removing not before $notbefore UTC"
_now = now(UTC)
if notbefore > _now
@info "sleeping for $(notbefore - _now)"
sleep(notbefore - _now)
end
u = wrkr.config.userdata
try
scaleset = ScaleSet(u["subscriptionid"], u["resourcegroup"], u["scalesetname"])
add_instance_to_preempted_list(manager, scaleset, u["instanceid"])
catch e
@info "error adding instance to preempted list"
end

try
@info "Removing preempted worker $pid from the Julia cluster"
# We can't use rmprocs here since the worker process might already be gone due to preemption. The
# worker process would usually do the following two lines (see the Distributed.message_handler_loop function)
Distributed.set_worker_state(Distributed.map_pid_wrkr[pid], Distributed.W_TERMINATED)
Distributed.deregister_worker(pid)
catch
@warn "unable to remove $pid"
end
end
end
end
end
end
@@ -807,6 +839,22 @@ function add_instance_to_pruned_list(manager::AzManager, scaleset::ScaleSet, ins
nothing
end

function add_instance_to_preempted_list(manager::AzManager, scaleset::ScaleSet, instanceid)
if haskey(manager.preempted, scaleset)
@debug "pushing worker with id=$instanceid onto preempted"
push!(manager.preempted[scaleset], string(instanceid))
else
@debug "creating preempted vector for id=$instanceid"
manager.preempted[scaleset] = Set{String}([string(instanceid)])
end
end

function ispreempted(manager::AzManager, config::WorkerConfig)
u = config.userdata
scaleset = ScaleSet(u["subscriptionid"], u["resourcegroup"], u["scalesetname"])
string(u["instanceid"]) get(manager.preempted, scaleset, Set{String}())
end

function add_instance_to_deleted_list(manager::AzManager, scaleset::ScaleSet, instanceid)
if haskey(manager.deleted, scaleset)
@debug "pushing worker with id=$instanceid onto deleted"
@@ -820,6 +868,12 @@ end

function Distributed.kill(manager::AzManager, id::Int, config::WorkerConfig)
@debug "kill for id=$id"

if ispreempted(manager, config)
@debug "kill on id=$id because it was preempted"
return nothing
end

try
remote_do(exit, id)
catch
@@ -933,10 +987,58 @@ function Distributed.manage(manager::AzManager, id::Integer, config::WorkerConfi
end
end

#=
Use LibCURL because HTTP.jl forces part of the request to run on a thread in the default thread-pool,
whereas we would like to run requests to the scale-set metadata server on the interactive thread-pool.
=#
mutable struct CurlDataStruct
body::Vector{UInt8}
currentsize::Csize_t
end

function curl_get_write_callback(curlbuf::Ptr{Cchar}, size::Csize_t, nmemb::Csize_t, datavoid::Ptr{Cvoid})
datastruct = unsafe_pointer_to_objref(datavoid)::CurlDataStruct

n = size*nmemb
newsize = datastruct.currentsize + n
resize!(datastruct.body, newsize)

_data = pointer(datastruct.body, datastruct.currentsize+1)
@ccall memcpy(_data::Ptr{Cvoid}, curlbuf::Ptr{Cvoid}, n::Csize_t)::Ptr{Cvoid}
datastruct.currentsize = newsize
return n
end

function curl_get_metadata(url)
datastruct = CurlDataStruct(UInt8[], 0)

headers = C_NULL
headers = curl_slist_append(headers, "Metadata: true")

curl = curl_easy_init()
curl_easy_setopt(curl, CURLOPT_URL, url)
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers)
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, @cfunction(curl_get_write_callback, Csize_t, (Ptr{Cchar}, Csize_t, Csize_t, Ptr{Cvoid})))
curl_easy_setopt(curl, CURLOPT_WRITEDATA, pointer_from_objref(datastruct))

curl_easy_perform(curl)

http_code = Array{Clong}(undef, 1)
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, http_code)
if http_code[1] > 200
error("Azure metaadata service return $(http_code[1]) response.")
end

curl_easy_cleanup(curl)

datastruct
end

function get_instanceid()
local r
try
_r = HTTP.request("GET", "http://169.254.169.254/metadata/instance/compute?api-version=2021-02-01", ["Metadata"=>"true"])
# _r = HTTP.request("GET", "http://169.254.169.254/metadata/instance/compute?api-version=2021-02-01", ["Metadata"=>"true"])
_r = curl_get_metadata("http://169.254.169.254/metadata/instance/compute?api-version=2021-02-01")
r = JSON.parse(String(_r.body))
catch
r = Dict()
@@ -945,60 +1047,72 @@ function get_instanceid()
end

"""
preempted([id=myid()|id="instanceid"])
ispreempted,notbefore = preempted([id=myid()|id="instanceid"])
Check to see if the machine `id::Int` has received an Azure spot preempt message. Returns
true if a preempt message is received and false otherwise.
(true, notbefore) if a preempt message is received and (false,"") otherwise. `notbefore`
is the date/time before which the machine is guaranteed to still exist.
"""
function preempted(instanceid::AbstractString="", clusterid::Int=0)
function preempted(instanceid::AbstractString, clusterid::Int)
isempty(instanceid) && (instanceid = get_instanceid())
clusterid == 0 && (clusterid = myid())
local _r
try
tic = time()
_r = HTTP.request("GET", "http://169.254.169.254/metadata/scheduledevents?api-version=2020-07-01", ["Metadata"=>"true"])
# _r = HTTP.request("GET", "http://169.254.169.254/metadata/scheduledevents?api-version=2020-07-01", ["Metadata"=>"true"])
_r = curl_get_metadata("http://169.254.169.254/metadata/scheduledevents?api-version=2020-07-01")
if time() - tic > 55 # 55 seconds, simply because it is less than 60, and 60 seconds is the eviction notice.
@warn "$(now()), took longer than 55 seconds to query the meta-data server for scheduled events (elapsed time=$(time() - tic))."
@debug "$(now()), took longer than 55 seconds to query the meta-data server for scheduled events (elapsed time=$(time() - tic))."
end
catch
@warn "unable to get scheduledevents."
return false
return false, ""
end
r = JSON.parse(String(_r.body))
for event in get(r, "Events", [])
if get(event, "EventType", "") == "Preempt" && instanceid ∈ get(event, "Resources", [])
@warn "Machine with id $clusterid ($instanceid) is being pre-empted" now(Dates.UTC) event["NotBefore"] event["EventType"] event["EventSource"]
return true
return true, event["NotBefore"]
end
end
return false
return false, ""
end
preempted(id::Int) = remotecall_fetch(preempted, id)

function _machine_preempt_loop(pid, clusterid)
instanceid = ""
while true
instanceid = get_instanceid()
instanceid == "" || break
sleep(1)
end
while true
if AzManagers.preempted(instanceid, clusterid)
# self-destruct button, Distributed should see that the process is exited and update the cluster book-keeping.
@info "self-destruct, killing pid=$pid"
run(`kill -9 $pid`)
exit()
break
end
sleep(1)
macro spawn_interactive(ex::Expr)
if VERSION > v"1.9"
esc(:(Threads.@spawn :interactive $ex))
else
esc(:(Threads.@spawn $ex))
end
end

function machine_prempt_loop()
project = dirname(Pkg.project().path)
pid = getpid()
id = myid()
open(`julia --project=$project -e "using AzManagers; AzManagers._machine_preempt_loop($pid, $id)"`)
struct SpotPreemptException <: Exception
instanceid::String
clusterid::Int
notbefore::String
end
Base.showerror(io::IO, e::SpotPreemptException) = print(io, "spot preemption on process '$(e.clusterid)' ($(e.instanceid)), not before '$(e.notbefore)'")

function machine_preempt_loop()
if VERSION >= v"1.9" && Threads.nthreads(:interactive) > 0
tsk = @spawn_interactive begin
instanceid = get_instanceid()
clusterid = myid()
@debug "starting preempt loop on $clusterid, $instanceid"

while true
ispreempted, notbefore = preempted(instanceid, clusterid)
if ispreempted
# pid=1 will catch this exception, and remove the worker from the Julia cluster.
throw(SpotPreemptException(instanceid, clusterid, notbefore))
end
sleep(1)
end
end
fetch(tsk)
else
@warn "AzManagers is not running the preempt loop for pid=$(myid()) since it requires at least one interactive thread on worker machines."
end
end

function azure_physical_name(keyval="PhysicalHostName")
@@ -1828,6 +1942,19 @@ function buildstartupscript_cluster(manager::AzManager, spot::Bool, ppi::Int, mp

juliaenvstring = remote_julia_environment_name == "" ? "" : """using Pkg; Pkg.activate(joinpath(Pkg.envdir(), "$remote_julia_environment_name")); """

# if spot is true, then ensure at least one interactive thread on workers so that one can check for spot evictions periodically.
if spot && VERSION >= v"1.9"
_julia_num_threads = split(julia_num_threads, ',')
julia_num_threads_default = length(_julia_num_threads) > 0 ? parse(Int, _julia_num_threads[1]) : 1
julia_num_threads_interactive = length(_julia_num_threads) > 1 ? parse(Int, _julia_num_threads[2]) : 0

if julia_num_threads_interactive == 0
@debug "Augmenting 'julia_num_threads' option with an interactive thread so it can be used on workers for spot-event polling."
julia_num_threads_interactive = 1
end
julia_num_threads = nthreads_filter("$julia_num_threads_default,$julia_num_threads_interactive")
end

_exeflags = isempty(exeflags) ? "-t $julia_num_threads" : "$exeflags -t $julia_num_threads"

if mpi_ranks_per_worker == 0
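
For reference, a hypothetical usage sketch of the behavior added above (the template name "cbox08" and the AzSession() setup are placeholder assumptions, not part of this commit; the addprocs keywords mirror those exercised in test/runtests.jl below):

using AzManagers, AzSessions, Distributed

session = AzSession()  # assumed AzSessions.jl authentication; configure as appropriate

# Request two spot workers with 4 default threads and 1 interactive thread each.
addprocs("cbox08", 2; session, spot=true, julia_num_threads="4,1", waitfor=true)

# On Julia 1.9+ with spot=true, a thread spec without an interactive thread (e.g. "4"
# or "4,0") is augmented to include one by buildstartupscript_cluster above, so each
# worker can run machine_preempt_loop on its interactive thread-pool.
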
42 changes: 28 additions & 14 deletions test/runtests.jl
@@ -120,6 +120,19 @@ end
end
rmprocs(workers())

group = "test$(randstring('a':'z',4))"
julia_num_threads = VERSION >= v"1.9" ? "2,0" : "2"
addprocs(templatename, 1; waitfor = true, group, session, julia_num_threads, spot = true)

@test remotecall_fetch(Threads.nthreads, workers()[1]) == 2

if VERSION >= v"1.9"
if workers()[1] != 1
@test remotecall_fetch(Threads.nthreads, workers()[1], :interactive) == 1
end
end
rmprocs(workers())

group = "test$(randstring('a':'z',4))"
julia_num_threads = VERSION >= v"1.9" ? "3,2" : "3"
addprocs(templatename, 1; waitfor = true, group, session, julia_num_threads, spot=true)
@@ -132,24 +145,25 @@ end
rmprocs(workers())
end

if VERSION >= v"1.9"
@testset "spot eviction" begin
group = "test$(randstring('a':'z',4))"
julia_num_threads = "2,1"
addprocs(templatename, 2; waitfor = true, group, session, julia_num_threads, spot = true)

@testset "spot eviction" begin
group = "test$(randstring('a':'z',4))"
addprocs(templatename, 2; waitfor = true, group, session, spot = true)

sleep(90)
AzManagers.simulate_spot_eviction(workers()[1])
AzManagers.simulate_spot_eviction(workers()[1])

tic = time()
while time() - tic < 300
if nprocs() < 3
@info "cluster responded to spot eviction in $(time() - tic) seconds"
break
tic = time()
while time() - tic < 300
if nprocs() < 3
@info "cluster responded to spot eviction in $(time() - tic) seconds"
break
end
sleep(10)
end
sleep(10)
@test nprocs() < 3
rmprocs(workers())
end
@test nprocs() < 3
rmprocs(workers())
end

@testset "environment, addproc" begin