aswerdlow committed on
Commit
3a60a49
·
1 Parent(s): 2f30910

Fixed demo instructions & yaml config

Browse files
README.md CHANGED
@@ -45,19 +45,13 @@ See [TRAIN.md](docs/TRAIN.md) for training commands.
45
 
46
  ## Inference
47
 
48
- <!-- Inference demo for **TODO**.
49
- ```
50
- TODO
51
- ``` -->
52
- <!-- <img src="docs/todo.png" width="1000"> -->
53
-
54
-
55
  Interactive demo:
 
 
 
 
 
56
  ```
57
- python demo/server.py
58
- python demo/client_simple_fasthtml.py
59
- ```
60
-
61
 
62
  ## Training
63
 
@@ -71,11 +65,12 @@ See [EVAL.md](docs/EVAL.md) for details.
71
  ### Citation
72
  To cite our work, please use the following:
73
  ```
74
- @article{TODO,
75
- title={TODO},
76
- author={TODO},
77
- journal={arXiv preprint arXiv:TODO},
78
- year={TODO}
 
79
  }
80
  ```
81
 
 
45
 
46
  ## Inference
47
 
 
 
 
 
 
 
 
48
  Interactive demo:
49
+ ```bash
50
+ mkdir -p ./ckpts/unidisc_interleaved
51
+ huggingface-cli download aswerdlow/unidisc_interleaved --local-dir ./ckpts/unidisc_interleaved
52
+ uv run demo/server.py experiments='[large_scale_train,large_scale_train_high_res_interleaved,eval_unified,large_scale_high_res_interleaved_inference]' trainer.load_from_state_dict="./ckpts/unidisc_interleaved/unidisc_interleaved.pt"
53
+ uv run demo/client.py
54
  ```
 
 
 
 
55
 
56
  ## Training
57
 
 
65
  ### Citation
66
  To cite our work, please use the following:
67
  ```
68
+ @article{swerdlow2025unidisc,
69
+ title = {Unified Multimodal Discrete Diffusion},
70
+ author = {Swerdlow, Alexander and Prabhudesai, Mihir and Gandhi, Siddharth and Pathak, Deepak and Fragkiadaki, Katerina},
71
+ journal = {arXiv preprint arXiv:2503.20853},
72
+ year = {2025},
73
+ doi = {10.48550/arXiv.2503.20853},
74
  }
75
  ```
76
 
configs/config.yaml CHANGED
@@ -293,103 +293,6 @@ hydra:
293
  subdir: ${hydra.job.id}
294
  job:
295
  chdir: true
