hvoss-techfak commited on
Commit
8fd7f49
·
1 Parent(s): db1bb68

fixed bug with the collision objective. Created fast collision objective. Also added this to the original library.

Browse files
Files changed (3) hide show
  1. app.py +45 -101
  2. collision_objectives.py +83 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -30,6 +30,10 @@ from jax_ik.objectives import (
30
  DistanceObjTraj,
31
  SphereCollisionPenaltyObjTraj,
32
  )
 
 
 
 
33
 
34
  def download_and_setup_files():
35
  os.makedirs("files", exist_ok=True)
@@ -78,13 +82,6 @@ class IKServer:
78
  self.max_num_steps_global = 20000
79
  self.max_learning_rate = 5.0
80
 
81
- # Caches for solver reuse (LRU)
82
- self.solver_cache = {}
83
- self.urdf_solver_cache = {}
84
- self.cache_access_order = []
85
- self.urdf_cache_access_order = []
86
- self.max_cache_size = 5
87
-
88
  # Animation buffers
89
  self.animation_frames_agent = []
90
  self.animation_frames_urdf = []
@@ -145,23 +142,6 @@ class IKServer:
145
  allowed = self.urdf_available_bones if is_urdf else self.available_bones
146
  return [b for b in bones if b in allowed]
147
 
148
- def _cache_key(self, bones, num_steps):
149
- return tuple(sorted(bones)) + (int(num_steps),)
150
-
151
- def _evict_lru(self, is_urdf=False):
152
- cache = self.urdf_solver_cache if is_urdf else self.solver_cache
153
- order = self.urdf_cache_access_order if is_urdf else self.cache_access_order
154
- if not order:
155
- return
156
- k = order.pop(0)
157
- cache.pop(k, None)
158
- gc.collect()
159
-
160
- def _invalidate_caches(self):
161
- self.solver_cache.clear(); self.urdf_solver_cache.clear()
162
- self.cache_access_order.clear(); self.urdf_cache_access_order.clear()
163
- gc.collect()
164
-
165
  def _create_solver(self, bones, is_urdf, num_steps):
166
  if is_urdf:
167
  return InverseKinematicsSolver(
@@ -188,29 +168,6 @@ class IKServer:
188
  compute_sdf=False,
189
  )
190
 
191
- def _get_solver(self, bones, is_urdf=False, num_steps=None):
192
- if num_steps is None:
193
- num_steps = self.args.num_steps
194
- num_steps = self._sanitize_num_steps(num_steps)
195
- key = self._cache_key(bones, num_steps)
196
- cache = self.urdf_solver_cache if is_urdf else self.solver_cache
197
- order = self.urdf_cache_access_order if is_urdf else self.cache_access_order
198
- if key in cache:
199
- if key in order: order.remove(key)
200
- order.append(key)
201
- return cache[key]
202
- if len(cache) >= self.max_cache_size:
203
- self._evict_lru(is_urdf)
204
- try:
205
- solver = self._create_solver(bones, is_urdf, num_steps)
206
- except Exception as e:
207
- logger.warning(f"Primary solver creation failure (is_urdf={is_urdf}): {e}; using defaults")
208
- bones = self.urdf_default_controlled_bones if is_urdf else self.default_controlled_bones
209
- solver = self._create_solver(bones, is_urdf, num_steps)
210
- cache[key] = solver
211
- order.append(key)
212
- return solver
213
-
214
  def _rebuild_solver_safe(self, is_urdf=False, force_defaults=False):
215
  try:
216
  if is_urdf:
@@ -232,8 +189,7 @@ class IKServer:
232
  return False
233
 
234
  def _attempt_solver_recovery(self, is_urdf=False):
235
- logger.warning("Attempting full solver recovery (invalidate caches + defaults)")
236
- self._invalidate_caches()
237
  if is_urdf:
238
  self.urdf_current_controlled_bones = self.urdf_default_controlled_bones.copy()
239
  return self._rebuild_solver_safe(True, force_defaults=True)
@@ -291,7 +247,8 @@ class IKServer:
291
  self.selectable_bones = [b for b in self.available_bones if b in self.bounds_dict]
292
  self.current_controlled_bones = self.default_controlled_bones.copy()
293
  self.current_end_effector = self.default_end_effector
