WeichenFan commited on
Commit
7f96c09
·
1 Parent(s): 2244ce5

Add application file

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -49,7 +49,8 @@ def load_model(model_name):
49
  else:
50
  current_model = StableDiffusion3Pipeline.from_pretrained(model_paths[model_name], torch_dtype=torch.bfloat16).to("cuda")
51
 
52
- return current_model
 
53
 
54
  def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50, use_cfg_zero_star=True, use_zero_init=True, zero_steps=0, seed=None, compare_mode=False):
55
  model = load_model(model_name)
 
49
  else:
50
  current_model = StableDiffusion3Pipeline.from_pretrained(model_paths[model_name], torch_dtype=torch.bfloat16).to("cuda")
51
 
52
+ return current_model.to('cuda')
53
+
54
 
55
  def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50, use_cfg_zero_star=True, use_zero_init=True, zero_steps=0, seed=None, compare_mode=False):
56
  model = load_model(model_name)