Skip to content

Commit

Permalink
Move retry/catch to a mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
agrare committed Oct 8, 2024
1 parent 6a4041c commit 6ffb6e9
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 52 deletions.
1 change: 1 addition & 0 deletions lib/floe.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
require_relative "floe/workflow/states/non_terminal_mixin"
require_relative "floe/workflow/states/parallel"
require_relative "floe/workflow/states/pass"
require_relative "floe/workflow/states/retry_catch_mixin"
require_relative "floe/workflow/states/succeed"
require_relative "floe/workflow/states/task"
require_relative "floe/workflow/states/wait"
Expand Down
20 changes: 16 additions & 4 deletions lib/floe/workflow/states/map.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

require_relative "input_output_mixin"
require_relative "non_terminal_mixin"
require_relative "retry_catch_mixin"

module Floe
class Workflow
module States
class Map < Floe::Workflow::State
include InputOutputMixin
include NonTerminalMixin
include RetryCatchMixin

attr_reader :end, :next, :parameters, :input_path, :output_path, :result_path,
:result_selector, :retry, :catch, :item_processor, :items_path,
Expand Down Expand Up @@ -58,8 +60,13 @@ def start(context)
end

# Completes the Map state once its item processors are done.
#
# When failed?(context) is true, the error from parse_error is offered to
# retry_state!, then catch_error!, then fail_workflow!, in that order; `||`
# stops at the first handler that returns a truthy value.
#
# Otherwise the outputs of all item processors are collected and passed
# through process_output to become the state's output. The output is only
# built on this success path, so it is neither computed twice nor computed
# from failed item results that an error handler would overwrite anyway.
#
# @param context the workflow context for this state
# @return whatever the parent implementation of finish returns
def finish(context)
  if failed?(context)
    error = parse_error(context)
    retry_state!(context, error) || catch_error!(context, error) || fail_workflow!(context, error)
  else
    result = each_item_processor(context).map(&:output)
    context.output = process_output(context, result)
  end
  super
end

Expand Down Expand Up @@ -105,10 +112,11 @@ def failed?(context)

# Some have failed, check the tolerated_failure thresholds to see if
# we should fail the whole state.
num_failed = contexts.select(&:failed?).count
num_failed = contexts.count(&:failed?)
return false if tolerated_failure_count && num_failed < tolerated_failure_count
return false if tolerated_failure_percentage && (100 * num_failed / contexts.count.to_f) < tolerated_failure_percentage
return true

true
end

private
Expand All @@ -128,6 +136,10 @@ def step_nonblock!(context)
end
end

# Returns the error for a failed Map state: the output of the first item
# processor that reports failed?, or nil when none has failed.
#
# The whole output Hash is returned rather than only its "Error" entry,
# because the retry/catch/fail handlers look the error name up themselves
# via error["Error"] and store the value they are given as the state output.
#
# @param context the workflow context for this state
# @return [Hash, nil] the failed item's output, or nil
def parse_error(context)
  each_item_processor(context).detect(&:failed?)&.output
end

# Validates this state against the given workflow; the only check performed
# here is the one delegated to validate_state_next!.
def validate_state!(workflow)
validate_state_next!(workflow)
end
Expand Down
57 changes: 57 additions & 0 deletions lib/floe/workflow/states/retry_catch_mixin.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# frozen_string_literal: true

