# Source code for viam.sessions_client

import asyncio
from copy import deepcopy
from datetime import timedelta
from enum import IntEnum
from threading import Lock, Thread
from typing import Optional

from grpclib import Status
from grpclib.client import Channel
from grpclib.events import RecvTrailingMetadata, SendRequest, listen
from grpclib.exceptions import GRPCError, StreamTerminatedError
from grpclib.metadata import _MetadataLike

from viam import logging
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse
from viam.rpc.dial import DialOptions, dial

LOGGER = logging.getLogger(__name__)

# Request-metadata key under which the active session id is sent.
SESSION_METADATA_KEY = "viam-sid"

# Fully-qualified gRPC method names that never receive session metadata.
# StartSession and SendSessionHeartbeat are issued by the session machinery
# itself on the same channel, so they must be exempt to avoid re-entering
# session setup from the request hook.
_EXEMPT_METHOD_NAMES = (
    "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
    "/proto.rpc.webrtc.v1.SignalingService/Call",
    "/proto.rpc.webrtc.v1.SignalingService/CallUpdate",
    "/proto.rpc.webrtc.v1.SignalingService/OptionalWebRTCConfig",
    "/proto.rpc.v1.AuthService/Authenticate",
    "/viam.robot.v1.RobotService/ResourceNames",
    "/viam.robot.v1.RobotService/ResourceRPCSubtypes",
    "/viam.robot.v1.RobotService/StartSession",
    "/viam.robot.v1.RobotService/SendSessionHeartbeat",
)
EXEMPT_METADATA_METHODS = frozenset(_EXEMPT_METHOD_NAMES)


class _SupportedState(IntEnum):
    UNKNOWN = 0
    TRUE = 1
    FALSE = 2