294
- self.solver = self._get_solver(self.current_controlled_bones, is_urdf=False, num_steps=self.current_num_steps)
 
295
  self.initial_rotations = np.zeros(len(self.solver.controlled_bones) * 3, dtype=np.float32)
296
  self.best_angles = self.initial_rotations.copy()
297
  self.mesh_data = load_mesh_data_from_gltf(self.args.gltf_file, self.solver.fk_solver)
@@ -314,7 +271,7 @@ class IKServer:
314
  self.urdf_selectable_bones = list(self.urdf_available_bones)
315
  self.urdf_current_controlled_bones = self.urdf_default_controlled_bones.copy()
316
  self.urdf_current_end_effector = self.urdf_default_end_effector
317
- self.urdf_solver = self._get_solver(self.urdf_current_controlled_bones, is_urdf=True, num_steps=self.urdf_current_num_steps)
318
  self.urdf_initial_rotations = np.zeros(len(self.urdf_solver.controlled_bones) * 3, dtype=np.float32)
319
  self.urdf_best_angles = self.urdf_initial_rotations.copy()
320
  self.urdf_mesh_data = load_mesh_data_from_urdf(self.urdf_file, self.urdf_solver.fk_solver)
@@ -328,7 +285,9 @@ class IKServer:
328
  use_head=True,
329
  weight=1.0,
330
  )
331
- self.collision_obj = SphereCollisionPenaltyObjTraj(
 
 
332
  {"center":[0.1,0.0,0.35],"radius":0.1},
333
  min_clearance=0.0,
334
  weight=1.0,
@@ -345,7 +304,9 @@ class IKServer:
345
  use_head=True,
346
  weight=1.0,
347
  )