module Floe
  class Workflow
    module States
      # Retry / Catch / fail handling shared between states.
      #
      # The including class is expected to respond to +retry+, +catch+,
      # +logger+, +long_name+, +wait_until!+ and +wait_until+, all of which
      # are called from the methods below.
      module RetryCatchMixin
        # @param error the error name to look up
        # @return the first entry of +self.retry+ whose match_error? accepts
        #   +error+, or nil when none does
        def find_retrier(error)
          self.retry.find { |candidate| candidate.match_error?(error) }
        end

        # @param error the error name to look up
        # @return the first entry of +self.catch+ whose match_error? accepts
        #   +error+, or nil when none does
        def find_catcher(error)
          self.catch.find { |candidate| candidate.match_error?(error) }
        end

        # Schedules another run of the current state when a retrier matches
        # the error and still has attempts left.
        #
        # @param context the workflow context; context["State"] holds the
        #   "RetryCount" / "Retrier" bookkeeping
        # @param error [Hash, nil] the error, looked up by its "Error" entry
        # @return [true, nil] true when a retry was scheduled, nil otherwise
        def retry_state!(context, error)
          return unless error

          retrier = find_retrier(error["Error"])
          return if retrier.nil?

          state = context["State"]

          # Counting starts over unless we are still on the same retrier.
          same_retrier = state.key?("RetryCount") && state["Retrier"] == retrier.error_equals
          unless same_retrier
            state["RetryCount"] = 0
            state["Retrier"]    = retrier.error_equals
          end

          attempt = (state["RetryCount"] += 1)
          return if attempt > retrier.max_attempts

          wait_until!(context, :seconds => retrier.sleep_duration(attempt))
          context.next_state = context.state_name
          context.output     = error
          logger.info("Running state: [#{long_name}] with input [#{context.json_input}] got error[#{context.json_output}]...Retry - delay: #{wait_until(context)}")
          true
        end

        # Routes the workflow to a matching catcher's next state.
        #
        # @param context the workflow context
        # @param error [Hash, nil] the error, looked up by its "Error" entry
        # @return [true, nil] true when a catcher handled the error, nil otherwise
        def catch_error!(context, error)
          return unless error

          catcher = find_catcher(error["Error"])
          return if catcher.nil?

          context.next_state = catcher.next
          # The error is placed into the input through the catcher's result_path.
          context.output = catcher.result_path.set(context.input, error)
          logger.info("Running state: [#{long_name}] with input [#{context.json_input}]...CatchError - next state: [#{context.next_state}] output: [#{context.json_output}]")

          true
        end

        # Ends the workflow with the error as its output.
        #
        # @param context the workflow context
        # @param error the value stored as the context output
        # @return the result of logger.error
        def fail_workflow!(context, error)
          # next_state is already nil and gets reset to nil again in super;
          # it is assigned here purely for completeness.
          context.next_state = nil
          context.output     = error
          logger.error("Running state: [#{long_name}] with input [#{context.json_input}]...Complete workflow - output: [#{context.json_output}]")
        end
      end
    end
  end
end
53 changes: 5 additions & 48 deletions lib/floe/workflow/states/task.rb
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# frozen_string_literal: true

require_relative "input_output_mixin"
require_relative "non_terminal_mixin"
require_relative "retry_catch_mixin"

module Floe
class Workflow
module States
class Task < Floe::Workflow::State
include InputOutputMixin
include NonTerminalMixin
include RetryCatchMixin

attr_reader :credentials, :end, :heartbeat_seconds, :next, :parameters,
:result_selector, :resource, :timeout_seconds, :retry, :catch,
Expand Down Expand Up @@ -82,54 +87,6 @@ def success?(context)
runner.success?(context.state["RunnerContext"])
end

# Returns the first entry of +self.retry+ whose match_error? accepts the
# given error name, or nil when none does.
def find_retrier(error)
  self.retry.find { |candidate| candidate.match_error?(error) }
end

# Returns the first entry of +self.catch+ whose match_error? accepts the
# given error name, or nil when none does.
def find_catcher(error)
  self.catch.find { |candidate| candidate.match_error?(error) }
end

# Schedules another run of the current state when a retrier matches the
# error and still has attempts left.
#
# context["State"] carries the bookkeeping: "RetryCount" and the
# error_equals of the retrier it belongs to ("Retrier").
#
# Returns true when a retry was scheduled, nil otherwise.
def retry_state!(context, error)
  return unless error

  retrier = find_retrier(error["Error"])
  return if retrier.nil?

  state = context["State"]

  # Counting starts over unless we are still on the same retrier.
  same_retrier = state.key?("RetryCount") && state["Retrier"] == retrier.error_equals
  unless same_retrier
    state["RetryCount"] = 0
    state["Retrier"]    = retrier.error_equals
  end

  attempt = (state["RetryCount"] += 1)
  return if attempt > retrier.max_attempts

  wait_until!(context, :seconds => retrier.sleep_duration(attempt))
  context.next_state = context.state_name
  context.output     = error
  logger.info("Running state: [#{long_name}] with input [#{context.json_input}] got error[#{context.json_output}]...Retry - delay: #{wait_until(context)}")
  true
end

# Routes the workflow to the next state of the first catcher matching the
# error's "Error" entry.
#
# Returns true when a catcher handled the error, nil otherwise.
def catch_error!(context, error)
  return unless error

  catcher = find_catcher(error["Error"])
  return if catcher.nil?

  context.next_state = catcher.next
  # The error is placed into the input through the catcher's result_path.
  context.output = catcher.result_path.set(context.input, error)
  logger.info("Running state: [#{long_name}] with input [#{context.json_input}]...CatchError - next state: [#{context.next_state}] output: [#{context.json_output}]")

  true
end

# Ends the workflow with the error as its output and logs the completion.
#
# Returns the result of logger.error.
def fail_workflow!(context, error)
  # next_state is already nil and gets reset to nil again in super;
  # it is assigned here purely for completeness.
  context.next_state = nil
  context.output     = error

  message = "Running state: [#{long_name}] with input [#{context.json_input}]...Complete workflow - output: [#{context.json_output}]"
  logger.error(message)
end

def parse_error(output)
return if output.nil?
return output if output.kind_of?(Hash)
Expand Down

0 comments on commit 6ffb6e9

Please sign in to comment.