Source code for concepts.simulator.pybullet.components.multi_robot_controller

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : multi_robot_controller.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 07/8/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

from typing import Callable, Iterator, NamedTuple, Optional, Sequence

import numpy as np

from jacinle.utils.enum import JacEnum
from concepts.simulator.pybullet.components.component_base import BulletComponent
from concepts.simulator.pybullet.components.robot_base import BulletArmRobotBase
from concepts.math.interpolation_utils import gen_linear_spline, get_next_target_linear_spline


class _ControlCommand(NamedTuple):
    func: callable
    kwargs: dict


class RecordedControlCommandType(JacEnum):
    """Kinds of entries that ``MultiRobotController`` appends to its recording."""

    # Opening and closing of a sync context; these entries carry no payload.
    BEGIN_SYNC_CONTEXT = 'begin_sync_context'
    END_SYNC_CONTEXT = 'end_sync_context'
    # A single ``MultiRobotController.do`` call (payload: ``RecordedControlDoCommand``).
    DO = 'do'
    # A ``do_synchronized_qpos_trajectories`` call
    # (payload: ``RecordedControlDoSynchronizedQposTrajectoriesCommand``).
    DO_SYNCHRONIZED_QPOS_TRAJECTORIES = 'do_synchronized_qpos_trajectories'
class RecordedControlDoCommand(NamedTuple):
    """Payload of a recorded ``DO`` entry: the arguments of one ``MultiRobotController.do`` call."""

    robot_index: int  # index into ``MultiRobotController.robots``
    action_name: str  # key of ``MultiRobotController.ACTION_NAME_MAPPING``
    kwargs: dict  # keyword arguments passed through to the robot's control method
class RecordedControlDoSynchronizedQposTrajectoriesCommand(NamedTuple):
    """Payload of a recorded ``DO_SYNCHRONIZED_QPOS_TRAJECTORIES`` entry (the arguments of that call)."""

    trajectories: Sequence[Sequence[np.ndarray]]  # one joint-space trajectory per robot
    step_size: float
    gains: float
    atol: float
    timeout: float
class RecordedControlCommand(NamedTuple):
    """One entry of a recording: the command kind plus its optional payload."""

    type: RecordedControlCommandType
    # Left as None for BEGIN/END_SYNC_CONTEXT entries; otherwise one of the RecordedControl*Command tuples.
    payload: Optional[NamedTuple] = None