348
- self.urdf_collision_obj = SphereCollisionPenaltyObjTraj(
 
 
349
  {"center":[0.2,0.0,0.35],"radius":0.1},
350
  min_clearance=0.0,
351
  weight=1.0,
@@ -359,28 +320,23 @@ class IKServer:
359
  tgt = np.array(payload.get("target",[0.0,0.2,0.35]))
360
  self.distance_obj.update_params({"bone_name": self.current_end_effector, "target_points": tgt, "weight": float(payload.get("distance_weight",1.0))})
361
  # collision updates (weight, center, radius, min_clearance)
362
- if payload.get("collision_enabled", False):
363
- coll_update = {"weight": float(payload.get("collision_weight",1.0))}
364
- center = payload.get("collision_center")
365
- if isinstance(center, (list, tuple)) and len(center)==3:
366
- coll_update["center"] = center
367
- radius = payload.get("collision_radius")
368
- if radius is not None:
369
- coll_update["radius"] = float(radius)
370
- if "collision_min_clearance" in payload:
371
- coll_update["min_clearance"] = float(payload.get("collision_min_clearance", 0.0))
372
- self.collision_obj.update_params(coll_update)
373
- else:
374
- # still update weight / min_clearance so they can be changed next enable
375
- mc = payload.get("collision_min_clearance")
376
- upd = {"weight": float(payload.get("collision_weight",1.0))}
377
- if mc is not None:
378
- upd["min_clearance"] = float(mc)
379
- self.collision_obj.update_params(upd)
380
  subpoints = self._safe_int(payload.get("subpoints",1), 1, 1, 100)
381
  mandatory, optional = [], []
382
  if payload.get("distance_enabled", True): mandatory.append(self.distance_obj)
383
- if payload.get("collision_enabled", False): optional.append(self.collision_obj)
 
384
  if payload.get("bone_zero_enabled", True):
385
  optional.append(BoneZeroRotationObj(weight=float(payload.get("bone_zero_weight",0.05))))
386
  if payload.get("derivative_enabled", True) and subpoints > 1:
@@ -421,27 +377,22 @@ class IKServer:
421
  def _build_urdf_objectives(self, payload):
422
  tgt = np.array(payload.get("target",[0.3,0.3,0.35]))
423
  self.urdf_distance_obj.update_params({"bone_name": self.urdf_current_end_effector, "target_points": tgt, "weight": float(payload.get("distance_weight",1.0))})
424
- if payload.get("collision_enabled", False):
425
- coll_update = {"weight": float(payload.get("collision_weight",1.0))}
426
- center = payload.get("collision_center")
427
- if isinstance(center, (list, tuple)) and len(center)==3:
428
- coll_update["center"] = center
429
- radius = payload.get("collision_radius")
430
- if radius is not None:
431
- coll_update["radius"] = float(radius)
432
- if "collision_min_clearance" in payload:
433
- coll_update["min_clearance"] = float(payload.get("collision_min_clearance", 0.0))
434
- self.urdf_collision_obj.update_params(coll_update)
435
- else:
436
- mc = payload.get("collision_min_clearance")
437
- upd = {"weight": float(payload.get("collision_weight",1.0))}
438
- if mc is not None:
439
- upd["min_clearance"] = float(mc)
440
- self.urdf_collision_obj.update_params(upd)
441
  subpoints = self._safe_int(payload.get("subpoints",1),1,1,100)
442
  mandatory, optional = [], []
443
  if payload.get("distance_enabled", True): mandatory.append(self.urdf_distance_obj)
444
- if payload.get("collision_enabled", False): optional.append(self.urdf_collision_obj)
 
445
  if payload.get("bone_zero_enabled", True):
446
  optional.append(BoneZeroRotationObj(weight=float(payload.get("bone_zero_weight",0.05))))
447
  if payload.get("derivative_enabled", True) and subpoints > 1:
@@ -481,7 +432,7 @@ class IKServer:
481
  self.current_controlled_bones = bones
482
  self.current_end_effector = eff
483
  self.current_num_steps = int(num_steps)
484
- self.solver = self._get_solver(bones, False, self.current_num_steps)
485
  self.initial_rotations = np.zeros(len(self.solver.controlled_bones)*3, dtype=np.float32)
486
  self.best_angles = self.initial_rotations.copy()
487
  self._setup_agent_objectives()
@@ -507,7 +458,7 @@ class IKServer:
507
  self.urdf_current_controlled_bones = bones
508
  self.urdf_current_end_effector = eff
509
  self.urdf_current_num_steps = int(num_steps)
510
- self.urdf_solver = self._get_solver(bones, True, self.urdf_current_num_steps)
511
  self.urdf_initial_rotations = np.zeros(len(self.urdf_solver.controlled_bones)*3, dtype=np.float32)
512
  self.urdf_best_angles = self.urdf_initial_rotations.copy()
513
  self._setup_urdf_objectives()
@@ -639,11 +590,7 @@ class IKServer:
639
  return
640
  try:
641
  gc.collect()
642
- # Memory guard: if RSS > 2GB, clear caches
643
- rss = self.process.memory_info().rss
644
- if rss > 16 * 1024**3:
645
- logger.warning("High memory usage detected; invalidating caches")
646
- self._invalidate_caches()
647
  except Exception:
648
  pass
649
  self.last_cleanup_time = now
@@ -757,10 +704,9 @@ class IKServer:
757
  @self.app.post("/reset")
758
  def reset():
759
  with self.solve_lock:
760
- self._invalidate_caches()
761
  self._attempt_solver_recovery(False)
762
  self._attempt_solver_recovery(True)
763
- return {"status":"ok","message":"solvers & caches reset"}
764
 
765
  @self.app.get("/health")
766
  def health():
@@ -769,8 +715,6 @@ class IKServer:
769
  "status":"ok",
770
  "agent_frames": len(self.animation_frames_agent),
771
  "urdf_frames": len(self.animation_frames_urdf),
772
- "cache_agent": len(self.solver_cache),
773
- "cache_urdf": len(self.urdf_solver_cache),
774
  "agent_fail_count": self.agent_fail_count,
775
  "urdf_fail_count": self.urdf_fail_count,
776
  "last_agent_error": self.last_agent_error,
 
30
  DistanceObjTraj,
31
  SphereCollisionPenaltyObjTraj,
32
  )
33
+ try:
34
+ from collision_objectives import FastSphereCollisionPenaltyObjTraj as _FastColl
35
+ except Exception: # fallback if file missing
36
+ _FastColl = SphereCollisionPenaltyObjTraj
37
 
38
  def download_and_setup_files():
39
  os.makedirs("files", exist_ok=True)
 
82
  self.max_num_steps_global = 20000
83
  self.max_learning_rate = 5.0
84
 
 
 
 
 
 
 
 
85
  # Animation buffers
86
  self.animation_frames_agent = []
87
  self.animation_frames_urdf = []
 
