Akash Garg committed
Commit f6a2f50 · Parent: 0c10674

adding variance slider for top_p

app.py CHANGED
@@ -39,9 +39,10 @@ def gen_save_folder(max_size=200):
 
     return new_folder
 
-def handle_text_prompt(input_prompt):
-    print(f"prompt: {input_prompt}")
-    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s([input_prompt], use_kv_cache=True, resolution_base=8.0)
+def handle_text_prompt(input_prompt, variance):
+    print(f"prompt: {input_prompt}, variance: {variance}")
+    top_p = None if variance == 0 else (100 - variance) / 100.0
+    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s([input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=top_p)
     # save output
     vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
     save_folder = gen_save_folder()
@@ -57,6 +58,7 @@ def build_interface():
         gr.Markdown(
             f"""
             # {title}
+            # Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
             """
         )
 
@@ -74,11 +76,13 @@ def build_interface():
             model3d = gr.Model3D(
                 label="Output", height="45em", interactive=False
             )
+            variance = gr.Slider(minimum=0, maximum=99, step=1, value=0, label="Variance")
 
         submit_button.click(
             handle_text_prompt,
             inputs=[
-                input_text_box
+                input_text_box,
+                variance
             ],
             outputs=[
                 model3d
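For reference, the new "Variance" slider is a user-facing remapping of top_p: 0 disables sampling entirely, and any positive value v becomes (100 - v) / 100. A minimal sketch of that mapping follows; the helper name is illustrative only, since app.py computes it inline.

def variance_to_top_p(variance: int):
    """Map the Gradio 'Variance' slider (0-99) to a nucleus-sampling threshold."""
    return None if variance == 0 else (100 - variance) / 100.0

assert variance_to_top_p(0) is None    # None -> deterministic argmax decoding
assert variance_to_top_p(10) == 0.9    # keep the nucleus covering ~90% of probability mass
assert variance_to_top_p(99) == 0.01   # only the very top of the distribution survives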
cube/cube3d/generate.py CHANGED
@@ -20,13 +20,13 @@ def generate_mesh(
     output_name,
     resolution_base=8.0,
     disable_postprocess=False,
-    top_k: int = 1,
+    top_p=None,
 ):
     mesh_v_f = engine.t2s(
         [prompt],
         use_kv_cache=True,
         resolution_base=resolution_base,
-        top_k=top_k,
+        top_p=top_p,
     )
     vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
     obj_path = os.path.join(output_dir, f"{output_name}.obj")
@@ -87,10 +87,10 @@ if __name__ == "__main__":
         help="Text prompt for generating a 3D mesh",
     )
     parser.add_argument(
-        "--top-k",
-        type=int,
-        default=1,
-        help="Top k filtering, 0 means no filtering, by default 1, which is determistic.",
+        "--top-p",
+        type=float,
+        default=None,
+        help="Float < 1: Keep smallest set of tokens with cumulative probability top_p. Default None: deterministic generation.",
     )
     parser.add_argument(
         "--render-gif",
@@ -136,7 +136,7 @@ if __name__ == "__main__":
         "output",
         args.resolution_base,
         args.disable_postprocessing,
-        args.top_k,
+        args.top_p,
     )
     if args.render_gif:
         gif_path = renderer.render_turntable(obj_path, args.output_dir)
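The CLI change is additive: --top-k is replaced by --top-p, which defaults to None so existing invocations stay deterministic. A small self-contained sketch (not repo code) of how the new flag parses:

import argparse

# Stand-alone illustration of the --top-p flag added in cube/cube3d/generate.py;
# nothing else from the script is reproduced here.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--top-p",
    type=float,
    default=None,
    help="Float < 1: Keep smallest set of tokens with cumulative probability top_p. Default None: deterministic generation.",
)

print(parser.parse_args([]).top_p)                  # None -> argmax decoding downstream
print(parser.parse_args(["--top-p", "0.9"]).top_p)  # 0.9  -> nucleus sampling downstream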
cube/cube3d/inference/engine.py CHANGED
@@ -160,7 +160,7 @@ class Engine:
         prompts: list[str],
         use_kv_cache: bool,
         guidance_scale: float = 3.0,
-        top_k: int = 1,
+        top_p: float = None,
     ):
         """
         Generates text using a GPT model based on the provided prompts.
@@ -168,7 +168,8 @@
             prompts (list[str]): A list of input prompts to generate text from.
             use_kv_cache (bool): Whether to use key-value caching for faster generation.
             guidance_scale (float, optional): The scale for guidance during generation. Default is 3.0.
-            top_k : (int, optional): Top k filtering, 0 means no filtering, by default 1.
+            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+                If None, argmax selection is performed (deterministic generation). Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept (stochastic generation).
         Returns:
             torch.Tensor: A tensor containing the generated token IDs.
         """
@@ -215,11 +216,10 @@ class Engine:
                 guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
             )
             logits = (1 + gamma) * logits - gamma * uncond_logits
-            probs = process_logits(
+            next_id = process_logits(
                 logits,
-                top_k=top_k,
+                top_p=top_p,
             )
-            next_id = torch.multinomial(probs, num_samples=1, replacement=True)
             output_ids.append(next_id)
             next_embed = self.gpt_model.encode_token(next_id)
             if guidance_scale > 0.0:
@@ -266,7 +266,7 @@
         guidance_scale: float = 3.0,
         resolution_base: float = 8.0,
         chunk_size: int = 100_000,
-        top_k: int = 1,
+        top_p: float = None,
     ):
         """
         Generates a 3D mesh from text prompts using a GPT model and shape decoder.
@@ -276,10 +276,12 @@
             guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
             resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
             chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
+            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+                If None, argmax selection is performed (deterministic generation). Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept (stochastic generation).
         Returns:
             mesh_v_f: The generated 3D mesh vertices and faces.
         """
-        output_ids = self.run_gpt(prompts, use_kv_cache, guidance_scale, top_k)
+        output_ids = self.run_gpt(prompts, use_kv_cache, guidance_scale, top_p)
         with torch.autocast(self.device.type, dtype=torch.bfloat16):
             mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
         return mesh_v_f
@@ -426,7 +428,7 @@ class EngineFast(Engine):
         prompts: list[str],
         use_kv_cache: bool,
         guidance_scale: float = 3.0,
-        top_k: int = 1,
+        top_p: float = None
     ):
         """
         Runs the GPT model to generate text based on the provided prompts.
@@ -434,6 +436,8 @@ class EngineFast(Engine):
             prompts (list[str]): A list of input prompts for the GPT model. Only a single prompt is supported.
             use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used)
             guidance_scale (float, optional): The scale factor for guidance. Default is 3.0.
+            top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+                If None, argmax selection is performed. Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept.
         Returns:
             torch.Tensor: A tensor containing the generated output token IDs.
         Raises:
@@ -464,9 +468,7 @@ class EngineFast(Engine):
         logits, uncond_logits = logits.float().chunk(2, dim=0)
         gamma = guidance_scale
         logits = (1 + gamma) * logits - gamma * uncond_logits
-
-        probs = process_logits(logits, top_k=top_k)
-        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
+        next_id = process_logits(logits, top_p=top_p)
 
         output_ids[:, 0] = next_id.squeeze()
         next_embed = self.gpt_model.encode_token(next_id)
@@ -488,8 +490,7 @@ class EngineFast(Engine):
                 guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
             )
             logits = (1 + gamma) * logits - gamma * uncond_logits
-            probs = process_logits(logits, top_k=top_k)
-            next_id = torch.multinomial(probs, num_samples=1, replacement=True)
+            next_id = process_logits(logits, top_p=top_p)
 
             output_ids[:, i] = next_id.squeeze()
             next_embed = self.gpt_model.encode_token(next_id)
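In both engines, token selection now happens entirely inside process_logits: sampling moves out of run_gpt, and the classifier-free guidance step is unchanged. A toy illustration of one decoding step, using only the formulas visible in this diff (the logits here are random stand-ins, not model outputs):

import torch

torch.manual_seed(0)
vocab_size, guidance_scale, max_new_tokens, i = 8, 3.0, 10, 2

cond_logits = torch.randn(1, vocab_size)    # stand-in for the conditional GPT logits
uncond_logits = torch.randn(1, vocab_size)  # stand-in for the unconditional logits

# Guidance strength decays linearly over the generation, as in Engine.run_gpt.
gamma = guidance_scale * (max_new_tokens - i) / max_new_tokens
logits = (1 + gamma) * cond_logits - gamma * uncond_logits

# With top_p=None, process_logits reduces to argmax; with a float top_p it would
# sample from the nucleus-filtered distribution instead.
next_id = torch.argmax(logits, dim=-1, keepdim=True)
print(gamma, next_id)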
cube/cube3d/inference/logits_postprocesses.py CHANGED
@@ -2,22 +2,28 @@ import torch
 import torch.nn.functional as F
 
 
-def top_k_filtering(logits, top_k: int = 1):
+def top_p_filtering(logits, top_p: float = 1.0):
     """
-    Filter a distribution of logits using top-k and/or top-p (nucleus) filtering.
+    Filter a distribution of logits using top-p filtering.
     The input logits tensor is modified in-place.
 
     Args:
-        logits: A tensor of logits to be filtered. Expected shape is [..., vocab_size].
-        top_k: If > 0, only keep the top k tokens with highest probability.
+        logits (torch.Tensor): A tensor of logits to be filtered. Expected shape is [..., vocab_size].
+        top_p (float, optional): The cumulative probability threshold for top-p sampling.
+            If < 1.0, only keep the smallest set of tokens whose
+            cumulative probability does not exceed this threshold.
 
     Returns:
-        A tensor of logits where values outside the top-k/top-p threshold are set to -∞.
+        torch.Tensor: logits where values outside the top-p threshold are set to -∞.
     """
-    if top_k > 0:
-        idx_to_remove = logits < logits.topk(top_k, largest=True, sorted=False, dim=-1)[
-            0
-        ].amin(dim=-1, keepdim=True)
+    if top_p < 1.0:
+        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
+        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum(dim=-1) > top_p
+        sorted_idx_to_remove[..., 0] = False
+
+        idx_to_remove = sorted_idx_to_remove.scatter(
+            -1, sorted_idx, sorted_idx_to_remove
+        )
         logits.masked_fill_(idx_to_remove, -torch.inf)
 
     return logits
@@ -25,19 +31,27 @@ def top_k_filtering(logits, top_k: int = 1):
 
 def process_logits(
     logits,
-    top_k: int = 1,
+    top_p: float = None,
 ):
     """
-    Process logits by optionally applying top-k filtering.
-    The final probabilities are returned after applying softmax on the filtered logits.
+    Process logits by optionally applying nucleus (top-p) filtering and token selection.
+
+    If `top_p` is None, the token with the highest probability (argmax) is selected.
+    If `top_p` is provided, the smallest set of tokens with cumulative probability ≥ top_p is kept, then softmax is applied to obtain
+    probabilities, and a token is sampled from this filtered distribution using `torch.multinomial`.
 
     Args:
-        logits: A tensor of logits to process. Expected shape is [..., vocab_size].
-        top_k: If > 0, only keep the top k tokens with highest probability.
+        logits (torch.Tensor): A tensor of logits to process.
+        top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+            If None, argmax selection is performed (deterministic generation). Otherwise, the smallest set of tokens with cumulative probability ≥ top_p is kept (stochastic generation).
 
     Returns:
-        A tensor of probabilities after filtering, with the same shape as the input logits.
+        torch.Tensor: The selected token index.
     """
-    logits = top_k_filtering(logits, top_k=top_k)
-    probs = F.softmax(logits, dim=-1)
-    return probs
+    if top_p is None:
+        next_id = torch.argmax(logits, dim=-1, keepdim=True)
+    else:
+        logits = top_p_filtering(logits, top_p=top_p)
+        probs = F.softmax(logits, dim=-1)
+        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
+    return next_id
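As a sanity check, here is a worked toy example of the new filter on a four-token vocabulary. It restates the same math as top_p_filtering on a concrete tensor rather than importing the package, so it runs with only PyTorch installed:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # probs ≈ [0.61, 0.22, 0.14, 0.03]
top_p = 0.9

sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)     # ≈ [0.61, 0.83, 0.97, 1.00]
sorted_idx_to_remove = cum_probs > top_p                     # tokens past the 0.9 threshold
sorted_idx_to_remove[..., 0] = False                         # always keep the most likely token
idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_idx, sorted_idx_to_remove)
filtered = logits.masked_fill(idx_to_remove, -torch.inf)     # only the first two tokens survive

probs = F.softmax(filtered, dim=-1)                          # ≈ [0.73, 0.27, 0.00, 0.00]
next_id = torch.multinomial(probs, num_samples=1)            # stochastic pick, as in process_logits
print(probs, next_id)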