-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathsessions_client.py
185 lines (154 loc) · 6.45 KB
/
sessions_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import asyncio
from copy import deepcopy
from datetime import timedelta
from enum import IntEnum
from threading import Lock, Thread
from typing import Optional
from grpclib import Status
from grpclib.client import Channel
from grpclib.events import RecvTrailingMetadata, SendRequest, listen
from grpclib.exceptions import GRPCError, StreamTerminatedError
from grpclib.metadata import _MetadataLike
from viam import logging
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse
from viam.rpc.dial import DialOptions, dial
LOGGER = logging.getLogger(__name__)

# gRPC metadata key under which the active session id is attached to requests.
SESSION_METADATA_KEY = "viam-sid"

# Fully-qualified gRPC method names whose requests are sent WITHOUT session
# metadata (checked by the SendRequest listener before attaching the session id).
EXEMPT_METADATA_METHODS = frozenset(
    {
        # gRPC server reflection
        "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
        # WebRTC signaling service
        "/proto.rpc.webrtc.v1.SignalingService/Call",
        "/proto.rpc.webrtc.v1.SignalingService/CallUpdate",
        "/proto.rpc.webrtc.v1.SignalingService/OptionalWebRTCConfig",
        # authentication
        "/proto.rpc.v1.AuthService/Authenticate",
        # robot resource listing
        "/viam.robot.v1.RobotService/ResourceNames",
        "/viam.robot.v1.RobotService/ResourceRPCSubtypes",
        # the session RPCs themselves
        "/viam.robot.v1.RobotService/StartSession",
        "/viam.robot.v1.RobotService/SendSessionHeartbeat",
    }
)
class _SupportedState(IntEnum):
UNKNOWN = 0
TRUE = 1
FALSE = 2
class SessionsClient:
    """
    A Session allows a client to express that it is actively connected and
    supports stopping actuating components when it's not.

    Outgoing requests on ``channel`` get the current session id attached under
    ``SESSION_METADATA_KEY`` (except methods in ``EXEMPT_METADATA_METHODS``).
    The session is started lazily on first use. It is kept alive by a daemon
    thread that runs its own event loop and sends heartbeats over a separately
    dialed, non-WebRTC channel. It is reset when the server reports
    ``SESSION_EXPIRED``.
    """

    channel: Channel
    client: RobotServiceStub
    _address: str
    _dial_options: DialOptions
    _disabled: bool
    _lock: Lock
    _current_id: str
    _heartbeat_interval: Optional[timedelta]
    _supported: _SupportedState
    _thread: Optional[Thread]

    def __init__(self, channel: Channel, direct_dial_address: str, dial_options: Optional[DialOptions], *, disabled: bool = False):
        """
        Args:
            channel: Channel whose outgoing requests receive session metadata and
                whose trailers are watched for session expiry.
            direct_dial_address: Address the heartbeat thread dials for its own channel.
            dial_options: Options for that dial. They are deep-copied, so the caller's
                object is not mutated, and WebRTC is always disabled on the copy.
            disabled: When True, no session metadata is ever attached.
        """
        self.channel = channel
        self.client = RobotServiceStub(channel)
        self._address = direct_dial_address
        self._disabled = disabled
        self._dial_options = deepcopy(dial_options) if dial_options is not None else DialOptions()
        self._dial_options.disable_webrtc = True
        self._lock = Lock()
        self._current_id = ""
        self._heartbeat_interval = None
        self._supported = _SupportedState.UNKNOWN
        self._thread = None
        listen(self.channel, SendRequest, self._send_request)
        listen(self.channel, RecvTrailingMetadata, self._recv_trailers)

    def reset(self):
        """Forget the current session and (best effort) wait for the heartbeat thread to stop."""
        with self._lock:
            thread = self._reset()
        # Join only after releasing the lock. The heartbeat loop has to acquire
        # ``self._lock`` to observe the reset and exit, so joining while holding
        # the lock could never succeed and would always wait the full timeout.
        self._join_heartbeat_thread(thread)

    def _reset(self) -> Optional[Thread]:
        """Clear all session state and detach the heartbeat thread.

        The caller must hold ``self._lock``.

        Returns:
            The previously running heartbeat thread, if any, so the caller can
            join it after releasing the lock.
        """
        LOGGER.debug("resetting session")
        self._supported = _SupportedState.UNKNOWN
        self._current_id = ""
        self._heartbeat_interval = None
        thread, self._thread = self._thread, None
        return thread

    @staticmethod
    def _join_heartbeat_thread(thread: Optional[Thread]):
        """Wait up to one second for a detached heartbeat thread to finish."""
        if thread is None:
            return
        try:
            thread.join(timeout=1)
        except RuntimeError:
            # e.g. when the heartbeat thread resets the session itself and thus
            # tries to join itself.
            LOGGER.debug("failed to join session heartbeat thread")

    async def _send_request(self, event: SendRequest):
        """grpclib ``SendRequest`` listener: add session metadata unless disabled or exempt."""
        if self._disabled:
            return
        if event.method_name in EXEMPT_METADATA_METHODS:
            return
        event.metadata.update(await self.metadata)

    async def _recv_trailers(self, event: RecvTrailingMetadata):
        """grpclib ``RecvTrailingMetadata`` listener: reset when the server reports the session expired."""
        if event.status == Status.INVALID_ARGUMENT and event.status_message == "SESSION_EXPIRED":
            LOGGER.debug("Session expired")
            self.reset()

    @property
    async def metadata(self) -> _MetadataLike:
        """Session metadata to attach to an outgoing request (awaitable property).

        While support is UNKNOWN, this starts a session via ``StartSession``,
        resuming ``self._current_id`` if one is set. It then sends one heartbeat
        to confirm heartbeats work and ensures a background heartbeat thread is
        running. Once support is known (TRUE or FALSE), the cached metadata is
        returned without any RPC.

        Returns:
            ``{SESSION_METADATA_KEY: <session id>}`` while a session is active,
            otherwise an empty dict.

        Raises:
            GRPCError: if ``StartSession`` fails with a status other than
                ``UNIMPLEMENTED``, or its response / heartbeat window is missing.
        """
        with self._lock:
            if self._disabled or self._supported != _SupportedState.UNKNOWN:
                return self._metadata

        request = StartSessionRequest(resume=self._current_id)
        try:
            response: StartSessionResponse = await self.client.StartSession(request)
        except GRPCError as error:
            if error.status != Status.UNIMPLEMENTED:
                raise
            # The server does not support sessions: remember that so we never ask again.
            with self._lock:
                thread = self._reset()
                self._supported = _SupportedState.FALSE
            self._join_heartbeat_thread(thread)
            return self._metadata

        if response is None:
            raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")
        if response.heartbeat_window is None:
            raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

        with self._lock:
            self._supported = _SupportedState.TRUE
            self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
            self._current_id = response.id

        # Tick once to ensure heartbeats are supported. A failed tick resets the
        # session, which is why ``_supported`` is re-checked below.
        await self._heartbeat_tick(self.client)

        with self._lock:
            # A concurrent ``metadata`` call may already have started a heartbeat
            # thread. That thread reads ``self._current_id`` on every tick, so it
            # keeps this newest session alive. Resetting here would instead wipe
            # the state set above and leave the new session without heartbeats.
            heartbeat_running = self._thread is not None and self._thread.is_alive()
            if self._supported == _SupportedState.TRUE and not heartbeat_running:
                # We send heartbeats faster than the interval window to
                # ensure that we don't fall outside of it and expire the session.
                wait = self._heartbeat_interval.total_seconds() / 5
                self._thread = Thread(
                    name="heartbeat-thread",
                    target=asyncio.run,
                    args=(self._heartbeat_process(wait),),
                    daemon=True,
                )
                self._thread.start()
        return self._metadata

    async def _heartbeat_tick(self, client: RobotServiceStub):
        """Send one heartbeat for the current session through ``client``.

        This does nothing if there is no current session id. A ``GRPCError`` or
        ``StreamTerminatedError`` is logged at debug level and resets the session
        instead of propagating.
        """
        with self._lock:
            if not self._current_id:
                LOGGER.debug("Failed to send heartbeat, session client reset")
                return
            request = SendSessionHeartbeatRequest(id=self._current_id)
        try:
            await client.SendSessionHeartbeat(request)
        except (GRPCError, StreamTerminatedError):
            LOGGER.debug("Heartbeat terminated", exc_info=True)
            self.reset()
        else:
            LOGGER.debug("Sent heartbeat successfully")

    async def _heartbeat_process(self, wait: float):
        """Heartbeat loop, run via ``asyncio.run`` on the background thread.

        Dials a dedicated channel to ``self._address`` with the WebRTC-disabled
        options from ``__init__``. It then sends a heartbeat every ``wait``
        seconds until the session is no longer marked supported. The dialed
        channel is always closed when the loop exits.
        """
        channel = await dial(address=self._address, options=self._dial_options)
        try:
            client = RobotServiceStub(channel.channel)
            while True:
                with self._lock:
                    if self._supported != _SupportedState.TRUE:
                        return
                await self._heartbeat_tick(client)
                await asyncio.sleep(wait)
        finally:
            channel.close()

    @property
    def _metadata(self) -> _MetadataLike:
        """The session-id metadata dict, or ``{}`` when no session is active. Does not lock."""
        if self._supported == _SupportedState.TRUE and self._current_id != "":
            return {SESSION_METADATA_KEY: self._current_id}
        return {}