Fixed demo instructions & yaml config
- README.md +11 -16
- configs/config.yaml +3 -97
- configs/config_empty.yaml +0 -8
- configs/slurm_example.yaml +94 -0
- demo/assets/boat.jpg +3 -0
- demo/assets/building.jpg +3 -0
- demo/assets/dog.jpg +3 -0
- demo/assets/dog_grass.jpg +3 -0
- demo/assets/mountain.jpg +3 -0
- demo/assets/pickup.jpg +3 -0
- demo/assets/tajmahal.jpg +3 -0
- demo/assets/venice.jpg +3 -0
- demo/client.py +9 -5
- demo/inference.py +10 -1
- model_setup.py +8 -2
README.md
CHANGED
@@ -45,19 +45,13 @@ See [TRAIN.md](docs/TRAIN.md) for training commands.
 
 ## Inference
 
-<!-- Inference demo for **TODO**.
-```
-TODO
-``` -->
-<!-- <img src="docs/todo.png" width="1000"> -->
-
-
 Interactive demo:
+```bash
+mkdir -p ./ckpts/unidisc_interleaved
+huggingface-cli download aswerdlow/unidisc_interleaved --local-dir ./ckpts/unidisc_interleaved
+uv run demo/server.py experiments='[large_scale_train,large_scale_train_high_res_interleaved,eval_unified,large_scale_high_res_interleaved_inference]' trainer.load_from_state_dict="./ckpts/unidisc_interleaved/unidisc_interleaved.pt"
+uv run demo/client.py
 ```
-python demo/server.py
-python demo/client_simple_fasthtml.py
-```
-
 
 ## Training
 
@@ -71,11 +65,12 @@ See [EVAL.md](docs/EVAL.md) for details.
 ### Citation
 To cite our work, please use the following:
 ```
-@article{
-title={
-author={
-journal={arXiv preprint arXiv:
-year={
+@article{swerdlow2025unidisc,
+  title   = {Unified Multimodal Discrete Diffusion},
+  author  = {Swerdlow, Alexander and Prabhudesai, Mihir and Gandhi, Siddharth and Pathak, Deepak and Fragkiadaki, Katerina},
+  journal = {arXiv preprint arXiv:2503.20853},
+  year    = {2025},
+  doi     = {10.48550/arXiv.2503.20853},
 }
 ```
 
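If the `huggingface-cli` binary is unavailable, the same checkpoint download can be done from Python via `huggingface_hub` (an equivalent sketch, not part of this commit):

```python
from huggingface_hub import snapshot_download

# Equivalent to the `huggingface-cli download` line in the README above.
snapshot_download(repo_id="aswerdlow/unidisc_interleaved", local_dir="./ckpts/unidisc_interleaved")
```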
configs/config.yaml
CHANGED
@@ -293,103 +293,6 @@ hydra:
     subdir: ${hydra.job.id}
   job:
     chdir: true
-  # launcher:
-  #   name: ${get_slurm_name:}
-  #   # See https://hydra.cc/docs/configure_hydra/workdir/
-  #   submitit_folder: ${hydra.sweep.dir}/%j
-  #   nodes: ${nodes} # Number of nodes. This value is *per* node
-  #   mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # 40GB per gpu. This value is *per* node
-  #   gpus_per_node: ${trainer.devices}
-  #   partition: ${partition}
-  #   constraint: ${constraint}
-  #   exclude: ${exclude_nodes:}
-
-  #   timeout_min: ${timeout_min}
-  #   max_num_timeout: 12 # Num requeue exlcuding pre-emptions
-  #   comment: aswerdlo
-  #   stderr_to_stdout: true
-
-  #   # Be careful with changing anything below.
-  #   # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
-  #   # see: https://github.com/huggingface/accelerate/issues/1918
-
-  #   # The accelerate launcher w/1 initial process and then spawn 1 per GPU
-  #   tasks_per_node: 1
-  #   cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
-  #   python: |
-  #     bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
-
-  #   # python: "${getpythoncmd:}"
-  #   # tasks_per_node: ${devices}
-  #   # cpus_per_task: 8
-  #   # python: 'python'
-
-  #   python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
-  #   signal: 'B:USR2@360'
-  #   post_srun_commands:
-  #     - ''
-  #     - wait
-
-  #   srun_args:
-  #     - '--jobid $SLURM_JOB_ID'
-
-  #   setup:
-  #     - |
-  #       export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
-  #       export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
-  #       export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
-  #       export NCCL_DEBUG=INFO
-  #       export NCCL_NSOCKS_PERTHREAD=4
-  #       export NCCL_SOCKET_NTHREADS=2
-  #       export OMP_NUM_THREADS=2
-  #       export PYTHONUNBUFFERED=1
-  #       export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
-  #       export LOCAL_JOB_FOLDER=$(dirname $STDOUT_PATH)
-  #       export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
-  #       if [ -n "$SLURM_RESTART_COUNT" ]; then
-  #         export RESTART_COUNT=$SLURM_RESTART_COUNT
-  #       else
-  #         export RESTART_COUNT=0
-  #       fi
-  #       export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"
-
-  #       mkdir -p $LOCAL_JOB_FOLDER
-  #       printenv > "$LOCAL_JOB_FOLDER"/env_"$SLURM_LOCALID_$RESTART_COUNT.txt"
-
-  #       echo "ibstatus: $(ibstatus)"
-  #       echo "ibdev2netdev: $(ibdev2netdev)"
-  #       echo "rdma device: $(rdma link)"
-  #       echo "environment: $(env | grep NCCL)"
-  #       echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
-  #       echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
-  #       echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"
-
-  #       trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
-  #       if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
-  #       if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
-  #       # ps auxww | grep $USER; \
-  #       pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
-  #       echo "Found parent PIDs: $pid"; \
-  #       for p in $pid; do \
-  #       echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
-  #       children=$(pgrep -P $p); \
-  #       echo "Children: $children"; \
-  #       if [ -n "$children" ]; then \
-  #       for child in $children; do \
-  #       ppid=$(ps -o ppid= -p $child | tr -d " ")
-  #       if [ "$ppid" -eq "$p" ]; then
-  #       echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
-  #       kill -USR2 $child &
-  #       else
-  #       echo "Skipping non-direct child process: PID $child with PPID $ppid"
-  #       fi
-  #       done; \
-  #       echo "Sent kill signals to children of $p"; \
-  #       else \
-  #       echo "No children found for $p"; \
-  #       fi; \
-  #       done; \
-  #       wait;' SIGUSR2
 
 checkpointing:
   # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
@@ -447,5 +350,8 @@ data:
   add_image_gen_tokens: false
   use_slow_tokenizer: false
   add_image_token: false
+  train: "unset_dataset"
+  val: "unset_dataset"
+  tokenizer_name_or_path: "NousResearch/Llama-2-7b-hf"
 
 dummyarg: null
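The three new fields are placeholders that experiment configs are expected to override. A quick check with OmegaConf (a sketch, assuming the fields sit directly under the `data` section as the hunk header suggests):

```python
from omegaconf import OmegaConf

# Load as plain YAML; interpolations are left unresolved, plain strings read as-is.
cfg = OmegaConf.load("configs/config.yaml")
assert cfg.data.train == "unset_dataset"  # placeholder until an experiment overrides it
assert cfg.data.tokenizer_name_or_path == "NousResearch/Llama-2-7b-hf"
```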
configs/config_empty.yaml
DELETED
@@ -1,8 +0,0 @@
-defaults:
-  - _self_
-  - /model: small
-  - /experiments: []
-
-# from omegaconf import OmegaConf
-# with open("config.yaml", "w") as fp:
-#     OmegaConf.save(config=config, f=fp.name)
configs/slurm_example.yaml
ADDED
@@ -0,0 +1,94 @@
+# This is an example slurm launcher config that should be added to the main config.yaml file under the hydra section. This cannot be run directly.
+hydra:
+  launcher:
+    name: ${get_slurm_name:}
+    # See https://hydra.cc/docs/configure_hydra/workdir/
+    submitit_folder: ${hydra.sweep.dir}/%j
+    nodes: ${nodes} # Number of nodes. This value is *per* node
+    mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # 40GB per gpu. This value is *per* node
+    gpus_per_node: ${trainer.devices}
+    partition: ${partition}
+    constraint: ${constraint}
+    exclude: ${exclude_nodes:}
+
+    timeout_min: ${timeout_min}
+    max_num_timeout: 12 # Num requeue exlcuding pre-emptions
+    comment: aswerdlo
+    stderr_to_stdout: true
+
+    # Be careful with changing anything below.
+    # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
+    # see: https://github.com/huggingface/accelerate/issues/1918
+
+    # The accelerate launcher w/1 initial process and then spawn 1 per GPU
+    tasks_per_node: 1
+    cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
+    python: |
+      bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+
+    python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
+    signal: 'B:USR2@360'
+    post_srun_commands:
+      - ''
+      - wait
+
+    srun_args:
+      - '--jobid $SLURM_JOB_ID'
+
+    setup:
+      - |
+        export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+        export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
+        export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
+        export NCCL_DEBUG=INFO
+        export NCCL_NSOCKS_PERTHREAD=4
+        export NCCL_SOCKET_NTHREADS=2
+        export OMP_NUM_THREADS=2
+        export PYTHONUNBUFFERED=1
+        export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
+        export LOCAL_JOB_FOLDER=$(dirname $STDOUT_PATH)
+        export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
+        if [ -n "$SLURM_RESTART_COUNT" ]; then
+          export RESTART_COUNT=$SLURM_RESTART_COUNT
+        else
+          export RESTART_COUNT=0
+        fi
+        export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"
+
+        mkdir -p $LOCAL_JOB_FOLDER
+        printenv > "$LOCAL_JOB_FOLDER"/env_"$SLURM_LOCALID_$RESTART_COUNT.txt"
+
+        echo "ibstatus: $(ibstatus)"
+        echo "ibdev2netdev: $(ibdev2netdev)"
+        echo "rdma device: $(rdma link)"
+        echo "environment: $(env | grep NCCL)"
+        echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
+        echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
+        echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"
+
+        trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
+        if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
+        if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
+        # ps auxww | grep $USER; \
+        pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
+        echo "Found parent PIDs: $pid"; \
+        for p in $pid; do \
+          echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
+          children=$(pgrep -P $p); \
+          echo "Children: $children"; \
+          if [ -n "$children" ]; then \
+            for child in $children; do \
+              ppid=$(ps -o ppid= -p $child | tr -d " ")
+              if [ "$ppid" -eq "$p" ]; then
+                echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
+                kill -USR2 $child &
+              else
+                echo "Skipping non-direct child process: PID $child with PPID $ppid"
+              fi
+            done; \
+            echo "Sent kill signals to children of $p"; \
+          else \
+            echo "No children found for $p"; \
+          fi; \
+        done; \
+        wait;' SIGUSR2
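One detail worth noting in the setup block is the rendezvous port derivation. A quick sanity check of that arithmetic (an illustrative sketch, not part of the config):

```python
def master_port(slurm_job_id: int) -> int:
    # Mirrors: export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
    return (slurm_job_id % 20001) + 30000

# The modulo keeps the result in [30000, 50000], an unprivileged port range,
# and deriving it from the job ID makes collisions between jobs unlikely.
assert 30000 <= master_port(1234567) <= 50000
```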
demo/assets/boat.jpg
ADDED (binary image, tracked with Git LFS)
demo/assets/building.jpg
ADDED (binary image, tracked with Git LFS)
demo/assets/dog.jpg
ADDED (binary image, tracked with Git LFS)
demo/assets/dog_grass.jpg
ADDED (binary image, tracked with Git LFS)
demo/assets/mountain.jpg
ADDED (binary image, tracked with Git LFS)
demo/assets/pickup.jpg
ADDED (binary image, tracked with Git LFS)
demo/assets/tajmahal.jpg
ADDED (binary image, tracked with Git LFS)
demo/assets/venice.jpg
ADDED (binary image, tracked with Git LFS)
demo/client.py
CHANGED
@@ -571,16 +571,17 @@ def post(
     port: int | None = 8001,
     reward_models: str | None = "False"
 ):
-
+    payload_messages = []
     if user_input:
-
+        payload_messages.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
 
+    image_message_content = []
     current_image = None
     if uploaded_file is not None and uploaded_file.filename != "No image":
         current_image = process(Image.open(io.BytesIO(uploaded_file.file.read())), int(resolution))
         img_data = encode_image(current_image)["url"]
 
-
+        image_message_content.append({
             "type": "image_url",
             "image_url": {"url": img_data},
             "is_mask": False
@@ -589,12 +590,15 @@ def post(
     if mask_data is not None and len(mask_data) > 0:
         mask_array = get_boolean_mask(mask_data)
         mask_data_url = encode_array_image(mask_array)["url"]
-
+        image_message_content.append({
             "type": "image_url",
             "image_url": {"url": mask_data_url},
             "is_mask": True
         })
 
+    if image_message_content:
+        payload_messages.append({"role": "assistant", "content": image_message_content})
+
     config_payload = {
         "max_tokens": int(max_tokens),
         "resolution": int(resolution),
@@ -608,7 +612,7 @@ def post(
     }
 
     payload = {
-        "messages":
+        "messages": payload_messages,
         "model": "unidisc",
         **config_payload
     }
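For reference, the request body assembled by `post()` after this change looks roughly like the following (a sketch with illustrative values; the base64 data URL is elided and the remaining `config_payload` fields are omitted):

```python
payload = {
    "messages": [
        # Text prompt, present only if the user typed one.
        {"role": "user", "content": [{"type": "text", "text": "a dog running on grass"}]},
        # Uploaded image and optional mask, present only if provided.
        {"role": "assistant", "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}, "is_mask": False},
        ]},
    ],
    "model": "unidisc",
    "max_tokens": 32,
    "resolution": 256,
}
```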
demo/inference.py
CHANGED
@@ -386,7 +386,16 @@ def inference(
         disable_mask_after_eos=True
     )
 
-
+    text_samples_list = [x.replace("You are a highly intelligent multimodal AI with the ability to analyze and generate images.", "").removeprefix(" ") for x in text_samples_list]
+    if isinstance(img_samples_list[0], Image.Image):
+        img_tensors = []
+        for img in img_samples_list:
+            img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
+            img_tensors.append(img_tensor.unsqueeze(0))
+        img_samples_list = torch.cat(img_tensors, dim=0)
+    else:
+        img_samples_list = torch.cat(img_samples_list, dim=0)
+
     reward_config = config.eval.auto_enhance_reward_config
     rewards, raw_rewards = model.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True)
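The new branch normalizes PIL outputs into a single float tensor batch before scoring. A minimal shape check of that conversion (a standalone sketch, independent of the surrounding function):

```python
import numpy as np
import torch
from PIL import Image

img = Image.new("RGB", (256, 256))
t = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0  # HWC uint8 -> CHW float in [0, 1]
batch = torch.cat([t.unsqueeze(0)], dim=0)                        # stack into (N, 3, H, W)
assert batch.shape == (1, 3, 256, 256)
```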
model_setup.py
CHANGED
@@ -12,7 +12,7 @@ from types import FrameType
 from contextlib import nullcontext
 
 import transformers
-from constants import HF_TOKEN, HF_CACHE_DIR
+from constants import HF_TOKEN, HF_CACHE_DIR, UNIDISC_DIR
 import hydra
 import hydra.utils
 import torch
@@ -599,6 +599,12 @@ def set_accelerator(self, accelerator, ckpt_path=None):
 
 def _load(obj, path, update_fn=None, key="model"):
     _ckpt_path = Path(path)
+
+    if not _ckpt_path.is_absolute() and not _ckpt_path.exists():
+        potential_path = UNIDISC_DIR / _ckpt_path
+        rprint(f"Relative path '{_ckpt_path}' not found. Trying path relative to script directory: '{potential_path}'")
+        _ckpt_path = potential_path
+
     if _ckpt_path.is_dir() and (_ckpt_path / "model.safetensors").exists():
         _ckpt_path = _ckpt_path / "model.safetensors"
         path = str(_ckpt_path)
@@ -635,7 +641,7 @@ def set_accelerator(self, accelerator, ckpt_path=None):
         gprint(f"Loaded state dict from {path}")
         # obj.load_state_dict(state_dict[key])
     else:
-        state_dict = torch.load(
+        state_dict = torch.load(_ckpt_path)
 
     if 'model' in state_dict and len(state_dict) < 10:
         state_dict = state_dict['model']