296
- # launcher:
297
- # name: ${get_slurm_name:}
298
- # # See https://hydra.cc/docs/configure_hydra/workdir/
299
- # submitit_folder: ${hydra.sweep.dir}/%j
300
- # nodes: ${nodes} # Number of nodes. This value is *per* node
301
- # mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # 40GB per gpu. This value is *per* node
302
- # gpus_per_node: ${trainer.devices}
303
- # partition: ${partition}
304
- # constraint: ${constraint}
305
- # exclude: ${exclude_nodes:}
306
-
307
- # timeout_min: ${timeout_min}
308
- # max_num_timeout: 12 # Num requeue excluding pre-emptions
309
- # comment: aswerdlo
310
- # stderr_to_stdout: true
311
-
312
- # # Be careful with changing anything below.
313
- # # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
314
- # # see: https://github.com/huggingface/accelerate/issues/1918
315
-
316
- # # The accelerate launcher w/1 initial process and then spawn 1 per GPU
317
- # tasks_per_node: 1
318
- # cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
319
- # python: |
320
- # bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
321
-
322
- # # python: "${getpythoncmd:}"
323
- # # tasks_per_node: ${devices}
324
- # # cpus_per_task: 8
325
- # # python: 'python'
326
-
327
- # python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
328
- # signal: 'B:USR2@360'
329
- # post_srun_commands:
330
- # - ''
331
- # - wait
332
-
333
- # srun_args:
334
- # - '--jobid $SLURM_JOB_ID'
335
-
336
- # setup:
337
- # - |
338
- # export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
339
- # export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
340
- # export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
341
- # export NCCL_DEBUG=INFO
342
- # export NCCL_NSOCKS_PERTHREAD=4
343
- # export NCCL_SOCKET_NTHREADS=2
344
- # export OMP_NUM_THREADS=2
345
- # export PYTHONUNBUFFERED=1
346
- # export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
347
- # export LOCAL_JOB_FOLDER=$(dirname $STDOUT_PATH)
348
- # export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
349
- # if [ -n "$SLURM_RESTART_COUNT" ]; then
350
- # export RESTART_COUNT=$SLURM_RESTART_COUNT
351
- # else
352
- # export RESTART_COUNT=0
353
- # fi
354
- # export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"
355
-
356
- # mkdir -p $LOCAL_JOB_FOLDER
357
- # printenv > "$LOCAL_JOB_FOLDER"/env_"$SLURM_LOCALID_$RESTART_COUNT.txt"
358
-
359
- # echo "ibstatus: $(ibstatus)"
360
- # echo "ibdev2netdev: $(ibdev2netdev)"
361
- # echo "rdma device: $(rdma link)"
362
- # echo "environment: $(env | grep NCCL)"
363
- # echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
364
- # echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
365
- # echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"
366
-
367
- # trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
368
- # if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
369
- # if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
370
- # # ps auxww | grep $USER; \
371
- # pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
372
- # echo "Found parent PIDs: $pid"; \
373
- # for p in $pid; do \
374
- # echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
375
- # children=$(pgrep -P $p); \
376
- # echo "Children: $children"; \
377
- # if [ -n "$children" ]; then \
378
- # for child in $children; do \
379
- # ppid=$(ps -o ppid= -p $child | tr -d " ")
380
- # if [ "$ppid" -eq "$p" ]; then
381
- # echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
382
- # kill -USR2 $child &
383
- # else
384
- # echo "Skipping non-direct child process: PID $child with PPID $ppid"
385
- # fi
386
- # done; \
387
- # echo "Sent kill signals to children of $p"; \
388
- # else \
389
- # echo "No children found for $p"; \
390
- # fi; \
391
- # done; \
392
- # wait;' SIGUSR2
393
 
394
  checkpointing:
395
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
@@ -447,5 +350,8 @@ data:
447
  add_image_gen_tokens: false
448
  use_slow_tokenizer: false
449
  add_image_token: false
 
 
 
450
 
451
  dummyarg: null
 
293
  subdir: ${hydra.job.id}
294
  job:
295
  chdir: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  checkpointing:
298
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
 
350
  add_image_gen_tokens: false
351
  use_slow_tokenizer: false
352
  add_image_token: false
353
+ train: "unset_dataset"
354
+ val: "unset_dataset"
355
+ tokenizer_name_or_path: "NousResearch/Llama-2-7b-hf"
356
 
357
  dummyarg: null
configs/config_empty.yaml DELETED
@@ -1,8 +0,0 @@
1
- defaults:
2
- - _self_
3
- - /model: small
4
- - /experiments: []
5
-
6
- # from omegaconf import OmegaConf
7
- # with open("config.yaml", "w") as fp:
8
- # OmegaConf.save(config=config, f=fp.name)
 
 
 
 
 
 
 
 
 
