abalakrishnaTRI committed
Commit 83cb829 · 0 Parent(s)

first commit

.gitignore ADDED
@@ -0,0 +1,150 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Logs
105
+ serve_images/
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
133
+
134
+ # Ruff
135
+ .ruff_cache/
136
+
137
+ # IDE caches
138
+ .idea/
139
+ .vscode/
140
+
141
+ # Mac OS
142
+ .DS_Store
143
+
144
+ # Tokens
145
+ .hf_token
146
+
147
+ # Scratch & Caches
148
+ __scratch/
149
+ scratch/
150
+ cache/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024-present, Toyota Research Institute.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Makefile ADDED
@@ -0,0 +1,17 @@
1
+ .PHONY: help check autoformat
2
+ .DEFAULT: help
3
+
4
+ # Generates a useful overview/help message for various make features - add to this as necessary!
5
+ help:
6
+ @echo "make check"
7
+ @echo " Run code style and linting (black, ruff) *without* changing files!"
8
+ @echo "make autoformat"
9
+ @echo " Run code styling (black, ruff) and update in place - committing with pre-commit also does this."
10
+
11
+ check:
12
+ black --check .
13
+ ruff check --show-source .
14
+
15
+ autoformat:
16
+ black .
17
+ ruff check --fix --show-fixes .
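A quick sketch of how these targets are meant to be used locally (assuming `black` and `ruff` are available in the environment, e.g. from the `[dev]` extra defined in `pyproject.toml`):

```bash
# Verify style/lint without touching files (what a pre-commit or CI check would run)
make check

# Apply black formatting and ruff autofixes in place
make autoformat
```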
README.md ADDED
@@ -0,0 +1,48 @@
1
+ # VLM Demo
2
+
3
+ > *VLM Demo*: Lightweight repo for chatting with models loaded into *VLM Bench*.
4
+
5
+ ---
6
+
7
+ ## Installation
8
+
9
+ This repository can be cloned and installed locally as follows:
10
+
11
+ ```bash
12
+ git clone [email protected]:TRI-ML/vlm-demo.git
13
+ cd vlm-demo
14
+ pip install -e .
15
+ ```
16
+
17
+ This repository also requires that the `vlm-bench` package (`vlbench`) and
18
+ `prismatic-vlms` package (`prisma`) are installed in the current environment.
19
+ These can both be installed from source from the following git repos:
20
+
21
+ `vlm-bench`: `https://github.com/TRI-ML/vlm-bench`
22
+ `prismatic-vlms`: `https://github.com/TRI-ML/prismatic-vlms`
23
+
24
+ ## Usage
25
+
26
+ Start Gradio Controller: `serve/controller.py`
27
+ Start Gradio Web Server: `serve/gradio_web_server.py`
28
+ Run interactive demo: `interactive_demo.py`
29
+
30
+ To run the demo, use the following commands:
31
+
32
+ Start Gradio Controller: `python -m serve.controller --host 0.0.0.0 --port 10000`
33
+ Start Gradio Web Server: `python -m serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --share`
34
+ Run interactive demo: `CUDA_VISIBLE_DEVICES=0 python -m interactive_demo --port 40000 --model_dir <PATH TO MODEL CKPT>`
35
+
36
+ ## Contributing
37
+
38
+ Before committing to the repository, *make sure to set up your dev environment!*
39
+
40
+ Here are the basic development environment setup guidelines:
41
+
42
+ + Fork/clone the repository, performing an editable installation. Make sure to install with the development dependencies
43
+ (e.g., `pip install -e ".[dev]"`); this will install `black`, `ruff`, and `pre-commit`.
44
+
45
+ + Install `pre-commit` hooks (`pre-commit install`).
46
+
47
+ + Branch for the specific feature/issue, issuing PR against the upstream repository for review.
48
+
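The Installation section above notes that `vlm-bench` and `prismatic-vlms` must be installed from source; a minimal sketch of one way to do that, assuming both repositories support a standard editable `pip` install from their root (check each repo's own README for the authoritative steps):

```bash
# Hypothetical source installs for the two dependencies
git clone https://github.com/TRI-ML/vlm-bench.git && pip install -e vlm-bench
git clone https://github.com/TRI-ML/prismatic-vlms.git && pip install -e prismatic-vlms
```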
interactive_demo.py ADDED
@@ -0,0 +1,289 @@
1
+ """
2
+ interactive_demo.py
3
+
4
+ Entry point for all VLM-Bench interactive demos; specify model and get a gradio UI where you can chat with it!
5
+
6
+ This file is heavily adapted from the script used to serve models in the LLaVa repo:
7
+ https://github.com/haotian-liu/LLaVA/blob/main/llava/serve/model_worker.py. It is
8
+ modified to ensure compatibility with our Prismatic models.
9
+ """
10
+ import asyncio
11
+ import json
12
+ import os
13
+ import threading
14
+ import time
15
+ import uuid
16
+ from dataclasses import dataclass
17
+ from functools import partial
18
+ from pathlib import Path
19
+ from typing import Union
20
+
21
+ import draccus
22
+ import requests
23
+ import torch
24
+ import uvicorn
25
+ from accelerate.utils import set_seed
26
+ from fastapi import BackgroundTasks, FastAPI, Request
27
+ from fastapi.responses import StreamingResponse
28
+ from llava.constants import WORKER_HEART_BEAT_INTERVAL
29
+ from llava.mm_utils import load_image_from_base64
30
+ from llava.utils import server_error_msg
31
+ from torchvision.transforms import Compose
32
+
33
+ from vlbench.models import load_vlm
34
+ from vlbench.overwatch import initialize_overwatch
35
+ from serve import INTERACTION_MODES_MAP, MODEL_ID_TO_NAME
36
+
37
+ GB = 1 << 30
38
+ worker_id = str(uuid.uuid4())[:6]
39
+ global_counter = 0
40
+ model_semaphore = None
41
+
42
+
43
+ def heart_beat_worker(controller):
44
+ while True:
45
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
46
+ controller.send_heart_beat()
47
+
48
+
49
+ class ModelWorker:
50
+ def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_base, model_name):
51
+ self.controller_addr = controller_addr
52
+ self.worker_addr = worker_addr
53
+ self.worker_id = worker_id
54
+ self.model_name = model_name
55
+
56
+ # logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
57
+ self.vlm = vlm
58
+ self.tokenizer, self.model, self.image_processor, self.context_len = (
59
+ vlm.tokenizer,
60
+ vlm.model,
61
+ vlm.image_processor,
62
+ vlm.max_length,
63
+ )
64
+
65
+ if not no_register:
66
+ self.register_to_controller()
67
+ self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,))
68
+ self.heart_beat_thread.start()
69
+
70
+ def register_to_controller(self):
71
+ # logger.info("Register to controller")
72
+
73
+ url = self.controller_addr + "/register_worker"
74
+ data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
75
+ r = requests.post(url, json=data)
76
+ assert r.status_code == 200
77
+
78
+ def send_heart_beat(self):
79
+ # logger.info(f"Send heart beat. Models: {[self.model_name]}. "
80
+ # f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
81
+ # f"global_counter: {global_counter}")
82
+
83
+ url = self.controller_addr + "/receive_heart_beat"
84
+
85
+ while True:
86
+ try:
87
+ ret = requests.post(
88
+ url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5
89
+ )
90
+ exist = ret.json()["exist"]
91
+ break
92
+ except requests.exceptions.RequestException:
93
+ pass
94
+ # logger.error(f"heart beat error: {e}")
95
+ time.sleep(5)
96
+
97
+ if not exist:
98
+ self.register_to_controller()
99
+
100
+ def get_queue_length(self):
101
+ if model_semaphore is None:
102
+ return 0
103
+ else:
104
+ return (
105
+ limit_model_concurrency
106
+ - model_semaphore._value
107
+ + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
108
+ )
109
+
110
+ def get_status(self):
111
+ return {
112
+ "model_names": [self.model_name],
113
+ "speed": 1,
114
+ "queue_length": self.get_queue_length(),
115
+ }
116
+
117
+ @torch.inference_mode()
118
+ def generate_stream(self, params):
119
+ prompt = params["prompt"]
120
+ ori_prompt = prompt
121
+ images = params.get("images", None)
122
+
123
+ temperature = params.get("temperature", 0.2)
124
+ max_new_tokens = params.get("max_new_tokens", 2048)
125
+ interaction_mode = INTERACTION_MODES_MAP[params.get("interaction_mode", "Chat")]
126
+
127
+ if temperature != 0:
128
+ self.vlm.set_generate_kwargs(
129
+ {"do_sample": True, "max_new_tokens": max_new_tokens, "temperature": temperature}
130
+ )
131
+ else:
132
+ self.vlm.set_generate_kwargs({"do_sample": False, "max_new_tokens": max_new_tokens})
133
+
134
+ if images is not None and len(images) == 1:
135
+ images = [load_image_from_base64(image) for image in images]
136
+ else:
137
+ raise NotImplementedError("Only supports queries with one image for now")
138
+
139
+ if interaction_mode == "chat":
140
+ question_prompt = [prompt]
141
+ else:
142
+ prompt_fn = self.vlm.get_prompt_fn(interaction_mode)
143
+ if interaction_mode != "captioning":
144
+ question_prompt = [prompt_fn(prompt)]
145
+ else:
146
+ question_prompt = [prompt_fn()]
147
+
148
+ if isinstance(self.image_processor, Compose) or hasattr(self.image_processor, "is_prismatic"):
149
+ # This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
150
+ pixel_values = self.image_processor(images[0].convert("RGB"))
151
+ else:
152
+ # Assume `image_transform` is a HF ImageProcessor...
153
+ pixel_values = self.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
154
+
155
+ generated_text = self.vlm.generate_answer(torch.unsqueeze(pixel_values.cuda(), 0), question_prompt)[0]
156
+ generated_text = generated_text.split("USER")[0].split("ASSISTANT")[0]
157
+ yield json.dumps({"text": ori_prompt + generated_text, "error_code": 0}).encode() + b"\0"
158
+
159
+ def generate_stream_gate(self, params):
160
+ try:
161
+ for x in self.generate_stream(params):
162
+ yield x
163
+ except ValueError as e:
164
+ print("Caught ValueError:", e)
165
+ ret = {
166
+ "text": server_error_msg,
167
+ "error_code": 1,
168
+ }
169
+ yield json.dumps(ret).encode() + b"\0"
170
+ except torch.cuda.CudaError as e:
171
+ print("Caught torch.cuda.CudaError:", e)
172
+ ret = {
173
+ "text": server_error_msg,
174
+ "error_code": 1,
175
+ }
176
+ yield json.dumps(ret).encode() + b"\0"
177
+ except Exception as e:
178
+ print("Caught Unknown Error", e)
179
+ ret = {
180
+ "text": server_error_msg,
181
+ "error_code": 1,
182
+ }
183
+ yield json.dumps(ret).encode() + b"\0"
184
+
185
+
186
+ app = FastAPI()
187
+
188
+
189
+ def release_model_semaphore(fn=None):
190
+ model_semaphore.release()
191
+ if fn is not None:
192
+ fn()
193
+
194
+
195
+ @app.post("/worker_generate_stream")
196
+ async def generate_stream(request: Request):
197
+ global model_semaphore, global_counter
198
+ global_counter += 1
199
+ params = await request.json()
200
+
201
+ if model_semaphore is None:
202
+ model_semaphore = asyncio.Semaphore(limit_model_concurrency)
203
+ await model_semaphore.acquire()
204
+ worker.send_heart_beat()
205
+ generator = worker.generate_stream_gate(params)
206
+ background_tasks = BackgroundTasks()
207
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
208
+ return StreamingResponse(generator, background=background_tasks)
209
+
210
+
211
+ @app.post("/worker_get_status")
212
+ async def get_status(request: Request):
213
+ return worker.get_status()
214
+
215
+
216
+ # Initialize Overwatch =>> Wraps `logging.Logger` and `accelerate.PartialState`
217
+ overwatch = initialize_overwatch(__name__)
218
+
219
+
220
+ @dataclass
221
+ class DemoConfig:
222
+ # fmt: off
223
+
224
+ # === Model Parameters =>> Quartz ===
225
+ model_family: str = "quartz" # Model family to load from in < `quartz` | `llava-v15` | ... >
226
+ model_id: str = "llava-v1.5-7b" # Model ID to load and run (instance of `model_family`)
227
+ model_dir: Path = ( # Path to model checkpoint to load --> should be self-contained
228
+ "resize-naive-siglip-vit-l-16-384px-no-align-2-epochs+13b+stage-finetune+x7"
229
+ )
230
+
231
+ # === Model Parameters =>> Official LLaVa ===
232
+ # model_family: str = "llava-v15"
233
+ # model_id: str = "llava-v1.5-13b"
234
+ # model_dir: Path = "liuhaotian/llava-v1.5-13b"
235
+
236
+ # Model Worker Parameters
237
+ host: str = "0.0.0.0"
238
+ port: int = 40000
239
+ controller_address: str = "http://localhost:10000"
240
+ model_base: str = "llava-v15"
241
+ limit_model_concurrency: int = 5
242
+ stream_interval: int = 1
243
+ no_register: bool = False
244
+
245
+ # Inference Parameters
246
+ device_batch_size: int = 1 # Device Batch Size set to 1 until LLaVa/HF LLaMa fixes bugs!
247
+ num_workers: int = 2 # Number of Dataloader Workers (on each process)
248
+
249
+ # HF Hub Credentials (for LLaMa-2)
250
+ hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
251
+
252
+ # Randomness
253
+ seed: int = 21 # Random Seed (for reproducibility)
254
+
255
+ def __post_init__(self) -> None:
256
+ if self.model_family == "quartz":
257
+ self.model_name = MODEL_ID_TO_NAME[str(self.model_dir)]
258
+ self.run_dir = Path("/mnt/fsx/x-onyx-vlms/runs") / self.model_dir
259
+ elif self.model_family in {"instruct-blip", "llava", "llava-v15"}:
260
+ self.model_name = MODEL_ID_TO_NAME[self.model_id]
261
+ self.run_dir = self.model_dir
262
+ else:
263
+ raise ValueError(f"Run Directory for `{self.model_family = }` does not exist!")
264
+ self.worker_address = f"http://localhost:{self.port}"
265
+
266
+ # fmt: on
267
+
268
+
269
+ @draccus.wrap()
270
+ def interactive_demo(cfg: DemoConfig):
271
+ # overwatch.info(f"Starting Evaluation for Dataset `{cfg.dataset.dataset_id}` w/ Model `{cfg.model_id}`")
272
+ set_seed(cfg.seed)
273
+
274
+ # Build the VLM --> Download/Load Pretrained Model from Checkpoint
275
+ overwatch.info("Initializing VLM =>> Bundling Models, Image Processors, and Tokenizer")
276
+ hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token]
277
+ vlm = load_vlm(cfg.model_family, cfg.model_id, cfg.run_dir, hf_token=hf_token)
278
+
279
+ global worker
280
+ global limit_model_concurrency
281
+ limit_model_concurrency = cfg.limit_model_concurrency
282
+ worker = ModelWorker(
283
+ cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_base, cfg.model_name
284
+ )
285
+ uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
286
+
287
+
288
+ if __name__ == "__main__":
289
+ interactive_demo()
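The worker above serves `/worker_generate_stream`, and `generate_stream` yields JSON payloads (`{"text": ..., "error_code": 0}`) separated by null bytes. A minimal client sketch for querying it directly, assuming a worker is already running on port 40000 and noting that `generate_stream` currently supports exactly one base64-encoded image per request (`example.png` is a placeholder path):

```python
import base64
import json

import requests

WORKER_ADDR = "http://localhost:40000"  # assumed worker host/port (see DemoConfig.port)

# Base64-encode a local image; the worker decodes it with load_image_from_base64
with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "prompt": "What is in this image?",
    "images": [image_b64],          # exactly one image is supported for now
    "interaction_mode": "Chat",     # a key of INTERACTION_MODES_MAP
    "temperature": 0.2,
    "max_new_tokens": 512,
}

# The response arrives as a stream of b"\0"-delimited JSON chunks
resp = requests.post(f"{WORKER_ADDR}/worker_generate_stream", json=payload, stream=True, timeout=60)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])
        else:
            print("worker error:", data)
```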
pyproject.toml ADDED
@@ -0,0 +1,64 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "vldemo"
7
+ authors = [
8
+ {name = "Siddharth Karamcheti", email="[email protected]"}
9
+ ]
10
+ description = "VLM Demo: Interactive Demo for VLMs"
11
+ version = "0.0.1"
12
+ readme = "README.md"
13
+ requires-python = ">=3.8"
14
+ keywords = ["machine learning"]
15
+ license = {file = "LICENSE"}
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Developers",
19
+ "Intended Audience :: Education",
20
+ "Intended Audience :: Science/Research",
21
+ "License :: OSI Approved :: MIT License",
22
+ "Operating System :: OS Independent",
23
+ "Programming Language :: Python :: 3",
24
+ "Programming Language :: Python :: 3.8",
25
+ "Programming Language :: Python :: 3.9",
26
+ "Programming Language :: Python :: 3.10",
27
+ "Programming Language :: Python :: 3 :: Only",
28
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
29
+ ]
30
+ dependencies = [
31
+
32
+ ]
33
+
34
+ [project.optional-dependencies]
35
+ dev = [
36
+ "black",
37
+ "gpustat",
38
+ "ipython",
39
+ "pre-commit",
40
+ "ruff",
41
+ ]
42
+
43
+ [project.urls]
44
+ homepage = "https://github.com/TRI-ML/vlm-demo"
45
+ repository = "https://github.com/TRI-ML/vlm-demo"
46
+ documentation = "https://github.com/TRI-ML/vlm-demo"
47
+
48
+ [tool.setuptools.packages.find]
49
+ where = ["."]
50
+ exclude = ["cache"]
51
+
52
+ [tool.black]
53
+ line-length = 121
54
+ target-version = ["py38", "py39", "py310"]
55
+ preview = true
56
+
57
+ [tool.ruff]
58
+ line-length = 121
59
+ target-version = "py38"
60
+ select = ["A", "B", "C90", "E", "F", "I", "RUF", "W"]
61
+ ignore = ["B008", "F722"]
62
+
63
+ [tool.ruff.per-file-ignores]
64
+ "__init__.py" = ["E402", "F401"]
serve/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ from collections import OrderedDict
2
+
3
+
4
+ # Arrange keys in display priority order (high --> low)
5
+ MODEL_ID_TO_NAME = OrderedDict(
6
+ [
7
+ (
8
+ "llava-lvis4v-lrv+lvis4v-lrv-resize-naive-clip-vit-l-14-336px-no-align-2-epochs-llama2pure+13b+stage-finetune+x7",
9
+ "Prism-CLIP 13B",
10
+ ),
11
+ (
12
+ "llava-lvis4v-lrv+lvis4v-lrv-resize-naive-clip-vit-l-14-336px-no-align-2-epochs-llama2pure+7b+stage-finetune+x7",
13
+ "Prism-CLIP 7B",
14
+ ),
15
+ (
16
+ "resize-naive-clip-vit-l-14-336px-no-align-llama2pure+13b+stage-finetune+x7",
17
+ "Prism-CLIP 13B (Controlled)",
18
+ ),
19
+ (
20
+ "resize-naive-clip-vit-l-14-336px-no-align-llama2pure+7b+stage-finetune+x7",
21
+ "Prism-CLIP 7B (Controlled)",
22
+ ),
23
+ (
24
+ "resize-naive-clip-vit-l-14-336px-no-align+13b+stage-finetune+x7",
25
+ "Prism-CLIP 13B (Controlled) - Chat",
26
+ ),
27
+ (
28
+ "resize-naive-clip-vit-l-14-336px-no-align+7b+stage-finetune+x7",
29
+ "Prism-CLIP 7B (Controlled) - Chat",
30
+ ),
31
+ ("llava-v1.5-7b", "LLaVA 1.5: 7B"),
32
+ ("llava-v1.5-13b", "LLaVA 1.5: 13B"),
33
+ ]
34
+ )
35
+
36
+ INTERACTION_MODES_MAP = OrderedDict(
37
+ [
38
+ ("Chat", "chat"),
39
+ ("Captioning", "captioning"),
40
+ ("Bounding Box Prediction", "bbox_pred"),
41
+ ("Visual Question Answering", "vqa"),
42
+ ("True/False Visual Question Answering", "true_false"),
43
+ ]
44
+ )
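These two mappings are the shared vocabulary between the Gradio front end and the worker: the web server displays `MODEL_ID_TO_NAME` values and `INTERACTION_MODES_MAP` keys in its dropdowns, and the worker translates the selected interaction mode back to an internal id. A small illustrative sketch:

```python
from serve import INTERACTION_MODES_MAP, MODEL_ID_TO_NAME

# UI label -> internal mode id consumed by ModelWorker.generate_stream
assert INTERACTION_MODES_MAP["Captioning"] == "captioning"

# Run/checkpoint id -> display name shown in the model dropdown
print(MODEL_ID_TO_NAME["llava-v1.5-7b"])  # LLaVA 1.5: 7B
```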
serve/controller.py ADDED
@@ -0,0 +1,298 @@
1
+ """
2
+ controller.py
3
+ A controller manages distributed workers.
4
+ It sends worker addresses to clients.
5
+
6
+ This file is exactly copied from
7
+ https://github.com/haotian-liu/LLaVA/blob/main/llava/serve/controller.py.
8
+ """
9
+ import argparse
10
+ import dataclasses
11
+ import json
12
+ import threading
13
+ import time
14
+ from enum import Enum, auto
15
+ from typing import List
16
+
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+ from fastapi import FastAPI, Request
21
+ from fastapi.responses import StreamingResponse
22
+ from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
23
+ from llava.utils import build_logger, server_error_msg
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
+ class DispatchMethod(Enum):
29
+ LOTTERY = auto()
30
+ SHORTEST_QUEUE = auto()
31
+
32
+ @classmethod
33
+ def from_str(cls, name):
34
+ if name == "lottery":
35
+ return cls.LOTTERY
36
+ elif name == "shortest_queue":
37
+ return cls.SHORTEST_QUEUE
38
+ else:
39
+ raise ValueError("Invalid dispatch method")
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class WorkerInfo:
44
+ model_names: List[str]
45
+ speed: int
46
+ queue_length: int
47
+ check_heart_beat: bool
48
+ last_heart_beat: str
49
+
50
+
51
+ def heart_beat_controller(controller):
52
+ while True:
53
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
+ controller.remove_stable_workers_by_expiration()
55
+
56
+
57
+ class Controller:
58
+ def __init__(self, dispatch_method: str):
59
+ # Dict[str -> WorkerInfo]
60
+ self.worker_info = {}
61
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
+
63
+ self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self,))
64
+ self.heart_beat_thread.start()
65
+
66
+ logger.info("Init controller")
67
+
68
+ def register_worker(self, worker_name: str, check_heart_beat: bool, worker_status: dict):
69
+ if worker_name not in self.worker_info:
70
+ logger.info(f"Register a new worker: {worker_name}")
71
+ else:
72
+ logger.info(f"Register an existing worker: {worker_name}")
73
+
74
+ if not worker_status:
75
+ worker_status = self.get_worker_status(worker_name)
76
+ if not worker_status:
77
+ return False
78
+
79
+ self.worker_info[worker_name] = WorkerInfo(
80
+ worker_status["model_names"],
81
+ worker_status["speed"],
82
+ worker_status["queue_length"],
83
+ check_heart_beat,
84
+ time.time(),
85
+ )
86
+
87
+ logger.info(f"Register done: {worker_name}, {worker_status}")
88
+ return True
89
+
90
+ def get_worker_status(self, worker_name: str):
91
+ try:
92
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
93
+ except requests.exceptions.RequestException as e:
94
+ logger.error(f"Get status fails: {worker_name}, {e}")
95
+ return None
96
+
97
+ if r.status_code != 200:
98
+ logger.error(f"Get status fails: {worker_name}, {r}")
99
+ return None
100
+
101
+ return r.json()
102
+
103
+ def remove_worker(self, worker_name: str):
104
+ del self.worker_info[worker_name]
105
+
106
+ def refresh_all_workers(self):
107
+ old_info = dict(self.worker_info)
108
+ self.worker_info = {}
109
+
110
+ for w_name, w_info in old_info.items():
111
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
112
+ logger.info(f"Remove stale worker: {w_name}")
113
+
114
+ def list_models(self):
115
+ model_names = set()
116
+
117
+ for _w_name, w_info in self.worker_info.items():
118
+ model_names.update(w_info.model_names)
119
+
120
+ return list(model_names)
121
+
122
+ def get_worker_address_lottery(self, model_name: str):
123
+ worker_names = []
124
+ worker_speeds = []
125
+ for w_name, w_info in self.worker_info.items():
126
+ if model_name in w_info.model_names:
127
+ worker_names.append(w_name)
128
+ worker_speeds.append(w_info.speed)
129
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
130
+ norm = np.sum(worker_speeds)
131
+ if norm < 1e-4:
132
+ return ""
133
+ worker_speeds = worker_speeds / norm
134
+ if True: # Directly return address
135
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
136
+ worker_name = worker_names[pt]
137
+ return worker_name
138
+
139
+ # Check status before returning
140
+ while True:
141
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
142
+ worker_name = worker_names[pt]
143
+
144
+ if self.get_worker_status(worker_name):
145
+ break
146
+ else:
147
+ self.remove_worker(worker_name)
148
+ worker_speeds[pt] = 0
149
+ norm = np.sum(worker_speeds)
150
+ if norm < 1e-4:
151
+ return ""
152
+ worker_speeds = worker_speeds / norm
153
+ continue
154
+ return worker_name
155
+
156
+ def get_worker_address_shortest_queue(self, model_name: str):
157
+ worker_names = []
158
+ worker_qlen = []
159
+ for w_name, w_info in self.worker_info.items():
160
+ if model_name in w_info.model_names:
161
+ worker_names.append(w_name)
162
+ worker_qlen.append(w_info.queue_length / w_info.speed)
163
+ if len(worker_names) == 0:
164
+ return ""
165
+ min_index = np.argmin(worker_qlen)
166
+ w_name = worker_names[min_index]
167
+ self.worker_info[w_name].queue_length += 1
168
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
+ return w_name
170
+
171
+ def get_worker_address(self, model_name: str):
172
+ if self.dispatch_method == DispatchMethod.LOTTERY:
173
+ return self.get_worker_address_lottery(model_name)
174
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
175
+ return self.get_worker_address_shortest_queue(model_name)
176
+ else:
177
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
178
+
179
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
180
+ if worker_name not in self.worker_info:
181
+ logger.info(f"Receive unknown heart beat. {worker_name}")
182
+ return False
183
+
184
+ self.worker_info[worker_name].queue_length = queue_length
185
+ self.worker_info[worker_name].last_heart_beat = time.time()
186
+ logger.info(f"Receive heart beat. {worker_name}")
187
+ return True
188
+
189
+ def remove_stable_workers_by_expiration(self):
190
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
191
+ to_delete = []
192
+ for worker_name, w_info in self.worker_info.items():
193
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
194
+ to_delete.append(worker_name)
195
+
196
+ for worker_name in to_delete:
197
+ self.remove_worker(worker_name)
198
+
199
+ def worker_api_generate_stream(self, params):
200
+ worker_addr = self.get_worker_address(params["model"])
201
+ if not worker_addr:
202
+ logger.info(f"no worker: {params['model']}")
203
+ ret = {
204
+ "text": server_error_msg,
205
+ "error_code": 2,
206
+ }
207
+ yield json.dumps(ret).encode() + b"\0"
208
+
209
+ try:
210
+ response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5)
211
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
212
+ if chunk:
213
+ yield chunk + b"\0"
214
+ except requests.exceptions.RequestException:
215
+ logger.info(f"worker timeout: {worker_addr}")
216
+ ret = {
217
+ "text": server_error_msg,
218
+ "error_code": 3,
219
+ }
220
+ yield json.dumps(ret).encode() + b"\0"
221
+
222
+ # Let the controller act as a worker to achieve hierarchical
223
+ # management. This can be used to connect isolated sub networks.
224
+ def worker_api_get_status(self):
225
+ model_names = set()
226
+ speed = 0
227
+ queue_length = 0
228
+
229
+ for w_name in self.worker_info:
230
+ worker_status = self.get_worker_status(w_name)
231
+ if worker_status is not None:
232
+ model_names.update(worker_status["model_names"])
233
+ speed += worker_status["speed"]
234
+ queue_length += worker_status["queue_length"]
235
+
236
+ return {
237
+ "model_names": list(model_names),
238
+ "speed": speed,
239
+ "queue_length": queue_length,
240
+ }
241
+
242
+
243
+ app = FastAPI()
244
+
245
+
246
+ @app.post("/register_worker")
247
+ async def register_worker(request: Request):
248
+ data = await request.json()
249
+ controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None))
250
+
251
+
252
+ @app.post("/refresh_all_workers")
253
+ async def refresh_all_workers():
254
+ controller.refresh_all_workers()
255
+
256
+
257
+ @app.post("/list_models")
258
+ async def list_models():
259
+ models = controller.list_models()
260
+ return {"models": models}
261
+
262
+
263
+ @app.post("/get_worker_address")
264
+ async def get_worker_address(request: Request):
265
+ data = await request.json()
266
+ addr = controller.get_worker_address(data["model"])
267
+ return {"address": addr}
268
+
269
+
270
+ @app.post("/receive_heart_beat")
271
+ async def receive_heart_beat(request: Request):
272
+ data = await request.json()
273
+ exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
274
+ return {"exist": exist}
275
+
276
+
277
+ @app.post("/worker_generate_stream")
278
+ async def worker_api_generate_stream(request: Request):
279
+ params = await request.json()
280
+ generator = controller.worker_api_generate_stream(params)
281
+ return StreamingResponse(generator)
282
+
283
+
284
+ @app.post("/worker_get_status")
285
+ async def worker_api_get_status(request: Request):
286
+ return controller.worker_api_get_status()
287
+
288
+
289
+ if __name__ == "__main__":
290
+ parser = argparse.ArgumentParser()
291
+ parser.add_argument("--host", type=str, default="localhost")
292
+ parser.add_argument("--port", type=int, default=21001)
293
+ parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue")
294
+ args = parser.parse_args()
295
+ logger.info(f"args: {args}")
296
+
297
+ controller = Controller(args.dispatch_method)
298
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
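Since the controller is a plain FastAPI app, workers and clients interact with it purely over the HTTP endpoints defined above. A minimal sketch of querying a running controller, assuming it was started with the default `--port 21001`:

```python
import requests

CONTROLLER_URL = "http://localhost:21001"  # default host/port from the argparse block above

# Re-poll registered workers, then list the models they serve
requests.post(f"{CONTROLLER_URL}/refresh_all_workers")
models = requests.post(f"{CONTROLLER_URL}/list_models").json()["models"]
print("available models:", models)

# Resolve a worker address for one model; an empty string means no worker is available
if models:
    ret = requests.post(f"{CONTROLLER_URL}/get_worker_address", json={"model": models[0]})
    print(models[0], "->", ret.json()["address"])
```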
serve/examples/cows_in_pasture.png ADDED
serve/examples/monkey_knives.png ADDED
serve/gradio_web_server.py ADDED
@@ -0,0 +1,462 @@
1
+ """
2
+ gradio_web_server.py
3
+
4
+ Entry point for all VLM-Bench interactive demos; specify model and get a gradio UI where you can chat with it!
5
+
6
+ This file is copied from the script used to define the gradio web server in the LLaVa codebase:
7
+ https://github.com/haotian-liu/LLaVA/blob/main/llava/serve/gradio_web_server.py with only very minor
8
+ modifications.
9
+ """
10
+
11
+ import argparse
12
+ import datetime
13
+ import hashlib
14
+ import json
15
+ import os
16
+ import time
17
+
18
+ import gradio as gr
19
+ import requests
20
+ from llava.constants import LOGDIR
21
+ from llava.conversation import conv_templates, default_conversation
22
+ from llava.utils import build_logger, moderation_msg, server_error_msg, violates_moderation
23
+
24
+ from serve import INTERACTION_MODES_MAP, MODEL_ID_TO_NAME
25
+
26
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
27
+
28
+ headers = {"User-Agent": "PrismaticVLMs Client"}
29
+
30
+ no_change_btn = gr.Button.update()
31
+ enable_btn = gr.Button.update(interactive=True)
32
+ disable_btn = gr.Button.update(interactive=False)
33
+
34
+
35
+ def get_conv_log_filename():
36
+ t = datetime.datetime.now()
37
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
38
+ return name
39
+
40
+
41
+ def get_model_list():
42
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
43
+ assert ret.status_code == 200
44
+ ret = requests.post(args.controller_url + "/list_models")
45
+ models = ret.json()["models"]
46
+ models = sorted(
47
+ models, key=lambda x: list(MODEL_ID_TO_NAME.values()).index(x) if x in MODEL_ID_TO_NAME.values() else len(models)
48
+ )
49
+ logger.info(f"Models: {models}")
50
+ return models
51
+
52
+
53
+ get_window_url_params = """
54
+ function() {
55
+ const params = new URLSearchParams(window.location.search);
56
+ url_params = Object.fromEntries(params);
57
+ console.log(url_params);
58
+ return url_params;
59
+ }
60
+ """
61
+
62
+
63
+ def load_demo(url_params, request: gr.Request):
64
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
65
+
66
+ dropdown_update = gr.Dropdown.update(visible=True)
67
+ if "model" in url_params:
68
+ model = url_params["model"]
69
+ if model in models:
70
+ dropdown_update = gr.Dropdown.update(value=model, visible=True)
71
+
72
+ state = default_conversation.copy()
73
+ return state, dropdown_update
74
+
75
+
76
+ def load_demo_refresh_model_list(request: gr.Request):
77
+ logger.info(f"load_demo. ip: {request.client.host}")
78
+ models = get_model_list()
79
+ state = default_conversation.copy()
80
+ dropdown_update = gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else "")
81
+ return state, dropdown_update
82
+
83
+
84
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
85
+ with open(get_conv_log_filename(), "a") as fout:
86
+ data = {
87
+ "tstamp": round(time.time(), 4),
88
+ "type": vote_type,
89
+ "model": model_selector,
90
+ "state": state.dict(),
91
+ "ip": request.client.host,
92
+ }
93
+ fout.write(json.dumps(data) + "\n")
94
+
95
+
96
+ # def upvote_last_response(state, model_selector, request: gr.Request):
97
+ # logger.info(f"upvote. ip: {request.client.host}")
98
+ # vote_last_response(state, "upvote", model_selector, request)
99
+ # return ("",) + (disable_btn,) * 3
100
+
101
+
102
+ # def downvote_last_response(state, model_selector, request: gr.Request):
103
+ # logger.info(f"downvote. ip: {request.client.host}")
104
+ # vote_last_response(state, "downvote", model_selector, request)
105
+ # return ("",) + (disable_btn,) * 3
106
+
107
+
108
+ # def flag_last_response(state, model_selector, request: gr.Request):
109
+ # logger.info(f"flag. ip: {request.client.host}")
110
+ # vote_last_response(state, "flag", model_selector, request)
111
+ # return ("",) + (disable_btn,) * 3
112
+
113
+
114
+ def regenerate(state, image_process_mode, request: gr.Request):
115
+ logger.info(f"regenerate. ip: {request.client.host}")
116
+ state.messages[-1][-1] = None
117
+ prev_human_msg = state.messages[-2]
118
+ if type(prev_human_msg[1]) in (tuple, list):
119
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
120
+ state.skip_next = False
121
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
122
+
123
+
124
+ def clear_history(request: gr.Request):
125
+ logger.info(f"clear_history. ip: {request.client.host}")
126
+ state = default_conversation.copy()
127
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
128
+
129
+
130
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
131
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
132
+ if len(text) <= 0 and image is None:
133
+ state.skip_next = True
134
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
135
+ if args.moderate:
136
+ flagged = violates_moderation(text)
137
+ if flagged:
138
+ state.skip_next = True
139
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5
140
+
141
+ text = text[:1536] # Hard cut-off
142
+ if image is not None:
143
+ text = text[:1200] # Hard cut-off for images
144
+ if "<image>" not in text:
145
+ # text = '<Image><image></Image>' + text
146
+ text = text + "\n<image>"
147
+ text = (text, image, image_process_mode)
148
+ if len(state.get_images(return_pil=True)) > 0:
149
+ state = default_conversation.copy()
150
+ state.append_message(state.roles[0], text)
151
+ state.append_message(state.roles[1], None)
152
+ state.skip_next = False
153
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
154
+
155
+
156
+ def http_bot(state, model_selector, interaction_mode, temperature, max_new_tokens, request: gr.Request):
157
+ logger.info(f"http_bot. ip: {request.client.host}")
158
+ start_tstamp = time.time()
159
+ model_name = model_selector
160
+
161
+ if state.skip_next:
162
+ # This generate call is skipped due to invalid inputs
163
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
164
+ return
165
+
166
+ if len(state.messages) == state.offset + 2:
167
+ # First round of conversation
168
+ # (Note): Hardcoding llava_v1 conv template for now
169
+ new_state = conv_templates["llava_v1"].copy()
170
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
171
+ new_state.append_message(new_state.roles[1], None)
172
+ state = new_state
173
+
174
+ # Query worker address
175
+ controller_url = args.controller_url
176
+ ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name})
177
+ worker_addr = ret.json()["address"]
178
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
179
+
180
+ # No available worker
181
+ if worker_addr == "":
182
+ state.messages[-1][-1] = server_error_msg
183
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
184
+ return
185
+
186
+ # Construct prompt
187
+ prompt = state.get_prompt()
188
+
189
+ all_images = state.get_images(return_pil=True)
190
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
191
+ for image, im_hash in zip(all_images, all_image_hash):
192
+ t = datetime.datetime.now()
193
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{im_hash}.jpg")
194
+ if not os.path.isfile(filename):
195
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
196
+ image.save(filename)
197
+
198
+ # Make requests
199
+ pload = {
200
+ "model": model_name,
201
+ "prompt": prompt,
202
+ "interaction_mode": interaction_mode,
203
+ "temperature": float(temperature),
204
+ "max_new_tokens": int(max_new_tokens),
205
+ "images": f"List of {len(state.get_images())} images: {all_image_hash}",
206
+ }
207
+ logger.info(f"==== request ====\n{pload}")
208
+
209
+ pload["images"] = state.get_images()
210
+
211
+ state.messages[-1][-1] = "▌"
212
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
213
+
214
+ try:
215
+ # Stream output
216
+ response = requests.post(
217
+ worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=10
218
+ )
219
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
220
+ if chunk:
221
+ data = json.loads(chunk.decode())
222
+ if data["error_code"] == 0:
223
+ output = data["text"][len(prompt) :].strip()
224
+ state.messages[-1][-1] = output + "▌"
225
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
226
+ else:
227
+ output = data["text"] + f" (error_code: {data['error_code']})"
228
+ state.messages[-1][-1] = output
229
+ yield (state, state.to_gradio_chatbot()) + (
230
+ disable_btn,
231
+ disable_btn,
232
+ disable_btn,
233
+ enable_btn,
234
+ enable_btn,
235
+ )
236
+ return
237
+ time.sleep(0.03)
238
+ except requests.exceptions.RequestException:
239
+ state.messages[-1][-1] = server_error_msg
240
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
241
+ return
242
+
243
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
244
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
245
+
246
+ finish_tstamp = time.time()
247
+ logger.info(f"{output}")
248
+
249
+ with open(get_conv_log_filename(), "a") as fout:
250
+ data = {
251
+ "tstamp": round(finish_tstamp, 4),
252
+ "type": "chat",
253
+ "model": model_name,
254
+ "start": round(start_tstamp, 4),
255
+ "finish": round(finish_tstamp, 4),
256
+ "state": state.dict(),
257
+ "images": all_image_hash,
258
+ "ip": request.client.host,
259
+ }
260
+ fout.write(json.dumps(data) + "\n")
261
+
262
+
263
+ title_markdown = """
264
+ # Prismatic VLMs: Investigating the Design Space of Visually-Conditioned Language Models
265
+ [[Project Page](TODO)] [[Code](TODO)]
266
+ [[Models](TODO)]
267
+ | 📚 [[Paper](TODO)]
268
+ """
269
+
270
+ tos_markdown = """
271
+ ### Terms of use
272
+ By using this service, users are required to agree to the following terms:
273
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may
274
+ generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The
275
+ service may collect user dialogue data for future research. Please click the "Flag" button if you get any
276
+ inappropriate answer! We will collect those to keep improving our moderator. For an optimal experience,
277
+ please use desktop computers for this demo, as mobile devices may compromise its quality. This website
278
+ is heavily inspired by the website released by [LLaVA](https://github.com/haotian-liu/LLaVA).
279
+ """
280
+
281
+
282
+ learn_more_markdown = """
283
+ ### License
284
+ The service is a research preview intended for non-commercial use only, subject to the model
285
+ [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA,
286
+ [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI,
287
+ and [Privacy Practices]
288
+ (https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb)
289
+ of ShareGPT. Please contact us if you find any potential violation.
290
+ """
291
+
292
+ block_css = """
293
+
294
+ #buttons button {
295
+ min-width: min(120px,100%);
296
+ }
297
+
298
+ """
299
+
300
+
301
+ def build_demo(embed_mode):
302
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
303
+
304
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="stone")) as demo:
305
+ state = gr.State()
306
+
307
+ if not embed_mode:
308
+ gr.Markdown(title_markdown)
309
+
310
+ with gr.Row():
311
+ with gr.Column(scale=3):
312
+ with gr.Row(elem_id="model_selector_row"):
313
+ model_selector = gr.Dropdown(
314
+ choices=models,
315
+ value=models[0] if len(models) > 0 else "",
316
+ interactive=True,
317
+ show_label=False,
318
+ container=False,
319
+ )
320
+
321
+ imagebox = gr.Image(type="pil")
322
+ image_process_mode = gr.Radio(
323
+ ["Crop", "Resize", "Pad", "Default"],
324
+ value="Default",
325
+ label="Preprocess for non-square image",
326
+ visible=False,
327
+ )
328
+
329
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
330
+ gr.Examples(
331
+ examples=[
332
+ [f"{cur_dir}/examples/cows_in_pasture.png", "How many cows are in this image?"],
333
+ [
334
+ f"{cur_dir}/examples/monkey_knives.png",
335
+ "What is happening in this image?",
336
+ ],
337
+ ],
338
+ inputs=[imagebox, textbox],
339
+ )
340
+
341
+ with gr.Accordion("Parameters", open=False):
342
+ temperature = gr.Slider(
343
+ minimum=0.0,
344
+ maximum=1.0,
345
+ value=0.2,
346
+ step=0.1,
347
+ interactive=True,
348
+ label="Temperature",
349
+ )
350
+ max_output_tokens = gr.Slider(
351
+ minimum=0,
352
+ maximum=4096,
353
+ value=2048,
354
+ step=64,
355
+ interactive=True,
356
+ label="Max output tokens",
357
+ )
358
+
359
+ with gr.Accordion("Interaction Mode", open=False):
360
+ interaction_modes = list(INTERACTION_MODES_MAP.keys())
361
+ interaction_mode = gr.Dropdown(
362
+ choices=interaction_modes,
363
+ value=interaction_modes[0] if len(interaction_modes) > 0 else "Chat",
364
+ interactive=True,
365
+ show_label=False,
366
+ container=False,
367
+ )
368
+
369
+ with gr.Column(scale=8):
370
+ chatbot = gr.Chatbot(elem_id="chatbot", label="PrismaticVLMs Chatbot", height=550)
371
+ with gr.Row():
372
+ with gr.Column(scale=8):
373
+ textbox.render()
374
+ with gr.Column(scale=1, min_width=50):
375
+ submit_btn = gr.Button(value="Generate", variant="primary")
376
+ with gr.Row(elem_id="buttons"):
377
+ # upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
378
+ # downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
379
+ # flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
380
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
381
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
382
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
383
+
384
+ if not embed_mode:
385
+ gr.Markdown(tos_markdown)
386
+ gr.Markdown(learn_more_markdown)
387
+ url_params = gr.JSON(visible=False)
388
+
389
+ # Register listeners
390
+ btn_list = [regenerate_btn, clear_btn]
391
+ # upvote_btn.click(
392
+ # upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
393
+ # )
394
+ # downvote_btn.click(
395
+ # downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
396
+ # )
397
+ # flag_btn.click(
398
+ # flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
399
+ # )
400
+
401
+ regenerate_btn.click(
402
+ regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox, *btn_list], queue=False
403
+ ).then(
404
+ http_bot,
405
+ [state, model_selector, interaction_mode, temperature, max_output_tokens],
406
+ [state, chatbot, *btn_list],
407
+ )
408
+
409
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, *btn_list], queue=False)
410
+
411
+ textbox.submit(
412
+ add_text,
413
+ [state, textbox, imagebox, image_process_mode],
414
+ [state, chatbot, textbox, imagebox, *btn_list],
415
+ queue=False,
416
+ ).then(
417
+ http_bot,
418
+ [state, model_selector, interaction_mode, temperature, max_output_tokens],
419
+ [state, chatbot, *btn_list],
420
+ )
421
+
422
+ submit_btn.click(
423
+ add_text,
424
+ [state, textbox, imagebox, image_process_mode],
425
+ [state, chatbot, textbox, imagebox, *btn_list],
426
+ queue=False,
427
+ ).then(
428
+ http_bot,
429
+ [state, model_selector, interaction_mode, temperature, max_output_tokens],
430
+ [state, chatbot, *btn_list],
431
+ )
432
+
433
+ if args.model_list_mode == "once":
434
+ demo.load(load_demo, [url_params], [state, model_selector], _js=get_window_url_params, queue=False)
435
+ elif args.model_list_mode == "reload":
436
+ demo.load(load_demo_refresh_model_list, None, [state, model_selector], queue=False)
437
+ else:
438
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
439
+
440
+ return demo
441
+
442
+
443
+ if __name__ == "__main__":
444
+ parser = argparse.ArgumentParser()
445
+ parser.add_argument("--host", type=str, default="0.0.0.0")
446
+ parser.add_argument("--port", type=int)
447
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
448
+ parser.add_argument("--concurrency-count", type=int, default=10)
449
+ parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
450
+ parser.add_argument("--share", action="store_true")
451
+ parser.add_argument("--moderate", action="store_true")
452
+ parser.add_argument("--embed", action="store_true")
453
+ args = parser.parse_args()
454
+ logger.info(f"args: {args}")
455
+
456
+ models = get_model_list()
457
+
458
+ logger.info(args)
459
+ demo = build_demo(args.embed)
460
+ demo.queue(concurrency_count=args.concurrency_count, api_open=False).launch(
461
+ server_name=args.host, server_port=args.port, share=args.share
462
+ )