class MultiRobotController(BulletComponent):
    """Coordinates several arm robots that share one PyBullet client.

    Commands issued with :meth:`do` while a sync context is active (see :meth:`make_sync_context`) are queued
    per robot and executed together when the context ends. Issued commands can optionally be recorded into
    :attr:`recorded_commands` and re-executed later with :meth:`replay`.
    """

    # Maps the action names accepted by :meth:`do` to the name of the robot method that implements them.
    ACTION_NAME_MAPPING = {
        'move_qpos': 'move_qpos_set_control',
        'move_qpos_path_v2': 'move_qpos_path_v2_set_control',
        'move_cartesian_trajectory': 'move_cartesian_trajectory_set_control',
        'open_gripper_free': 'open_gripper_free_set_control',
        'close_gripper_free': 'close_gripper_free_set_control',
        'grasp': 'grasp_set_control',
    }

    def __init__(self, robots: Sequence[BulletArmRobotBase]):
        """Create a controller for ``robots``.

        Args:
            robots: the robots to control; must be non-empty. The first robot's client is used as this
                component's client.
        """
        assert len(robots) > 0
        super().__init__(robots[0].client)
        self.robots = robots
        self.current_ctx = None  # the active MultiRobotControllerContext, or None
        self.recording_enabled = False
        self.recorded_commands = list()

    def enable_recording(self):
        """Start appending issued commands to :attr:`recorded_commands`."""
        self.recording_enabled = True

    def disable_recording(self):
        """Stop recording commands; already recorded commands are kept."""
        self.recording_enabled = False

    # Backward-compatible alias for the original, misspelled method name.
    disable_recoding = disable_recording

    def get_concat_qpos(self):
        """Return the joint positions of all robots concatenated in robot order."""
        return np.concatenate([robot.get_qpos() for robot in self.robots])

    def make_sync_context(self):
        """Create a new (not yet active) sync context bound to this controller."""
        return MultiRobotControllerContext(self)

    def do(self, robot_index: int, action_name: str, **kwargs) -> _ControlCommand:
        """Create a control command for one robot and queue it on the active sync context, if any.

        Args:
            robot_index: index into :attr:`robots`.
            action_name: one of the keys of :attr:`ACTION_NAME_MAPPING`.
            **kwargs: forwarded to the robot's control method when the command runs.

        Returns:
            the created command. When no sync context is active it is only returned, not executed.

        Raises:
            KeyError: if ``action_name`` is not a key of :attr:`ACTION_NAME_MAPPING`.
        """
        assert 0 <= robot_index < len(self.robots)
        # Resolve the control method *before* recording so that invalid calls never end up in the recording.
        cmd = _ControlCommand(getattr(self.robots[robot_index], self.ACTION_NAME_MAPPING[action_name]), kwargs)
        if self.recording_enabled:
            self.recorded_commands.append(RecordedControlCommand(
                RecordedControlCommandType.DO,
                RecordedControlDoCommand(robot_index, action_name, kwargs)
            ))
        if self.current_ctx is not None:
            self.current_ctx.commands.setdefault(robot_index, list()).append(cmd)
        return cmd

    def do_synchronized_qpos_trajectories(
        self, trajectories: Sequence[Sequence[np.ndarray]],
        step_size: float = 1, gains: float = 0.3, atol: float = 0.03, timeout: float = 20, verbose: bool = False
    ):
        """Move all robots along joint-space trajectories in lock step, blocking until done.

        Args:
            trajectories: one trajectory per robot, in the order of :attr:`robots`; all trajectories must
                have the same number of waypoints.
            step_size: forwarded to the spline target lookup.
            gains: position-control gains.
            atol: L1 tolerance on the concatenated goal configuration.
            timeout: forwarded to the client's timeout loop.
            verbose: print the remaining distance at every step.

        Returns:
            the value returned by ``MultiRobotMoveTrajectory2.move``.
        """
        assert len(trajectories) == len(self.robots) > 0
        nr_length = len(trajectories[0])
        for trajectory in trajectories:
            assert len(trajectory) == nr_length
        if self.recording_enabled:
            self.recorded_commands.append(RecordedControlCommand(
                RecordedControlCommandType.DO_SYNCHRONIZED_QPOS_TRAJECTORIES,
                RecordedControlDoSynchronizedQposTrajectoriesCommand(trajectories, step_size, gains, atol, timeout)
            ))
        return MultiRobotMoveTrajectory2(self, self.robots, trajectories).move(
            step_size=step_size, gains=gains, atol=atol, timeout=timeout, verbose=verbose
        )

    def stable_reset(self, nr_steps=10):
        """Hold every robot at its current joint positions for ``nr_steps`` simulation steps."""
        robot_qposes = [robot.get_qpos() for robot in self.robots]
        for _ in range(nr_steps):
            for robot, qpos in zip(self.robots, robot_qposes):
                robot.set_full_hold_position_control(qpos)
            self.client.step()

    def replay(self, commands):
        """Re-execute a sequence of :class:`RecordedControlCommand` entries (e.g. :attr:`recorded_commands`).

        Raises:
            ValueError: on an unknown command type.
            RuntimeError: on an END_SYNC_CONTEXT entry while no sync context is active.
        """
        assert self.recording_enabled is False, 'Replay is not allowed when recording is enabled.'
        for cmd in commands:
            payload = cmd.payload
            if cmd.type is RecordedControlCommandType.DO:
                self.do(payload.robot_index, payload.action_name, **payload.kwargs)
            elif cmd.type is RecordedControlCommandType.DO_SYNCHRONIZED_QPOS_TRAJECTORIES:
                self.do_synchronized_qpos_trajectories(
                    payload.trajectories,
                    step_size=payload.step_size, gains=payload.gains, atol=payload.atol, timeout=payload.timeout
                )
            elif cmd.type is RecordedControlCommandType.BEGIN_SYNC_CONTEXT:
                self.make_sync_context().begin()
            elif cmd.type is RecordedControlCommandType.END_SYNC_CONTEXT:
                if self.current_ctx is None:
                    raise RuntimeError('END_SYNC_CONTEXT encountered while no sync context is active.')
                self.current_ctx.end()
            else:
                raise ValueError(f'Unknown command type: {cmd.type}')