configs/slurm_example.yaml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an example slurm launcher config that should be added to the main config.yaml file under the hydra section. This cannot be run directly.
2
+ hydra:
3
+ launcher:
4
+ name: ${get_slurm_name:}
5
+ # See https://hydra.cc/docs/configure_hydra/workdir/
6
+ submitit_folder: ${hydra.sweep.dir}/%j
7
+ nodes: ${nodes} # Number of nodes. This value is *per* node
8
+ mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # 40GB per gpu. This value is *per* node
9
+ gpus_per_node: ${trainer.devices}
10
+ partition: ${partition}
11
+ constraint: ${constraint}
12
+ exclude: ${exclude_nodes:}
13
+
14
+ timeout_min: ${timeout_min}
15
+ max_num_timeout: 12 # Num requeue excluding pre-emptions
16
+ comment: aswerdlo
17
+ stderr_to_stdout: true
18
+
19
+ # Be careful with changing anything below.
20
+ # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
21
+ # see: https://github.com/huggingface/accelerate/issues/1918
22
+
23
+ # The accelerate launcher w/1 initial process and then spawn 1 per GPU
24
+ tasks_per_node: 1
25
+ cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
26
+ python: |
27
+ bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
28
+
29
+ python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
30
+ signal: 'B:USR2@360'
31
+ post_srun_commands:
32
+ - ''
33
+ - wait
34
+
35
+ srun_args:
36
+ - '--jobid $SLURM_JOB_ID'
37
+
38
+ setup:
39
+ - |
40
+ export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
41
+ export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
42
+ export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
43
+ export NCCL_DEBUG=INFO
44
+ export NCCL_NSOCKS_PERTHREAD=4
45
+ export NCCL_SOCKET_NTHREADS=2
46
+ export OMP_NUM_THREADS=2
47
+ export PYTHONUNBUFFERED=1
48
+ export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
49
+ export LOCAL_JOB_FOLDER=$(dirname $STDOUT_PATH)
50
+ export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
51
+ if [ -n "$SLURM_RESTART_COUNT" ]; then
52
+ export RESTART_COUNT=$SLURM_RESTART_COUNT
53
+ else
54
+ export RESTART_COUNT=0
55
+ fi
56
+ export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"
57
+
58
+ mkdir -p $LOCAL_JOB_FOLDER
59
+ printenv > "$LOCAL_JOB_FOLDER"/env_"$SLURM_LOCALID_$RESTART_COUNT.txt"
60
+
61
+ echo "ibstatus: $(ibstatus)"
62
+ echo "ibdev2netdev: $(ibdev2netdev)"
63
+ echo "rdma device: $(rdma link)"
64
+ echo "environment: $(env | grep NCCL)"
65
+ echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
66
+ echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
67
+ echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"
68
+
69
+ trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
70
+ if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
71
+ if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
72
+ # ps auxww | grep $USER; \
73
+ pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
74
+ echo "Found parent PIDs: $pid"; \
75
+ for p in $pid; do \
76
+ echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
77
+ children=$(pgrep -P $p); \
78
+ echo "Children: $children"; \
79
+ if [ -n "$children" ]; then \
80
+ for child in $children; do \
81
+ ppid=$(ps -o ppid= -p $child | tr -d " ")
82
+ if [ "$ppid" -eq "$p" ]; then
83
+ echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
84
+ kill -USR2 $child &
85
+ else
86
+ echo "Skipping non-direct child process: PID $child with PPID $ppid"
87
+ fi
88
+ done; \
89
+ echo "Sent kill signals to children of $p"; \
90
+ else \
91
+ echo "No children found for $p"; \
92
+ fi; \
93
+ done; \
94
+ wait;' SIGUSR2
demo/assets/boat.jpg ADDED

Git LFS Details

  • SHA256: 76b5ab9ce3c9fb282d3eb53f812d0afd4f972fb8b2b6d8ce771022fbda928f39
  • Pointer size: 131 Bytes
  • Size of remote file: 274 kB
demo/assets/building.jpg ADDED

Git LFS Details

  • SHA256: c3b8fe94b65f17ea90b6b158c7e78ef80155aeaae341de3f76ca00f8fb763eb9
  • Pointer size: 130 Bytes
  • Size of remote file: 14.2 kB
