#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : client.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/17/2020
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import time
import os.path as osp
import tempfile
import threading
import collections
import functools
import contextlib
import warnings
from typing import Any, Optional, Union, Tuple, List, Dict
import numpy as np
import pybullet as p
import pybullet_data
import jacinle
import jacinle.io as io
from concepts.simulator.pybullet.world import BulletWorld
from concepts.utils.typing_utils import Vec3f, Vec4f
# Names exported by ``from <this module> import *``: only the client wrapper.
__all__ = [
    'BulletClient',
]
[docs]
class BulletP(object):
    """A thin proxy around the ``pybullet`` module bound to a single physics client.

    Attribute access is forwarded to ``pybullet``. Callable attributes are returned as
    ``functools.partial`` objects with ``physicsClientId`` pre-filled, so ``bp.stepSimulation()``
    is equivalent to ``pybullet.stepSimulation(physicsClientId=bp.client_id)``. Non-callable
    attributes (e.g. constants) are returned unchanged.
    """

    def __init__(self, client_id=None):
        self.client_id = client_id

    def set_client_id(self, client_id):
        """Bind this proxy to the physics client ``client_id``."""
        self.client_id = client_id

    def __getattr__(self, item):
        # Dunder probes (e.g. ``__setstate__`` / ``__deepcopy__`` looked up by ``copy`` and
        # ``pickle``) are not pybullet API. Forwarding them used to break copying: on an instance
        # created without ``__init__``, reading ``self.client_id`` re-entered ``__getattr__``
        # forever (RecursionError).
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
        # Read through ``__dict__`` so a missing ``client_id`` cannot recurse into ``__getattr__``.
        assert self.__dict__.get('client_id') is not None, 'BulletP is not bound to a client; call set_client_id() first.'
        func = getattr(p, item)
        if callable(func):
            return functools.partial(func, physicsClientId=self.client_id)
        return func
[docs]
class MouseEvent(collections.namedtuple(
    '_MouseEvent',
    ('eventType', 'mousePosX', 'mousePosY', 'buttonIndex', 'buttonState'),
)):
    """One mouse event, built from an entry returned by ``getMouseEvents``.

    The five fields are, in order: ``eventType``, ``mousePosX``, ``mousePosY``,
    ``buttonIndex`` and ``buttonState``. Instances are plain tuples, so they also
    support positional indexing and unpacking.
    """