class MultiRobotControllerContext(object):
    """Sync context that queues per-robot commands and runs them together when it ends.

    Typical use::

        with controller.make_sync_context():
            controller.do(0, action_name, **kwargs0)
            controller.do(1, action_name, **kwargs1)
        # Both queues have been executed at this point.
    """

    def __init__(self, controller: MultiRobotController):
        self.controller = controller
        # robot_index -> list of queued _ControlCommand objects, in issue order.
        self.commands = dict()

    def begin(self):
        """Make this the controller's active context, clear the queue, and return ``self``."""
        self.controller.current_ctx = self
        self.commands = dict()
        if self.controller.recording_enabled:
            self.controller.recorded_commands.append(RecordedControlCommand(RecordedControlCommandType.BEGIN_SYNC_CONTEXT))
        return self

    def __enter__(self):
        return self.begin()

    def end(self):
        """Run the queued commands, then deactivate this context.

        The controller's active context is cleared even if a command raises. Otherwise later ``do`` calls
        would be queued on a context that never runs.
        """
        if self.controller.recording_enabled:
            self.controller.recorded_commands.append(RecordedControlCommand(RecordedControlCommandType.END_SYNC_CONTEXT))
        try:
            self.run_commands()
        finally:
            self.controller.current_ctx = None

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end()

    def run_commands(self):
        """Execute all queued commands, stepping the simulation until every robot's queue is exhausted.

        Each robot's commands run one after another. Calling a command's ``func(**kwargs)`` gives an iterator
        that is advanced once per simulation step. When it is exhausted, the robot's next queued command starts.
        Robots with no (remaining) commands are held at the joint positions they had when they became idle.
        """
        robots = self.controller.robots
        iterators = dict()
        cursor = dict()  # robot_index -> index of the command currently running
        for robot_index, cmds in self.commands.items():
            cursor[robot_index] = 0
            iterators[robot_index] = cmds[0].func(**cmds[0].kwargs)

        dones = [False] * len(robots)
        hold_qposes = [None] * len(robots)
        for i in range(len(robots)):
            if i not in self.commands:
                dones[i] = True
                hold_qposes[i] = robots[i].get_qpos()
        if all(dones):
            return

        while True:
            for i, done in enumerate(dones):
                if done:
                    robots[i].set_full_hold_position_control(hold_qposes[i])
                    continue
                try:
                    next(iterators[i])
                except StopIteration:
                    if cursor[i] + 1 < len(self.commands[i]):
                        cursor[i] += 1
                        cmd = self.commands[i][cursor[i]]
                        iterators[i] = cmd.func(**cmd.kwargs)
                    else:
                        dones[i] = True
                        hold_qposes[i] = robots[i].get_qpos()
            self.controller.client.step()
            if all(dones):
                break
