#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : ikfast_common.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 12/26/2022
#
# This file is part of HACL-PyTorch.
# Distributed under terms of the MIT license.
import itertools
from typing import Optional, Iterable, Tuple, List
import random
import numpy as np
import numpy.random as npr
from jacinle.logging import get_logger
from concepts.utils.rotationlib import mat2quat, quat2mat
from concepts.simulator.pybullet.world import BulletWorld
from concepts.simulator.pybullet.rotation_utils import quat_conjugate, quat_mul
# Module-level logger, created from this file's path; used by gen_ik to report IK failures.
logger = get_logger(__file__)
class IKFastWrapperBase(object):
    """Simulator-agnostic wrapper around an IKFast-style solver module.

    The wrapped ``module`` must expose ``get_fk(qpos)`` returning ``(pos, mat)``
    and ``get_ik(mat, pos[, free_joint_values])`` returning a list of joint
    solutions or ``None``. Subclasses supply the current joint state by
    implementing :meth:`get_current_joint_positions` and
    :meth:`get_current_free_joint_positions`.
    """

    def __init__(
        self, module,
        joint_ids: List[int], free_joint_ids: List[int] = tuple(),
        joints_lower: Optional[np.ndarray] = None, joints_upper: Optional[np.ndarray] = None,
        use_xyzw: bool = True,  # PyBullet uses xyzw.
        max_attempts: int = 1000,
        fix_free_joint_positions: bool = False,
        shuffle_solutions: bool = False,
        sort_closest_solution: bool = False,
    ):
        """Store the solver configuration and record the initial free-joint values.

        Args:
            module: solver object providing ``get_fk`` and ``get_ik``.
            joint_ids: ids of all joints handled by the solver.
            free_joint_ids: ids (a subset of ``joint_ids``) whose values are
                sampled and passed to the solver rather than solved for.
            joints_lower: per-joint lower limits, aligned with ``joint_ids``.
            joints_upper: per-joint upper limits, aligned with ``joint_ids``.
            use_xyzw: if True, quaternions at the API boundary are xyzw and are
                reordered before/after calling the rotation helpers.
            max_attempts: default number of free-joint samples tried by ``gen_ik``.
            fix_free_joint_positions: if True, ``gen_ik`` only tries the
                free-joint values recorded at construction time.
            shuffle_solutions: if True, shuffle the solutions of each solver call.
            sort_closest_solution: if True, yield the accepted solutions of each
                solver call ordered by distance to the reference configuration.

        Raises:
            ValueError: if free joints are given but a limit array is missing
                (their sampling range cannot be derived).
        """
        self.module = module
        self.joint_ids = joint_ids
        self.free_joint_ids = free_joint_ids
        self.use_xyzw = use_xyzw
        self.max_attempts = max_attempts
        self.joints_lower = joints_lower
        self.joints_upper = joints_upper

        free_indices = self._free_joint_indices()
        if free_indices and (joints_lower is None or joints_upper is None):
            raise ValueError('joints_lower and joints_upper are required when free_joint_ids is not empty.')
        # Limits of the free joints, ordered as those joints appear in joint_ids.
        self.free_joints_lower = np.array([joints_lower[i] for i in free_indices])
        self.free_joints_upper = np.array([joints_upper[i] for i in free_indices])

        self.fix_free_joint_positions = fix_free_joint_positions
        # Subclasses must be able to answer this query before this constructor runs.
        self.initial_free_joint_positions = self.get_current_free_joint_positions()
        self.shuffle_solutions = shuffle_solutions
        self.sort_closest_solution = sort_closest_solution

    def _free_joint_indices(self) -> List[int]:
        """Return the positions within ``joint_ids`` that refer to free joints."""
        return [i for i, joint_id in enumerate(self.joint_ids) if joint_id in self.free_joint_ids]

    def get_current_joint_positions(self) -> np.ndarray:
        """Return the current positions of all joints in ``joint_ids`` (subclass hook)."""
        raise NotImplementedError()

    def get_current_free_joint_positions(self) -> np.ndarray:
        """Return the current positions of the free joints (subclass hook)."""
        raise NotImplementedError()

    def fk(self, qpos: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run forward kinematics for ``qpos``.

        Returns:
            ``(pos, quat)`` where ``pos`` is whatever ``module.get_fk`` returns and
            ``quat`` is the rotation converted by ``mat2quat``, reordered to xyzw
            when ``use_xyzw`` is set.
        """
        pos, mat = self.module.get_fk(list(qpos))
        quat = mat2quat(mat)
        if self.use_xyzw:
            # The reorder implies mat2quat yields wxyz; move w to the end.
            return pos, quat[[1, 2, 3, 0]]
        return pos, quat

    def ik_internal(self, pos: np.ndarray, quat: np.ndarray, sampled: Optional[np.ndarray] = None) -> List[np.ndarray]:
        """Call the solver once for a target pose.

        Args:
            pos: target position (any array-like).
            quat: target orientation, xyzw if ``use_xyzw`` else as expected by ``quat2mat``.
            sampled: optional free-joint values forwarded to the solver.

        Returns:
            A list of joint-position arrays; empty if the solver returned ``None``.
        """
        pos = np.asarray(pos)
        quat = np.asarray(quat)
        if self.use_xyzw:
            # Move w to the front before handing the quaternion to quat2mat.
            quat = quat[[3, 0, 1, 2]]
        mat = quat2mat(quat)

        solver_args = [mat.tolist(), pos.tolist()]
        if sampled is not None:
            solver_args.append(list(sampled))
        solutions = self.module.get_ik(*solver_args)

        if solutions is None:
            return list()
        return [np.array(solution) for solution in solutions]

    def gen_ik(self, pos: np.ndarray, quat: np.ndarray, last_qpos: Optional[np.ndarray] = None, max_attempts: Optional[int] = None, max_distance: float = float('inf'), verbose: bool = False) -> Iterable[np.ndarray]:
        """Lazily generate IK solutions for a target pose.

        Free-joint values are tried in order: first the reference values (or only
        the initial ones when ``fix_free_joint_positions`` is set), then uniform
        samples within the free-joint limits, up to the attempt limit. For each
        try, solutions that respect the joint limits and lie closer than
        ``max_distance`` to the reference configuration are yielded.

        Args:
            pos: target position.
            quat: target orientation (see :meth:`ik_internal`).
            last_qpos: reference configuration aligned with ``joint_ids``; when
                None the current joint state is queried instead.
            max_attempts: number of free-joint samples to try; None uses
                ``self.max_attempts``.
            max_distance: solutions at this distance or farther from the
                reference configuration are rejected.
            verbose: if True, print rejected too-far solutions.

        Yields:
            Accepted joint-position arrays.
        """
        if last_qpos is None:
            current_joint_positions = self.get_current_joint_positions()
            current_free_joint_positions = self.get_current_free_joint_positions()
        else:
            current_joint_positions = last_qpos
            current_free_joint_positions = [last_qpos[i] for i in self._free_joint_indices()]

        if self.fix_free_joint_positions:
            candidates = [self.initial_free_joint_positions]
        else:
            candidates = itertools.chain(
                [current_free_joint_positions],
                gen_uniform_sample_joints(self.free_joints_lower, self.free_joints_upper)
            )
        attempt_limit = self.max_attempts if max_attempts is None else max_attempts
        candidates = itertools.islice(candidates, attempt_limit)

        # A missing limit array means "unbounded" on that side.
        lower = -np.inf if self.joints_lower is None else self.joints_lower
        upper = np.inf if self.joints_upper is None else self.joints_upper

        succeeded = False
        for sampled in candidates:
            solutions = self.ik_internal(pos, quat, sampled)
            if self.shuffle_solutions:
                random.shuffle(solutions)

            accepted = list()
            for solution in solutions:
                if not check_joint_limits(solution, lower, upper):
                    continue
                if distance_fn(solution, current_joint_positions) < max_distance:
                    succeeded = True
                    accepted.append(solution)
                elif verbose:
                    print(f'IK solution is too far from current joint positions: {solution} vs {current_joint_positions}')

            if self.sort_closest_solution:
                accepted.sort(key=lambda qpos: distance_fn(qpos, current_joint_positions))
            yield from accepted

        # Only warn when the caller relied on the default attempt budget.
        if not succeeded and max_attempts is None:
            logger.warning(f'Failed to find IK solution for {pos} {quat} after {self.max_attempts} attempts.')
class IKFastWrapper(IKFastWrapperBase):
    """IKFast wrapper whose joint limits and joint state come from a body in a BulletWorld."""

    def __init__(
        self,
        world: BulletWorld, module,
        body_id, joint_ids: List[int], free_joint_ids: List[int] = tuple(),
        use_xyzw: bool = True,  # PyBullet uses xyzw.
        max_attempts: int = 1000,
        fix_free_joint_positions: bool = False,
        shuffle_solutions: bool = False,
        sort_closest_solution: bool = False
    ):
        """Read the joint limits of ``body_id`` from ``world`` and set up the solver.

        Args:
            world: simulator world used to query joint info and joint state.
            module: solver object forwarded to the base class.
            body_id: id of the body whose joints are controlled.
            joint_ids: ids of all joints handled by the solver.
            free_joint_ids: ids of the free (sampled) joints.
            use_xyzw: forwarded to the base class.
            max_attempts: forwarded to the base class.
            fix_free_joint_positions: forwarded to the base class.
            shuffle_solutions: forwarded to the base class.
            sort_closest_solution: forwarded to the base class.
        """
        # These must be set before the base constructor runs: it calls
        # get_current_free_joint_positions(), which reads world and body_id.
        self.world = world
        self.body_id = body_id

        joint_info = [self.world.get_joint_info_by_id(self.body_id, joint_id) for joint_id in joint_ids]
        joints_lower = np.array([info.joint_lower_limit for info in joint_info])
        joints_upper = np.array([info.joint_upper_limit for info in joint_info])

        super().__init__(
            module, joint_ids, free_joint_ids,
            joints_lower=joints_lower, joints_upper=joints_upper,
            use_xyzw=use_xyzw, max_attempts=max_attempts,
            fix_free_joint_positions=fix_free_joint_positions,
            shuffle_solutions=shuffle_solutions,
            sort_closest_solution=sort_closest_solution,
        )

    def _joint_positions(self, joint_ids) -> np.ndarray:
        """Query the world for the positions of the given joints of this body."""
        return np.array([self.world.get_joint_state_by_id(self.body_id, joint_id).position for joint_id in joint_ids])

    def get_current_joint_positions(self) -> np.ndarray:
        """Return the current positions of all joints in ``joint_ids``."""
        return self._joint_positions(self.joint_ids)

    def get_current_free_joint_positions(self) -> np.ndarray:
        """Return the current positions of the joints in ``free_joint_ids``."""
        return self._joint_positions(self.free_joint_ids)
def check_joint_limits(qpos: np.ndarray, lower_limits: np.ndarray, upper_limits: np.ndarray) -> bool:
    """Return True iff every joint value lies within its limits (bounds inclusive).

    Args:
        qpos: joint positions (any array-like).
        lower_limits: lower bounds, broadcastable against ``qpos``.
        upper_limits: upper bounds, broadcastable against ``qpos``.

    Returns:
        A builtin ``bool``; an empty ``qpos`` is trivially within limits.
    """
    qpos = np.asarray(qpos)
    within = np.logical_and(qpos >= lower_limits, qpos <= upper_limits)
    # np.all yields numpy.bool_; convert so the declared return type holds.
    return bool(np.all(within))
def random_select_solution(solutions: List[np.ndarray]) -> np.ndarray:
    """Pick one solution uniformly at random using the ``random`` module.

    Args:
        solutions: non-empty sequence of candidate joint configurations.

    Returns:
        One element of ``solutions`` (the same object, not a copy).

    Raises:
        IndexError: if ``solutions`` is empty (raised by ``random.choice``).
    """
    return random.choice(solutions)
def distance_fn(qpos1: np.ndarray, qpos2: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between two joint configurations.

    Args:
        qpos1: first configuration (any array-like).
        qpos2: second configuration, same shape as ``qpos1`` or broadcastable.

    Returns:
        The 2-norm of ``qpos1 - qpos2`` as a builtin ``float``.
    """
    difference = np.asarray(qpos1) - np.asarray(qpos2)
    return float(np.linalg.norm(difference, ord=2))
def closest_select_solution(solutions: List[np.ndarray], current_qpos: np.ndarray) -> np.ndarray:
    """Return the solution closest to ``current_qpos`` according to ``distance_fn``.

    Args:
        solutions: non-empty sequence of candidate joint configurations.
        current_qpos: reference configuration to measure distance from.

    Returns:
        The element of ``solutions`` with the smallest distance; the first one
        wins on ties (``min`` semantics).

    Raises:
        ValueError: if ``solutions`` is empty (raised by ``min``).
    """
    def distance_to_current(qpos: np.ndarray) -> float:
        return distance_fn(qpos, current_qpos)

    return min(solutions, key=distance_to_current)