Spaces: Running on L40S
Akash Garg committed · Commit f6a2f50 · 1 Parent(s): 0c10674

adding variance slider for top_p
Browse files:
- app.py (+8 -4)
- cube/cube3d/generate.py (+7 -7)
- cube/cube3d/inference/engine.py (+14 -13)
- cube/cube3d/inference/logits_postprocesses.py (+32 -18)
app.py
CHANGED
@@ -39,9 +39,10 @@ def gen_save_folder(max_size=200):
 …
     return new_folder
 
-def handle_text_prompt(input_prompt):
-    print(f"prompt: {input_prompt}")
-    …
+def handle_text_prompt(input_prompt, variance):
+    print(f"prompt: {input_prompt}, variance: {variance}")
+    top_p = None if variance == 0 else (100 - variance) / 100.0
+    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s([input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=top_p)
     # save output
     vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
     save_folder = gen_save_folder()
@@ -57,6 +58,7 @@ def build_interface():
         gr.Markdown(
             f"""
             # {title}
+            # Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
             """
         )
 
@@ -74,11 +76,13 @@ def build_interface():
             model3d = gr.Model3D(
                 label="Output", height="45em", interactive=False
             )
+            variance = gr.Slider(minimum=0, maximum=99, step=1, value=0, label="Variance")
 
             submit_button.click(
                 handle_text_prompt,
                 inputs=[
-                    input_text_box
+                    input_text_box,
+                    variance
                 ],
                 outputs=[
                     model3d
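In the UI, the new slider is labelled "Variance" and is converted into a nucleus-sampling threshold before it reaches the engine. A minimal sketch of that mapping (the helper name variance_to_top_p is hypothetical; the expression itself is copied from handle_text_prompt above):

    def variance_to_top_p(variance: int):
        # Same expression as in handle_text_prompt: 0 disables sampling,
        # any other slider value becomes a top_p threshold in (0, 1).
        return None if variance == 0 else (100 - variance) / 100.0

    variance_to_top_p(0)    # None -> deterministic argmax decoding
    variance_to_top_p(10)   # 0.9  -> nucleus sampling with a wide token set
    variance_to_top_p(99)   # 0.01 -> nucleus sampling restricted to the very top tokens

Note that larger slider values map to smaller top_p values, i.e. a tighter nucleus.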
cube/cube3d/generate.py
CHANGED
@@ -20,13 +20,13 @@ def generate_mesh(
     output_name,
     resolution_base=8.0,
     disable_postprocess=False,
-    …
+    top_p=None,
 ):
     mesh_v_f = engine.t2s(
         [prompt],
         use_kv_cache=True,
         resolution_base=resolution_base,
-        …
+        top_p=top_p,
     )
     vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
     obj_path = os.path.join(output_dir, f"{output_name}.obj")
@@ -87,10 +87,10 @@ if __name__ == "__main__":
         help="Text prompt for generating a 3D mesh",
     )
     parser.add_argument(
-        "--top-…
-        type=…
-        default=…
-        help="…
+        "--top-p",
+        type=float,
+        default=None,
+        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
     )
     parser.add_argument(
         "--render-gif",
@@ -136,7 +136,7 @@ if __name__ == "__main__":
         "output",
         args.resolution_base,
         args.disable_postprocessing,
-        args.…
+        args.top_p,
     )
     if args.render_gif:
         gif_path = renderer.render_turntable(obj_path, args.output_dir)
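The same parameter is now exposed to callers of generate_mesh (and, via the new --top-p flag, on the command line). A hedged usage sketch; the leading engine/prompt/output-directory arguments are assumed from the surrounding script and do not appear in this diff, only the trailing parameters do:

    obj_path = generate_mesh(
        engine,                      # assumed: an Engine/EngineFast built elsewhere in the script
        "a weathered wooden chest",  # assumed prompt text, for illustration only
        "outputs",                   # assumed output directory
        "output",                    # output_name, as in the script's __main__ call
        resolution_base=8.0,
        disable_postprocess=False,
        top_p=0.9,                   # new: nucleus sampling; None restores deterministic decoding
    )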
cube/cube3d/inference/engine.py
CHANGED
@@ -160,7 +160,7 @@ class Engine:
         prompts: list[str],
         use_kv_cache: bool,
         guidance_scale: float = 3.0,
-        …
+        top_p: float = None,
     ):
         """
         Generates text using a GPT model based on the provided prompts.
@@ -168,7 +168,8 @@
             prompts (list[str]): A list of input prompts to generate text from.
             use_kv_cache (bool): Whether to use key-value caching for faster generation.
             guidance_scale (float, optional): The scale for guidance during generation. Default is 3.0.
-            …
+            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+                If None, argmax selection is performed (deterministic generation). Otherwise, the smallest set of tokens with cumulative probability ≥ top_p is kept (stochastic generation).
         Returns:
             torch.Tensor: A tensor containing the generated token IDs.
         """
@@ -215,11 +216,10 @@
                 guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
             )
             logits = (1 + gamma) * logits - gamma * uncond_logits
-            …
+            next_id = process_logits(
                 logits,
-                …
+                top_p=top_p,
             )
-            next_id = torch.multinomial(probs, num_samples=1, replacement=True)
             output_ids.append(next_id)
             next_embed = self.gpt_model.encode_token(next_id)
             if guidance_scale > 0.0:
@@ -266,7 +266,7 @@
         guidance_scale: float = 3.0,
         resolution_base: float = 8.0,
         chunk_size: int = 100_000,
-        …
+        top_p: float = None,
     ):
         """
         Generates a 3D mesh from text prompts using a GPT model and shape decoder.
@@ -276,10 +276,12 @@
             guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
             resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
             chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
+            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+                If None, argmax selection is performed (deterministic generation). Otherwise, the smallest set of tokens with cumulative probability ≥ top_p is kept (stochastic generation).
         Returns:
             mesh_v_f: The generated 3D mesh vertices and faces.
         """
-        output_ids = self.run_gpt(prompts, use_kv_cache, guidance_scale, …
+        output_ids = self.run_gpt(prompts, use_kv_cache, guidance_scale, top_p)
         with torch.autocast(self.device.type, dtype=torch.bfloat16):
             mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
         return mesh_v_f
@@ -426,7 +428,7 @@ class EngineFast(Engine):
         prompts: list[str],
         use_kv_cache: bool,
         guidance_scale: float = 3.0,
-        …
+        top_p: float = None
     ):
         """
         Runs the GPT model to generate text based on the provided prompts.
@@ -434,6 +436,8 @@
             prompts (list[str]): A list of input prompts for the GPT model. Only a single prompt is supported.
             use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used)
             guidance_scale (float, optional): The scale factor for guidance. Default is 3.0.
+            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+                If None, argmax selection is performed. Otherwise, the smallest set of tokens with cumulative probability ≥ top_p is kept.
         Returns:
             torch.Tensor: A tensor containing the generated output token IDs.
         Raises:
@@ -464,9 +468,7 @@
         logits, uncond_logits = logits.float().chunk(2, dim=0)
         gamma = guidance_scale
         logits = (1 + gamma) * logits - gamma * uncond_logits
-        …
-        probs = process_logits(logits, top_k=top_k)
-        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
+        next_id = process_logits(logits, top_p=top_p)
 
         output_ids[:, 0] = next_id.squeeze()
         next_embed = self.gpt_model.encode_token(next_id)
@@ -488,8 +490,7 @@
             guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
             )
             logits = (1 + gamma) * logits - gamma * uncond_logits
-            …
-            next_id = torch.multinomial(probs, num_samples=1, replacement=True)
+            next_id = process_logits(logits, top_p=top_p)
 
             output_ids[:, i] = next_id.squeeze()
             next_embed = self.gpt_model.encode_token(next_id)
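End to end, top_p now flows from t2s into run_gpt and then into process_logits at every decoding step, so a caller can switch between deterministic and stochastic decoding per request. A minimal sketch of both modes; the engine_fast instance, its construction and checkpoints are assumed (not part of this diff), while the keyword arguments are the ones shown above and in app.py:

    # Deterministic: top_p=None makes process_logits fall back to argmax at each step.
    mesh_v_f = engine_fast.t2s(
        ["a small toy robot"],   # illustrative prompt
        use_kv_cache=True,
        resolution_base=8.0,
        top_p=None,
    )

    # Stochastic: nucleus sampling with top_p=0.9 at each step, so repeated calls
    # with the same prompt can produce different meshes.
    mesh_v_f = engine_fast.t2s(
        ["a small toy robot"],
        use_kv_cache=True,
        resolution_base=8.0,
        top_p=0.9,
    )
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]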
cube/cube3d/inference/logits_postprocesses.py
CHANGED
@@ -2,22 +2,28 @@ import torch
 import torch.nn.functional as F
 
 
-def …
+def top_p_filtering(logits, top_p: float = 1.0):
     """
-    Filter a distribution of logits using top-…
+    Filter a distribution of logits using top-p filtering.
     The input logits tensor is modified in-place.
 
     Args:
-        logits: A tensor of logits to be filtered. Expected shape is [..., vocab_size].
-        …
+        logits (torch.Tensor): A tensor of logits to be filtered. Expected shape is [..., vocab_size].
+        top_p (float, optional): The cumulative probability threshold for top-p sampling.
+            If < 1.0, only keep the smallest set of tokens whose
+            cumulative probability does not exceed this threshold.
 
     Returns:
-        …
+        torch.Tensor: logits where values outside the top-p threshold are set to -∞.
     """
-    if …
-        …
-        …
-        ]
+    if top_p < 1.0:
+        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
+        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum(dim=-1) > top_p
+        sorted_idx_to_remove[..., 0] = False
+
+        idx_to_remove = sorted_idx_to_remove.scatter(
+            -1, sorted_idx, sorted_idx_to_remove
+        )
         logits.masked_fill_(idx_to_remove, -torch.inf)
 
     return logits
@@ -25,19 +31,27 @@ def top_k_filtering(logits, top_k: int = 1):
 
 def process_logits(
     logits,
-    …
+    top_p: float = None,
 ):
     """
-    Process logits by optionally applying top-…
-    …
+    Process logits by optionally applying nucleus (top-p) filtering and token selection.
+
+    If `top_p` is None, the token with the highest probability (argmax) is selected.
+    If `top_p` is provided, the smallest set of tokens with cumulative probability ≥ top_p is kept, then softmax is applied to obtain
+    probabilities. A token is sampled from this filtered distribution using `torch.multinomial`.
 
     Args:
-        logits: A tensor of logits to process.
-        …
+        logits (torch.Tensor): A tensor of logits to process.
+        top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+            If None, argmax selection is performed (deterministic generation). Otherwise, the smallest set of tokens with cumulative probability ≥ top_p is kept (stochastic generation).
 
     Returns:
-        …
+        torch.Tensor: the selected token index.
     """
-    …
-    …
-    …
+    if top_p is None:
+        next_id = torch.argmax(logits, dim=-1, keepdim=True)
+    else:
+        logits = top_p_filtering(logits, top_p=top_p)
+        probs = F.softmax(logits, dim=-1)
+        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
+    return next_id
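For reference, a self-contained sketch of what the new selection path computes on toy logits. It re-implements the nucleus step inline so it runs without the cube3d package; the numbers and tensor shapes are illustrative, not taken from the commit:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])    # one decoding "step" over a 5-token vocabulary

    # top_p is None -> deterministic argmax, as in process_logits above.
    next_id = torch.argmax(logits, dim=-1, keepdim=True)     # tensor([[0]])

    # top_p = 0.9 -> drop tokens once cumulative probability exceeds 0.9, then sample.
    top_p = 0.9
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    remove = sorted_logits.softmax(dim=-1).cumsum(dim=-1) > top_p
    remove[..., 0] = False                                    # always keep the most likely token
    mask = remove.scatter(-1, sorted_idx, remove)             # map the mask back to vocabulary order
    probs = F.softmax(logits.masked_fill(mask, -torch.inf), dim=-1)
    next_id = torch.multinomial(probs, num_samples=1)         # stochastic pick from the kept tokens

With these logits only the first two tokens survive the filter, and sampling happens over their renormalised probabilities (≈0.73 and ≈0.27).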