Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
8fd7f49
1
Parent(s):
db1bb68
fixed bug with the collision objective. Created fast collision objective. Also added this to the original library.
Browse files- app.py +45 -101
- collision_objectives.py +83 -0
- requirements.txt +1 -1
app.py
CHANGED
@@ -30,6 +30,10 @@ from jax_ik.objectives import (
|
|
30 |
DistanceObjTraj,
|
31 |
SphereCollisionPenaltyObjTraj,
|
32 |
)
|
|
|
|
|
|
|
|
|
33 |
|
34 |
def download_and_setup_files():
|
35 |
os.makedirs("files", exist_ok=True)
|
@@ -78,13 +82,6 @@ class IKServer:
|
|
78 |
self.max_num_steps_global = 20000
|
79 |
self.max_learning_rate = 5.0
|
80 |
|
81 |
-
# Caches for solver reuse (LRU)
|
82 |
-
self.solver_cache = {}
|
83 |
-
self.urdf_solver_cache = {}
|
84 |
-
self.cache_access_order = []
|
85 |
-
self.urdf_cache_access_order = []
|
86 |
-
self.max_cache_size = 5
|
87 |
-
|
88 |
# Animation buffers
|
89 |
self.animation_frames_agent = []
|
90 |
self.animation_frames_urdf = []
|
@@ -145,23 +142,6 @@ class IKServer:
|
|
145 |
allowed = self.urdf_available_bones if is_urdf else self.available_bones
|
146 |
return [b for b in bones if b in allowed]
|
147 |
|
148 |
-
def _cache_key(self, bones, num_steps):
|
149 |
-
return tuple(sorted(bones)) + (int(num_steps),)
|
150 |
-
|
151 |
-
def _evict_lru(self, is_urdf=False):
|
152 |
-
cache = self.urdf_solver_cache if is_urdf else self.solver_cache
|
153 |
-
order = self.urdf_cache_access_order if is_urdf else self.cache_access_order
|
154 |
-
if not order:
|
155 |
-
return
|
156 |
-
k = order.pop(0)
|
157 |
-
cache.pop(k, None)
|
158 |
-
gc.collect()
|
159 |
-
|
160 |
-
def _invalidate_caches(self):
|
161 |
-
self.solver_cache.clear(); self.urdf_solver_cache.clear()
|
162 |
-
self.cache_access_order.clear(); self.urdf_cache_access_order.clear()
|
163 |
-
gc.collect()
|
164 |
-
|
165 |
def _create_solver(self, bones, is_urdf, num_steps):
|
166 |
if is_urdf:
|
167 |
return InverseKinematicsSolver(
|
@@ -188,29 +168,6 @@ class IKServer:
|
|
188 |
compute_sdf=False,
|
189 |
)
|
190 |
|
191 |
-
def _get_solver(self, bones, is_urdf=False, num_steps=None):
|
192 |
-
if num_steps is None:
|
193 |
-
num_steps = self.args.num_steps
|
194 |
-
num_steps = self._sanitize_num_steps(num_steps)
|
195 |
-
key = self._cache_key(bones, num_steps)
|
196 |
-
cache = self.urdf_solver_cache if is_urdf else self.solver_cache
|
197 |
-
order = self.urdf_cache_access_order if is_urdf else self.cache_access_order
|
198 |
-
if key in cache:
|
199 |
-
if key in order: order.remove(key)
|
200 |
-
order.append(key)
|
201 |
-
return cache[key]
|
202 |
-
if len(cache) >= self.max_cache_size:
|
203 |
-
self._evict_lru(is_urdf)
|
204 |
-
try:
|
205 |
-
solver = self._create_solver(bones, is_urdf, num_steps)
|
206 |
-
except Exception as e:
|
207 |
-
logger.warning(f"Primary solver creation failure (is_urdf={is_urdf}): {e}; using defaults")
|
208 |
-
bones = self.urdf_default_controlled_bones if is_urdf else self.default_controlled_bones
|
209 |
-
solver = self._create_solver(bones, is_urdf, num_steps)
|
210 |
-
cache[key] = solver
|
211 |
-
order.append(key)
|
212 |
-
return solver
|
213 |
-
|
214 |
def _rebuild_solver_safe(self, is_urdf=False, force_defaults=False):
|
215 |
try:
|
216 |
if is_urdf:
|
@@ -232,8 +189,7 @@ class IKServer:
|
|
232 |
return False
|
233 |
|
234 |
def _attempt_solver_recovery(self, is_urdf=False):
|
235 |
-
logger.warning("Attempting full solver recovery (
|
236 |
-
self._invalidate_caches()
|
237 |
if is_urdf:
|
238 |
self.urdf_current_controlled_bones = self.urdf_default_controlled_bones.copy()
|
239 |
return self._rebuild_solver_safe(True, force_defaults=True)
|
@@ -291,7 +247,8 @@ class IKServer:
|
|
291 |
self.selectable_bones = [b for b in self.available_bones if b in self.bounds_dict]
|
292 |
self.current_controlled_bones = self.default_controlled_bones.copy()
|
293 |
self.current_end_effector = self.default_end_effector
|
294 |
-
|
|
|
295 |
self.initial_rotations = np.zeros(len(self.solver.controlled_bones) * 3, dtype=np.float32)
|
296 |
self.best_angles = self.initial_rotations.copy()
|
297 |
self.mesh_data = load_mesh_data_from_gltf(self.args.gltf_file, self.solver.fk_solver)
|
@@ -314,7 +271,7 @@ class IKServer:
|
|
314 |
self.urdf_selectable_bones = list(self.urdf_available_bones)
|
315 |
self.urdf_current_controlled_bones = self.urdf_default_controlled_bones.copy()
|
316 |
self.urdf_current_end_effector = self.urdf_default_end_effector
|
317 |
-
self.urdf_solver = self.
|
318 |
self.urdf_initial_rotations = np.zeros(len(self.urdf_solver.controlled_bones) * 3, dtype=np.float32)
|
319 |
self.urdf_best_angles = self.urdf_initial_rotations.copy()
|
320 |
self.urdf_mesh_data = load_mesh_data_from_urdf(self.urdf_file, self.urdf_solver.fk_solver)
|
@@ -328,7 +285,9 @@ class IKServer:
|
|
328 |
use_head=True,
|
329 |
weight=1.0,
|
330 |
)
|
331 |
-
|
|
|
|
|
332 |
{"center":[0.1,0.0,0.35],"radius":0.1},
|
333 |
min_clearance=0.0,
|
334 |
weight=1.0,
|
@@ -345,7 +304,9 @@ class IKServer:
|
|
345 |
use_head=True,
|
346 |
weight=1.0,
|
347 |
)
|
348 |
-
|
|
|
|
|
349 |
{"center":[0.2,0.0,0.35],"radius":0.1},
|
350 |
min_clearance=0.0,
|
351 |
weight=1.0,
|
@@ -359,28 +320,23 @@ class IKServer:
|
|
359 |
tgt = np.array(payload.get("target",[0.0,0.2,0.35]))
|
360 |
self.distance_obj.update_params({"bone_name": self.current_end_effector, "target_points": tgt, "weight": float(payload.get("distance_weight",1.0))})
|
361 |
# collision updates (weight, center, radius, min_clearance)
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
# still update weight / min_clearance so they can be changed next enable
|
375 |
-
mc = payload.get("collision_min_clearance")
|
376 |
-
upd = {"weight": float(payload.get("collision_weight",1.0))}
|
377 |
-
if mc is not None:
|
378 |
-
upd["min_clearance"] = float(mc)
|
379 |
-
self.collision_obj.update_params(upd)
|
380 |
subpoints = self._safe_int(payload.get("subpoints",1), 1, 1, 100)
|
381 |
mandatory, optional = [], []
|
382 |
if payload.get("distance_enabled", True): mandatory.append(self.distance_obj)
|
383 |
-
|
|
|
384 |
if payload.get("bone_zero_enabled", True):
|
385 |
optional.append(BoneZeroRotationObj(weight=float(payload.get("bone_zero_weight",0.05))))
|
386 |
if payload.get("derivative_enabled", True) and subpoints > 1:
|
@@ -421,27 +377,22 @@ class IKServer:
|
|
421 |
def _build_urdf_objectives(self, payload):
|
422 |
tgt = np.array(payload.get("target",[0.3,0.3,0.35]))
|
423 |
self.urdf_distance_obj.update_params({"bone_name": self.urdf_current_end_effector, "target_points": tgt, "weight": float(payload.get("distance_weight",1.0))})
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
else:
|
436 |
-
mc = payload.get("collision_min_clearance")
|
437 |
-
upd = {"weight": float(payload.get("collision_weight",1.0))}
|
438 |
-
if mc is not None:
|
439 |
-
upd["min_clearance"] = float(mc)
|
440 |
-
self.urdf_collision_obj.update_params(upd)
|
441 |
subpoints = self._safe_int(payload.get("subpoints",1),1,1,100)
|
442 |
mandatory, optional = [], []
|
443 |
if payload.get("distance_enabled", True): mandatory.append(self.urdf_distance_obj)
|
444 |
-
|
|
|
445 |
if payload.get("bone_zero_enabled", True):
|
446 |
optional.append(BoneZeroRotationObj(weight=float(payload.get("bone_zero_weight",0.05))))
|
447 |
if payload.get("derivative_enabled", True) and subpoints > 1:
|
@@ -481,7 +432,7 @@ class IKServer:
|
|
481 |
self.current_controlled_bones = bones
|
482 |
self.current_end_effector = eff
|
483 |
self.current_num_steps = int(num_steps)
|
484 |
-
self.solver = self.
|
485 |
self.initial_rotations = np.zeros(len(self.solver.controlled_bones)*3, dtype=np.float32)
|
486 |
self.best_angles = self.initial_rotations.copy()
|
487 |
self._setup_agent_objectives()
|
@@ -507,7 +458,7 @@ class IKServer:
|
|
507 |
self.urdf_current_controlled_bones = bones
|
508 |
self.urdf_current_end_effector = eff
|
509 |
self.urdf_current_num_steps = int(num_steps)
|
510 |
-
self.urdf_solver = self.
|
511 |
self.urdf_initial_rotations = np.zeros(len(self.urdf_solver.controlled_bones)*3, dtype=np.float32)
|
512 |
self.urdf_best_angles = self.urdf_initial_rotations.copy()
|
513 |
self._setup_urdf_objectives()
|
@@ -639,11 +590,7 @@ class IKServer:
|
|
639 |
return
|
640 |
try:
|
641 |
gc.collect()
|
642 |
-
#
|
643 |
-
rss = self.process.memory_info().rss
|
644 |
-
if rss > 16 * 1024**3:
|
645 |
-
logger.warning("High memory usage detected; invalidating caches")
|
646 |
-
self._invalidate_caches()
|
647 |
except Exception:
|
648 |
pass
|
649 |
self.last_cleanup_time = now
|
@@ -757,10 +704,9 @@ class IKServer:
|
|
757 |
@self.app.post("/reset")
|
758 |
def reset():
|
759 |
with self.solve_lock:
|
760 |
-
self._invalidate_caches()
|
761 |
self._attempt_solver_recovery(False)
|
762 |
self._attempt_solver_recovery(True)
|
763 |
-
return {"status":"ok","message":"solvers
|
764 |
|
765 |
@self.app.get("/health")
|
766 |
def health():
|
@@ -769,8 +715,6 @@ class IKServer:
|
|
769 |
"status":"ok",
|
770 |
"agent_frames": len(self.animation_frames_agent),
|
771 |
"urdf_frames": len(self.animation_frames_urdf),
|
772 |
-
"cache_agent": len(self.solver_cache),
|
773 |
-
"cache_urdf": len(self.urdf_solver_cache),
|
774 |
"agent_fail_count": self.agent_fail_count,
|
775 |
"urdf_fail_count": self.urdf_fail_count,
|
776 |
"last_agent_error": self.last_agent_error,
|
|
|
30 |
DistanceObjTraj,
|
31 |
SphereCollisionPenaltyObjTraj,
|
32 |
)
|
33 |
+
try:
|
34 |
+
from collision_objectives import FastSphereCollisionPenaltyObjTraj as _FastColl
|
35 |
+
except Exception: # fallback if file missing
|
36 |
+
_FastColl = SphereCollisionPenaltyObjTraj
|
37 |
|
38 |
def download_and_setup_files():
|
39 |
os.makedirs("files", exist_ok=True)
|
|
|
82 |
self.max_num_steps_global = 20000
|
83 |
self.max_learning_rate = 5.0
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
# Animation buffers
|
86 |
self.animation_frames_agent = []
|
87 |
self.animation_frames_urdf = []
|
|
|
142 |
allowed = self.urdf_available_bones if is_urdf else self.available_bones
|
143 |
return [b for b in bones if b in allowed]
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
def _create_solver(self, bones, is_urdf, num_steps):
|
146 |
if is_urdf:
|
147 |
return InverseKinematicsSolver(
|
|
|
168 |
compute_sdf=False,
|
169 |
)
|
170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
def _rebuild_solver_safe(self, is_urdf=False, force_defaults=False):
|
172 |
try:
|
173 |
if is_urdf:
|
|
|
189 |
return False
|
190 |
|
191 |
def _attempt_solver_recovery(self, is_urdf=False):
|
192 |
+
logger.warning("Attempting full solver recovery (reset to defaults)")
|
|
|
193 |
if is_urdf:
|
194 |
self.urdf_current_controlled_bones = self.urdf_default_controlled_bones.copy()
|
195 |
return self._rebuild_solver_safe(True, force_defaults=True)
|
|
|
247 |
self.selectable_bones = [b for b in self.available_bones if b in self.bounds_dict]
|
248 |
self.current_controlled_bones = self.default_controlled_bones.copy()
|
249 |
self.current_end_effector = self.default_end_effector
|
250 |
+
# Direct solver creation (no cache)
|
251 |
+
self.solver = self._create_solver(self.current_controlled_bones, False, self.current_num_steps)
|
252 |
self.initial_rotations = np.zeros(len(self.solver.controlled_bones) * 3, dtype=np.float32)
|
253 |
self.best_angles = self.initial_rotations.copy()
|
254 |
self.mesh_data = load_mesh_data_from_gltf(self.args.gltf_file, self.solver.fk_solver)
|
|
|
271 |
self.urdf_selectable_bones = list(self.urdf_available_bones)
|
272 |
self.urdf_current_controlled_bones = self.urdf_default_controlled_bones.copy()
|
273 |
self.urdf_current_end_effector = self.urdf_default_end_effector
|
274 |
+
self.urdf_solver = self._create_solver(self.urdf_current_controlled_bones, True, self.urdf_current_num_steps)
|
275 |
self.urdf_initial_rotations = np.zeros(len(self.urdf_solver.controlled_bones) * 3, dtype=np.float32)
|
276 |
self.urdf_best_angles = self.urdf_initial_rotations.copy()
|
277 |
self.urdf_mesh_data = load_mesh_data_from_urdf(self.urdf_file, self.urdf_solver.fk_solver)
|
|
|
285 |
use_head=True,
|
286 |
weight=1.0,
|
287 |
)
|
288 |
+
use_fast = os.environ.get("USE_FAST_COLLISION", "1") != "0"
|
289 |
+
CollCls = _FastColl if use_fast else SphereCollisionPenaltyObjTraj
|
290 |
+
self.collision_obj = CollCls(
|
291 |
{"center":[0.1,0.0,0.35],"radius":0.1},
|
292 |
min_clearance=0.0,
|
293 |
weight=1.0,
|
|
|
304 |
use_head=True,
|
305 |
weight=1.0,
|
306 |
)
|
307 |
+
use_fast = os.environ.get("USE_FAST_COLLISION", "1") != "0"
|
308 |
+
CollCls = _FastColl if use_fast else SphereCollisionPenaltyObjTraj
|
309 |
+
self.urdf_collision_obj = CollCls(
|
310 |
{"center":[0.2,0.0,0.35],"radius":0.1},
|
311 |
min_clearance=0.0,
|
312 |
weight=1.0,
|
|
|
320 |
tgt = np.array(payload.get("target",[0.0,0.2,0.35]))
|
321 |
self.distance_obj.update_params({"bone_name": self.current_end_effector, "target_points": tgt, "weight": float(payload.get("distance_weight",1.0))})
|
322 |
# collision updates (weight, center, radius, min_clearance)
|
323 |
+
collision_enabled = payload.get("collision_enabled", False)
|
324 |
+
coll_update = {"weight": float(payload.get("collision_weight",1.0)) if collision_enabled else 0.0}
|
325 |
+
center = payload.get("collision_center")
|
326 |
+
if isinstance(center, (list, tuple)) and len(center)==3:
|
327 |
+
coll_update["center"] = center
|
328 |
+
radius = payload.get("collision_radius")
|
329 |
+
if radius is not None:
|
330 |
+
coll_update["radius"] = float(radius)
|
331 |
+
if "collision_min_clearance" in payload:
|
332 |
+
coll_update["min_clearance"] = float(payload.get("collision_min_clearance", 0.0))
|
333 |
+
# Always update params so compiled graph stays stable
|
334 |
+
self.collision_obj.update_params(coll_update)
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
subpoints = self._safe_int(payload.get("subpoints",1), 1, 1, 100)
|
336 |
mandatory, optional = [], []
|
337 |
if payload.get("distance_enabled", True): mandatory.append(self.distance_obj)
|
338 |
+
# Always include collision objective to avoid JIT retrace when toggling
|
339 |
+
optional.append(self.collision_obj)
|
340 |
if payload.get("bone_zero_enabled", True):
|
341 |
optional.append(BoneZeroRotationObj(weight=float(payload.get("bone_zero_weight",0.05))))
|
342 |
if payload.get("derivative_enabled", True) and subpoints > 1:
|
|
|
377 |
def _build_urdf_objectives(self, payload):
|
378 |
tgt = np.array(payload.get("target",[0.3,0.3,0.35]))
|
379 |
self.urdf_distance_obj.update_params({"bone_name": self.urdf_current_end_effector, "target_points": tgt, "weight": float(payload.get("distance_weight",1.0))})
|
380 |
+
collision_enabled = payload.get("collision_enabled", False)
|
381 |
+
coll_update = {"weight": float(payload.get("collision_weight",1.0)) if collision_enabled else 0.0}
|
382 |
+
center = payload.get("collision_center")
|
383 |
+
if isinstance(center, (list, tuple)) and len(center)==3:
|
384 |
+
coll_update["center"] = center
|
385 |
+
radius = payload.get("collision_radius")
|
386 |
+
if radius is not None:
|
387 |
+
coll_update["radius"] = float(radius)
|
388 |
+
if "collision_min_clearance" in payload:
|
389 |
+
coll_update["min_clearance"] = float(payload.get("collision_min_clearance", 0.0))
|
390 |
+
self.urdf_collision_obj.update_params(coll_update)
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
subpoints = self._safe_int(payload.get("subpoints",1),1,1,100)
|
392 |
mandatory, optional = [], []
|
393 |
if payload.get("distance_enabled", True): mandatory.append(self.urdf_distance_obj)
|
394 |
+
# Always include collision objective with weight possibly zero
|
395 |
+
optional.append(self.urdf_collision_obj)
|
396 |
if payload.get("bone_zero_enabled", True):
|
397 |
optional.append(BoneZeroRotationObj(weight=float(payload.get("bone_zero_weight",0.05))))
|
398 |
if payload.get("derivative_enabled", True) and subpoints > 1:
|
|
|
432 |
self.current_controlled_bones = bones
|
433 |
self.current_end_effector = eff
|
434 |
self.current_num_steps = int(num_steps)
|
435 |
+
self.solver = self._create_solver(bones, False, self.current_num_steps)
|
436 |
self.initial_rotations = np.zeros(len(self.solver.controlled_bones)*3, dtype=np.float32)
|
437 |
self.best_angles = self.initial_rotations.copy()
|
438 |
self._setup_agent_objectives()
|
|
|
458 |
self.urdf_current_controlled_bones = bones
|
459 |
self.urdf_current_end_effector = eff
|
460 |
self.urdf_current_num_steps = int(num_steps)
|
461 |
+
self.urdf_solver = self._create_solver(bones, True, self.urdf_current_num_steps)
|
462 |
self.urdf_initial_rotations = np.zeros(len(self.urdf_solver.controlled_bones)*3, dtype=np.float32)
|
463 |
self.urdf_best_angles = self.urdf_initial_rotations.copy()
|
464 |
self._setup_urdf_objectives()
|
|
|
590 |
return
|
591 |
try:
|
592 |
gc.collect()
|
593 |
+
# Removed cache invalidation logic (no caches anymore)
|
|
|
|
|
|
|
|
|
594 |
except Exception:
|
595 |
pass
|
596 |
self.last_cleanup_time = now
|
|
|
704 |
@self.app.post("/reset")
|
705 |
def reset():
|
706 |
with self.solve_lock:
|
|
|
707 |
self._attempt_solver_recovery(False)
|
708 |
self._attempt_solver_recovery(True)
|
709 |
+
return {"status":"ok","message":"solvers reset"}
|
710 |
|
711 |
@self.app.get("/health")
|
712 |
def health():
|
|
|
715 |
"status":"ok",
|
716 |
"agent_frames": len(self.animation_frames_agent),
|
717 |
"urdf_frames": len(self.animation_frames_urdf),
|
|
|
|
|
718 |
"agent_fail_count": self.agent_fail_count,
|
719 |
"urdf_fail_count": self.urdf_fail_count,
|
720 |
"last_agent_error": self.last_agent_error,
|
collision_objectives.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import jax
|
2 |
+
import jax.numpy as jnp
|
3 |
+
import numpy as np
|
4 |
+
from jax.tree_util import register_pytree_node_class
|
5 |
+
from jax_ik.objectives import ObjectiveFunction
|
6 |
+
|
7 |
+
@register_pytree_node_class
|
8 |
+
class FastSphereCollisionPenaltyObjTraj(ObjectiveFunction):
|
9 |
+
"""Vectorized sphere collision penalty over bone segments.
|
10 |
+
|
11 |
+
Weight stored as Python float (static aux) to avoid tracer-to-Python
|
12 |
+
concretization when objective always present with varying weight.
|
13 |
+
"""
|
14 |
+
def __init__(self, sphere_collider: dict, weight: float = 1.0, min_clearance: float = 0.05, segment_radius: float = 0.02):
|
15 |
+
self.center = jnp.asarray(sphere_collider["center"], jnp.float32)
|
16 |
+
self.radius = jnp.asarray(sphere_collider["radius"], jnp.float32)
|
17 |
+
self.min_clearance = jnp.asarray(min_clearance, jnp.float32)
|
18 |
+
self.segment_radius = jnp.asarray(segment_radius, jnp.float32)
|
19 |
+
self.weight = float(weight)
|
20 |
+
|
21 |
+
# pytree impl ------------------------------------------------------------
|
22 |
+
def tree_flatten(self):
|
23 |
+
# weight treated as static (aux) so changing it may retrace but avoids concretization errors
|
24 |
+
return (self.center, self.radius, self.min_clearance, self.segment_radius), (self.weight,)
|
25 |
+
|
26 |
+
@classmethod
|
27 |
+
def tree_unflatten(cls, aux, leaves):
|
28 |
+
(weight,) = aux
|
29 |
+
c, r, mc, sr = leaves
|
30 |
+
return cls(dict(center=c, radius=r), weight, mc, sr)
|
31 |
+
|
32 |
+
# API --------------------------------------------------------------------
|
33 |
+
def update_params(self, p: dict) -> None:
|
34 |
+
if "sphere_collider" in p:
|
35 |
+
collider = p["sphere_collider"]
|
36 |
+
if "center" in collider:
|
37 |
+
self.center = jnp.asarray(collider["center"], jnp.float32)
|
38 |
+
if "radius" in collider:
|
39 |
+
self.radius = jnp.asarray(collider["radius"], jnp.float32)
|
40 |
+
if "center" in p:
|
41 |
+
self.center = jnp.asarray(p["center"], jnp.float32)
|
42 |
+
if "radius" in p:
|
43 |
+
self.radius = jnp.asarray(p["radius"], jnp.float32)
|
44 |
+
if "min_clearance" in p:
|
45 |
+
self.min_clearance = jnp.asarray(p["min_clearance"], jnp.float32)
|
46 |
+
if "segment_radius" in p:
|
47 |
+
self.segment_radius = jnp.asarray(p["segment_radius"], jnp.float32)
|
48 |
+
if "weight" in p:
|
49 |
+
self.weight = float(p["weight"])
|
50 |
+
|
51 |
+
def get_params(self) -> dict:
|
52 |
+
return dict(
|
53 |
+
sphere_collider=dict(center=np.asarray(self.center).tolist(), radius=float(self.radius)),
|
54 |
+
min_clearance=float(self.min_clearance),
|
55 |
+
segment_radius=float(self.segment_radius),
|
56 |
+
weight=float(self.weight),
|
57 |
+
)
|
58 |
+
|
59 |
+
# core -------------------------------------------------------------------
|
60 |
+
def _penalty_single(self, cfg, fk_solver) -> jnp.ndarray:
|
61 |
+
fk = fk_solver.compute_fk_from_angles(cfg) # (N,4,4)
|
62 |
+
heads = fk[:, :3, 3] # (N,3)
|
63 |
+
parents = jnp.asarray(fk_solver.parent_list, jnp.int32) # (N,)
|
64 |
+
seg_mask = (parents >= 0).astype(jnp.float32) # (N,)
|
65 |
+
safe_parent_indices = jnp.where(parents >= 0, parents, 0)
|
66 |
+
p_head = heads[safe_parent_indices]
|
67 |
+
c_head = heads
|
68 |
+
v = c_head - p_head
|
69 |
+
dot_vv = jnp.sum(v * v, axis=1) + 1e-6
|
70 |
+
eff_rad = self.radius + self.min_clearance + self.segment_radius
|
71 |
+
vc = self.center - p_head
|
72 |
+
t = jnp.clip(jnp.sum(vc * v, axis=1) / dot_vv, 0.0, 1.0)
|
73 |
+
closest = p_head + t[:, None] * v
|
74 |
+
dist = jnp.linalg.norm(self.center - closest, axis=1)
|
75 |
+
penetration = jnp.maximum(0.0, eff_rad - dist)
|
76 |
+
return jnp.sum((penetration ** 2) * seg_mask)
|
77 |
+
|
78 |
+
def __call__(self, X: jnp.ndarray, fk_solver) -> jnp.ndarray:
|
79 |
+
if X.ndim == 1:
|
80 |
+
loss = self._penalty_single(X, fk_solver)
|
81 |
+
else:
|
82 |
+
loss = jnp.mean(jax.vmap(lambda c: self._penalty_single(c, fk_solver))(X))
|
83 |
+
return loss * jnp.float32(self.weight)
|
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
jax-ik==0.1.
|
2 |
wandb
|
3 |
pillow
|
4 |
scipy
|
|
|
1 |
+
jax-ik==0.1.6
|
2 |
wandb
|
3 |
pillow
|
4 |
scipy
|