Source code for viam.components.input.client

import asyncio
from threading import Lock, RLock
from time import time
from typing import Any, Dict, List, Mapping, Optional

from google.protobuf.struct_pb2 import Struct
from grpclib import GRPCError, Status
from grpclib.client import Channel

import viam
from viam.errors import NotSupportedError
from viam.logging import getLogger
from viam.proto.common import DoCommandRequest, DoCommandResponse, Geometry
from viam.proto.component.inputcontroller import (
    GetControlsRequest,
    GetControlsResponse,
    GetEventsRequest,
    GetEventsResponse,
    InputControllerServiceStub,
    StreamEventsRequest,
    StreamEventsResponse,
    TriggerEventRequest,
)
from viam.resource.rpc_client_base import ReconfigurableResourceRPCClientBase, ResourceRPCClientBase
from viam.utils import ValueTypes, dict_to_struct, get_geometries, struct_to_dict

from .input import Control, ControlFunction, Controller, Event, EventType

LOGGER = getLogger(__name__)


[docs]class ControllerClient(Controller, ReconfigurableResourceRPCClientBase): """gRPC client for an Input Controller""" def __init__(self, name: str, channel: Channel): self.channel = channel self.client = InputControllerServiceStub(channel) self.callbacks: Dict[Control, Dict[EventType, Optional[ControlFunction]]] = {} self._lock = RLock() self._stream_lock = Lock() self._is_streaming = False self._is_stream_ready = False self._callback_extra: Struct = dict_to_struct({}) super().__init__(name)
[docs] async def get_controls( self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs, ) -> List[Control]: md = kwargs.get("metadata", self.Metadata()).proto request = GetControlsRequest(controller=self.name, extra=dict_to_struct(extra)) response: GetControlsResponse = await self.client.GetControls(request, timeout=timeout, metadata=md) return [Control(control) for control in response.controls]
[docs] async def get_events( self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs, ) -> Dict[Control, Event]: md = kwargs.get("metadata", self.Metadata()).proto request = GetEventsRequest(controller=self.name, extra=dict_to_struct(extra)) response: GetEventsResponse = await self.client.GetEvents(request, timeout=timeout, metadata=md) return {Control(event.control): Event.from_proto(event) for (event) in response.events}
[docs] def register_control_callback( self, control: Control, triggers: List[EventType], function: Optional[ControlFunction], extra: Optional[Dict[str, Any]] = None, **kwargs, ): md = kwargs.get("metadata", self.Metadata()) self._callback_extra = dict_to_struct(extra) with self._lock: callbacks = self.callbacks.get(control, {}) for trigger in triggers: if trigger == EventType.BUTTON_CHANGE: callbacks[EventType.BUTTON_PRESS] = function callbacks[EventType.BUTTON_RELEASE] = function else: callbacks[trigger] = function self.callbacks[control] = callbacks def handle_task_result(task: asyncio.Task): try: result = task.result() LOGGER.debug(f"Task {task.get_name()} returned with result {result}") except asyncio.CancelledError: pass except Exception: LOGGER.exception("Exception raised by task = %r", task) task = asyncio.create_task(self._stream_events(md), name=f"{viam._TASK_PREFIX}-input_stream_events") task.add_done_callback(handle_task_result)
[docs] def reset_channel(self, channel: Channel): super().reset_channel(channel) with self._lock: for control, callback in self.callbacks.items(): for event_type, func in callback.items(): self.register_control_callback(control, [event_type], func)
[docs] async def trigger_event( self, event: Event, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs, ): md = kwargs.get("metadata", self.Metadata()).proto request = TriggerEventRequest(controller=self.name, event=event.proto, extra=dict_to_struct(extra)) try: await self.client.TriggerEvent(request, timeout=timeout, metadata=md) except GRPCError as e: if e.status == Status.UNIMPLEMENTED and ("does not support triggering events" in e.message if e.message else False): raise NotSupportedError(f"Input controller named {self.name} does not support triggering events") raise e
async def _stream_events(self, metadata: ResourceRPCClientBase.Metadata): with self._stream_lock: if self._is_streaming: return self._is_streaming = True if not self.callbacks: return md = metadata.proto request = StreamEventsRequest(controller=self.name, events=[], extra=self._callback_extra) with self._lock: for control, callbacks in self.callbacks.items(): event = StreamEventsRequest.Events( control=control, events=[et for (et, func) in callbacks.items() if func is not None], cancelled_events=[et for (et, func) in callbacks.items() if func is None], ) request.events.append(event) try: async with self.client.StreamEvents.open(metadata=md) as stream: await stream.send_message(request, end=True) self._send_connection_status(True) reply: StreamEventsResponse async for reply in stream: event = reply.event self._execute_callback(Event.from_proto(event)) except Exception as e: LOGGER.error(e) finally: self._send_connection_status(False) with self._stream_lock: self._is_streaming = False self._is_stream_ready = False def _send_connection_status(self, connected: bool): for control in self.callbacks.keys(): event = Event(time=time(), event=EventType.CONNECT if connected else EventType.DISCONNECT, control=control, value=0) self._execute_callback(event) def _execute_callback(self, event: Event): try: callbacks = self.callbacks[event.control] callback = callbacks.get(event.event, None) all_callback = callbacks.get(EventType.ALL_EVENTS, None) except KeyError: return if callback is not None: callback(event) if all_callback is not None: all_callback(event)
[docs] async def do_command( self, command: Mapping[str, ValueTypes], *, timeout: Optional[float] = None, **kwargs, ) -> Mapping[str, ValueTypes]: md = kwargs.get("metadata", self.Metadata()).proto request = DoCommandRequest(name=self.name, command=dict_to_struct(command)) response: DoCommandResponse = await self.client.DoCommand(request, timeout=timeout, metadata=md) return struct_to_dict(response.result)
[docs] async def get_geometries(self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs) -> List[Geometry]: md = kwargs.get("metadata", self.Metadata()) return await get_geometries(self.client, self.name, extra, timeout, md)