142
  allowed = self.urdf_available_bones if is_urdf else self.available_bones
143
  return [b for b in bones if b in allowed]
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def _create_solver(self, bones, is_urdf, num_steps):
146
  if is_urdf:
147
  return InverseKinematicsSolver(
 
168
  compute_sdf=False,
169
  )
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def _rebuild_solver_safe(self, is_urdf=False, force_defaults=False):
172
  try:
173
  if is_urdf:
 
189
  return False
190
 
191
  def _attempt_solver_recovery(self, is_urdf=False):
192
+ logger.warning("Attempting full solver recovery (reset to defaults)")
 
193
  if is_urdf:
194
  self.urdf_current_controlled_bones = self.urdf_default_controlled_bones.copy()
195
  return self._rebuild_solver_safe(True, force_defaults=True)
 
247
  self.selectable_bones = [b for b in self.available_bones if b in self.bounds_dict]
248
  self.current_controlled_bones = self.default_controlled_bones.copy()
249
  self.current_end_effector = self.default_end_effector
250
+ # Direct solver creation (no cache)
251
+ self.solver = self._create_solver(self.current_controlled_bones, False, self.current_num_steps)
252
  self.initial_rotations = np.zeros(len(self.solver.controlled_bones) * 3, dtype=np.float32)
253
  self.best_angles = self.initial_rotations.copy()
254
  self.mesh_data = load_mesh_data_from_gltf(self.args.gltf_file, self.solver.fk_solver)
 
271
  self.urdf_selectable_bones = list(self.urdf_available_bones)
272
  self.urdf_current_controlled_bones = self.urdf_default_controlled_bones.copy()
273
  self.urdf_current_end_effector = self.urdf_default_end_effector
274
+ self.urdf_solver = self._create_solver(self.urdf_current_controlled_bones, True, self.urdf_current_num_steps)
275
  self.urdf_initial_rotations = np.zeros(len(self.urdf_solver.controlled_bones) * 3, dtype=np.float32)
276
  self.urdf_best_angles = self.urdf_initial_rotations.copy()
277
  self.urdf_mesh_data = load_mesh_data_from_urdf(self.urdf_file, self.urdf_solver.fk_solver)
 
285
  use_head=True,
286
  weight=1.0,
287
  )
288
+ use_fast = os.environ.get("USE_FAST_COLLISION", "1") != "0"
289
+ CollCls = _FastColl if use_fast else SphereCollisionPenaltyObjTraj
290
+ self.collision_obj = CollCls(
291
  {"center":[0.1,0.0,0.35],"radius":0.1},
292
  min_clearance=0.0,
293
  weight=1.0,
 
304
  use_head=True,
305
  weight=1.0,
306
  )
