multimodalart HF Staff committed on
Commit
3a7a77c
·
verified ·
1 Parent(s): 37d511b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -0
app.py CHANGED
@@ -1,9 +1,39 @@
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
  from diffusers import Cosmos2TextToImagePipeline, EDMEulerScheduler
 
5
  import random
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  model_2b_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
8
 
9
  pipe_2b = Cosmos2TextToImagePipeline.from_pretrained(
 
1
import subprocess
import os

# Install flash-attn at app startup (Hugging Face Spaces pattern).
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes pip skip compiling the CUDA
# kernels and use the prebuilt wheel, which is far faster on Spaces.
# Fix: the original passed env={...} alone, which REPLACES the whole process
# environment (dropping PATH and CUDA variables) — merge with os.environ.
# Also pass an argv list with the default shell=False; no shell parsing needed.
# Best-effort on purpose: no check=True, a failed install should not crash
# startup here (matching the original's behavior of ignoring the return code).
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
7
  import gradio as gr
8
  import spaces
9
  import torch
10
  from diffusers import Cosmos2TextToImagePipeline, EDMEulerScheduler
11
+ from transformers import AutoModelForCausalLM, SiglipProcessor
12
  import random
13
 
14
# Add flash_attention_2 to the safeguard model: the pipeline loads its safety
# LLM internally via AutoModelForCausalLM.from_pretrained, so we cannot pass
# kwargs directly — monkeypatch the loader's defaults instead.
def patch_from_pretrained(cls):
    """Monkeypatch ``cls.from_pretrained`` to default to flash-attn + bf16.

    Only fills in missing kwargs (``dict.setdefault``), so any explicit
    ``attn_implementation`` or ``torch_dtype`` passed by a caller still wins.

    Args:
        cls: a class exposing a ``from_pretrained`` classmethod (e.g.
            ``AutoModelForCausalLM``). Patched in place; returns nothing.
    """
    from functools import wraps

    orig_method = cls.from_pretrained

    @wraps(orig_method)  # preserve the wrapped loader's name/docstring
    def new_from_pretrained(*args, **kwargs):
        kwargs.setdefault("attn_implementation", "flash_attention_2")
        kwargs.setdefault("torch_dtype", torch.bfloat16)
        return orig_method(*args, **kwargs)

    cls.from_pretrained = new_from_pretrained
24
+
25
+ patch_from_pretrained(AutoModelForCausalLM)
26
+
27
# Add a `use_fast` to the safeguard image processor
def patch_processor_fast(cls):
    """Wrap ``cls.from_pretrained`` so the fast image processor is the default.

    A caller-supplied ``use_fast`` kwarg is left untouched; only a missing
    key is filled in. Patches ``cls`` in place.
    """
    original = cls.from_pretrained

    def wrapper(*args, **kwargs):
        if "use_fast" not in kwargs:
            kwargs["use_fast"] = True
        return original(*args, **kwargs)

    cls.from_pretrained = wrapper
34
+
35
+ patch_processor_fast(SiglipProcessor)
36
+
37
  model_2b_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
38
 
39
  pipe_2b = Cosmos2TextToImagePipeline.from_pretrained(