Upload 15 files
- .gitattributes +10 -0
- README.md +155 -0
- assets/data_engine.jpg +3 -0
- assets/gradio.jpg +3 -0
- assets/logo.png +3 -0
- assets/model.jpg +3 -0
- assets/teaser_example.jpg +3 -0
- demo/example_images/demo_dog.jpg +3 -0
- demo/example_images/demo_helmet.png +3 -0
- demo/example_images/demo_letter.jpg +0 -0
- demo/example_images/demo_output.jpg +3 -0
- demo/example_images/demo_person.jpg +3 -0
- demo/example_images/demo_tomato.jpg +3 -0
- demo/gradio_demo.py +319 -0
- demo/inference_single_image.py +197 -0
- requirements.txt +24 -0
.gitattributes
CHANGED
@@ -34,3 +34,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 tokenizer.json filter=lfs diff=lfs merge=lfs -text
+assets/data_engine.jpg filter=lfs diff=lfs merge=lfs -text
+assets/gradio.jpg filter=lfs diff=lfs merge=lfs -text
+assets/logo.png filter=lfs diff=lfs merge=lfs -text
+assets/model.jpg filter=lfs diff=lfs merge=lfs -text
+assets/teaser_example.jpg filter=lfs diff=lfs merge=lfs -text
+demo/example_images/demo_dog.jpg filter=lfs diff=lfs merge=lfs -text
+demo/example_images/demo_helmet.png filter=lfs diff=lfs merge=lfs -text
+demo/example_images/demo_output.jpg filter=lfs diff=lfs merge=lfs -text
+demo/example_images/demo_person.jpg filter=lfs diff=lfs merge=lfs -text
+demo/example_images/demo_tomato.jpg filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,155 @@

<div align=center>
  <img src="assets/logo.png" width=300 >
</div>

# 🦖🧠 Rex-Thinker: Grounded Object Referring via Chain-of-Thought Reasoning 🦖🧠

<div align=center>

<p align="center">
  <a href="https://bagel-ai.org/">
    <img
      src="https://img.shields.io/badge/RexThinker-Website-Red?logo=afdian&logoColor=white&color=blue"
      alt="RexThinker Website"
    />
  </a>
  <a href="https://arxiv.org/abs/2505.14683">
    <img
      src="https://img.shields.io/badge/RexThinker-Paper-Red%25red?logo=arxiv&logoColor=red&color=yellow"
      alt="RexThinker Paper on arXiv"
    />
  </a>
  <a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
    <img
      src="https://img.shields.io/badge/RexThinker-Weight-orange?logo=huggingface&logoColor=yellow"
      alt="RexThinker weight on Hugging Face"
    />
  </a>
  <a href="https://demo.bagel-ai.org/">
    <img
      src="https://img.shields.io/badge/RexThinker-Data-orange?logo=huggingface&logoColor=yellow"
      alt="RexThinker data on Hugging Face"
    />
  </a>
</p>

</div>

> We propose Rex-Thinker, a Chain-of-Thought (CoT) reasoning model for object referring that addresses two key challenges: lack of interpretability and inability to reject unmatched expressions. Instead of directly predicting bounding boxes, Rex-Thinker reasons step-by-step over candidate objects to determine which, if any, match a given expression. Rex-Thinker is trained in two stages: supervised fine-tuning to learn structured CoT reasoning, followed by reinforcement learning with GRPO to enhance accuracy, faithfulness, and generalization. Our approach improves both prediction precision and interpretability, while enabling the model to abstain when no suitable object is found. Below is an example of the model's reasoning process:

<p align="center"><img src="assets/teaser_example.jpg" width="95%"></p>

## Method

**Rex-Thinker** reformulates object referring as a **Chain-of-Thought (CoT)** reasoning task to improve both interpretability and reliability. The model follows a structured three-stage reasoning paradigm:

1. **Planning**: Decompose the referring expression into interpretable subgoals.

2. **Action**: Evaluate each candidate object (obtained via an open-vocabulary detector) against these subgoals using step-by-step reasoning.

3. **Summarization**: Aggregate the intermediate results to output the final prediction, or abstain when no object matches.

Each reasoning step is grounded in a specific candidate object region through **Box Hints**, making the process transparent and verifiable.

Rex-Thinker is implemented on top of **Qwen2.5-VL** and trained in two stages:

- **Supervised Fine-Tuning (SFT)**
  Cold-start training using GPT-4o-generated CoT traces as supervision.

- **GRPO-based Reinforcement Learning**
  Further optimizes reasoning accuracy, generalization, and rejection ability via Group Relative Policy Optimization.

This CoT-based framework enables Rex-Thinker to make faithful, interpretable predictions while generalizing well to out-of-domain referring scenarios.

<p align="center"><img src="assets/model.jpg" width="95%"></p>

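The demo scripts in this repo pair a fixed system prompt with a user question that passes the detector's candidate boxes to the model as a JSON Box Hint. Below is a minimal sketch of how that input is composed; the prompt strings mirror `demo/inference_single_image.py`, while the box coordinates and referring expression are placeholder values:

```python
import json

# System prompt used by the demo scripts (copied from demo/inference_single_image.py)
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the "
    "Assistant solves it. The assistant first thinks about the reasoning process in "
    "the mind and then provides the user with the answer. The reasoning process and "
    "answer are enclosed within <think> </think> and <answer> </answer> tags, "
    "respectively, i.e., <think> reasoning process here </think>"
    "<answer> answer here </answer>."
)

# Candidate boxes from the open-vocabulary detector, already converted to the
# Qwen-2.5-VL coordinate format. The values below are made-up placeholders.
cate_name = "helmet"
proposed_box = [[100, 200, 180, 260], [220, 210, 300, 270]]
referring_expression = "the fourth helmet from left"

# Box Hint plus question, formatted as in the demo scripts
hint = json.dumps({cate_name: proposed_box})
question = (
    f"Hint: Object and its coordinates in this image: {hint}\n"
    f"Please detect {referring_expression} in the image."
)
print(question)
```
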
## 1. Installation ⛳️

```bash
conda create -n rexthinker -y python=3.10
conda activate rexthinker
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
pip install -v -e .

# additional package: Grounding DINO
git clone https://github.com/IDEA-Research/GroundingDINO.git
cd GroundingDINO
## To support torch 2.6
git remote add quantumope https://github.com/QuantuMope/GroundingDINO.git
git fetch quantumope PR/andrew/add-torch26-support-ms-deform-attn
git merge quantumope/PR/andrew/add-torch26-support-ms-deform-attn
## Continue with installation
pip install -v -e .
mkdir weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights
cd ..
```

+
### 1.1 Download Pre-trained Model
|
95 |
+
We provide the pre-trained model weights of Rex-Thinker-GRPO, which is trained on HumanRef-CoT through SFT and GRPO. You can download the model weights from [Hugging Face](https://huggingface.co/IDEA-Research/Rex-Thinker-GRPO-7B).
|
96 |
+
|
97 |
+
Or you can also using the following command to download the pre-trained models:
|
98 |
+
```bash
|
99 |
+
git lfs install
|
100 |
+
git clone https://huggingface.co/IDEA-Research/Rex-Thinker-GRPO-7B IDEA-Research/Rex-Thinker-GRPO-7B
|
101 |
+
```
|
102 |
+
|
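If you prefer the Python API, a rough equivalent using `huggingface_hub` (pulled in as a dependency of `transformers`; the local directory below is just an example) is:

```python
from huggingface_hub import snapshot_download

# Download the Rex-Thinker-GRPO-7B weights into a local folder (example path)
snapshot_download(
    repo_id="IDEA-Research/Rex-Thinker-GRPO-7B",
    local_dir="IDEA-Research/Rex-Thinker-GRPO-7B",
)
```
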
## 2. Inference 🚀
We provide a simple inference script to test the model. In this script, we use Grounding DINO to get the candidate boxes. You can run the following command to test the model:

```bash
CUDA_VISIBLE_DEVICES=0 python demo/inference_single_image.py \
    --image_path demo/example_images/demo_helmet.png \
    --cate_name helmet \
    --ref_exp "the forth helmet from left" \
    --vis_path vis/example_output.jpg
```

You will get output from the terminal like this:
```text
<think>OK, the user needs us to detect the fourth helmet from left. To accomplish this task, I need to break it down into the following steps:
- Step 1: Sort the helmets from left to right.
- Step 2: Find the fourth helmet from the sorted list.

# Step 1: Sort the helmets from left to right
I see 6 helmets in this image, and their order from left to right is [Helmet 5, Helmet 1, Helmet 3, Helmet 2, Helmet 4, Helmet 6].

# Step 2: Find the fourth helmet from the sorted list
From the sorted list [Helmet 5, Helmet 1, Helmet 3, Helmet 2, Helmet 4, Helmet 6], the fourth helmet from the left is Helmet 2.

# Summarize and Re-Check answer
Let’s now recheck our answer and put ✅ for the target helmet and ❌ for others
- Helmet 5: It is the first helmet from left → ❌
- Helmet 1: It is the second helmet from left → ❌
- Helmet 3: It is the third helmet from left → ❌
- Helmet 2: It is the fourth helmet from left → ✅
- Helmet 4: It is the fifth helmet from left → ❌
- Helmet 6: It is the sixth helmet from left → ❌</think><answer>json
[{"bbox_2d": [578, 359, 825, 580], "label": "the forth helmet from left"}]
</answer>
```

and visualized results like this:
<p align="center"><img src="demo/example_images/demo_output.jpg" width="80%"></p>

+
## 3. Gradio Demo 🤗
|
143 |
+
We provide a Gradio demo for you to test the model. You can run the following command to start the Gradio demo:
|
144 |
+
```bash
|
145 |
+
CUDA_VISIBLE_DEVICES=0 python demo/gradio_demo.py \
|
146 |
+
--model_path IDEA-Research/Rex-Thinker-GRPO-7B \
|
147 |
+
--server_ip 0.0.0.0 \
|
148 |
+
--server_port 7860
|
149 |
+
```
|
150 |
+
|
151 |
+
Then you can open your browser and visit `http://localhost:7860` to see the Gradio demo. You can input the image path, category name, and referring expression to test the model.
|
152 |
+
|
153 |
+
<p align="center"><img src="assets/gradio.jpg" width="95%"></p>
|
154 |
+
|
155 |
+
## Citation 📜
|
assets/data_engine.jpg
ADDED
Git LFS Details

assets/gradio.jpg
ADDED
Git LFS Details

assets/logo.png
ADDED
Git LFS Details

assets/model.jpg
ADDED
Git LFS Details

assets/teaser_example.jpg
ADDED
Git LFS Details

demo/example_images/demo_dog.jpg
ADDED
Git LFS Details

demo/example_images/demo_helmet.png
ADDED
Git LFS Details

demo/example_images/demo_letter.jpg
ADDED

demo/example_images/demo_output.jpg
ADDED
Git LFS Details

demo/example_images/demo_person.jpg
ADDED
Git LFS Details

demo/example_images/demo_tomato.jpg
ADDED
Git LFS Details

demo/gradio_demo.py
ADDED
@@ -0,0 +1,319 @@
import argparse
import json

import gradio as gr
import numpy as np
import torch
from groundingdino.util.inference import load_model
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from tools.inference_tools import (
    convert_boxes_from_absolute_to_qwen25_format,
    inference_gdino,
    postprocess_and_vis_inference_out,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str, default="IDEA-Research/Rex-Thinker-GRPO-7B"
    )
    parser.add_argument(
        "--gdino_config",
        type=str,
        default="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    )
    parser.add_argument(
        "--gdino_weights",
        type=str,
        default="GroundingDINO/weights/groundingdino_swint_ogc.pth",
    )
    parser.add_argument(
        "--server_ip",
        type=str,
        default="0.0.0.0",
        help="IP address to bind the server to",
    )
    parser.add_argument(
        "--server_port",
        type=int,
        default=2512,
        help="Port to run the server on",
    )
    return parser.parse_args()


def initialize_models(args):
    # Load GDINO model
    gdino_model = load_model(args.gdino_config, args.gdino_weights).to("cuda")
    gdino_model.eval()

    # Load Rex-Thinker-GRPO
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(
        args.model_path, min_pixels=16 * 28 * 28, max_pixels=1280 * 28 * 28
    )

    return (gdino_model, processor, model)


def process_image(
    image,
    system_prompt,
    cate_name,
    referring_expression,
    draw_width,
    font_size,
    gdino_model,
    rexthinker_processor,
    rexthinker_model,
):
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Run GDINO inference
    gdino_boxes = inference_gdino(
        image,
        [cate_name],
        gdino_model,
        TEXT_TRESHOLD=0.25,
        BOX_TRESHOLD=0.25,
    )
    proposed_box = convert_boxes_from_absolute_to_qwen25_format(
        gdino_boxes, image.width, image.height
    )

    hint = json.dumps(
        {
            f"{cate_name}": proposed_box,
        }
    )
    question = f"Hint: Object and its coordinates in this image: {hint}\nPlease detect {referring_expression} in the image."

    # compose input
    print(f"system_prompt: {system_prompt}")
    print(f"question: {question}")
    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": question},
            ],
        },
    ]

    text = rexthinker_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = rexthinker_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    input_height = inputs["image_grid_thw"][0][1] * 14
    input_width = inputs["image_grid_thw"][0][2] * 14

    # Inference: Generation of the output
    generated_ids = rexthinker_model.generate(**inputs, max_new_tokens=4096)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = rexthinker_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    output_text = output_text[0]

    ref_vis_result, gdino_vis_result = postprocess_and_vis_inference_out(
        image,
        output_text,
        proposed_box,
        gdino_boxes,
        font_size=font_size,
        draw_width=draw_width,
        input_height=input_height,
        input_width=input_width,
    )

    return gdino_vis_result, ref_vis_result, output_text


def create_demo(models):
    (
        gdino_model,
        rexthinker_processor,
        rexthinker_model,
    ) = models

    with gr.Blocks() as demo:
        gr.Markdown("# Rex-Thinker Demo")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                gdino_prompt = gr.Textbox(
                    label="Object Category Name to get Candidate boxes",
                    placeholder="person",
                    value="person",
                )
                referring_prompt = gr.Textbox(
                    label="Referring Expression",
                    placeholder="person wearning red shirt and a black hat",
                    value="person wearning red shirt and a black hat",
                )
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value="A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                )
                with gr.Row():
                    draw_width = gr.Slider(
                        minimum=5.0,
                        maximum=100.0,
                        value=10.0,
                        step=1,
                        label="Draw Width for Visualization",
                    )
                    font_size = gr.Slider(
                        minimum=5.0,
                        maximum=100.0,
                        value=20.0,
                        step=1,
                        label="Font size for Visualization",
                    )
                run_button = gr.Button("Run")

            with gr.Column():
                gdino_output = gr.Image(label="GroundingDINO Detection")
                final_output = gr.Image(label="Rex-Thinker Visualization")
            with gr.Column():
                llm_output = gr.Textbox(
                    label="LLM Raw Output", interactive=False, lines=50, max_lines=100
                )

        # Add examples section
        gr.Markdown("## Examples")
        examples = gr.Examples(
            examples=[
                [
                    "demo/example_images/demo_tomato.jpg",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "tomato",
                    "ripe tomato",
                    10,
                    20,
                ],
                [
                    "demo/example_images/demo_helmet.png",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "helmet",
                    "the forth helmet from left",
                    10,
                    20,
                ],
                [
                    "demo/example_images/demo_person.jpg",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "person",
                    "person in the red car but not driving",
                    10,
                    20,
                ],
                [
                    "demo/example_images/demo_letter.jpg",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "person",
                    "person wearing cloth that has two letters",
                    10,
                    20,
                ],
                [
                    "demo/example_images/demo_dog.jpg",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "dog",
                    "the dog sleep on the bed with a pot under its body",
                    10,
                    20,
                ],
            ],
            inputs=[
                input_image,
                system_prompt,
                gdino_prompt,
                referring_prompt,
                draw_width,
                font_size,
            ],
            outputs=[gdino_output, final_output, llm_output],
            fn=lambda img, sys, p1, p2, d, f: process_image(
                img,
                sys,
                p1,
                p2,
                d,
                f,
                gdino_model,
                rexthinker_processor,
                rexthinker_model,
            ),
            cache_examples=False,
        )

        run_button.click(
            fn=lambda img, sys, p1, p2, d, f: process_image(
                img,
                sys,
                p1,
                p2,
                d,
                f,
                gdino_model,
                rexthinker_processor,
                rexthinker_model,
            ),
            inputs=[
                input_image,
                system_prompt,
                gdino_prompt,
                referring_prompt,
                draw_width,
                font_size,
            ],
            outputs=[gdino_output, final_output, llm_output],
        )

    return demo


def main():
    args = parse_args()
    models = initialize_models(args)
    demo = create_demo(models)
    demo.launch(server_name=args.server_ip, server_port=args.server_port, share=True)


if __name__ == "__main__":
    main()
demo/inference_single_image.py
ADDED
@@ -0,0 +1,197 @@
import argparse
import json
import os

import torch
from groundingdino.util.inference import load_model
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from tools.inference_tools import (
    convert_boxes_from_absolute_to_qwen25_format,
    inference_gdino,
    postprocess_and_vis_inference_out,
)

SYSTEM_PROMPT = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."


def get_args():
    parser = argparse.ArgumentParser(description="Inference script for Qwen-2.5-VL")
    parser.add_argument(
        "--image_path",
        type=str,
        default="demo/example_images/demo_helmet.png",
        help="Path to the input image",
    )
    parser.add_argument(
        "--cate_name",
        type=str,
        default="helmet",
        help='text prompt for grounding dino, e.g. "cat", "dog", "car"',
    )
    parser.add_argument(
        "--ref_exp",
        type=str,
        default="the forth helmet from left",
        help="Reference expression for Rex-Thinker, e.g. 'the cat on the left'",
    )
    parser.add_argument(
        "--vis_path",
        type=str,
        default="vis/example_output.jpg",
        help="Path to save the visualization result",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="IDEA-Research/Rex-Thinker-GRPO-7B",
        help="Path to the input image",
    )
    parser.add_argument(
        "--gdino_config",
        type=str,
        default="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
        help="Path to Grounding DINO config",
    )
    parser.add_argument(
        "--gdino_weights",
        type=str,
        default="GroundingDINO/weights/groundingdino_swint_ogc.pth",
        help="Path to Grounding DINO weights",
    )
    parser.add_argument(
        "--qwen_model_path",
        type=str,
        default="IDEA-Research/Rex-Thinker-GRPO-7B",
        help="Path to Qwen-2.5-VL model or model identifier from Hugging Face Hub",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    image_path = args.image_path
    cate_name = args.cate_name
    ref_exp = args.ref_exp
    gdino_config = args.gdino_config
    gdino_weights = args.gdino_weights
    qwen_model_path = args.qwen_model_path

    # Load the Grounding DINO model
    gdino_model = load_model(gdino_config, gdino_weights)
    gdino_model.eval()

    # Load Rex-Thinker-GRPO
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(
        args.model_path, min_pixels=16 * 28 * 28, max_pixels=1280 * 28 * 28
    )

    # Load the image
    image = Image.open(image_path).convert("RGB")

    # Prepare the text prompts for Grounding DINO
    prompts = [cate_name]

    # Run inference with Grounding DINO to get box hint
    gdino_boxes = inference_gdino(image, prompts, gdino_model)

    proposed_box = convert_boxes_from_absolute_to_qwen25_format(
        gdino_boxes, image.width, image.height
    )
    hint = json.dumps(
        {
            f"{cate_name}": proposed_box,
        }
    )
    question = f"Hint: Object and its coordinates in this image: {hint}\nPlease detect {ref_exp} in the image."

    # compose input
    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT,
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": question},
            ],
        },
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    input_height = inputs["image_grid_thw"][0][1] * 14
    input_width = inputs["image_grid_thw"][0][2] * 14

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=4096)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    output_text = output_text[0]
    print(output_text)

    ref_vis_result, gdino_vis_result = postprocess_and_vis_inference_out(
        image,
        output_text,
        proposed_box,
        gdino_boxes,
        font_size=20,
        draw_width=10,
        input_height=input_height,
        input_width=input_width,
    )

    # Create a new image with white background for the combined result
    combined_width = gdino_vis_result.width + ref_vis_result.width
    combined_height = max(gdino_vis_result.height, ref_vis_result.height)
    combined_image = Image.new("RGB", (combined_width, combined_height), "white")

    # Paste the images side by side
    combined_image.paste(gdino_vis_result, (0, 0))
    combined_image.paste(ref_vis_result, (gdino_vis_result.width, 0))

    # Add titles
    draw = ImageDraw.Draw(combined_image)
    font = ImageFont.truetype("tools/Tahoma.ttf", 30)

    # Add Grounding DINO title
    draw.text((10, 10), "Grounding DINO Output", fill="black", font=font)

    # Add Rex-Thinker title
    draw.text(
        (gdino_vis_result.width + 10, 10), "Rex-Thinker Output", fill="black", font=font
    )

    # Save the combined visualization result
    os.makedirs(os.path.dirname(args.vis_path), exist_ok=True)
    combined_image.save(args.vis_path)
requirements.txt
ADDED
@@ -0,0 +1,24 @@
accelerate
codetiming
datasets
flash-attn>=2.4.3
liger-kernel
mathruler
numpy
omegaconf
pandas
peft
pillow
pyarrow>=15.0.0
pylatexenc
qwen-vl-utils
ray[default]
tensordict
torchdata
transformers==4.51.3
vllm==0.8.2
wandb
tensorboard
gradio==4.44.1
pydantic==2.10.6
tabulate