Source code for concepts.dm.crow.interfaces.controller_interface
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : controller_interface.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/16/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""A controller interface connects the controller commands output by a policy or a planner to the robot simulation/physical system.
Here, we distinguish between the simulation interface and the physical interface by whether they support state save/restore.
"""
import contextlib
from typing import Optional, Tuple, Dict, Callable
from concepts.dsl.tensor_value import TensorValue
from concepts.dm.crow.controller import CrowControllerApplier
from concepts.dm.crow.crow_domain import CrowState
from concepts.dm.crow.executors.crow_executor import CrowExecutor
# Public names exported by a star-import of this module: the execution error type and the three interface classes.
__all__ = ['CrowControllerExecutionError', 'CrowControllerInterfaceBase', 'CrowSimulationControllerInterface', 'CrowPhysicalControllerInterface']
[docs]
class CrowControllerExecutionError(Exception):
    """Exception type used to signal that the execution of a controller failed.

    It carries no extra state beyond what :class:`Exception` provides.
    """
[docs]
class CrowControllerInterfaceBase:
    """Common base class of all controller interfaces.

    An interface keeps a table that maps controller names to Python callables. Stepping with an action looks up
    the callable registered under the action's name and invokes it with the action's arguments. By convention,
    a controller whose execution fails should raise an exception.
    """

    def __init__(self, executor: Optional[CrowExecutor] = None):
        """Create an interface with an empty controller table.

        Args:
            executor: an optional executor to associate with the interface. It is only stored by this class.
        """
        self._executor = executor
        self._controllers = {}

    @property
    def executor(self) -> Optional[CrowExecutor]:
        """The executor given at construction time (may be None)."""
        return self._executor

    @property
    def controllers(self) -> Dict[str, Callable]:
        """The mapping from controller names to their registered callables."""
        return self._controllers

    def register_controller(self, name: str, function: Callable):
        """Register ``function`` under ``name``, replacing any previous entry with the same name.

        Args:
            name: the controller name.
            function: the callable to invoke for this controller.

        Returns:
            the interface itself, so that registrations can be chained.
        """
        self.controllers[name] = function
        return self

    def step(self, action: CrowControllerApplier, **kwargs) -> None:
        """Execute a controller action by forwarding its name and arguments to :meth:`step_internal`.

        Args:
            action: the action to take; its ``name`` and ``arguments`` attributes are read.
            **kwargs: extra keyword arguments passed through to the controller function.

        Returns:
            whatever :meth:`step_internal` returns.
        """
        controller_name, controller_args = action.name, action.arguments
        return self.step_internal(controller_name, *controller_args, **kwargs)

    def step_without_error(self, action: CrowControllerApplier, **kwargs) -> bool:
        """Execute an action and report success as a boolean.

        Only :class:`CrowControllerExecutionError` is converted into a ``False`` result; any other exception
        propagates to the caller.

        Args:
            action: the action to take.
            **kwargs: extra keyword arguments passed through to :meth:`step`.

        Returns:
            True if the step finished without a CrowControllerExecutionError, False otherwise.
        """
        try:
            self.step(action, **kwargs)
        except CrowControllerExecutionError:
            succeeded = False
        else:
            succeeded = True
        return succeeded

    def step_internal(self, name: str, *args, **kwargs) -> None:
        """Look up the controller registered under ``name`` and call it.

        Args:
            name: the controller name.
            *args: positional arguments for the controller. TensorValue arguments whose dtype is a Python-object
                value type are replaced by the result of their ``item()`` call; all others are passed unchanged.
            **kwargs: keyword arguments passed to the controller unchanged.

        Returns:
            the return value of the controller function.

        Raises:
            ValueError: if no controller has been registered under ``name``.
        """
        if name not in self.controllers:
            raise ValueError(f"Controller {name} not found.")

        unwrapped_args = []
        for arg in args:
            if isinstance(arg, TensorValue) and arg.dtype.is_pyobj_value_type:
                arg = arg.item()
            unwrapped_args.append(arg)
        return self.controllers[name](*unwrapped_args, **kwargs)
[docs]
class CrowSimulationControllerInterface(CrowControllerInterfaceBase):
    """Controller interface for a simulator, i.e., a system whose state can be saved and restored.

    On top of the base interface it maintains a counter of executed actions and declares the
    :meth:`save_state`, :meth:`restore_state`, and :meth:`get_crow_state` hooks, which raise
    NotImplementedError here and are meant to be provided by subclasses.
    """

    def __init__(self, executor: Optional[CrowExecutor] = None):
        """Create the interface and start the action counter at zero.

        Args:
            executor: an optional executor, forwarded to the base class.
        """
        super().__init__(executor)
        self._action_counter = 0

    def step_with_saved_state(self, action: CrowControllerApplier, **kwargs) -> Tuple[bool, int]:
        """Save the current state, then execute the action.

        The same keyword arguments are passed to both :meth:`save_state` and :meth:`step`.

        Args:
            action: the action to take.

        Returns:
            bool: whether the execution is successful (False if a CrowControllerExecutionError was raised).
            int: the identifier of the state saved before the action was executed.
        """
        state_identifier = self.save_state(**kwargs)
        try:
            self.step(action, **kwargs)
        except CrowControllerExecutionError:
            succeeded = False
        else:
            succeeded = True
        return succeeded, state_identifier

    def step_internal(self, name: str, *args, **kwargs) -> None:
        """Run the base-class :meth:`step_internal` and count the attempt.

        The action counter is incremented in a ``finally`` clause, so it advances whether the controller
        returns normally or an exception is raised.
        """
        try:
            result = super().step_internal(name, *args, **kwargs)
        finally:
            self.increment_action_counter()
        return result

    def reset_action_counter(self):
        """Set the action counter back to zero."""
        self._action_counter = 0

    def get_action_counter(self) -> int:
        """Return the current value of the action counter."""
        return self._action_counter

    def increment_action_counter(self):
        """Advance the action counter by one."""
        self._action_counter += 1

    def save_state(self, **kwargs) -> int:
        """Save the simulator state and return its identifier. Not implemented in this class."""
        raise NotImplementedError()

    def restore_state(self, state_identifier: int, **kwargs):
        """Restore the simulator state with the given identifier. Not implemented in this class."""
        raise NotImplementedError()

    def get_crow_state(self) -> CrowState:
        """Get the state of the simulation interface. Not implemented in this class."""
        raise NotImplementedError()

    @contextlib.contextmanager
    def restore_context(self, verbose: bool = False, **kwargs):
        """Context manager that rolls the simulator back when the ``with`` block exits.

        On entry, the state is saved and the action counter is recorded. On exit (normal or via an exception),
        the saved state is restored first and then the action counter is reset to the recorded value.

        Args:
            verbose: accepted for interface compatibility; not used by this implementation.
            **kwargs: accepted for interface compatibility; not forwarded to save_state/restore_state here.
        """
        checkpoint = self.save_state()
        counter_at_entry = self._action_counter
        try:
            yield
        finally:
            self.restore_state(checkpoint)
            self._action_counter = counter_at_entry
[docs]
class CrowPhysicalControllerInterface(CrowControllerInterfaceBase):
    """Controller interface for a physical system.

    It adds nothing to :class:`CrowControllerInterfaceBase`; in particular, it declares no state save/restore
    methods.
    """