forked from HansZ8/RoboJuDo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmujoco_env.py
More file actions
155 lines (118 loc) · 5.35 KB
/
mujoco_env.py
File metadata and controls
155 lines (118 loc) · 5.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import logging
import time
import mujoco
import mujoco_viewer
import numpy as np
from robojudo.environment import Environment, env_registry
from robojudo.environment.env_cfgs import MujocoEnvCfg
from robojudo.environment.utils.mujoco_viz import MujocoVisualizer
from robojudo.utils.util_func import quat_rotate_inverse_np, quatToEuler
# Module-level logger named after this module; nothing in this file configures
# handlers or levels, so that is left to the application's logging setup.
logger = logging.getLogger(__name__)
@env_registry.register
class MujocoEnv(Environment):
    """MuJoCo-backed simulation environment driven by joint-space PD targets.

    The environment loads a MuJoCo model from ``cfg_env.xml``, opens a
    ``mujoco_viewer`` window and, on every :meth:`step`, turns PD targets into
    clipped torques that are applied for ``sim_decimation`` physics sub-steps.

    NOTE(review): ``num_dofs``, ``torque_limits``, ``born_place_align``,
    ``base_align``, ``update_with_fk``, ``fk``, ``_torso_name`` and the
    ``dof_pos`` / ``dof_vel`` / ``base_quat`` / ``base_pos`` accessors are used
    here but not defined in this class; they are presumably provided by
    ``Environment`` -- confirm against the base class.
    """

    cfg_env: MujocoEnvCfg

    def __init__(self, cfg_env: MujocoEnvCfg, device="cpu"):
        """Create the model, simulation data and viewer, then cache the initial state.

        Args:
            cfg_env: Environment configuration. This constructor reads
                ``sim_duration``, ``sim_dt``, ``sim_decimation``, ``xml`` and
                ``visualize_extras`` from it.
            device: Forwarded unchanged to ``Environment.__init__``.
        """
        super().__init__(cfg_env=cfg_env, device=device)

        self.sim_duration = cfg_env.sim_duration
        self.sim_dt = cfg_env.sim_dt
        self.sim_decimation = cfg_env.sim_decimation
        # One control step spans ``sim_decimation`` physics steps.
        self.control_dt = self.sim_dt * self.sim_decimation

        self.model = mujoco.MjModel.from_xml_path(cfg_env.xml)  # pyright: ignore[reportAttributeAccessIssue]
        # The configured physics step overrides whatever timestep the XML declares.
        self.model.opt.timestep = self.sim_dt
        self.data = mujoco.MjData(self.model)  # pyright: ignore[reportAttributeAccessIssue]
        # Advance once so ``self.data`` holds a computed state before the viewer
        # and the first ``update()`` read from it.
        mujoco.mj_step(self.model, self.data)  # pyright: ignore[reportAttributeAccessIssue]

        # NOTE(review): the keyword spelling ``diable_key_callbacks`` is kept
        # exactly as in the source; it presumably matches the project's
        # ``mujoco_viewer`` build -- confirm before "fixing" the typo.
        self.viewer = mujoco_viewer.MujocoViewer(
            self.model,
            self.data,
            width=1200,
            height=900,
            hide_menus=True,
            diable_key_callbacks=True,
        )
        self.viewer.cam.distance = 3.0
        self.viewer.cam.elevation = -10.0
        self.viewer.cam.azimuth = 180.0

        # Optional extra drawing layer on top of the viewer.
        if cfg_env.visualize_extras:
            self.visualizer = MujocoVisualizer(self.viewer)
        else:
            self.visualizer = None

        self.last_time = time.time()
        self.update()  # get initial state

    def reborn(self, init_qpos=None):
        """Re-initialise the simulation state.

        Args:
            init_qpos: Optional value written to ``qpos[0:7]``; when given, all
                velocities and controls are zeroed as well. When ``None`` the
                model's keyframe 0 is restored instead.
                NOTE(review): presumably a 7-vector (base position followed by
                a quaternion) for a floating-base model -- confirm.
        """
        if init_qpos is not None:
            self.data.qpos[0:7] = init_qpos
            self.data.qvel[:] = 0.0
            self.data.ctrl[:] = 0.0
        else:
            mujoco.mj_resetDataKeyframe(self.model, self.data, 0)  # pyright: ignore[reportAttributeAccessIssue]
        # NOTE(review): source indentation was lost; the forward pass is assumed
        # to run on both paths (after the ``else``) -- confirm.
        # Recompute derived quantities so they match the freshly written state.
        mujoco.mj_forward(self.model, self.data)  # pyright: ignore[reportAttributeAccessIssue]

    def reset(self):
        """Re-anchor the born place (when alignment is enabled) and refresh state."""
        if self.born_place_align:  # TODO: merge
            # Read the un-aligned pose first, then use it as the new born place.
            self.born_place_align = False  # disable during reset
            try:
                self.update()
            finally:
                # Restore the flag even if ``update`` raises, so a failed reset
                # cannot silently leave alignment switched off.
                self.born_place_align = True  # enable after reset
            self.set_born_place()
        # NOTE(review): source indentation was lost; this final refresh is
        # assumed to run unconditionally -- confirm.
        self.update()

    def set_gains(self, stiffness, damping):
        """Set the per-joint PD gains used by :meth:`step`.

        Args:
            stiffness: Proportional gains, one per degree of freedom.
            damping: Derivative gains, one per degree of freedom.

        Raises:
            AssertionError: If either sequence does not have ``num_dofs`` entries.
        """
        assert len(stiffness) == self.num_dofs and len(damping) == self.num_dofs
        self.stiffness = np.asarray(stiffness)
        self.damping = np.asarray(damping)

    def self_check(self):
        """No-op: this environment performs no self check."""
        pass

    def set_born_place(self, quat: np.ndarray | None = None, pos: np.ndarray | None = None):
        """Set the born place, defaulting to the current base pose.

        Args:
            quat: Orientation to use; ``self.base_quat`` when ``None``.
            pos: Position to use; ``self.base_pos`` when ``None``.
        """
        quat_ = self.base_quat if quat is None else quat
        pos_ = self.base_pos if pos is None else pos
        super().set_born_place(quat_, pos_)

    def update(self, simple=False):  # TODO: clean sensors in xml
        """Refresh the cached robot state from ``self.data``.

        Args:
            simple: When true, only the joint positions and velocities are
                refreshed and the method returns early.
        """
        # Convert each buffer once per call. ``astype`` returns a new float32
        # array, so the slices below never alias the simulator's own memory.
        qpos = self.data.qpos.astype(np.float32)
        qvel = self.data.qvel.astype(np.float32)

        # The actuated joints are taken as the trailing ``num_dofs`` entries.
        self._dof_pos = qpos[-self.num_dofs :].copy()
        self._dof_vel = qvel[-self.num_dofs :].copy()
        if simple:
            return

        # NOTE(review): assumes a floating base whose free joint occupies
        # qpos[0:7] / qvel[0:6]; the index list moves the first quaternion
        # component to the end (presumably MuJoCo w-x-y-z -> x-y-z-w) -- confirm.
        quat = qpos[3:7][[1, 2, 3, 0]]
        ang_vel = qvel[3:6]
        base_pos = qpos[:3]
        lin_vel = qvel[0:3]

        if self.born_place_align:
            quat, base_pos = self.base_align.align_transform(quat, base_pos)

        # NOTE(review): the linear velocity is rotated with the (possibly
        # aligned) quaternion although it is read from the un-aligned simulator
        # state -- confirm this is intended.
        lin_vel = quat_rotate_inverse_np(quat, lin_vel)
        rpy = quatToEuler(quat)

        self._base_rpy = rpy.copy()
        self._base_quat = quat.copy()
        self._base_ang_vel = ang_vel.copy()
        self._base_pos = base_pos.copy()
        self._base_lin_vel = lin_vel.copy()

        if self.update_with_fk:
            fk_info = self.fk()
            self._fk_info = fk_info.copy()
            torso_info = fk_info[self._torso_name]
            self._torso_ang_vel = torso_info["ang_vel"]
            self._torso_quat = torso_info["quat"]
            self._torso_pos = torso_info["pos"]

    def step(self, pd_target, hand_pose=None):
        """Advance the simulation by one control step.

        For each of the ``sim_decimation`` physics sub-steps the torque
        ``(pd_target - dof_pos) * stiffness - dof_vel * damping`` is computed
        from the latest joint state, clipped to ``±torque_limits`` and written
        to ``data.ctrl``. The full state is refreshed once at the end.

        Args:
            pd_target: Target joint positions, one per degree of freedom.
            hand_pose: Optional value that is only logged here; it is not
                applied to the simulation.

        Raises:
            AssertionError: If ``pd_target`` does not have ``num_dofs`` entries.
        """
        assert len(pd_target) == self.num_dofs, "pd_target len should be num_dofs of env"

        if hand_pose is not None:
            # Lazy %-formatting: the previous call passed ``hand_pose`` without
            # a placeholder, which made ``logging`` report a formatting error
            # instead of emitting the message.
            logger.info("Hand pose--> %s", hand_pose)

        # Keep the camera centred on the first three qpos entries.
        self.viewer.cam.lookat = self.data.qpos.astype(np.float32)[:3]
        if self.viewer.is_alive:
            self.viewer.render()

        for _ in range(self.sim_decimation):
            torque = (pd_target - self.dof_pos) * self.stiffness - self.dof_vel * self.damping
            torque = np.clip(torque, -self.torque_limits, self.torque_limits)
            self.data.ctrl = torque

            mujoco.mj_step(self.model, self.data)  # pyright: ignore[reportAttributeAccessIssue]
            # Cheap refresh so the next sub-step's PD law sees current joints.
            self.update(simple=True)

        self.update(simple=False)

    def shutdown(self):
        """Close the viewer window."""
        self.viewer.close()
if __name__ == "__main__":
    # Manual smoke test: open the G1 MuJoCo environment and hold an all-zero
    # PD target indefinitely (stop the process to exit).
    from robojudo.config.g1.env.g1_mujuco_env_cfg import G1MujocoEnvCfg

    env = MujocoEnv(cfg_env=G1MujocoEnvCfg())
    env.viewer._paused = False

    # The target never changes and ``step`` does not modify it, so build it once.
    zero_target = np.zeros(env.num_dofs)
    while True:
        env.step(zero_target)
        time.sleep(0.02)