demo/assets/dog.jpg ADDED

Git LFS Details

  • SHA256: 030ca382b90b831bdaa1b52db905dd1ae98beef57c7d1504e85ca1ffc2b5f23f
  • Pointer size: 129 Bytes
  • Size of remote file: 9.89 kB
demo/assets/dog_grass.jpg ADDED

Git LFS Details

  • SHA256: 2d9e610f8a7dee65e894ad964cd5a75707053475d076c23b82b5adab5f7adf1e
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB
demo/assets/mountain.jpg ADDED

Git LFS Details

  • SHA256: 14cbc3df4f8c9b4b0681fdda773cadb2867379116370eceb988ebf5482d4279b
  • Pointer size: 130 Bytes
  • Size of remote file: 10 kB
demo/assets/pickup.jpg ADDED

Git LFS Details

  • SHA256: 30daa8b78d8eee141ba9691f36e25d69d54627378e75200b750ac21de20766c7
  • Pointer size: 131 Bytes
  • Size of remote file: 582 kB
demo/assets/tajmahal.jpg ADDED

Git LFS Details

  • SHA256: c5d233d65f537bff66cd7e523d5ba3f2b1fdfc30c6302c2e05498c52e0b97258
  • Pointer size: 131 Bytes
  • Size of remote file: 423 kB
demo/assets/venice.jpg ADDED

Git LFS Details

  • SHA256: f929ae5c19233571960d81d4ba36aebbb75acb1a07d4ea61c2d581c4d826eeba
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
demo/client.py CHANGED
@@ -571,16 +571,17 @@ def post(
571
  port: int | None = 8001,
572
  reward_models: str | None = "False"
573
  ):
574
- messages = []
575
  if user_input:
576
- messages.append({"type": "text", "text": user_input})
577
 
 
578
  current_image = None
579
  if uploaded_file is not None and uploaded_file.filename != "No image":
580
  current_image = process(Image.open(io.BytesIO(uploaded_file.file.read())), int(resolution))
581
  img_data = encode_image(current_image)["url"]
582
 