[docs]
class BulletClient(object):
    """A wrapper for the pybullet client.

    Holds one pybullet physics client (``self.client_id``) and a :class:`BulletWorld` record of the
    loaded bodies (``self.w`` / ``self.world``). It also holds a client-bound pybullet proxy
    (``self.p``) and a registry of named debug-drawing items (``self.debug_items``).
    """

    DEFAULT_ENGINE_PARAMETERS = {'numSolverIterations': 10}
    DEFAULT_FPS = 240
    DEFAULT_GRAVITY = (0, 0, -9.8)
    DEFAULT_ASSETS_ROOT = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))), 'assets')

    def __init__(
        self,
        assets_root: Optional[str] = None,
        is_gui: bool = False,
        *,
        fps: Optional[int] = None,
        render_fps: Optional[int] = None,
        gravity: Optional[Union[Tuple[float], float]] = None,
        connect: bool = True,
        client_id: int = -1,
        width: Optional[int] = 960,
        height: Optional[int] = 960,
        additional_title: Optional[str] = None,
        save_video: Optional[str] = None,
        enable_realtime_rendering: Optional[bool] = None,
        enable_debug_gui: bool = False,
        engine_parameters: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the BulletClient.

        Args:
            assets_root: the root directory of the assets (by default it is the `assets` directory in the `concepts` package).
            is_gui: whether to enable the GUI.
            fps: the physics simulation FPS; the time step is ``1 / fps`` (default: ``DEFAULT_FPS`` = 240).
            render_fps: the rendering FPS (default: same as ``fps``). Forced to 0 when ``is_gui`` is False.
            gravity: the gravity vector, or a scalar interpreted as its z-component (default: (0, 0, -9.8)).
            connect: whether to connect to the server immediately (only used when ``client_id`` is -1).
            client_id: the client id to connect to. If this is set to -1, a new client id will be created.
            width: the width of the window.
            height: the height of the window.
            additional_title: optional text drawn in the scene (via ``addUserDebugText`` at (0, 0, 1)) after connecting.
            save_video: the path to save an MP4 video to (passed to pybullet as ``--mp4``).
            enable_realtime_rendering: whether to throttle stepping to ``render_fps`` (default: True iff the effective render_fps > 0).
            enable_debug_gui: whether to enable the debug GUI.
            engine_parameters: physics engine parameters passed to ``setPhysicsEngineParameter``;
                when None, ``DEFAULT_ENGINE_PARAMETERS`` is used.
        """
        self.is_gui = is_gui
        self.fps = fps if fps is not None else type(self).DEFAULT_FPS
        self.render_fps = render_fps if render_fps is not None else self.fps
        if not self.is_gui:
            # Nothing is rendered without a GUI, so stepping is never throttled.
            self.render_fps = 0
        self.gravity = canonize_gravity(gravity if gravity is not None else type(self).DEFAULT_GRAVITY)
        self.engine_parameters = engine_parameters
        self.client_id = None
        self.assets_root = assets_root if assets_root is not None else type(self).DEFAULT_ASSETS_ROOT
        self.save_video = save_video
        self.additional_title = additional_title
        self.enable_realtime_rendering = enable_realtime_rendering if enable_realtime_rendering is not None else self.render_fps > 0
        self.enable_debug_gui = enable_debug_gui
        self.w = BulletWorld()
        self.p = BulletP()
        self.width = width
        self.height = height
        self.debug_items = dict()
        if client_id == -1:
            if connect:
                self.connect()
        else:
            # Attach to an existing client instead of creating a new connection.
            self.client_id = client_id
            self.w.set_client_id(self.client_id)
            self.p.set_client_id(self.client_id)
        self._nonphysics_step_callbacks = []

    debug_items: Dict[str, Union[int, Tuple[int, ...]]]
    """The debug items that are added to the world. The key is the name of the item, and the value is the item id."""

    @property
    def world(self):
        """Alias for `self.w`."""
        return self.w

    @contextlib.contextmanager
    def with_fps(self, fps: Optional[int] = None, render_fps: Optional[int] = None, realtime_rendering: Optional[bool] = None):
        """Temporarily override ``fps``, ``render_fps`` and ``enable_realtime_rendering``.

        If only ``fps`` is given, ``render_fps`` follows it. If ``render_fps`` is given and
        ``realtime_rendering`` is not, realtime rendering is enabled. The previous values are
        restored on exit, including when the body raises.
        """
        saved = (self.fps, self.render_fps, self.enable_realtime_rendering)
        if realtime_rendering is None and render_fps is not None:  # if render_fps is set, we assume realtime rendering is enabled.
            realtime_rendering = True
        if fps is not None:
            self.fps = fps
        if render_fps is not None:
            self.render_fps = render_fps
        elif fps is not None:
            self.render_fps = fps
        if realtime_rendering is not None:
            self.enable_realtime_rendering = realtime_rendering
        try:
            yield
        finally:
            self.fps, self.render_fps, self.enable_realtime_rendering = saved

    def set_rendering_fps(self, render_fps: Optional[int] = None):
        """Set the rendering FPS (``None`` means "same as ``fps``") and enable realtime rendering."""
        if render_fps is None:
            self.render_fps = self.fps
        else:
            self.render_fps = render_fps
        self.enable_realtime_rendering = True

    def set_enable_realtime_rendering(self, enable_realtime_rendering: Optional[bool] = None):
        """Set the realtime-rendering flag; ``None`` means "enabled iff ``render_fps > 0``"."""
        if enable_realtime_rendering is None:
            self.enable_realtime_rendering = self.render_fps > 0
        else:
            self.enable_realtime_rendering = enable_realtime_rendering

    def connect(self, suppress_warnings: bool = True):
        """Connect to a new pybullet client, optionally suppressing stdout while connecting."""
        if suppress_warnings:
            with jacinle.suppress_stdout():
                self._connect()
        else:
            self._connect()

    def _connect(self):
        """Create the pybullet connection and apply this client's configuration to it."""
        options = ''
        if self.save_video:
            options += f'--mp4="{self.save_video}" --mp4fps=60'
        if self.width is not None:
            options += ' --width={}'.format(self.width)
        if self.height is not None:
            options += ' --height={}'.format(self.height)
        self.client_id = p.connect(p.GUI if self.is_gui else p.DIRECT, options=options)

        if self.save_video:
            p.configureDebugVisualizer(p.COV_ENABLE_SINGLE_STEP_RENDERING, 1, physicsClientId=self.client_id)
        if self.is_gui and self.enable_realtime_rendering:
            p.configureDebugVisualizer(p.COV_ENABLE_SINGLE_STEP_RENDERING, 1, physicsClientId=self.client_id)

        # Disable the cache of the URDF files. This would allow us to load JIT URDF files.
        p.setPhysicsEngineParameter(enableFileCaching=0, physicsClientId=self.client_id)
        p.setAdditionalSearchPath(pybullet_data.getDataPath(), physicsClientId=self.client_id)
        if self.engine_parameters is not None:
            p.setPhysicsEngineParameter(physicsClientId=self.client_id, **self.engine_parameters)
        else:
            p.setPhysicsEngineParameter(physicsClientId=self.client_id, **type(self).DEFAULT_ENGINE_PARAMETERS)

        p.configureDebugVisualizer(p.COV_ENABLE_MOUSE_PICKING, 0, physicsClientId=self.client_id)
        # Show or hide the debug GUI (e.g., synthetic camera views and parameters).
        p.configureDebugVisualizer(p.COV_ENABLE_GUI, self.enable_debug_gui, physicsClientId=self.client_id)
        if self.enable_debug_gui:
            p.configureDebugVisualizer(p.COV_ENABLE_RGB_BUFFER_PREVIEW, 1, physicsClientId=self.client_id)
            p.configureDebugVisualizer(p.COV_ENABLE_DEPTH_BUFFER_PREVIEW, 1, physicsClientId=self.client_id)
            p.configureDebugVisualizer(p.COV_ENABLE_SEGMENTATION_MARK_PREVIEW, 1, physicsClientId=self.client_id)
        p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 1, physicsClientId=self.client_id)

        if self.assets_root is not None:
            file_io = p.loadPlugin('fileIOPlugin', physicsClientId=self.client_id)
            if file_io >= 0:
                p.executePluginCommand(file_io, textArgument=self.assets_root, intArgs=[p.AddFileIOAction], physicsClientId=self.client_id)
            else:
                raise RuntimeError('pybullet: cannot load FileIO!')
            p.setAdditionalSearchPath(self.assets_root, physicsClientId=self.client_id)

        # NB(Jiayuan Mao @ 10/04): also add the temp dir to the asset path so that we can load JIT URDF files.
        p.setAdditionalSearchPath(tempfile.gettempdir(), physicsClientId=self.client_id)
        p.setGravity(*self.gravity, physicsClientId=self.client_id)
        p.setTimeStep(1.0 / self.fps, physicsClientId=self.client_id)
        self.w.set_client_id(self.client_id)
        self.p.set_client_id(self.client_id)

        # Set the title of the window.
        if self.additional_title is not None:
            p.addUserDebugText(self.additional_title, [0, 0, 1], [0, 0, 0], parentObjectUniqueId=0, physicsClientId=self.client_id)

    def is_connected(self):
        """Return pybullet's ``isConnected`` status for this client."""
        return p.isConnected(physicsClientId=self.client_id)

    def has_gui(self):
        """Return True if this client was connected with ``p.GUI``."""
        return p.getConnectionInfo(physicsClientId=self.client_id)['connectionMethod'] == p.GUI

    def disconnect(self):
        """Disconnect this client from pybullet."""
        p.disconnect(physicsClientId=self.client_id)

    def reset_world(self):
        """Reset the simulation, re-apply gravity and time step, and start a fresh world record."""
        p.resetSimulation(physicsClientId=self.client_id)
        p.setGravity(*self.gravity, physicsClientId=self.client_id)
        p.setTimeStep(1.0 / self.fps, physicsClientId=self.client_id)
        # Should also remember to reset the world record.
        self.w = BulletWorld()
        self.w.set_client_id(self.client_id)

    def set_rendering(self, enable: bool):
        """Enable or disable rendering in the debug visualizer."""
        p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, int(enable), physicsClientId=self.client_id)

    @contextlib.contextmanager
    def disable_rendering(self, disable_rendering: bool = True, reset: bool = False, suppress_stdout: bool = False):
        """Temporarily disable rendering (e.g. while loading many models).

        Args:
            disable_rendering: whether to actually turn rendering off inside the block.
            reset: whether to call :meth:`reset_world` before entering the block.
            suppress_stdout: whether to suppress stdout inside the block.

        Rendering is re-enabled on exit (if the client is still connected), including when the
        body raises.
        """
        if reset:
            self.reset_world()
        with jacinle.cond_with(
            jacinle.suppress_stdout(),
            suppress_stdout
        ):
            if disable_rendering:
                p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 0, physicsClientId=self.client_id)
            try:
                yield
            finally:
                if disable_rendering and self.is_connected():
                    p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 1, physicsClientId=self.client_id)

    @contextlib.contextmanager
    def disable_stdout(self, activate: bool = True):
        """Suppress stdout inside the block when ``activate`` is True."""
        with jacinle.cond_with(
            jacinle.suppress_stdout(),
            activate
        ):
            yield

    @contextlib.contextmanager
    def disable_world_update(self):
        """Temporarily disable the world update. Specifically, when loading a new model, the world object `self.w` will not be updated.
        This function also disables rendering of the pybullet debug renderer. Thus, this functionality is useful when loading a large number of models."""
        warnings.warn('`disable_world_update` is deprecated. Use `disable_rendering` instead.', DeprecationWarning)
        with self.disable_rendering(suppress_stdout=True):
            yield

    def step(self, steps=1, realtime_rendering: Optional[bool] = None):
        """Advance the simulation by ``steps`` physics steps.

        After each physics step, the registered non-physics callbacks are run. When realtime
        rendering is on (the argument, or ``self.enable_realtime_rendering`` if None) and
        ``render_fps > 0``, each step is throttled to ``1 / render_fps`` seconds via a jacinle clock.
        """
        use_realtime = realtime_rendering if realtime_rendering is not None else self.enable_realtime_rendering
        clock = None
        if use_realtime and self.render_fps > 0:
            clock = jacinle.Clock(1 / self.render_fps)
        for _ in range(steps):
            p.stepSimulation(physicsClientId=self.client_id)
            self._nonphysics_step()
            if clock is not None:
                clock.tick()

    def step_until_stable(self, max_steps: int = int(1e6), velocity_threshold: float = 1e-3, angular_velocity_threshold: float = 1e-3, joint_velocity_threshold: float = 1e-3):
        """Step (without realtime throttling) until :meth:`is_stable` holds or ``max_steps`` is reached."""
        for _ in range(max_steps):
            p.stepSimulation(physicsClientId=self.client_id)
            self._nonphysics_step()
            if self.is_stable(velocity_threshold=velocity_threshold, angular_velocity_threshold=angular_velocity_threshold, joint_velocity_threshold=joint_velocity_threshold):
                break

    def is_stable(self, velocity_threshold: float = 1e-3, angular_velocity_threshold: float = 1e-3, joint_velocity_threshold: float = 1e-3):
        """Return True iff all recorded bodies and joints are (nearly) at rest.

        Every component of each recorded body's base linear/angular velocity, and the velocity of
        each recorded joint, must be within the corresponding threshold in absolute value.
        """
        for body_id in self.world.body_names.int_to_string:
            vel, ang_vel = p.getBaseVelocity(body_id, physicsClientId=self.client_id)
            if any(abs(v) > velocity_threshold for v in vel):
                return False
            if any(abs(v) > angular_velocity_threshold for v in ang_vel):
                return False
        for body_id, joint_id in self.world.joint_names.int_to_string:
            vel = p.getJointState(body_id, joint_id, physicsClientId=self.client_id)[1]
            # The joint velocity may be a scalar or a sequence; handle both.
            if isinstance(vel, (list, tuple)):
                if any(abs(v) > joint_velocity_threshold for v in vel):
                    return False
            elif abs(vel) > joint_velocity_threshold:
                return False
        return True

    def _nonphysics_step(self):
        """Invoke every registered non-physics step callback, in registration order."""
        for cb in self._nonphysics_step_callbacks:
            cb()

    def add_nonphysics_step_callback(self, cb):
        """Register ``cb`` (a no-argument callable) to run after every physics step."""
        self._nonphysics_step_callbacks.append(cb)

    def remove_nonphysics_step_callback(self, cb):
        """Unregister a callback previously added with :meth:`add_nonphysics_step_callback`."""
        self._nonphysics_step_callbacks.remove(cb)

    def load_urdf(self, xml_path, pos=(0, 0, 0), quat=(0, 0, 0, 1), body_name: Optional[str] = None, group: Optional[str] = '__UNSET__', static=False, scale: float = 1.0, rgba=None, notify_world_update=True) -> int:
        """Load a URDF file and return its pybullet body id.

        ``assets://`` in ``xml_path`` is expanded to the assets root. ``pos``/``quat`` default to the
        identity pose when None. Self-collision is enabled. When ``group`` is left unset, it becomes
        ``'fixed'`` for static bodies and ``'rigid'`` otherwise.

        Raises:
            RuntimeError: if pybullet fails to load the file.
        """
        xml_path = self._canonize_asset_path(xml_path)
        pos, quat = canonize_default_pos_and_quat(pos, quat)
        try:
            ret = p.loadURDF(xml_path, pos, quat, useFixedBase=static, globalScaling=scale, physicsClientId=self.client_id, flags=p.URDF_USE_SELF_COLLISION)
        except p.error as e:
            raise RuntimeError('pybullet: cannot load URDF file: {}'.format(xml_path)) from e
        if notify_world_update:
            if group == '__UNSET__':
                group = 'fixed' if static else 'rigid'
            self.w.notify_update(ret, body_name=body_name, group=group)
        if rgba is not None:
            self.w.change_visual_color(ret, rgba=rgba)
        return ret

    def load_urdf_template(self, xml_path: str, replaces: Dict[str, Any], pos=None, quat=None, **kwargs) -> int:
        """Load a URDF after textual placeholder substitution.

        Keys of ``replaces`` are substituted longest-first. A sequence value ``v`` replaces
        ``key + str(i)`` with ``str(v[i])`` for each index ``i``. The result is written to a temp
        file and loaded via :meth:`load_urdf`.
        """
        xml_path = self._canonize_asset_path(xml_path)
        with open(xml_path) as f:
            xml_content = f.read()
        # Longest keys first, so a key that is a prefix of another does not clobber it.
        for k, v in sorted(replaces.items(), key=lambda x: len(x[0]), reverse=True):
            if isinstance(v, (tuple, list)):
                for i, vi in enumerate(v):
                    xml_content = xml_content.replace(k + str(i), str(vi))
            else:
                xml_content = xml_content.replace(k, str(v))
        with io.tempfile('w', '.xml') as f:
            f.write(xml_content)
            f.flush()
            return self.load_urdf(f.name, pos=pos, quat=quat, **kwargs)

    def load_sdf(self, xml_path, scale=1.0, notify_world_update=True) -> int:
        """Load an SDF file via ``p.loadSDF`` and return its result."""
        xml_path = self._canonize_asset_path(xml_path)
        ret = p.loadSDF(xml_path, globalScaling=scale, physicsClientId=self.client_id)
        if notify_world_update:
            self.w.notify_update(ret)
        return ret

    def load_mjcf(self, xml_path, pos=(0, 0, 0), quat=(0, 0, 0, 1), body_name=None, group='__UNSET__', static=False, notify_world_update=True) -> int:
        """Load an MJCF file (colors taken from the file) via ``p.loadMJCF`` and return its result."""
        xml_path = self._canonize_asset_path(xml_path)
        pos, quat = canonize_default_pos_and_quat(pos, quat)
        ret = p.loadMJCF(xml_path, pos, quat, useFixedBase=static, physicsClientId=self.client_id, flags=p.MJCF_COLORS_FROM_FILE)
        if notify_world_update:
            if group == '__UNSET__':
                group = 'fixed' if static else 'rigid'
            self.w.notify_update(ret, body_name=body_name, group=group)
        return ret

    def loads_mjcf(self, xml_content, pos=None, quat=None, save_to=None, **kwargs) -> int:
        """Load MJCF from a string (or an object serialized with ``io.dumps_xml``).

        If ``save_to`` is given, the XML text is also written to that path.
        """
        if not isinstance(xml_content, str):
            xml_content = io.dumps_xml(xml_content)
        if save_to is not None:
            with open(save_to, 'w') as f:
                f.write(xml_content)
        with io.tempfile('w', '.xml') as f:
            f.write(xml_content)
            f.flush()
            return self.load_mjcf(f.name, pos=pos, quat=quat, **kwargs)

    def remove_body(self, body_id):
        """Remove a body from the simulation via ``p.removeBody``."""
        return p.removeBody(body_id, physicsClientId=self.client_id)

    def _canonize_asset_path(self, path):
        """Expand the ``assets://`` prefix to the configured assets root directory."""
        return path.replace('assets://', self.assets_root + '/')

    def get_mouse_events(self) -> List[MouseEvent]:
        """Return the pending mouse events as :class:`MouseEvent` tuples."""
        return [MouseEvent(*event) for event in self.p.getMouseEvents()]

    def update_viewer(self):
        """Poll (and discard) mouse events; the waiting helpers below call this in their loops."""
        self.p.getMouseEvents()

    def update_viewer_twice(self):
        """Call :meth:`update_viewer` twice, 0.1 s apart."""
        self.update_viewer()
        time.sleep(0.1)
        self.update_viewer()

    def wait_for_duration(self, duration):
        """Busy-wait for ``duration`` wall-clock seconds while updating the viewer."""
        t0 = time.time()
        while time.time() - t0 <= duration:
            self.update_viewer()

    def wait_forever(self):
        """Update the viewer in an endless loop until Ctrl+C is pressed."""
        print(jacinle.colored('Entering the infinite loop. Press Ctrl+C to exit.', 'yellow'))
        try:
            while True:
                self.update_viewer()
        except KeyboardInterrupt:
            print(jacinle.colored('Ctrl+C detected. Exiting...', 'yellow'))

    def wait_for_user(self, message='Press enter to continue...'):
        """Block until the user enters a line; return it (or None on Ctrl+C).

        Entering ``ipdb`` opens the debugger; entering ``exit`` calls ``sys.exit(0)``.
        """
        import platform
        try:
            message = jacinle.colored('Entering the infinite loop. Enter a command to continue. Enter ipdb to enter the debugger and exit to force quit. User message:', 'yellow') + '\n' + message
            if self.has_gui() and platform.system() == 'Darwin':
                # OS X doesn't multi-thread the OpenGL visualizer
                rv = self._threaded_input(message)
            else:
                rv = input(message)
            if rv.strip().lower() == 'ipdb':
                import ipdb
                ipdb.set_trace()
            elif rv.strip().lower() == 'exit':
                import sys
                sys.exit(0)
            return rv
        except KeyboardInterrupt:
            return None

    def timeout(self, duration: float):
        """Yield ``int(duration * fps)`` times, i.e. once per physics step of ``duration`` seconds."""
        for _ in range(int(duration * self.fps)):
            yield

    def absolute_timeout(self, duration: float):
        """Return ``jacinle.timeout(duration, fps=self.fps)``."""
        return jacinle.timeout(duration, fps=self.fps)

    def _threaded_input(self, *args, **kwargs):
        """Read a line with ``input`` on a worker thread while the main thread keeps updating the viewer.

        An exception raised by ``input`` in the worker (e.g. EOFError) is re-raised here instead of
        surfacing as an unrelated IndexError.
        """
        # OS X doesn't multi-thread the OpenGL visualizer
        data = []
        errors = []

        def _read():
            try:
                data.append(input(*args, **kwargs))
            except Exception as e:  # handed back to the calling thread below
                errors.append(e)

        thread = threading.Thread(target=_read)
        thread.start()
        try:
            while thread.is_alive():
                self.update_viewer()
        finally:
            thread.join()
        if errors:
            raise errors[-1]
        return data[-1]

    def add_debug_line(self, start_pos, end_pos, color, name=None, life_time=0) -> int:
        """Draw a debug line; ``life_time`` 0 means permanent. Registered under ``name`` if given."""
        # ``lifeTime`` must be passed by keyword: the 4th positional slot of addUserDebugLine is ``lineWidth``.
        rv = p.addUserDebugLine(start_pos, end_pos, color, lifeTime=life_time, physicsClientId=self.client_id)
        return self.register_debug_item(name, rv)

    def add_debug_text(self, text, pos, color, name=None, life_time=0) -> int:
        """Draw debug text; ``life_time`` 0 means permanent. Registered under ``name`` if given."""
        # ``lifeTime`` must be passed by keyword: the 4th positional slot of addUserDebugText is ``textSize``.
        rv = p.addUserDebugText(text, pos, color, lifeTime=life_time, physicsClientId=self.client_id)
        return self.register_debug_item(name, rv)

    def add_debug_ray(self, start_pos, delta, color, length: float = 1.0, name=None, life_time=0) -> int:
        """Draw a ray from ``start_pos`` along ``delta * length``, with a small cube marking the origin.

        Returns the tuple of item ids (the line first, then the cube edges).
        """
        end_pos = np.asarray(start_pos) + np.asarray(delta) * length
        line_id = p.addUserDebugLine(start_pos, end_pos, color, lifeTime=life_time, physicsClientId=self.client_id)
        cube_ids = self.add_debug_cube(start_pos, (0.05, 0.05, 0.05), color, life_time=life_time)
        return self.register_debug_item(name, (line_id, *cube_ids))

    def add_debug_cube(self, center, extent, color, name=None, life_time=0) -> Tuple[int, ...]:
        """Draw the 12 edges of an axis-aligned box centered at ``center`` with side lengths ``extent``."""
        rv = list()
        center = np.asarray(center)
        extent = np.asarray(extent)
        min_point = center - extent / 2
        max_point = center + extent / 2
        edges = [
            (min_point, min_point + np.array([extent[0], 0, 0])),
            (min_point, min_point + np.array([0, extent[1], 0])),
            (min_point, min_point + np.array([0, 0, extent[2]])),
            (min_point + np.array([0, extent[1], 0]), min_point + np.array([extent[0], extent[1], 0])),
            (min_point + np.array([0, extent[1], 0]), min_point + np.array([0, extent[1], extent[2]])),
            (min_point + np.array([extent[0], 0, 0]), min_point + np.array([extent[0], extent[1], 0])),
            (min_point + np.array([extent[0], 0, 0]), min_point + np.array([extent[0], 0, extent[2]])),
            (min_point + np.array([0, 0, extent[2]]), min_point + np.array([0, extent[1], extent[2]])),
            (min_point + np.array([0, 0, extent[2]]), min_point + np.array([extent[0], 0, extent[2]])),
            (min_point + np.array([extent[0], extent[1], 0]), max_point),
            (min_point + np.array([0, extent[1], extent[2]]), max_point),
            (min_point + np.array([extent[0], 0, extent[2]]), max_point),
        ]
        for edge in edges:
            rv.append(self.add_debug_line(edge[0], edge[1], color, life_time=life_time))
        return self.register_debug_item(name, tuple(rv))

    def register_debug_item(self, name: Optional[str], item_id: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
        """Record ``item_id`` under ``name`` (removing any previous item of that name); return ``item_id``."""
        if name is not None:
            if name in self.debug_items:
                self.remove_debug_item(name)
            self.debug_items[name] = item_id
        return item_id

    def remove_debug_item(self, item_id: Union[str, int, Tuple[int, ...]]):
        """Remove a debug item given by registered name, item id, or tuple/list of item ids.

        Removing by name also drops the name from ``self.debug_items``.

        Raises:
            KeyError: if ``item_id`` is a name that is not registered.
        """
        if isinstance(item_id, str):
            # Pop (not just read) so the registry never points at already-removed items.
            item_id = self.debug_items.pop(item_id)
        if isinstance(item_id, (tuple, list)):
            for i in item_id:
                p.removeUserDebugItem(i, physicsClientId=self.client_id)
        else:
            return p.removeUserDebugItem(item_id, physicsClientId=self.client_id)

    def add_debug_coordinate_system(self, pos, principle_axes, size: float = 0.1, name=None, life_time=0) -> Tuple[int, ...]:
        """Draw one line of length ``size`` from ``pos`` along each axis in ``principle_axes``.

        The i-th axis is drawn red/green/blue for i = 0/1/2. ``pos`` and the axes may be any
        array-likes (tuples, lists or NumPy arrays).
        """
        origin = np.asarray(pos)
        items = list()
        for i, axis in enumerate(principle_axes):
            items.append(self.add_debug_line(origin, origin + size * np.asarray(axis), color=(i == 0, i == 1, i == 2), life_time=life_time))
        return self.register_debug_item(name, tuple(items))
[docs]
def canonize_gravity(gravity):
    """Normalize a gravity specification into a 3-tuple.

    Args:
        gravity: either a real scalar, interpreted as the z-component (the result is
            ``(0, 0, gravity)``), or an iterable of exactly three components.

    Returns:
        A tuple ``(gx, gy, gz)``.
    """
    import numbers  # local import: only this helper needs it

    # ``numbers.Real`` also covers NumPy scalars such as ``np.float32`` / ``np.int64``. These are
    # not subclasses of the builtin ``float``/``int``, so they used to fall through to ``tuple()``
    # and raise TypeError.
    if isinstance(gravity, numbers.Real):
        return (0, 0, gravity)
    gravity = tuple(gravity)
    assert len(gravity) == 3, f'gravity must have exactly 3 components, got {len(gravity)}.'
    return gravity
[docs]
def canonize_default_pos_and_quat(pos: Optional[Vec3f], quat: Optional[Vec4f]):
    """Fill in the identity pose for missing pose components.

    A ``None`` position becomes the origin ``(0, 0, 0)``. A ``None`` quaternion becomes
    ``(0, 0, 0, 1)``, the identity rotation with the scalar part last. Non-``None`` inputs are
    returned unchanged (the same objects).

    Returns:
        The pair ``(pos, quat)``.
    """
    return (
        (0, 0, 0) if pos is None else pos,
        (0, 0, 0, 1) if quat is None else quat,
    )