# Source code for viam.sessions_client

import asyncio
import importlib
import pkgutil
from copy import deepcopy
from datetime import timedelta
from enum import IntEnum
from threading import Lock, Thread
from typing import MutableMapping, Optional

from grpclib import Status
from grpclib.client import Channel
from grpclib.events import RecvTrailingMetadata, SendRequest, listen
from grpclib.exceptions import GRPCError, StreamTerminatedError
from grpclib.metadata import _MetadataLike

from viam import logging
from viam.gen.common.v1.common_pb2 import safety_heartbeat_monitored
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse
from viam.rpc.dial import DialOptions, dial

LOGGER = logging.getLogger(__name__)
SESSION_METADATA_KEY = "viam-sid"


class _SupportedState(IntEnum):
    UNKNOWN = 0
    TRUE = 1
    FALSE = 2


class SessionsClient:
    """
    A Session allows a client to express that it is actively connected and
    supports stopping actuating components when it's not.

    The client hooks into the ``SendRequest`` and ``RecvTrailingMetadata``
    events of the given channel. For requests whose method is flagged as
    safety-heartbeat monitored it lazily starts a session, attaches the
    session id as request metadata, and keeps the session alive from a
    daemon thread that sends heartbeats over its own dialed connection.
    """

    channel: Channel
    client: RobotServiceStub

    _address: str
    _dial_options: DialOptions
    _disabled: bool
    _lock: Lock
    _current_id: str
    _heartbeat_interval: Optional[timedelta]
    _supported: _SupportedState
    _thread: Optional[Thread]

    # Class-level cache (shared by every instance) mapping a full gRPC method
    # path to whether that method is safety-heartbeat monitored.
    _HEARTBEAT_MONITORED_METHODS: MutableMapping[str, bool] = {}

    def __init__(self, channel: Channel, direct_dial_address: str, dial_options: Optional[DialOptions], *, disabled: bool = False):
        """
        Create a sessions client bound to ``channel``.

        Args:
            channel: Channel whose requests get session metadata attached.
            direct_dial_address: Address dialed by the heartbeat thread.
            dial_options: Options for that dial. They are deep-copied (or
                defaulted) so the caller's object is never mutated.
            disabled: When True, no session is started and no metadata is
                attached to requests.
        """
        self.channel = channel
        self.client = RobotServiceStub(channel)
        self._address = direct_dial_address
        self._disabled = disabled
        self._dial_options = deepcopy(dial_options) if dial_options is not None else DialOptions()
        # The heartbeat connection is always dialed with WebRTC disabled.
        self._dial_options.disable_webrtc = True

        self._lock = Lock()
        self._current_id = ""
        self._heartbeat_interval = None
        self._supported = _SupportedState.UNKNOWN
        self._thread = None

        listen(self.channel, SendRequest, self._send_request)
        listen(self.channel, RecvTrailingMetadata, self._recv_trailers)

    def reset(self):
        """Forget the current session, taking the lock first."""
        with self._lock:
            self._reset()

    def _reset(self):
        """
        Clear all session state. Callers must already hold ``self._lock``.

        The heartbeat thread, if any, is joined for at most one second and
        then dropped.
        """
        LOGGER.debug("resetting session")
        self._supported = _SupportedState.UNKNOWN
        self._current_id = ""
        self._heartbeat_interval = None
        if self._thread is not None:
            try:
                self._thread.join(timeout=1)
            except RuntimeError:
                # Thread.join raises RuntimeError when a thread tries to join
                # itself, which happens when the reset originates on the
                # heartbeat thread.
                LOGGER.debug("failed to join session heartbeat thread")
            self._thread = None

    async def _send_request(self, event: SendRequest):
        """Attach session metadata to outgoing heartbeat-monitored requests."""
        if self._disabled:
            return

        if not self._is_safety_heartbeat_monitored(event.method_name):
            return

        event.metadata.update(await self.metadata)

    async def _recv_trailers(self, event: RecvTrailingMetadata):
        """Reset the session when the trailers report that it has expired."""
        if event.status == Status.INVALID_ARGUMENT and event.status_message == "SESSION_EXPIRED":
            LOGGER.debug("Session expired")
            self.reset()

    @property
    async def metadata(self) -> _MetadataLike:
        """
        Return the metadata to attach to a request, starting a session first
        if support has not been determined yet.

        Returns:
            ``{SESSION_METADATA_KEY: <session id>}`` when a session is active,
            otherwise an empty mapping.

        Raises:
            GRPCError: if starting the session fails with any status other
                than ``UNIMPLEMENTED``, or if the response is missing or
                carries no heartbeat window.
        """
        with self._lock:
            if self._disabled or self._supported != _SupportedState.UNKNOWN:
                return self._metadata

        request = StartSessionRequest(resume=self._current_id)
        try:
            response: StartSessionResponse = await self.client.StartSession(request)
        except GRPCError as error:
            if error.status == Status.UNIMPLEMENTED:
                # Starting a session is not implemented on the other side;
                # record that so no further attempts are made.
                with self._lock:
                    self._reset()
                    self._supported = _SupportedState.FALSE
                    return self._metadata
            else:
                raise

        if response is None:
            raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")

        # NOTE(review): generated protobuf messages normally never expose a
        # sub-message field as None, so this guard may never fire - confirm
        # whether a HasField("heartbeat_window") check was intended.
        if response.heartbeat_window is None:
            raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

        with self._lock:
            self._supported = _SupportedState.TRUE
            self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
            self._current_id = response.id

        # tick once to ensure heartbeats are supported
        await self._heartbeat_tick(self.client)

        with self._lock:
            if self._thread is not None:
                self._reset()
            if self._supported == _SupportedState.TRUE:
                # We send heartbeats faster than the interval window to
                # ensure that we don't fall outside of it and expire the session.
                wait = self._heartbeat_interval.total_seconds() / 5

                self._thread = Thread(
                    name="heartbeat-thread",
                    target=asyncio.run,
                    args=(self._heartbeat_process(wait),),
                    daemon=True,
                )
                self._thread.start()

        return self._metadata

    async def _heartbeat_tick(self, client: RobotServiceStub):
        """
        Send one heartbeat for the current session through ``client``.

        Does nothing when there is no current session id. A ``GRPCError`` or
        ``StreamTerminatedError`` from the call resets the session instead of
        propagating.
        """
        with self._lock:
            if not self._current_id:
                LOGGER.debug("Failed to send heartbeat, session client reset")
                return
            request = SendSessionHeartbeatRequest(id=self._current_id)

        # The network call and the reset below run outside the lock: reset()
        # acquires the same non-reentrant lock itself.
        try:
            await client.SendSessionHeartbeat(request)
        except (GRPCError, StreamTerminatedError):
            LOGGER.debug("Heartbeat terminated", exc_info=True)
            self.reset()
        else:
            LOGGER.debug("Sent heartbeat successfully")

    async def _heartbeat_process(self, wait: float):
        """
        Body of the heartbeat thread: send a heartbeat every ``wait`` seconds
        until the session is no longer marked as supported.

        Args:
            wait: Seconds to sleep between heartbeats.
        """
        channel = await dial(address=self._address, options=self._dial_options)
        try:
            client = RobotServiceStub(channel.channel)
            while True:
                with self._lock:
                    if self._supported != _SupportedState.TRUE:
                        return

                await self._heartbeat_tick(client)
                await asyncio.sleep(wait)
        finally:
            # Release the connection dialed for this loop however the loop
            # ends; previously it was never closed.
            channel.close()

    @property
    def _metadata(self) -> _MetadataLike:
        """Session metadata for the current state; empty when no session is active."""
        if self._supported == _SupportedState.TRUE and self._current_id != "":
            return {SESSION_METADATA_KEY: self._current_id}
        return {}

    def _is_safety_heartbeat_monitored(self, method: str) -> bool:
        """
        Report whether ``method`` is flagged as safety-heartbeat monitored.

        Results, including negative ones and failed lookups, are memoised in
        the class-level ``_HEARTBEAT_MONITORED_METHODS`` cache.

        Args:
            method: Full gRPC method path of the form
                ``/viam.<type>.<subtype>.<version>.<Service>/<Method>``.

        Returns:
            The value of the method's ``safety_heartbeat_monitored`` option,
            or False when the path does not have that form, the method cannot
            be found, or the option is not set.
        """
        cache = self._HEARTBEAT_MONITORED_METHODS
        if method in cache:
            return cache[method]

        try:
            is_monitored = self._lookup_safety_heartbeat_monitored(method)
        except Exception:
            # Best effort: any failure while inspecting the generated modules
            # means the method is treated as not monitored.
            LOGGER.debug("Failed to determine whether %s is heartbeat monitored", method, exc_info=True)
            is_monitored = False

        cache[method] = is_monitored
        return is_monitored

    @staticmethod
    def _lookup_safety_heartbeat_monitored(method: str) -> bool:
        """Resolve ``method`` against the generated descriptors, without caching."""
        parts = method.split("/")
        if len(parts) != 3:
            return False
        _, service_path, method_name = parts

        path_parts = service_path.split(".")
        if len(path_parts) < 5 or path_parts[0] != "viam":
            return False
        resource_type, resource_subtype, version, service_name = path_parts[1:5]

        package = importlib.import_module(f"viam.gen.{resource_type}.{resource_subtype}.{version}")
        for mod in pkgutil.iter_modules(package.__path__):
            if "_pb2" not in mod.name:
                continue
            # Import the submodule explicitly: it is only available as an
            # attribute of the package once something has imported it.
            submod = importlib.import_module(f"{package.__name__}.{mod.name}")
            descriptor = getattr(submod, "DESCRIPTOR")
            for service in descriptor.services_by_name.values():
                if service.name != service_name:
                    continue
                for method_actual in service.methods:
                    if method_actual.name != method_name:
                        continue
                    options = method_actual.GetOptions()
                    if options.HasExtension(safety_heartbeat_monitored):
                        return options.Extensions[safety_heartbeat_monitored]
                    return False
        return False