Skip to content

Commit

Permalink
fix(api): Reconnect WebSocket when resuming app from a paused state (#…
Browse files Browse the repository at this point in the history
…5567)

* fix(api): Reconnect WebSocket when resuming app from a paused state
  • Loading branch information
tyllark authored Oct 18, 2024
1 parent 6e6edab commit 8222a03
Show file tree
Hide file tree
Showing 11 changed files with 365 additions and 2 deletions.
6 changes: 5 additions & 1 deletion packages/api/amplify_api/lib/amplify_api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ library amplify_api;

export 'package:amplify_api/src/api_plugin_impl.dart';
export 'package:amplify_api_dart/amplify_api_dart.dart'
hide AmplifyAPIDart, ConnectivityPlatform, ConnectivityStatus;
hide
AmplifyAPIDart,
ConnectivityPlatform,
ProcessLifeCycle,
ConnectivityStatus;
2 changes: 2 additions & 0 deletions packages/api/amplify_api/lib/src/api_plugin_impl.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

import 'package:amplify_api/src/connectivity_plus_platform.dart';
import 'package:amplify_api/src/flutter_life_cycle.dart';
import 'package:amplify_api_dart/amplify_api_dart.dart';
import 'package:amplify_core/amplify_core.dart';

Expand All @@ -14,6 +15,7 @@ class AmplifyAPI extends AmplifyAPIDart with AWSDebuggable {
super.options,
}) : super(
connectivity: const ConnectivityPlusPlatform(),
processLifeCycle: FlutterLifeCycle(),
);

@override
Expand Down
42 changes: 42 additions & 0 deletions packages/api/amplify_api/lib/src/flutter_life_cycle.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import 'dart:async';

import 'package:amplify_api_dart/amplify_api_dart.dart';
import 'package:flutter/widgets.dart';
import 'package:meta/meta.dart';

/// {@template amplify_api.flutter_life_cycle}
/// Creates a stream of [ProcessStatus] mapped from [AppLifecycleListener](https://api.flutter.dev/flutter/widgets/AppLifecycleListener-class.html).
/// {@endtemplate}
@internal
class FlutterLifeCycle extends ProcessLifeCycle {
/// {@macro amplify_api.flutter_life_cycle}
FlutterLifeCycle() {
AppLifecycleListener(
onStateChange: _onStateChange,
);
}

final _stateController =
StreamController<ProcessStatus>.broadcast(sync: true);

@override
Stream<ProcessStatus> get onStateChanged => _stateController.stream;

void _onStateChange(AppLifecycleState state) {
switch (state) {
case AppLifecycleState.detached:
_stateController.add(ProcessStatus.detached);
case AppLifecycleState.paused:
_stateController.add(ProcessStatus.paused);
case AppLifecycleState.hidden:
_stateController.add(ProcessStatus.hidden);
case AppLifecycleState.inactive:
_stateController.add(ProcessStatus.inactive);
case AppLifecycleState.resumed:
_stateController.add(ProcessStatus.resumed);
}
}
}
1 change: 1 addition & 0 deletions packages/api/amplify_api_dart/lib/amplify_api_dart.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ export 'src/graphql/model_helpers/model_subscriptions.dart';