307
+ use_fast = os.environ.get("USE_FAST_COLLISION", "1") != "0"
308
+ CollCls = _FastColl if use_fast else SphereCollisionPenaltyObjTraj
309
+ self.urdf_collision_obj = CollCls(
310
  {"center":[0.2,0.0,0.35],"radius":0.1},
311
  min_clearance=0.0,
312
  weight=1.0,
 
320
  tgt = np.array(payload.get("target",[0.0,0.2,0.35]))
321
  self.distance_obj.update_params({"bone_name": self.current_end_effector, "target_points": tgt, "weight": float(payload.get("distance_weight",1.0))})
322
  # collision updates (weight, center, radius, min_clearance)
323
+ collision_enabled = payload.get("collision_enabled", False)
324
+ coll_update = {"weight": float(payload.get("collision_weight",1.0)) if collision_enabled else 0.0}
325
+ center = payload.get("collision_center")
326
+ if isinstance(center, (list, tuple)) and len(center)==3:
327
+ coll_update["center"] = center
328
+ radius = payload.get("collision_radius")
329
+ if radius is not None:
330
+ coll_update["radius"] = float(radius)
331
+ if "collision_min_clearance" in payload:
332
+ coll_update["min_clearance"] = float(payload.get("collision_min_clearance", 0.0))
333
+ # Always update params so compiled graph stays stable
334
+ self.collision_obj.update_params(coll_update)
 
 
 
 
 
 
335
  subpoints = self._safe_int(payload.get("subpoints",1), 1, 1, 100)
336
  mandatory, optional = [], []
337
  if payload.get("distance_enabled", True): mandatory.append(self.distance_obj)
338
+ # Always include collision objective to avoid JIT retrace when toggling
339
+ optional.append(self.collision_obj)
340
  if payload.get("bone_zero_enabled", True):
341
  optional.append(BoneZeroRotationObj(weight=float(payload.get("bone_zero_weight",0.05))))
342
  if payload.get("derivative_enabled", True) and subpoints > 1:
 
377
  def _build_urdf_objectives(self, payload):
378
  tgt = np.array(payload.get("target",[0.3,0.3,0.35]))
379
  self.urdf_distance_obj.update_params({"bone_name": self.urdf_current_end_effector, "target_points": tgt, "weight": float(payload.get("distance_weight",1.0))})
380
+ collision_enabled = payload.get("collision_enabled", False)
381
+ coll_update = {"weight": float(payload.get("collision_weight",1.0)) if collision_enabled else 0.0}
382
+ center = payload.get("collision_center")
383
+ if isinstance(center, (list, tuple)) and len(center)==3:
384
+ coll_update["center"] = center
385
+ radius = payload.get("collision_radius")
386
+ if radius is not None:
387
+ coll_update["radius"] = float(radius)
388
+ if "collision_min_clearance" in payload:
389
+ coll_update["min_clearance"] = float(payload.get("collision_min_clearance", 0.0))
390
+ self.urdf_collision_obj.update_params(coll_update)
 
 
 
 
 
 
391
  subpoints = self._safe_int(payload.get("subpoints",1),1,1,100)
392
  mandatory, optional = [], []
393
  if payload.get("distance_enabled", True): mandatory.append(self.urdf_distance_obj)
394
+ # Always include collision objective with weight possibly zero
395
+ optional.append(self.urdf_collision_obj)
396
  if payload.get("bone_zero_enabled", True):
397
  optional.append(BoneZeroRotationObj(weight=float(payload.get("bone_zero_weight",0.05))))
398
  if payload.get("derivative_enabled", True) and subpoints > 1:
 
432
  self.current_controlled_bones = bones
433
  self.current_end_effector = eff
434
  self.current_num_steps = int(num_steps)
435
+ self.solver = self._create_solver(bones, False, self.current_num_steps)
436
  self.initial_rotations = np.zeros(len(self.solver.controlled_bones)*3, dtype=np.float32)
437
  self.best_angles = self.initial_rotations.copy()
438
  self._setup_agent_objectives()
 
458
  self.urdf_current_controlled_bones = bones
459
  self.urdf_current_end_effector = eff
460
  self.urdf_current_num_steps = int(num_steps)
461
+ self.urdf_solver = self._create_solver(bones, True, self.urdf_current_num_steps)
462
  self.urdf_initial_rotations = np.zeros(len(self.urdf_solver.controlled_bones)*3, dtype=np.float32)
463
  self.urdf_best_angles = self.urdf_initial_rotations.copy()
464
  self._setup_urdf_objectives()
 
590
  return
591
  try:
592
  gc.collect()
593
+ # Removed cache invalidation logic (no caches anymore)
 
 
 
 
594
  except Exception:
595
  pass
596
  self.last_cleanup_time = now
 
704
  @self.app.post("/reset")
705
  def reset():
706
  with self.solve_lock:
 
707
  self._attempt_solver_recovery(False)
708
  self._attempt_solver_recovery(True)
709
+ return {"status":"ok","message":"solvers reset"}
710
 
711
  @self.app.get("/health")
712
  def health():
 
715
  "status":"ok",
716
  "agent_frames": len(self.animation_frames_agent),
717
  "urdf_frames": len(self.animation_frames_urdf),
 
 
718
  "agent_fail_count": self.agent_fail_count,
719
  "urdf_fail_count": self.urdf_fail_count,
720
  "last_agent_error": self.last_agent_error,
collision_objectives.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ from jax.tree_util import register_pytree_node_class
5
+ from jax_ik.objectives import ObjectiveFunction
6
+
7
+ @register_pytree_node_class
8
+ class FastSphereCollisionPenaltyObjTraj(ObjectiveFunction):
9
+ """Vectorized sphere collision penalty over bone segments.
10
+
11
+ Weight stored as Python float (static aux) to avoid tracer-to-Python
12
+ concretization when objective always present with varying weight.
13
+ """
14
+ def __init__(self, sphere_collider: dict, weight: float = 1.0, min_clearance: float = 0.05, segment_radius: float = 0.02):
15
+ self.center = jnp.asarray(sphere_collider["center"], jnp.float32)
16
+ self.radius = jnp.asarray(sphere_collider["radius"], jnp.float32)
17
+ self.min_clearance = jnp.asarray(min_clearance, jnp.float32)
18
+ self.segment_radius = jnp.asarray(segment_radius, jnp.float32)
19
+ self.weight = float(weight)
20
+
21
+ # pytree impl ------------------------------------------------------------
22
+ def tree_flatten(self):
23
+ # weight treated as static (aux) so changing it may retrace but avoids concretization errors
24
+ return (self.center, self.radius, self.min_clearance, self.segment_radius), (self.weight,)
25
+
26
+ @classmethod
27
+ def tree_unflatten(cls, aux, leaves):
28
+ (weight,) = aux
29
+ c, r, mc, sr = leaves
30
+ return cls(dict(center=c, radius=r), weight, mc, sr)
31
+
32
+ # API --------------------------------------------------------------------
33
+ def update_params(self, p: dict) -> None:
34
+ if "sphere_collider" in p:
35
+ collider = p["sphere_collider"]
36
+ if "center" in collider:
37
+ self.center = jnp.asarray(collider["center"], jnp.float32)
38
+ if "radius" in collider:
39
+ self.radius = jnp.asarray(collider["radius"], jnp.float32)
40
+ if "center" in p:
41
+ self.center = jnp.asarray(p["center"], jnp.float32)
42
+ if "radius" in p:
43
+ self.radius = jnp.asarray(p["radius"], jnp.float32)
44
+ if "min_clearance" in p:
45
+ self.min_clearance = jnp.asarray(p["min_clearance"], jnp.float32)
46
+ if "segment_radius" in p:
47
+ self.segment_radius = jnp.asarray(p["segment_radius"], jnp.float32)
48
+ if "weight" in p:
49
+ self.weight = float(p["weight"])
50
+
51
+ def get_params(self) -> dict:
52
+ return dict(
53
+ sphere_collider=dict(center=np.asarray(self.center).tolist(), radius=float(self.radius)),
54
+ min_clearance=float(self.min_clearance),
55
+ segment_radius=float(self.segment_radius),
56
+ weight=float(self.weight),
57
+ )
58
+
59
+ # core -------------------------------------------------------------------
60
+ def _penalty_single(self, cfg, fk_solver) -> jnp.ndarray:
61
+ fk = fk_solver.compute_fk_from_angles(cfg) # (N,4,4)
62
+ heads = fk[:, :3, 3] # (N,3)
63
+ parents = jnp.asarray(fk_solver.parent_list, jnp.int32) # (N,)
64
+ seg_mask = (parents >= 0).astype(jnp.float32) # (N,)
65
+ safe_parent_indices = jnp.where(parents >= 0, parents, 0)
66
+ p_head = heads[safe_parent_indices]
67
+ c_head = heads
68
+ v = c_head - p_head
69
+ dot_vv = jnp.sum(v * v, axis=1) + 1e-6
70
+ eff_rad = self.radius + self.min_clearance + self.segment_radius
71
+ vc = self.center - p_head
72
+ t = jnp.clip(jnp.sum(vc * v, axis=1) / dot_vv, 0.0, 1.0)
73
+ closest = p_head + t[:, None] * v
74
+ dist = jnp.linalg.norm(self.center - closest, axis=1)
75
+ penetration = jnp.maximum(0.0, eff_rad - dist)
76
+ return jnp.sum((penetration ** 2) * seg_mask)
77
+
78
+ def __call__(self, X: jnp.ndarray, fk_solver) -> jnp.ndarray:
79
+ if X.ndim == 1:
80
+ loss = self._penalty_single(X, fk_solver)
81
+ else:
82
+ loss = jnp.mean(jax.vmap(lambda c: self._penalty_single(c, fk_solver))(X))
83
+ return loss * jnp.float32(self.weight)
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- jax-ik==0.1.5
2
  wandb
3
  pillow
4
  scipy
 
1
+ jax-ik==0.1.6
2
  wandb
3
  pillow
4
  scipy