Commit 3f9caff · committed by hanjiaming.0208
Parent: e6c3189

add 512px AR
Files changed:
- app.py (+11, -7)
- t2i_inference.py (+9, -22)
- tok/mm_autoencoder.py (+4, -3)
app.py CHANGED

@@ -3,6 +3,7 @@ import gradio as gr
 from torchvision.transforms.functional import to_tensor
 from huggingface_hub import hf_hub_download, snapshot_download, login
 
+from tok.ar_dtok.ar_model import ARModel
 from t2i_inference import T2IConfig, TextToImageInference
 
 def generate_text(self, image: str, prompt: str) -> str:
@@ -29,16 +30,16 @@ def generate_text(self, image: str, prompt: str) -> str:
 login(token=os.getenv('HF_TOKEN'))
 config = T2IConfig()
 config.model = snapshot_download("csuhan/Tar-7B-v0.1")
-config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth")
+config.ar_path = {
+    "1024px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
+    "512px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth"),
+}
 config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
 config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
 inference = TextToImageInference(config)
 
-def generate_image(prompt, top_p, top_k, cfg_scale):
-    config.top_p = top_p
-    config.top_k = top_k
-    config.cfg_scale = cfg_scale
-    image = inference.generate_image(prompt)
+def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
+    image = inference.generate_image(prompt, resolution, top_p, top_k, cfg_scale)
     return image
 
 def clear_inputs_t2i():
@@ -68,6 +69,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         with gr.Column(scale=1):
             prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
             with gr.Accordion("Advanced Settings", open=False):
+                resolution = gr.Choice(
+                    ["512px", "1024px"], value="1024px", label="Resolution"
+                )
                 top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                 top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
                 cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
@@ -79,7 +83,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     generate_btn.click(
         generate_image,
-        inputs=[prompt, top_p, top_k, cfg_scale],
+        inputs=[prompt, resolution, top_p, top_k, cfg_scale],
        outputs=output_image
     )
    clear_btn.click(
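For reference, a minimal sketch of how the updated entry points are used end to end. The repo IDs, checkpoint filenames, slider defaults, and the generate_image signature come from the diff above; the prompt and output filename mirror the main() block that this commit removes from t2i_inference.py, and everything else is illustrative rather than part of the commit.

import os

from huggingface_hub import hf_hub_download, login, snapshot_download

from t2i_inference import T2IConfig, TextToImageInference

login(token=os.getenv('HF_TOKEN'))  # as in app.py; only needed if the repos are gated

config = T2IConfig()
config.model = snapshot_download("csuhan/Tar-7B-v0.1")
# One AR detokenizer checkpoint per supported output resolution.
config.ar_path = {
    "1024px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
    "512px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth"),
}
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
inference = TextToImageInference(config)

# Sampling options are now per-call arguments instead of config attributes;
# the values below are the Gradio sliders' defaults.
image = inference.generate_image("A photo of a macaw", "512px",
                                 top_p=0.95, top_k=1200, cfg_scale=4.0)
image.save("generated_image.png")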
t2i_inference.py CHANGED

@@ -13,7 +13,7 @@ from tok.mm_autoencoder import MMAutoEncoder
 class T2IConfig:
     model_path: str = "csuhan/Tar-1.5B"
     # visual tokenizer config
-    ar_path
+    ar_path = None
     encoder_path: str = 'ta_tok.pth'
     decoder_path: str = 'vq_ds16_t2i.pt'
 
@@ -39,17 +39,18 @@ class TextToImageInference:
 
         # Initialize visual tokenizer
         config = dict(
-            ar_path=self.config.ar_path,
+            ar_path_dict=self.config.ar_path,
             encoder_path=self.config.encoder_path,
             decoder_path=self.config.decoder_path,
             encoder_args={'input_type': 'rec'},
             decoder_args={},
         )
         self.visual_tokenizer = MMAutoEncoder(**config).eval().to(dtype=self.config.dtype, device=self.device)
-        self.visual_tokenizer.ar_model.cls_token_num = self.config.seq_len
+        for ar_model in self.visual_tokenizer.ar_model.values():
+            ar_model.cls_token_num = self.config.seq_len
         self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1
 
-    def generate_image(self, prompt: str) -> Image.Image:
+    def generate_image(self, prompt, resolution, top_p, top_k, cfg_scale) -> Image.Image:
         # Prepare prompt
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -69,8 +70,8 @@ class TextToImageInference:
             max_new_tokens=self.config.seq_len,
             do_sample=True,
             temperature=self.config.temperature,
-            top_p=self.config.top_p,
-            top_k=self.config.top_k)
+            top_p=top_p,
+            top_k=top_k)
 
         # Process generated tokens
         gen_text = self.tokenizer.batch_decode(gen_ids)[0]
@@ -80,21 +81,7 @@ class TextToImageInference:
 
         gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
             gen_code,
-            {'cfg_scale': self.config.cfg_scale},
+            {'cfg_scale': cfg_scale, 'resolution': resolution},
         )
         gen_image = Image.fromarray(gen_tensor[0].numpy())
-        return gen_image
-
-def main():
-    config = T2IConfig()
-    config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth")
-    config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
-    config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
-    inference = TextToImageInference(config)
-
-    prompt = "A photo of a macaw"
-    image = inference.generate_image(prompt)
-    image.save("generated_image.png")
-
-if __name__ == "__main__":
-    main()
+        return gen_image
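Two behavioural points fall out of this hunk: the sampling knobs (top_p, top_k, cfg_scale) move from T2IConfig attributes to per-call arguments of generate_image, and the chosen resolution is forwarded to the visual tokenizer inside the same options dict as cfg_scale. Since MMAutoEncoder.ar_sample reads that key with a "1024px" default (see the next file), a caller that never passes a resolution keeps the previous output size.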
tok/mm_autoencoder.py CHANGED

@@ -8,17 +8,18 @@ from tok.ta_tok import TextAlignedTokenizer
 
 class MMAutoEncoder(nn.Module):
     def __init__(self,
-                 ar_path,
+                 ar_path_dict,
                  encoder_path, decoder_path,
                  encoder_args={}, decoder_args={}):
         super().__init__()
-        self.ar_model = ARModel.from_checkpoint(ar_path)
+        self.ar_model = {resolution: ARModel.from_checkpoint(ar_path) for resolution, ar_path in ar_path_dict.items()}
 
         self.encoder = TextAlignedTokenizer.from_checkpoint(encoder_path, load_teacher=False, **encoder_args)
         self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)
 
     def ar_sample(self, x, args):
-        x = self.ar_model.sample(
+        resolution = args.get("resolution", "1024px")
+        x = self.ar_model[resolution].sample(
             x,
             cfg_scale=args.get('cfg_scale', 1.0),
             cfg_interval=args.get('cfg_interval', -1),
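The pattern introduced here is a plain dict of AR models keyed by resolution, with the key looked up from the options dict at sample time. A self-contained toy version of that dispatch (FakeARModel and everything else below is illustrative, not repo code):

class FakeARModel:
    """Stand-in for ARModel; only the shape of the sample() call matters here."""
    def __init__(self, name):
        self.name = name

    def sample(self, x, cfg_scale=1.0):
        return f"{self.name}: sampled {x!r} at cfg_scale={cfg_scale}"

# One model per resolution, analogous to calling ARModel.from_checkpoint once per checkpoint.
ar_model = {res: FakeARModel(res) for res in ("512px", "1024px")}

def ar_sample(x, args):
    resolution = args.get("resolution", "1024px")  # same fallback as the commit
    return ar_model[resolution].sample(x, cfg_scale=args.get("cfg_scale", 1.0))

print(ar_sample("gen_code", {"resolution": "512px", "cfg_scale": 4.0}))
print(ar_sample("gen_code", {"cfg_scale": 4.0}))  # no resolution key: uses the 1024px model

One thing to double-check with this design: because self.ar_model is a plain dict rather than an nn.ModuleDict, the AR models are not registered as submodules of MMAutoEncoder, so the .eval().to(dtype=..., device=...) call in t2i_inference.py does not reach them; that is only safe if ARModel.from_checkpoint already returns each model in eval mode on the intended device and dtype.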