Mountchicken committed
Commit 56d0a80 · verified · 1 Parent(s): 917232f

Upload 15 files

.gitattributes CHANGED
@@ -34,3 +34,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ assets/data_engine.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/gradio.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/logo.png filter=lfs diff=lfs merge=lfs -text
+ assets/model.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/teaser_example.jpg filter=lfs diff=lfs merge=lfs -text
+ demo/example_images/demo_dog.jpg filter=lfs diff=lfs merge=lfs -text
+ demo/example_images/demo_helmet.png filter=lfs diff=lfs merge=lfs -text
+ demo/example_images/demo_output.jpg filter=lfs diff=lfs merge=lfs -text
+ demo/example_images/demo_person.jpg filter=lfs diff=lfs merge=lfs -text
+ demo/example_images/demo_tomato.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,155 @@

<div align=center>
<img src="assets/logo.png" width=300 >
</div>

# 🦖🧠 Rex-Thinker: Grounded Object Referring via Chain-of-Thought Reasoning 🦖🧠

<div align=center>

<p align="center">
  <a href="https://bagel-ai.org/">
    <img
      src="https://img.shields.io/badge/RexThinker-Website-Red?logo=afdian&logoColor=white&color=blue"
      alt="RexThinker Website"
    />
  </a>
  <a href="https://arxiv.org/abs/2505.14683">
    <img
      src="https://img.shields.io/badge/RexThinker-Paper-Red%25red?logo=arxiv&logoColor=red&color=yellow"
      alt="RexThinker Paper on arXiv"
    />
  </a>
  <a href="https://huggingface.co/IDEA-Research/Rex-Thinker-GRPO-7B">
    <img
      src="https://img.shields.io/badge/RexThinker-Weight-orange?logo=huggingface&logoColor=yellow"
      alt="RexThinker weight on Hugging Face"
    />
  </a>
  <a href="https://demo.bagel-ai.org/">
    <img
      src="https://img.shields.io/badge/RexThinker-Data-orange?logo=huggingface&logoColor=yellow"
      alt="RexThinker data on Hugging Face"
    />
  </a>
</p>

</div>

> We propose Rex-Thinker, a Chain-of-Thought (CoT) reasoning model for object referring that addresses two key challenges: lack of interpretability and the inability to reject unmatched expressions. Instead of directly predicting bounding boxes, Rex-Thinker reasons step by step over candidate objects to determine which, if any, match a given expression. Rex-Thinker is trained in two stages: supervised fine-tuning to learn structured CoT reasoning, followed by reinforcement learning with GRPO to enhance accuracy, faithfulness, and generalization. Our approach improves both prediction precision and interpretability, while enabling the model to abstain when no suitable object is found. Below is an example of the model's reasoning process:

<p align="center"><img src="assets/teaser_example.jpg" width="95%"></p>


## Method

**Rex-Thinker** reformulates object referring as a **Chain-of-Thought (CoT)** reasoning task to improve both interpretability and reliability. The model follows a structured three-stage reasoning paradigm:

1. **Planning**: Decompose the referring expression into interpretable subgoals.

2. **Action**: Evaluate each candidate object (obtained via an open-vocabulary detector) against these subgoals using step-by-step reasoning.

3. **Summarization**: Aggregate the intermediate results to output the final prediction — or abstain when no object matches.

Each reasoning step is grounded in a specific candidate object region through **Box Hints**, making the process transparent and verifiable.
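
To make the Box Hint format concrete, the sketch below mirrors how the demo scripts in this upload (`demo/inference_single_image.py`, `demo/gradio_demo.py`) turn detector proposals into the hint that precedes the question; the category name and box coordinates here are illustrative placeholders only.

```python
import json

# Candidate boxes for one category, as returned by the open-vocabulary
# detector and converted to the model's coordinate format.
# The coordinates below are made-up examples.
cate_name = "helmet"
proposed_box = [[100, 50, 180, 120], [220, 60, 300, 130]]

# The hint lists every candidate region so the model can reason over them one by one.
hint = json.dumps({cate_name: proposed_box})
question = (
    f"Hint: Object and its coordinates in this image: {hint}\n"
    f"Please detect the forth helmet from left in the image."
)
print(question)
```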

Rex-Thinker is implemented on top of **Qwen2.5-VL** and trained in two stages:

- **Supervised Fine-Tuning (SFT)**
  Cold-start training using GPT-4o-generated CoT traces as supervision.

- **GRPO-based Reinforcement Learning**
  Further optimizes reasoning accuracy, generalization, and rejection ability via Group Relative Policy Optimization.

This CoT-based framework enables Rex-Thinker to make faithful, interpretable predictions while generalizing well to out-of-domain referring scenarios.


<p align="center"><img src="assets/model.jpg" width="95%"></p>


## 1. Installation ⛳️

```bash
conda create -n rexthinker python=3.10 -y
conda activate rexthinker
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
pip install -v -e .

# additional dependency: Grounding DINO
git clone https://github.com/IDEA-Research/GroundingDINO.git
cd GroundingDINO
## To support torch 2.6
git remote add quantumope https://github.com/QuantuMope/GroundingDINO.git
git fetch quantumope PR/andrew/add-torch26-support-ms-deform-attn
git merge quantumope/PR/andrew/add-torch26-support-ms-deform-attn
## Continue with installation
pip install -v -e .
mkdir weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights
cd ..
```
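
If the Grounding DINO extension compiled correctly, loading the detector with the config and weights placed by the commands above should just work; an optional sanity check (run from the repo root) is sketched below.

```python
# Optional sanity check: the Grounding DINO package imports and the checkpoint loads.
from groundingdino.util.inference import load_model

gdino_model = load_model(
    "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "GroundingDINO/weights/groundingdino_swint_ogc.pth",
)
print(type(gdino_model).__name__)  # expect the Grounding DINO model class
```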

### 1.1 Download Pre-trained Model
We provide the pre-trained model weights of Rex-Thinker-GRPO, which is trained on HumanRef-CoT through SFT and GRPO. You can download the model weights from [Hugging Face](https://huggingface.co/IDEA-Research/Rex-Thinker-GRPO-7B).

Alternatively, you can use the following command to download the pre-trained model:
```bash
git lfs install
git clone https://huggingface.co/IDEA-Research/Rex-Thinker-GRPO-7B IDEA-Research/Rex-Thinker-GRPO-7B
```
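
Once the weights are available (locally or via the Hub name), they load like a standard Qwen2.5-VL checkpoint; the snippet below simply repeats the loading code used in `demo/inference_single_image.py`.

```python
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_path = "IDEA-Research/Rex-Thinker-GRPO-7B"

# Same loading arguments as the demo scripts in this repo.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(
    model_path, min_pixels=16 * 28 * 28, max_pixels=1280 * 28 * 28
)
```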

## 2. Inference 🚀
We provide a simple inference script to test the model. In this script, we use Grounding DINO to get the candidate boxes. You can run the following command to test the model:

```bash
CUDA_VISIBLE_DEVICES=0 python demo/inference_single_image.py \
    --image_path demo/example_images/demo_helmet.png \
    --cate_name helmet \
    --ref_exp "the forth helmet from left" \
    --vis_path vis/example_output.jpg
```

You will get output from the terminal like this:
```text
<think>OK, the user needs us to detect the fourth helmet from left. To accomplish this task, I need to break it down into the following steps:
- Step 1: Sort the helmets from left to right.
- Step 2: Find the fourth helmet from the sorted list.

# Step 1: Sort the helmets from left to right
I see 6 helmets in this image, and their order from left to right is [Helmet 5, Helmet 1, Helmet 3, Helmet 2, Helmet 4, Helmet 6].

# Step 2: Find the fourth helmet from the sorted list
From the sorted list [Helmet 5, Helmet 1, Helmet 3, Helmet 2, Helmet 4, Helmet 6], the fourth helmet from the left is Helmet 2.

# Summarize and Re-Check answer
Let’s now recheck our answer and put ✅ for the target helmet and ❌ for others
- Helmet 5: It is the first helmet from left → ❌
- Helmet 1: It is the second helmet from left → ❌
- Helmet 3: It is the third helmet from left → ❌
- Helmet 2: It is the fourth helmet from left → ✅
- Helmet 4: It is the fifth helmet from left → ❌
- Helmet 6: It is the sixth helmet from left → ❌</think><answer>json
[{"bbox_2d": [578, 359, 825, 580], "label": "the forth helmet from left"}]
</answer>
```

and a visualized result like this:
<p align="center"><img src="demo/example_images/demo_output.jpg" width="80%"></p>
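
To consume the prediction programmatically rather than only visualizing it, the boxes can be pulled out of the `<answer>` block. The `extract_boxes` helper below is an illustrative sketch and not part of this repo (the repo's own post-processing lives in `tools/inference_tools.py`); note that the raw coordinates are in the resized model-input space, which the demo rescales back to the original image during post-processing.

```python
import json
import re


def extract_boxes(output_text: str):
    """Parse the JSON list inside <answer>...</answer> from the model output.

    Illustrative helper only; returns [] when the model abstains or emits
    no answer block.
    """
    match = re.search(r"<answer>(.*?)</answer>", output_text, re.DOTALL)
    if match is None:
        return []
    body = match.group(1).strip()
    # The answer block may start with a "json" tag, as in the example above.
    if body.startswith("json"):
        body = body[len("json"):].strip()
    return json.loads(body)


# With the example output above, this returns:
# [{'bbox_2d': [578, 359, 825, 580], 'label': 'the forth helmet from left'}]
```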


## 3. Gradio Demo 🤗
We provide a Gradio demo for you to test the model. You can run the following command to start the Gradio demo:
```bash
CUDA_VISIBLE_DEVICES=0 python demo/gradio_demo.py \
    --model_path IDEA-Research/Rex-Thinker-GRPO-7B \
    --server_ip 0.0.0.0 \
    --server_port 7860
```

Then you can open your browser and visit `http://localhost:7860` to see the Gradio demo. You can input the image, category name, and referring expression to test the model.

<p align="center"><img src="assets/gradio.jpg" width="95%"></p>

## Citation 📜
assets/data_engine.jpg ADDED

Git LFS Details

  • SHA256: ee0de7a98fdd735c2e9d1b48509ee9c12c38b0e0b95594dc7f9ea9e3a4209c29
  • Pointer size: 131 Bytes
  • Size of remote file: 500 kB
assets/gradio.jpg ADDED

Git LFS Details

  • SHA256: e1a4152589c2628b4535a34c1cb4abba1393c8b1d333e2e1a8f93aa35b5ff462
  • Pointer size: 131 Bytes
  • Size of remote file: 731 kB
assets/logo.png ADDED

Git LFS Details

  • SHA256: fe991cdf4e4120f85cdb2612d2e94b6838ccf57e877ca96f957b0d1f7905557e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
assets/model.jpg ADDED

Git LFS Details

  • SHA256: 142784a336982ac5eb8485a1672f4f0b02d03f8505dc430888f6bcacda3d1214
  • Pointer size: 131 Bytes
  • Size of remote file: 586 kB
assets/teaser_example.jpg ADDED

Git LFS Details

  • SHA256: 0304086e1b807243bab3f7511b3fa26e485c626c7e992b65ff26c18391259408
  • Pointer size: 131 Bytes
  • Size of remote file: 831 kB
demo/example_images/demo_dog.jpg ADDED

Git LFS Details

  • SHA256: 11b4d1efa0a566d092e3e9ec1706bbbaa38d89229d987e4b444c4151fd4208a8
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
demo/example_images/demo_helmet.png ADDED

Git LFS Details

  • SHA256: 7a69bf695a512f85cb6dc5012387a1a7f2f26d74e36ab4026f8a1775600feea0
  • Pointer size: 131 Bytes
  • Size of remote file: 260 kB
demo/example_images/demo_letter.jpg ADDED
demo/example_images/demo_output.jpg ADDED

Git LFS Details

  • SHA256: 81fb4fb8f0ec80b9dfe2acb2fa0b2004392d18c2de2df66252596d2ada25cc44
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
demo/example_images/demo_person.jpg ADDED

Git LFS Details

  • SHA256: f111ed549e78268977e53156ef9f836055bedcf7f047c5bc4da7a5749af194e4
  • Pointer size: 131 Bytes
  • Size of remote file: 203 kB
demo/example_images/demo_tomato.jpg ADDED

Git LFS Details

  • SHA256: 5e9735bd5a4f6cdf3a227bd8b1552130409832fe409ed346cb0dca290394741f
  • Pointer size: 131 Bytes
  • Size of remote file: 462 kB
demo/gradio_demo.py ADDED
@@ -0,0 +1,319 @@
import argparse
import json

import gradio as gr
import numpy as np
import torch
from groundingdino.util.inference import load_model
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from tools.inference_tools import (
    convert_boxes_from_absolute_to_qwen25_format,
    inference_gdino,
    postprocess_and_vis_inference_out,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str, default="IDEA-Research/Rex-Thinker-GRPO-7B"
    )
    parser.add_argument(
        "--gdino_config",
        type=str,
        default="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    )
    parser.add_argument(
        "--gdino_weights",
        type=str,
        default="GroundingDINO/weights/groundingdino_swint_ogc.pth",
    )
    parser.add_argument(
        "--server_ip",
        type=str,
        default="0.0.0.0",
        help="IP address to bind the server to",
    )
    parser.add_argument(
        "--server_port",
        type=int,
        default=2512,
        help="Port to run the server on",
    )
    return parser.parse_args()


def initialize_models(args):
    # Load GDINO model
    gdino_model = load_model(args.gdino_config, args.gdino_weights).to("cuda")
    gdino_model.eval()

    # Load Rex-Thinker-GRPO
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(
        args.model_path, min_pixels=16 * 28 * 28, max_pixels=1280 * 28 * 28
    )

    return (gdino_model, processor, model)


def process_image(
    image,
    system_prompt,
    cate_name,
    referring_expression,
    draw_width,
    font_size,
    gdino_model,
    rexthinker_processor,
    rexthinker_model,
):
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Run GDINO inference
    gdino_boxes = inference_gdino(
        image,
        [cate_name],
        gdino_model,
        TEXT_TRESHOLD=0.25,
        BOX_TRESHOLD=0.25,
    )
    proposed_box = convert_boxes_from_absolute_to_qwen25_format(
        gdino_boxes, image.width, image.height
    )

    hint = json.dumps(
        {
            f"{cate_name}": proposed_box,
        }
    )
    question = f"Hint: Object and its coordinates in this image: {hint}\nPlease detect {referring_expression} in the image."

    # compose input
    print(f"system_prompt: {system_prompt}")
    print(f"question: {question}")
    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": question},
            ],
        },
    ]

    text = rexthinker_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = rexthinker_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    input_height = inputs["image_grid_thw"][0][1] * 14
    input_width = inputs["image_grid_thw"][0][2] * 14

    # Inference: Generation of the output
    generated_ids = rexthinker_model.generate(**inputs, max_new_tokens=4096)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = rexthinker_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    output_text = output_text[0]

    ref_vis_result, gdino_vis_result = postprocess_and_vis_inference_out(
        image,
        output_text,
        proposed_box,
        gdino_boxes,
        font_size=font_size,
        draw_width=draw_width,
        input_height=input_height,
        input_width=input_width,
    )

    return gdino_vis_result, ref_vis_result, output_text


def create_demo(models):
    (
        gdino_model,
        rexthinker_processor,
        rexthinker_model,
    ) = models

    with gr.Blocks() as demo:
        gr.Markdown("# Rex-Thinker Demo")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                gdino_prompt = gr.Textbox(
                    label="Object Category Name to get Candidate boxes",
                    placeholder="person",
                    value="person",
                )
                referring_prompt = gr.Textbox(
                    label="Referring Expression",
                    placeholder="person wearing red shirt and a black hat",
                    value="person wearing red shirt and a black hat",
                )
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value="A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                )
                with gr.Row():
                    draw_width = gr.Slider(
                        minimum=5.0,
                        maximum=100.0,
                        value=10.0,
                        step=1,
                        label="Draw Width for Visualization",
                    )
                    font_size = gr.Slider(
                        minimum=5.0,
                        maximum=100.0,
                        value=20.0,
                        step=1,
                        label="Font size for Visualization",
                    )
                run_button = gr.Button("Run")

            with gr.Column():
                gdino_output = gr.Image(label="GroundingDINO Detection")
                final_output = gr.Image(label="Rex-Thinker Visualization")
            with gr.Column():
                llm_output = gr.Textbox(
                    label="LLM Raw Output", interactive=False, lines=50, max_lines=100
                )

        # Add examples section
        gr.Markdown("## Examples")
        examples = gr.Examples(
            examples=[
                [
                    "demo/example_images/demo_tomato.jpg",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "tomato",
                    "ripe tomato",
                    10,
                    20,
                ],
                [
                    "demo/example_images/demo_helmet.png",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "helmet",
                    "the forth helmet from left",
                    10,
                    20,
                ],
                [
                    "demo/example_images/demo_person.jpg",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "person",
                    "person in the red car but not driving",
                    10,
                    20,
                ],
                [
                    "demo/example_images/demo_letter.jpg",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "person",
                    "person wearing cloth that has two letters",
                    10,
                    20,
                ],
                [
                    "demo/example_images/demo_dog.jpg",
                    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.",
                    "dog",
                    "the dog sleep on the bed with a pot under its body",
                    10,
                    20,
                ],
            ],
            inputs=[
                input_image,
                system_prompt,
                gdino_prompt,
                referring_prompt,
                draw_width,
                font_size,
            ],
            outputs=[gdino_output, final_output, llm_output],
            fn=lambda img, sys, p1, p2, d, f: process_image(
                img,
                sys,
                p1,
                p2,
                d,
                f,
                gdino_model,
                rexthinker_processor,
                rexthinker_model,
            ),
            cache_examples=False,
        )

        run_button.click(
            fn=lambda img, sys, p1, p2, d, f: process_image(
                img,
                sys,
                p1,
                p2,
                d,
                f,
                gdino_model,
                rexthinker_processor,
                rexthinker_model,
            ),
            inputs=[
                input_image,
                system_prompt,
                gdino_prompt,
                referring_prompt,
                draw_width,
                font_size,
            ],
            outputs=[gdino_output, final_output, llm_output],
        )

    return demo


def main():
    args = parse_args()
    models = initialize_models(args)
    demo = create_demo(models)
    demo.launch(server_name=args.server_ip, server_port=args.server_port, share=True)


if __name__ == "__main__":
    main()
demo/inference_single_image.py ADDED
@@ -0,0 +1,197 @@
import argparse
import json
import os

import torch
from groundingdino.util.inference import load_model
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from tools.inference_tools import (
    convert_boxes_from_absolute_to_qwen25_format,
    inference_gdino,
    postprocess_and_vis_inference_out,
)

SYSTEM_PROMPT = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."


def get_args():
    parser = argparse.ArgumentParser(description="Inference script for Qwen-2.5-VL")
    parser.add_argument(
        "--image_path",
        type=str,
        default="demo/example_images/demo_helmet.png",
        help="Path to the input image",
    )
    parser.add_argument(
        "--cate_name",
        type=str,
        default="helmet",
        help='text prompt for grounding dino, e.g. "cat", "dog", "car"',
    )
    parser.add_argument(
        "--ref_exp",
        type=str,
        default="the forth helmet from left",
        help="Reference expression for Rex-Thinker, e.g. 'the cat on the left'",
    )
    parser.add_argument(
        "--vis_path",
        type=str,
        default="vis/example_output.jpg",
        help="Path to save the visualization result",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="IDEA-Research/Rex-Thinker-GRPO-7B",
        help="Path to the Rex-Thinker model or model identifier from Hugging Face Hub",
    )
    parser.add_argument(
        "--gdino_config",
        type=str,
        default="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
        help="Path to Grounding DINO config",
    )
    parser.add_argument(
        "--gdino_weights",
        type=str,
        default="GroundingDINO/weights/groundingdino_swint_ogc.pth",
        help="Path to Grounding DINO weights",
    )
    parser.add_argument(
        "--qwen_model_path",
        type=str,
        default="IDEA-Research/Rex-Thinker-GRPO-7B",
        help="Path to Qwen-2.5-VL model or model identifier from Hugging Face Hub",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    image_path = args.image_path
    cate_name = args.cate_name
    ref_exp = args.ref_exp
    gdino_config = args.gdino_config
    gdino_weights = args.gdino_weights
    qwen_model_path = args.qwen_model_path

    # Load the Grounding DINO model
    gdino_model = load_model(gdino_config, gdino_weights)
    gdino_model.eval()

    # Load Rex-Thinker-GRPO
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(
        args.model_path, min_pixels=16 * 28 * 28, max_pixels=1280 * 28 * 28
    )

    # Load the image
    image = Image.open(image_path).convert("RGB")

    # Prepare the text prompts for Grounding DINO
    prompts = [cate_name]

    # Run inference with Grounding DINO to get box hint
    gdino_boxes = inference_gdino(image, prompts, gdino_model)

    proposed_box = convert_boxes_from_absolute_to_qwen25_format(
        gdino_boxes, image.width, image.height
    )
    hint = json.dumps(
        {
            f"{cate_name}": proposed_box,
        }
    )
    question = f"Hint: Object and its coordinates in this image: {hint}\nPlease detect {ref_exp} in the image."

    # compose input
    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT,
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": question},
            ],
        },
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    input_height = inputs["image_grid_thw"][0][1] * 14
    input_width = inputs["image_grid_thw"][0][2] * 14

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=4096)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    output_text = output_text[0]
    print(output_text)

    ref_vis_result, gdino_vis_result = postprocess_and_vis_inference_out(
        image,
        output_text,
        proposed_box,
        gdino_boxes,
        font_size=20,
        draw_width=10,
        input_height=input_height,
        input_width=input_width,
    )

    # Create a new image with white background for the combined result
    combined_width = gdino_vis_result.width + ref_vis_result.width
    combined_height = max(gdino_vis_result.height, ref_vis_result.height)
    combined_image = Image.new("RGB", (combined_width, combined_height), "white")

    # Paste the images side by side
    combined_image.paste(gdino_vis_result, (0, 0))
    combined_image.paste(ref_vis_result, (gdino_vis_result.width, 0))

    # Add titles
    draw = ImageDraw.Draw(combined_image)
    font = ImageFont.truetype("tools/Tahoma.ttf", 30)

    # Add Grounding DINO title
    draw.text((10, 10), "Grounding DINO Output", fill="black", font=font)

    # Add Rex-Thinker title
    draw.text(
        (gdino_vis_result.width + 10, 10), "Rex-Thinker Output", fill="black", font=font
    )

    # Save the combined visualization result
    os.makedirs(os.path.dirname(args.vis_path), exist_ok=True)
    combined_image.save(args.vis_path)
requirements.txt ADDED
@@ -0,0 +1,24 @@
accelerate
codetiming
datasets
flash-attn>=2.4.3
liger-kernel
mathruler
numpy
omegaconf
pandas
peft
pillow
pyarrow>=15.0.0
pylatexenc
qwen-vl-utils
ray[default]
tensordict
torchdata
transformers==4.51.3
vllm==0.8.2
wandb
tensorboard
gradio==4.44.1
pydantic==2.10.6
tabulate