583
- messages.append({
584
  "type": "image_url",
585
  "image_url": {"url": img_data},
586
  "is_mask": False
@@ -589,12 +590,15 @@ def post(
589
  if mask_data is not None and len(mask_data) > 0:
590
  mask_array = get_boolean_mask(mask_data)
591
  mask_data_url = encode_array_image(mask_array)["url"]
592
- messages.append({
593
  "type": "image_url",
594
  "image_url": {"url": mask_data_url},
595
  "is_mask": True
596
  })
597
 
 
 
 
598
  config_payload = {
599
  "max_tokens": int(max_tokens),
600
  "resolution": int(resolution),
@@ -608,7 +612,7 @@ def post(
608
  }
609
 
610
  payload = {
611
- "messages": [{"role": "user", "content": messages}],
612
  "model": "unidisc",
613
  **config_payload
614
  }
 
571
  port: int | None = 8001,
572
  reward_models: str | None = "False"
573
  ):
574
+ payload_messages = []
575
  if user_input:
576
+ payload_messages.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
577
 
578
+ image_message_content = []
579
  current_image = None
580
  if uploaded_file is not None and uploaded_file.filename != "No image":
581
  current_image = process(Image.open(io.BytesIO(uploaded_file.file.read())), int(resolution))
582
  img_data = encode_image(current_image)["url"]
583
 
584
+ image_message_content.append({
585
  "type": "image_url",
586
  "image_url": {"url": img_data},
587
  "is_mask": False
 
590
  if mask_data is not None and len(mask_data) > 0:
591
  mask_array = get_boolean_mask(mask_data)
592
  mask_data_url = encode_array_image(mask_array)["url"]
593
+ image_message_content.append({
594
  "type": "image_url",
595
  "image_url": {"url": mask_data_url},
596
  "is_mask": True
597
  })
598
 
599
+ if image_message_content:
600
+ payload_messages.append({"role": "assistant", "content": image_message_content})
601
+
602
  config_payload = {
603
  "max_tokens": int(max_tokens),
604
  "resolution": int(resolution),
 
612
  }
613
 
614
  payload = {
615
+ "messages": payload_messages,
616
  "model": "unidisc",
617
  **config_payload
618
  }
demo/inference.py CHANGED
@@ -386,7 +386,16 @@ def inference(
386
  disable_mask_after_eos=True
387
  )
388
 
389
- img_samples_list = torch.cat(img_samples_list, dim=0)
 
 
 
 
 
 
 
 
 
390
  reward_config = config.eval.auto_enhance_reward_config
391
  rewards, raw_rewards = model.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True)
392
 
 
386
  disable_mask_after_eos=True
387
  )
388
 
389
+ text_samples_list = [x.replace("You are a highly intelligent multimodal AI with the ability to analyze and generate images.", "").removeprefix(" ") for x in text_samples_list]
390
+ if isinstance(img_samples_list[0], Image.Image):
391
+ img_tensors = []
392
+ for img in img_samples_list:
393
+ img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
394
+ img_tensors.append(img_tensor.unsqueeze(0))
395
+ img_samples_list = torch.cat(img_tensors, dim=0)
396
+ else:
397
+ img_samples_list = torch.cat(img_samples_list, dim=0)
398
+
399
  reward_config = config.eval.auto_enhance_reward_config
400
  rewards, raw_rewards = model.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True)
401
 
model_setup.py CHANGED
@@ -12,7 +12,7 @@ from types import FrameType
12
  from contextlib import nullcontext
13
 
14
  import transformers
15
- from constants import HF_TOKEN, HF_CACHE_DIR
16
  import hydra
17
  import hydra.utils
18
  import torch
@@ -599,6 +599,12 @@ def set_accelerator(self, accelerator, ckpt_path=None):
599
 
600
  def _load(obj, path, update_fn=None, key="model"):
601
  _ckpt_path = Path(path)
 
 
 
 
 
 
602
  if _ckpt_path.is_dir() and (_ckpt_path / "model.safetensors").exists():
603
  _ckpt_path = _ckpt_path / "model.safetensors"
604
  path = str(_ckpt_path)
@@ -635,7 +641,7 @@ def set_accelerator(self, accelerator, ckpt_path=None):
635
  gprint(f"Loaded state dict from {path}")
636
  # obj.load_state_dict(state_dict[key])
637
  else:
638
- state_dict = torch.load(path)
639
 
640
  if 'model' in state_dict and len(state_dict) < 10:
641
  state_dict = state_dict['model']
 
12
  from contextlib import nullcontext
13
 
14
  import transformers
15
+ from constants import HF_TOKEN, HF_CACHE_DIR, UNIDISC_DIR
16
  import hydra
17
  import hydra.utils
18
  import torch
 
599
 
600
  def _load(obj, path, update_fn=None, key="model"):
601
  _ckpt_path = Path(path)
602
+
603
+ if not _ckpt_path.is_absolute() and not _ckpt_path.exists():
604
+ potential_path = UNIDISC_DIR / _ckpt_path
605
+ rprint(f"Relative path '{_ckpt_path}' not found. Trying path relative to script directory: '{potential_path}'")
606
+ _ckpt_path = potential_path
607
+
608
  if _ckpt_path.is_dir() and (_ckpt_path / "model.safetensors").exists():
609
  _ckpt_path = _ckpt_path / "model.safetensors"
610
  path = str(_ckpt_path)
 
641
  gprint(f"Loaded state dict from {path}")
642
  # obj.load_state_dict(state_dict[key])
643
  else:
644
+ state_dict = torch.load(_ckpt_path)
645
 
646
  if 'model' in state_dict and len(state_dict) < 10:
647
  state_dict = state_dict['model']