Source code for concepts.simulator.pybullet.ikfast.ikfast_common

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : ikfast_common.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 12/26/2022
#
# This file is part of HACL-PyTorch.
# Distributed under terms of the MIT license.

import itertools
from typing import Optional, Iterable, Tuple, List

import random
import numpy as np
import numpy.random as npr

from jacinle.logging import get_logger

from concepts.utils.rotationlib import mat2quat, quat2mat

from concepts.simulator.pybullet.world import BulletWorld
from concepts.simulator.pybullet.rotation_utils import quat_conjugate, quat_mul

logger = get_logger(__file__)


class IKFastWrapperBase(object):
    """Common driver around a compiled IKFast module.

    Subclasses supply the robot state by implementing
    :meth:`get_current_joint_positions` and :meth:`get_current_free_joint_positions`.
    The wrapped ``module`` is called as ``module.get_fk(qpos_list)`` and
    ``module.get_ik(rotation_matrix, position[, free_joint_values])`` below.
    """

    def __init__(
        self,
        module,
        joint_ids: List[int],
        free_joint_ids: List[int] = tuple(),
        joints_lower: Optional[np.ndarray] = None,
        joints_upper: Optional[np.ndarray] = None,
        use_xyzw: bool = True,  # PyBullet uses xyzw.
        max_attempts: int = 1000,
        fix_free_joint_positions: bool = False,
        shuffle_solutions: bool = False,
        sort_closest_solution: bool = False,
    ):
        """Store the solver configuration.

        Args:
            module: the compiled IKFast module (see the class docstring for the calls made on it).
            joint_ids: ids of the joints whose positions make up a ``qpos`` vector.
            free_joint_ids: the subset of ``joint_ids`` that is sampled and passed to ``get_ik``
                as free parameters.
            joints_lower: per-joint lower limits, aligned with ``joint_ids``.
            joints_upper: per-joint upper limits, aligned with ``joint_ids``.
            use_xyzw: if True, quaternions are given/returned as (x, y, z, w); otherwise (w, x, y, z).
            max_attempts: default cap on the number of free-joint samples tried by :meth:`gen_ik`.
            fix_free_joint_positions: if True, :meth:`gen_ik` always uses the free-joint positions
                recorded at construction time instead of sampling.
            shuffle_solutions: shuffle the solutions returned for each free-joint sample.
            sort_closest_solution: sort the accepted solutions of each sample by distance to the
                reference configuration.
        """
        self.module = module
        self.joint_ids = joint_ids
        self.free_joint_ids = free_joint_ids
        self.use_xyzw = use_xyzw
        self.max_attempts = max_attempts

        self.joints_lower = joints_lower
        self.joints_upper = joints_upper

        # Limits of the free joints, taken in the order the free joints appear in ``joint_ids``.
        # NOTE(review): subclasses may report current free-joint positions in ``free_joint_ids``
        # order; keep both id lists consistently ordered so samples line up with these limits.
        free_indices = [i for i, joint_id in enumerate(self.joint_ids) if joint_id in free_joint_ids]
        self.free_joints_lower = np.array([joints_lower[i] for i in free_indices])
        self.free_joints_upper = np.array([joints_upper[i] for i in free_indices])

        self.fix_free_joint_positions = fix_free_joint_positions
        # Snapshot used by gen_ik when fix_free_joint_positions is set.
        self.initial_free_joint_positions = self.get_current_free_joint_positions()
        self.shuffle_solutions = shuffle_solutions
        self.sort_closest_solution = sort_closest_solution

    def get_current_joint_positions(self) -> np.ndarray:
        """Return the current positions of ``joint_ids``; must be implemented by subclasses."""
        raise NotImplementedError()

    def get_current_free_joint_positions(self) -> np.ndarray:
        """Return the current positions of the free joints; must be implemented by subclasses."""
        raise NotImplementedError()

    def fk(self, qpos: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run forward kinematics for the configuration ``qpos``.

        Returns:
            ``(pos, quat)``: ``pos`` exactly as returned by ``module.get_fk`` and ``quat``
            converted from the returned rotation matrix with ``mat2quat``, reordered to
            (x, y, z, w) when ``use_xyzw`` is set.
        """
        pos, mat = self.module.get_fk(list(qpos))
        quat = mat2quat(mat)
        if self.use_xyzw:
            # mat2quat's output is treated as (w, x, y, z); move w to the end.
            return pos, quat[[1, 2, 3, 0]]
        return pos, quat

    def ik_internal(self, pos: np.ndarray, quat: np.ndarray, sampled: Optional[np.ndarray] = None) -> List[np.ndarray]:
        """Query the IKFast module once.

        Args:
            pos: target position (any array-like of 3 numbers).
            quat: target orientation quaternion (any array-like of 4 numbers), ordered
                according to ``use_xyzw``.
            sampled: values for the free joints, or None to call ``get_ik`` without them.

        Returns:
            The raw solutions as arrays (unfiltered); an empty list when the module returns None.
        """
        # Accept lists/tuples as well as arrays: the reordering and ``tolist`` below need ndarrays.
        pos = np.asarray(pos, dtype=np.float64)
        quat = np.asarray(quat, dtype=np.float64)
        if self.use_xyzw:
            quat = quat[[3, 0, 1, 2]]  # (x, y, z, w) -> (w, x, y, z) for quat2mat.
        mat = quat2mat(quat)
        if sampled is None:
            solutions = self.module.get_ik(mat.tolist(), pos.tolist())
        else:
            solutions = self.module.get_ik(mat.tolist(), pos.tolist(), list(sampled))
        if solutions is None:
            return list()
        return [np.array(solution) for solution in solutions]

    def gen_ik(
        self,
        pos: np.ndarray,
        quat: np.ndarray,
        last_qpos: Optional[np.ndarray],
        max_attempts: Optional[int] = None,
        max_distance: float = float('inf'),
        verbose: bool = False,
    ) -> Iterable[np.ndarray]:
        """Lazily yield IK solutions that respect the joint limits.

        The first attempt uses the free-joint values of the reference configuration; later
        attempts sample the free joints uniformly within their limits. With
        ``fix_free_joint_positions`` only the positions recorded at construction are tried, and
        with no free joints a single attempt is made (every further attempt would repeat it).

        Args:
            pos: target position.
            quat: target orientation, ordered according to ``use_xyzw``.
            last_qpos: reference configuration aligned with ``joint_ids``; when None, the
                current state from :meth:`get_current_joint_positions` is used.
            max_attempts: number of free-joint samples to try; defaults to ``self.max_attempts``.
                When given explicitly, a failure is not logged.
            max_distance: solutions whose L2 distance to the reference is not strictly smaller
                than this are rejected.
            verbose: print solutions rejected for being too far from the reference.

        Yields:
            Accepted solutions, grouped per attempt (optionally shuffled/sorted within each attempt).
        """
        if last_qpos is None:
            current_joint_positions = self.get_current_joint_positions()
            current_free_joint_positions = self.get_current_free_joint_positions()
        else:
            current_joint_positions = last_qpos
            current_free_joint_positions = [
                last_qpos[i] for i, joint_id in enumerate(self.joint_ids) if joint_id in self.free_joint_ids
            ]

        if self.fix_free_joint_positions:
            samples = [self.initial_free_joint_positions]
        elif len(self.free_joint_ids) == 0:
            # Without free joints every sample is the same empty vector: repeating the query
            # would only recompute and re-yield identical solutions.
            samples = [current_free_joint_positions]
        else:
            samples = itertools.chain(
                [current_free_joint_positions],
                gen_uniform_sample_joints(self.free_joints_lower, self.free_joints_upper),
            )

        attempt_limit = self.max_attempts if max_attempts is None else max_attempts
        samples = itertools.islice(samples, attempt_limit)

        succeeded = False
        num_attempts = 0
        for sampled in samples:
            num_attempts += 1
            solutions = self.ik_internal(pos, quat, sampled)
            if self.shuffle_solutions:
                random.shuffle(solutions)

            accepted = list()
            for solution in solutions:
                if not check_joint_limits(solution, self.joints_lower, self.joints_upper):
                    continue
                if distance_fn(solution, current_joint_positions) < max_distance:
                    succeeded = True
                    accepted.append(solution)
                elif verbose:
                    print(f'IK solution is too far from current joint positions: {solution} vs {current_joint_positions}')

            if self.sort_closest_solution:
                accepted.sort(key=lambda qpos: distance_fn(qpos, current_joint_positions))
            yield from accepted

        if not succeeded and max_attempts is None:
            logger.warning(f'Failed to find IK solution for {pos} {quat} after {num_attempts} attempts.')
class IKFastWrapper(IKFastWrapperBase):
    """IKFast wrapper that reads joint limits and joint states from a :class:`BulletWorld` body."""

    def __init__(
        self,
        world: BulletWorld,
        module,
        body_id,
        joint_ids: List[int],
        free_joint_ids: List[int] = tuple(),
        use_xyzw: bool = True,  # PyBullet uses xyzw.
        max_attempts: int = 1000,
        fix_free_joint_positions: bool = False,
        shuffle_solutions: bool = False,
        sort_closest_solution: bool = False
    ):
        """Build the wrapper for ``body_id`` in ``world``.

        Joint limits are read from ``world.get_joint_info_by_id`` for every id in ``joint_ids``;
        the remaining arguments are forwarded to :class:`IKFastWrapperBase`.
        """
        # The base constructor calls get_current_free_joint_positions(), which reads
        # self.world and self.body_id, so these must be assigned before super().__init__.
        self.world = world
        self.module = module
        self.body_id = body_id

        infos = [self.world.get_joint_info_by_id(self.body_id, joint_id) for joint_id in joint_ids]
        lower_limits = np.array([info.joint_lower_limit for info in infos])
        upper_limits = np.array([info.joint_upper_limit for info in infos])

        super().__init__(
            module,
            joint_ids,
            free_joint_ids,
            joints_lower=lower_limits,
            joints_upper=upper_limits,
            use_xyzw=use_xyzw,
            max_attempts=max_attempts,
            fix_free_joint_positions=fix_free_joint_positions,
            shuffle_solutions=shuffle_solutions,
            sort_closest_solution=sort_closest_solution,
        )
        # NOTE: a check `len(self.free_joint_ids) + 6 == len(self.joint_ids)` was considered
        # by the original author but is intentionally not enforced.

    def _read_joint_positions(self, ids) -> np.ndarray:
        """Return the current positions of the joints ``ids`` of ``self.body_id`` as an array."""
        return np.array([self.world.get_joint_state_by_id(self.body_id, joint_id).position for joint_id in ids])

    def get_current_joint_positions(self) -> np.ndarray:
        """Current positions of all ``joint_ids``, in ``joint_ids`` order."""
        return self._read_joint_positions(self.joint_ids)

    def get_current_free_joint_positions(self) -> np.ndarray:
        """Current positions of the free joints, in ``free_joint_ids`` order."""
        return self._read_joint_positions(self.free_joint_ids)
def check_joint_limits(qpos: np.ndarray, lower_limits: Optional[np.ndarray], upper_limits: Optional[np.ndarray]) -> bool:
    """Return whether every joint value lies within ``[lower_limits, upper_limits]`` (inclusive).

    Args:
        qpos: joint configuration (array-like).
        lower_limits: per-joint lower bounds (array-like), or None for no lower bound.
        upper_limits: per-joint upper bounds (array-like), or None for no upper bound.

    Returns:
        A numpy boolean that is True iff all element-wise checks pass.
    """
    # Convert first: with plain Python lists, ``>=``/``<=`` would perform a lexicographic
    # list comparison and return a single (wrong) bool instead of element-wise results.
    qpos = np.asarray(qpos)
    within = np.ones(qpos.shape, dtype=bool)
    if lower_limits is not None:
        within = np.logical_and(within, qpos >= np.asarray(lower_limits))
    if upper_limits is not None:
        within = np.logical_and(within, qpos <= np.asarray(upper_limits))
    return np.all(within)
def uniform_sample_joints(lower_limits: np.ndarray, upper_limits: np.ndarray) -> np.ndarray:
    """Draw one value per joint uniformly from its ``[lower, upper)`` range.

    Uses the global ``numpy.random`` state, one draw per joint in order; pairs are
    formed like ``zip`` (stopping at the shorter input).
    """
    per_joint_draws = map(npr.uniform, lower_limits, upper_limits)
    return np.array(list(per_joint_draws))
def gen_uniform_sample_joints(lower_limits: np.ndarray, upper_limits: np.ndarray) -> Iterable[np.ndarray]:
    """Yield an endless stream of independent uniform joint samples within the given limits."""
    for _ in itertools.count():
        yield uniform_sample_joints(lower_limits, upper_limits)
def random_select_solution(solutions: List[np.ndarray]) -> np.ndarray:
    """Pick one solution uniformly at random using the stdlib ``random`` module state."""
    chosen = random.choice(solutions)
    return chosen
def distance_fn(qpos1: np.ndarray, qpos2: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between two joint configurations."""
    delta = np.array(qpos1) - np.array(qpos2)
    return np.linalg.norm(delta, ord=2)
def closest_select_solution(solutions: List[np.ndarray], current_qpos: np.ndarray) -> np.ndarray:
    """Return the solution nearest to ``current_qpos`` under :func:`distance_fn`.

    Ties keep the earliest solution; an empty ``solutions`` raises ``ValueError`` (from ``min``).
    """
    def _distance_to_current(candidate):
        return distance_fn(candidate, current_qpos)

    return min(solutions, key=_distance_to_current)