/// Network connectivity util not needed by consumers of Flutter package amplify_api.
export 'src/graphql/web_socket/types/connectivity_platform.dart';
export 'src/graphql/web_socket/types/process_life_cycle.dart';
9 changes: 8 additions & 1 deletion packages/api/amplify_api_dart/lib/src/api_plugin_impl.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import 'package:amplify_api_dart/src/graphql/web_socket/blocs/web_socket_bloc.da
import 'package:amplify_api_dart/src/graphql/web_socket/services/web_socket_service.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/state/web_socket_state.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/connectivity_platform.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/process_life_cycle.dart';
import 'package:amplify_api_dart/src/util/amplify_api_config.dart';
import 'package:amplify_api_dart/src/util/amplify_authorization_rest_client.dart';
import 'package:amplify_core/amplify_core.dart';
Expand All @@ -30,8 +31,10 @@ class AmplifyAPIDart extends APIPluginInterface with AWSDebuggable {
AmplifyAPIDart({
APIPluginOptions options = const APIPluginOptions(),
ConnectivityPlatform connectivity = const ConnectivityPlatform(),
ProcessLifeCycle processLifeCycle = const ProcessLifeCycle(),
}) : _options = options,
_connectivity = connectivity {
_connectivity = connectivity,
_processLifeCycle = processLifeCycle {
_options.authProviders.forEach(registerAuthProvider);
}

Expand All @@ -43,6 +46,9 @@ class AmplifyAPIDart extends APIPluginInterface with AWSDebuggable {
/// Creates a stream representing network connectivity at the hardware level.
final ConnectivityPlatform _connectivity;

/// Creates a stream representing the process life cycle state.
final ProcessLifeCycle _processLifeCycle;

/// A map of the keys from the Amplify API config with auth modes to HTTP clients
/// to use for requests to that endpoint/auth mode. e.g. { "myEndpoint.AWS_IAM": AWSHttpClient}
final Map<String, AWSHttpClient> _clientPool = {};
Expand Down Expand Up @@ -277,6 +283,7 @@ class AmplifyAPIDart extends APIPluginInterface with AWSDebuggable {
wsService: AmplifyWebSocketService(),
subscriptionOptions: _options.subscriptionOptions,
connectivity: _connectivity,
processLifeCycle: _processLifeCycle,
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import 'package:amplify_api_dart/src/graphql/web_socket/services/web_socket_serv
import 'package:amplify_api_dart/src/graphql/web_socket/state/web_socket_state.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/state/ws_subscriptions_state.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/connectivity_platform.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/process_life_cycle.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/subscriptions_event.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/web_socket_types.dart';
import 'package:amplify_core/amplify_core.dart' hide SubscriptionEvent;
Expand All @@ -33,8 +34,10 @@ class WebSocketBloc with AWSDebuggable, AmplifyLoggerMixin {
required WebSocketService wsService,
required GraphQLSubscriptionOptions subscriptionOptions,
required ConnectivityPlatform connectivity,
required ProcessLifeCycle processLifeCycle,
AWSHttpClient? pollClientOverride,
}) : _connectivity = connectivity,
_processLifeCycle = processLifeCycle,
_pollClient = pollClientOverride ?? AWSHttpClient() {
final subBlocs = <String, SubscriptionBloc<Object?>>{};

Expand All @@ -49,6 +52,7 @@ class WebSocketBloc with AWSDebuggable, AmplifyLoggerMixin {
);
final blocStream = _wsEventStream.asyncExpand(_eventTransformer);
_networkSubscription = _getConnectivityStream();
_processLifeCycleSubscription = _getProcessLifecycleStream();
_stateSubscription = blocStream.listen(_emit);
add(const InitEvent());
}
Expand Down Expand Up @@ -81,10 +85,14 @@ class WebSocketBloc with AWSDebuggable, AmplifyLoggerMixin {
late final Stream<WebSocketEvent> _wsEventStream = _wsEventController.stream;
late final StreamSubscription<WebSocketState> _stateSubscription;
late final StreamSubscription<ConnectivityStatus> _networkSubscription;
late final StreamSubscription<ProcessStatus> _processLifeCycleSubscription;

/// Creates a stream representing network connectivity at the hardware level.
final ConnectivityPlatform _connectivity;

/// Creates a stream representing the process life cycle state.
final ProcessLifeCycle _processLifeCycle;

/// The underlying event stream, used only in testing.
@visibleForTesting
Stream<WebSocketEvent> get wsEventStream => _wsEventStream;
Expand Down Expand Up @@ -164,6 +172,8 @@ class WebSocketBloc with AWSDebuggable, AmplifyLoggerMixin {
yield* _networkLoss();
} else if (event is NetworkFoundEvent) {
yield* _networkFound();
} else if (event is ProcessResumeEvent) {
yield* _processResumed();
} else if (event is PollSuccessEvent) {
yield* _pollSuccess();
} else if (event is PollFailedEvent) {
Expand Down Expand Up @@ -328,6 +338,16 @@ class WebSocketBloc with AWSDebuggable, AmplifyLoggerMixin {
yield* const Stream.empty();
}

Stream<WebSocketState> _processResumed() async* {
final state = _currentState;
if (state is ConnectedState) {
yield state.reconnecting(networkState: NetworkState.disconnected);
add(const ReconnectEvent());
}
// TODO(dnys1): Yield broken on web debug build.
yield* const Stream.empty();
}

/// Handle successful polls
Stream<WebSocketState> _pollSuccess() async* {
// TODO(dnys1): Yield broken on web debug build.
Expand Down Expand Up @@ -467,6 +487,7 @@ class WebSocketBloc with AWSDebuggable, AmplifyLoggerMixin {
await Future.wait<void>([
// TODO(equartey): https://github.com/fluttercommunity/plus_plugins/issues/1382
if (!isWindows()) _networkSubscription.cancel(),
_processLifeCycleSubscription.cancel(),
Future.value(_pollClient.close()),
_stateSubscription.cancel(),
_wsEventController.close(),
Expand Down Expand Up @@ -507,6 +528,41 @@ class WebSocketBloc with AWSDebuggable, AmplifyLoggerMixin {
);
}

/// Process life cycle stream monitors when the process resumes from a paused state.
StreamSubscription<ProcessStatus> _getProcessLifecycleStream() {
var prev = ProcessStatus.detached;
return _processLifeCycle.onStateChanged.listen(
(state) {
if (_isResuming(state, prev)) {
// ignore: invalid_use_of_internal_member
if (!WebSocketOptions.autoReconnect) {
_shutdownWithException(
const NetworkException(
'Unable to recover network connection, web socket will close.',
recoverySuggestion: 'Avoid pausing the process.',
),
StackTrace.current,
);
} else {
add(const ProcessResumeEvent());
}
}

prev = state;
},
onError: (Object e, StackTrace st) =>
logger.error('Error in process life cycle stream $e, $st'),
);
}

bool _isResuming(ProcessStatus current, ProcessStatus previous) {
if (previous != ProcessStatus.paused) return false;

return current == ProcessStatus.hidden ||
current == ProcessStatus.inactive ||
current == ProcessStatus.resumed;
}

Future<void> _poll() async {
try {
final res = await _sendPollRequest();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

/// Possible process life cycle states
enum ProcessStatus {
/// Engine is running without a view.
detached,

/// Application is not visible to the user or responding to user input.
paused,

/// All views of an application are hidden.
hidden,

/// A view of the application is visible, but none have input.
inactive,

/// Default running mode.
resumed,
}

/// {@template amplify_api_dart.process_life_cycle}
/// Used to create a stream representing the process life cycle state.
///
/// The generated stream is empty.
/// {@endtemplate}
class ProcessLifeCycle {
/// {@macro amplify_api_dart.process_life_cycle}
const ProcessLifeCycle();

/// Generates a new stream of [ProcessStatus].
Stream<ProcessStatus> get onStateChanged => const Stream.empty();
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ class NetworkLossEvent extends NetworkEvent {
String get runtimeTypeName => 'NetworkLossEvent';
}

/// Discrete class for when the process is resumed
/// Triggers when AppLifecycleListener detects the process has been resumed.
class ProcessResumeEvent extends WebSocketEvent {
/// Create a process resumed event
const ProcessResumeEvent();

@override
String get runtimeTypeName => 'ProcessResumeEvent';

@override
Map<String, Object?> toJson() => const {};
}

/// Triggers when a successful ping to AppSync is made
class PollSuccessEvent extends WebSocketEvent {
/// Create a successful Poll event
Expand Down
4 changes: 4 additions & 0 deletions packages/api/amplify_api_dart/test/graphql_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ void main() {
'payload': {'data': mockSubscriptionData},
};

mockProcessLifeCycleController = StreamController<ProcessStatus>();
mockWebSocketService = MockWebSocketService();
const subscriptionOptions = GraphQLSubscriptionOptions(
pollInterval: Duration(seconds: 1),
Expand All @@ -292,6 +293,7 @@ void main() {
subscriptionOptions: subscriptionOptions,
pollClientOverride: mockClient.client,
connectivity: const ConnectivityPlatform(),
processLifeCycle: const MockProcessLifeCycle(),
);

sendMockConnectionAck(mockWebSocketBloc!, mockWebSocketService!);
Expand Down Expand Up @@ -599,6 +601,7 @@ void main() {
});

test('should have correct state flow during a failure', () async {
mockProcessLifeCycleController = StreamController<ProcessStatus>();
mockWebSocketService = MockWebSocketService();
const subscriptionOptions = GraphQLSubscriptionOptions(
pollInterval: Duration(seconds: 1),
Expand All @@ -613,6 +616,7 @@ void main() {
subscriptionOptions: subscriptionOptions,
pollClientOverride: mockClient.client,
connectivity: const ConnectivityPlatform(),
processLifeCycle: const MockProcessLifeCycle(),
);

final blocReady = Completer<void>();
Expand Down
11 changes: 11 additions & 0 deletions packages/api/amplify_api_dart/test/util.dart
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import 'package:amplify_api_dart/src/graphql/web_socket/blocs/web_socket_bloc.da
import 'package:amplify_api_dart/src/graphql/web_socket/services/web_socket_service.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/state/web_socket_state.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/connectivity_platform.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/process_life_cycle.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/web_socket_types.dart';
import 'package:amplify_core/amplify_core.dart';
import 'package:amplify_core/src/config/amplify_outputs/data/data_outputs.dart';
Expand Down Expand Up @@ -329,6 +330,16 @@ class MockConnectivity extends ConnectivityPlatform {
mockNetworkStreamController.stream;
}

late StreamController<ProcessStatus> mockProcessLifeCycleController;

class MockProcessLifeCycle extends ProcessLifeCycle {
const MockProcessLifeCycle();

@override
Stream<ProcessStatus> get onStateChanged =>
mockProcessLifeCycleController.stream;
}

/// Ensures a query predicate converts to JSON correctly.
void testQueryPredicateTranslation(
QueryPredicate? queryPredicate,
Expand Down
Loading

0 comments on commit 8222a03

Please sign in to comment.