class SessionsClient:
    """
    A Session allows a client to express that it is actively connected and
    supports stopping actuating components when it's not.
    """

    channel: Channel
    client: RobotServiceStub
    _address: str
    _dial_options: DialOptions
    _disabled: bool
    _lock: Lock
    _current_id: str
    _heartbeat_interval: Optional[timedelta]
    _supported: _SupportedState
    _thread: Optional[Thread]

    def __init__(
        self,
        channel: Channel,
        direct_dial_address: str,
        dial_options: Optional[DialOptions],
        *,
        disabled: bool = False,
    ):
        """Create a sessions client attached to ``channel``.

        Args:
            channel: Channel whose outgoing requests get session metadata and
                whose trailing metadata is watched for session expiry.
            direct_dial_address: Address the heartbeat thread dials for its own
                connection.
            dial_options: Options for that heartbeat dial. They are deep-copied
                so the caller's object is not mutated. ``None`` means a default
                ``DialOptions()``.
            disabled: When True, no session metadata is ever attached.
        """
        self.channel = channel
        self.client = RobotServiceStub(channel)
        self._address = direct_dial_address
        self._disabled = disabled
        # Copy first so forcing disable_webrtc below does not touch the caller's options.
        self._dial_options = deepcopy(dial_options) if dial_options is not None else DialOptions()
        self._dial_options.disable_webrtc = True
        self._lock = Lock()
        self._current_id = ""
        self._heartbeat_interval = None
        self._supported = _SupportedState.UNKNOWN
        self._thread = None

        listen(self.channel, SendRequest, self._send_request)
        listen(self.channel, RecvTrailingMetadata, self._recv_trailers)

    def reset(self):
        """Discard the current session state while holding the lock."""
        with self._lock:
            self._reset()

    def _reset(self):
        """Clear session id, interval and support state; stop tracking the heartbeat thread.

        Every caller in this class holds ``self._lock`` when calling this.
        Joining the heartbeat thread waits at most one second.
        """
        LOGGER.debug("resetting session")
        self._supported = _SupportedState.UNKNOWN
        self._current_id = ""
        self._heartbeat_interval = None
        if self._thread is not None:
            try:
                self._thread.join(timeout=1)
            except RuntimeError:
                # reset() can be triggered from the heartbeat thread itself
                # (via _heartbeat_tick), and a thread cannot join itself.
                LOGGER.debug("failed to join session heartbeat thread")
            self._thread = None

    async def _send_request(self, event: SendRequest):
        """grpclib ``SendRequest`` hook: add session metadata to non-exempt requests."""
        if self._disabled:
            return
        if event.method_name in EXEMPT_METADATA_METHODS:
            return
        event.metadata.update(await self.metadata)

    async def _recv_trailers(self, event: RecvTrailingMetadata):
        """grpclib ``RecvTrailingMetadata`` hook: reset when the server reports SESSION_EXPIRED."""
        if event.status == Status.INVALID_ARGUMENT and event.status_message == "SESSION_EXPIRED":
            LOGGER.debug("Session expired")
            # NOTE(review): reset() may join the heartbeat thread for up to 1s,
            # blocking this event loop meanwhile.
            self.reset()

    @property
    async def metadata(self) -> _MetadataLike:
        """Return the metadata to attach to requests, starting a session if needed.

        If sessions are disabled, or support was already determined, the current
        metadata is returned without any RPC. Otherwise StartSession is called:

        * UNIMPLEMENTED marks sessions unsupported and returns ``{}``.
        * On success, one heartbeat is sent on ``self.client``, then a daemon
          thread is started to keep sending heartbeats.

        Raises:
            GRPCError: any StartSession error other than UNIMPLEMENTED, or
                INTERNAL if the response lacks a heartbeat window.
        """
        with self._lock:
            if self._disabled or self._supported != _SupportedState.UNKNOWN:
                return self._metadata

        request = StartSessionRequest(resume=self._current_id)
        try:
            response: StartSessionResponse = await self.client.StartSession(request)
        except GRPCError as error:
            if error.status != Status.UNIMPLEMENTED:
                raise
            with self._lock:
                self._reset()
                self._supported = _SupportedState.FALSE
                return self._metadata

        if response is None:
            raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")
        # Protobuf message fields are never None; an unset Duration would read
        # as zero and make the heartbeat loop spin. Check presence explicitly.
        if not response.HasField("heartbeat_window"):
            raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

        with self._lock:
            self._supported = _SupportedState.TRUE
            self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
            self._current_id = response.id

        # tick once to ensure heartbeats are supported
        await self._heartbeat_tick(self.client)

        with self._lock:
            if self._thread is not None:
                # NOTE(review): if a heartbeat thread already exists (e.g. a
                # concurrent metadata call started one), this resets all state;
                # no new thread is started below and {} is returned.
                self._reset()
            if self._supported == _SupportedState.TRUE:
                # We send heartbeats faster than the interval window to
                # ensure that we don't fall outside of it and expire the session.
                wait = self._heartbeat_interval.total_seconds() / 5
                self._thread = Thread(
                    name="heartbeat-thread",
                    target=asyncio.run,
                    args=(self._heartbeat_process(wait),),
                    daemon=True,
                )
                self._thread.start()

        return self._metadata

    async def _heartbeat_tick(self, client: RobotServiceStub):
        """Send one heartbeat for the current session id using ``client``.

        Does nothing if there is no current session id. If the heartbeat fails
        with ``GRPCError`` or ``StreamTerminatedError``, the session is reset.
        """
        with self._lock:
            if not self._current_id:
                LOGGER.debug("Failed to send heartbeat, session client reset")
                return
            request = SendSessionHeartbeatRequest(id=self._current_id)

        try:
            await client.SendSessionHeartbeat(request)
        except (GRPCError, StreamTerminatedError):
            LOGGER.debug("Heartbeat terminated", exc_info=True)
            self.reset()
        else:
            LOGGER.debug("Sent heartbeat successfully")

    async def _heartbeat_process(self, wait: float):
        """Heartbeat loop run via ``asyncio.run`` in the heartbeat thread.

        Dials its own connection with ``self._dial_options``. Presumably this is
        because the main channel belongs to another event loop (TODO confirm).
        It sends a heartbeat every ``wait`` seconds until the session is no
        longer marked supported, and always closes the dialed channel on exit.
        """
        channel = await dial(address=self._address, options=self._dial_options)
        try:
            client = RobotServiceStub(channel.channel)
            while True:
                with self._lock:
                    if self._supported != _SupportedState.TRUE:
                        return
                await self._heartbeat_tick(client)
                await asyncio.sleep(wait)
        finally:
            # Without this, every heartbeat thread that exits leaks a connection.
            channel.close()

    @property
    def _metadata(self) -> _MetadataLike:
        """Return ``{SESSION_METADATA_KEY: id}`` when a session is active, else ``{}``.

        Does not take the lock itself.
        """
        if self._supported == _SupportedState.TRUE and self._current_id != "":
            return {SESSION_METADATA_KEY: self._current_id}
        return {}