class MultiRobotMoveTrajectory2(object):
    """Drive several robots along joint-space trajectories in lock step.

    The per-robot trajectories are concatenated along the joint axis into one trajectory over the stacked
    joint vector of all robots, which is then tracked with a linear spline.
    """

    def __init__(self, controller: MultiRobotController, robots: Sequence[BulletArmRobotBase], qpos_trajectories: Sequence[Sequence[np.ndarray]]):
        self.controller = controller
        self.robots = robots
        self.qpos_trajectories = qpos_trajectories
        # (nr_waypoints, total_nr_joints) after dropping near-duplicate waypoints.
        self.concat_qpos_trajectories = self._dedup_qpos_trajectory(np.concatenate(list(qpos_trajectories), axis=1))
        # Robot i owns the slice [q_start_indices[i], q_start_indices[i] + q_lengths[i]) of the stacked vector.
        self.q_lengths = [len(t[0]) for t in qpos_trajectories]
        self.q_start_indices = np.cumsum([0] + self.q_lengths)[:-1]

    @property
    def client(self):
        """The simulation client, taken from the first robot."""
        return self.robots[0].client

    def _dedup_qpos_trajectory(self, qpos_trajectory, threshold: float = 0.01):
        """Drop waypoints that lie within ``threshold`` (L2) of the last kept waypoint.

        If more than one waypoint survives and the input's final waypoint (the goal) was dropped, the last
        surviving waypoint is replaced by the goal. The output therefore still ends at the requested goal.
        """
        qpos_trajectory = np.asarray(qpos_trajectory)
        kept = list()
        for qpos in qpos_trajectory:
            # Compare with the last *kept* waypoint. Comparing with the immediately preceding one would let a
            # slow drift (many sub-threshold steps) collapse a long path into its first point.
            if not kept or np.linalg.norm(qpos - kept[-1], ord=2) > threshold:
                kept.append(qpos)
        if len(kept) > 1 and not np.array_equal(kept[-1], qpos_trajectory[-1]):
            # The goal itself was discarded as a near-duplicate; snap the final waypoint back onto it.
            kept[-1] = qpos_trajectory[-1]
        return np.array(kept)

    def set_control(
        self, step_size: float = 1, gains: float = 0.3, atol: float = 0.03, timeout: float = 20, verbose: bool = False,
    ):
        """Generator that sets one control target per step; the caller is responsible for stepping the simulation.

        Yields once after each control update. The generator's return value (``StopIteration.value``) is True
        in two cases: when the stacked joint vector is within ``atol`` (L1) of the final waypoint, or when the
        robots have stopped moving for more than 10 consecutive steps within ``10 * atol`` of it. It is False
        when they stall elsewhere or when the client's timeout loop ends.
        """
        spl = gen_linear_spline(self.concat_qpos_trajectories)
        goal = self.concat_qpos_trajectories[-1]
        prev_qpos = None
        nr_not_moving_steps = 0
        next_id = None
        for _ in self.client.timeout(timeout):
            current_qpos = self.controller.get_concat_qpos()
            # NOTE(review): minimum_x presumably keeps the target search from moving back along the spline
            # (to just behind the previous target) — confirm against get_next_target_linear_spline.
            next_id, next_target = get_next_target_linear_spline(
                spl, current_qpos, step_size,
                minimum_x=next_id - step_size + 0.2 if next_id is not None else None
            )
            last_norm = np.linalg.norm(goal - current_qpos, ord=1)
            if verbose:
                print('last_norm', last_norm)

            if prev_qpos is not None:
                last_moving_dist = np.linalg.norm(prev_qpos - current_qpos, ord=1)
                if last_moving_dist < 0.001:
                    nr_not_moving_steps += 1
                else:
                    nr_not_moving_steps = 0
                if nr_not_moving_steps > 10:
                    if last_norm < atol * 10:
                        return True
                    print('Not moving for a long time (10 steps).')
                    return False
            prev_qpos = current_qpos

            if last_norm < atol:
                return True

            for i, robot in enumerate(self.robots):
                start = self.q_start_indices[i]
                robot.set_arm_joint_position_control(
                    next_target[start:start + self.q_lengths[i]], gains=gains, set_gripper_control=True
                )
            yield
        return False

    def move(self, step_size: float = 1, gains: float = 0.3, atol: float = 0.03, timeout: float = 20, verbose: bool = False):
        """Run :meth:`set_control` to completion, stepping the simulation after every control update.

        Returns:
            the value returned by :meth:`set_control` (True on success, False on stall or timeout).
        """
        control_iter = self.set_control(step_size=step_size, gains=gains, atol=atol, timeout=timeout, verbose=verbose)
        while True:
            try:
                next(control_iter)
            except StopIteration as e:
                # A plain ``for`` loop would consume StopIteration and discard the generator's return value.
                return e.value
            self.client.step()