hamishivi commited on
Commit
feab349
·
verified ·
1 Parent(s): 887c784
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import gradio as gr
3
  import torch
 
4
  from transformers import (
5
  MODEL_FOR_MASKED_LM_MAPPING,
6
  )
@@ -67,6 +68,10 @@ def main():
67
  pipeline.simplex_value = simplex_value
68
  pipeline.top_p = top_p
69
  pipeline.temperature = temperature
 
 
 
 
70
  # tulu chat template
71
  inputs = "<|user|>\n" + inputs + "<|assistant|>\n"
72
 
 
1
  import logging
2
  import gradio as gr
3
  import torch
4
+ import numpy as np
5
  from transformers import (
6
  MODEL_FOR_MASKED_LM_MAPPING,
7
  )
 
68
  pipeline.simplex_value = simplex_value
69
  pipeline.top_p = top_p
70
  pipeline.temperature = temperature
71
+
72
+ # Ensure timesteps are properly set as a sequence
73
+ pipeline.scheduler.timesteps = torch.arange(0, diffusion_steps).flip(0)
74
+
75
  # tulu chat template
76
  inputs = "<|user|>\n" + inputs + "<|assistant|>\n"
77