import asyncio
import importlib
import pkgutil
from copy import deepcopy
from datetime import timedelta
from enum import IntEnum
from threading import Lock, Thread
from typing import MutableMapping, Optional
from grpclib import Status
from grpclib.client import Channel
from grpclib.events import RecvTrailingMetadata, SendRequest, listen
from grpclib.exceptions import GRPCError, StreamTerminatedError
from grpclib.metadata import _MetadataLike
from viam import logging
from viam.gen.common.v1.common_pb2 import safety_heartbeat_monitored
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse
from viam.rpc.dial import DialOptions, dial
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Metadata key under which the current session ID is attached to outgoing
# requests (see ``SessionsClient._metadata``).
SESSION_METADATA_KEY = "viam-sid"
class _SupportedState(IntEnum):
    """Tri-state flag tracking whether the server is known to support sessions."""

    # Support has not been determined yet (initial state, and the state after a reset).
    UNKNOWN = 0
    # A session was started successfully.
    TRUE = 1
    # The server reported that starting a session is unimplemented.
    FALSE = 2
class SessionsClient:
    """
    A Session allows a client to express that it is actively connected and
    supports stopping actuating components when it's not.

    The client registers listeners on the given channel. Before a request for a
    safety-heartbeat-monitored method is sent, the current session ID is added
    to the request metadata (starting a session first if none exists yet).
    While a session is active, a daemon thread sends periodic heartbeats over
    its own, separately dialed channel.
    """

    channel: Channel
    client: RobotServiceStub
    _address: str
    _dial_options: DialOptions
    _disabled: bool
    _lock: Lock
    _current_id: str
    _heartbeat_interval: Optional[timedelta]
    _supported: _SupportedState
    _thread: Optional[Thread]

    # Cache of "is this method heartbeat monitored?" keyed by the full method
    # path. Defined on the class, so it is shared by every instance.
    _HEARTBEAT_MONITORED_METHODS: MutableMapping[str, bool] = {}

    def __init__(self, channel: Channel, direct_dial_address: str, dial_options: Optional[DialOptions], *, disabled: bool = False):
        """
        Create a sessions client and hook it into ``channel``'s request events.

        Args:
            channel: Channel whose outgoing requests receive session metadata.
            direct_dial_address: Address dialed by the heartbeat thread.
            dial_options: Options used when dialing the heartbeat channel. They are
                copied, and WebRTC is disabled on the copy. Defaults are used when ``None``.
            disabled: When ``True``, no session metadata is attached to requests.
        """
        self.channel = channel
        self.client = RobotServiceStub(channel)
        self._address = direct_dial_address
        self._disabled = disabled

        # Work on a private copy so the caller's options are never mutated.
        self._dial_options = deepcopy(dial_options) if dial_options is not None else DialOptions()
        self._dial_options.disable_webrtc = True

        self._lock = Lock()
        self._current_id = ""
        self._heartbeat_interval = None
        self._supported = _SupportedState.UNKNOWN
        self._thread = None

        listen(self.channel, SendRequest, self._send_request)
        listen(self.channel, RecvTrailingMetadata, self._recv_trailers)

    def reset(self):
        """Forget the current session (thread-safe wrapper around ``_reset``)."""
        with self._lock:
            self._reset()

    def _reset(self):
        """Clear all session state. The caller must already hold ``self._lock``."""
        LOGGER.debug("resetting session")
        self._supported = _SupportedState.UNKNOWN
        self._current_id = ""
        self._heartbeat_interval = None
        if self._thread is not None:
            try:
                self._thread.join(timeout=1)
            except RuntimeError:
                # Raised e.g. when the heartbeat thread resets the session itself
                # and therefore tries to join its own thread.
                LOGGER.debug("failed to join session heartbeat thread")
            self._thread = None

    async def _send_request(self, event: SendRequest):
        """Attach session metadata to outgoing requests for heartbeat-monitored methods."""
        if self._disabled:
            return

        if not self._is_safety_heartbeat_monitored(event.method_name):
            return

        event.metadata.update(await self.metadata)

    async def _recv_trailers(self, event: RecvTrailingMetadata):
        """Reset the session when the server reports that it has expired."""
        if event.status == Status.INVALID_ARGUMENT and event.status_message == "SESSION_EXPIRED":
            LOGGER.debug("Session expired")
            self.reset()

    @property
    async def metadata(self) -> _MetadataLike:
        """
        Return the metadata to attach to a request, starting a session if needed.

        This is an awaitable property: use ``await client.metadata``.

        Returns:
            ``{SESSION_METADATA_KEY: <session id>}`` while a session is active,
            otherwise an empty mapping.

        Raises:
            GRPCError: If starting the session fails for any reason other than the
                server not implementing it, or if the response is incomplete.
        """
        with self._lock:
            # Nothing to negotiate when disabled or when support is already known.
            if self._disabled or self._supported != _SupportedState.UNKNOWN:
                return self._metadata

        request = StartSessionRequest(resume=self._current_id)
        try:
            response: StartSessionResponse = await self.client.StartSession(request)
        except GRPCError as error:
            if error.status == Status.UNIMPLEMENTED:
                # The server has no session support: remember that and carry on without one.
                with self._lock:
                    self._reset()
                    self._supported = _SupportedState.FALSE
                    return self._metadata
            else:
                raise

        if response is None:
            raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")

        if response.heartbeat_window is None:
            raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

        with self._lock:
            self._supported = _SupportedState.TRUE
            self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
            self._current_id = response.id

        # tick once to ensure heartbeats are supported
        await self._heartbeat_tick(self.client)

        with self._lock:
            if self._thread is not None:
                self._reset()
            # A failed tick (or the reset above) leaves the state UNKNOWN, in which
            # case no heartbeat thread is started.
            if self._supported == _SupportedState.TRUE:
                # We send heartbeats faster than the interval window to
                # ensure that we don't fall outside of it and expire the session.
                wait = self._heartbeat_interval.total_seconds() / 5

                # The thread runs its own event loop via ``asyncio.run``.
                self._thread = Thread(
                    name="heartbeat-thread",
                    target=asyncio.run,
                    args=(self._heartbeat_process(wait),),
                    daemon=True,
                )
                self._thread.start()

        return self._metadata

    async def _heartbeat_tick(self, client: RobotServiceStub):
        """
        Send a single heartbeat for the current session using ``client``.

        Does nothing when there is no current session. Resets the session if the
        heartbeat fails.
        """
        with self._lock:
            # Capture the ID while holding the lock so a concurrent reset cannot
            # swap it for an empty string before the request is built.
            session_id = self._current_id
            if not session_id:
                LOGGER.debug("Failed to send heartbeat, session client reset")
                return

        request = SendSessionHeartbeatRequest(id=session_id)

        try:
            await client.SendSessionHeartbeat(request)
        except (GRPCError, StreamTerminatedError):
            LOGGER.debug("Heartbeat terminated", exc_info=True)
            self.reset()
        else:
            LOGGER.debug("Sent heartbeat successfully")

    async def _heartbeat_process(self, wait: float):
        """
        Send heartbeats every ``wait`` seconds until the session is no longer active.

        Runs inside the heartbeat thread's own event loop and therefore dials a
        dedicated channel instead of reusing ``self.channel``.
        """
        channel = await dial(address=self._address, options=self._dial_options)
        try:
            client = RobotServiceStub(channel.channel)
            while True:
                with self._lock:
                    if self._supported != _SupportedState.TRUE:
                        return
                await self._heartbeat_tick(client)
                await asyncio.sleep(wait)
        finally:
            # Release the dedicated connection; otherwise every heartbeat thread
            # that ends would leak its channel.
            channel.close()

    @property
    def _metadata(self) -> _MetadataLike:
        """Session metadata for the current state; empty when no session is active."""
        if self._supported == _SupportedState.TRUE and self._current_id != "":
            return {SESSION_METADATA_KEY: self._current_id}

        return {}

    def _is_safety_heartbeat_monitored(self, method: str) -> bool:
        """
        Report whether ``method`` is marked as safety-heartbeat monitored.

        Args:
            method: Full method path, expected in the form
                ``/viam.<type>.<subtype>.<version>.<Service>/<Method>``.

        Returns:
            The value of the ``safety_heartbeat_monitored`` option on the method, or
            ``False`` when the path is malformed, the method cannot be found, the
            option is absent, or the lookup fails. Results are cached per method path.
        """
        cache = self._HEARTBEAT_MONITORED_METHODS
        if method in cache:
            return cache[method]

        try:
            is_monitored = self._lookup_safety_heartbeat_monitored(method)
        except Exception:
            # Deliberately best-effort: a failed lookup must not break the request,
            # but leave a trace instead of swallowing the error silently.
            LOGGER.debug("Failed to determine whether %s is heartbeat monitored", method, exc_info=True)
            is_monitored = False

        cache[method] = is_monitored
        return is_monitored

    @staticmethod
    def _lookup_safety_heartbeat_monitored(method: str) -> bool:
        """Resolve ``method`` against the generated protobuf descriptors (uncached; may raise)."""
        parts = method.split("/")
        if len(parts) != 3:
            return False
        service_path, method_name = parts[1], parts[2]

        path_parts = service_path.split(".")
        if len(path_parts) < 5 or path_parts[0] != "viam":
            return False
        resource_type, resource_subtype, version, service_name = path_parts[1:5]

        package = importlib.import_module(f"viam.gen.{resource_type}.{resource_subtype}.{version}")
        for mod in pkgutil.iter_modules(package.__path__):
            if "_pb2" not in mod.name:
                continue
            # Import the submodule explicitly: it is only reachable as an attribute
            # of the package if something else happened to import it already.
            submod = importlib.import_module(f"{package.__name__}.{mod.name}")
            for service in submod.DESCRIPTOR.services_by_name.values():
                if service.name != service_name:
                    continue
                for method_actual in service.methods:
                    if method_actual.name != method_name:
                        continue
                    options = method_actual.GetOptions()
                    if options.HasExtension(safety_heartbeat_monitored):
                        return options.Extensions[safety_heartbeat_monitored]
                    return False
        return False