hanjiaming.0208 committed
Commit 3f9caff · 1 Parent(s): e6c3189

add 512px AR

Files changed (3):
  1. app.py (+11 -7)
  2. t2i_inference.py (+9 -22)
  3. tok/mm_autoencoder.py (+4 -3)
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
 from torchvision.transforms.functional import to_tensor
 from huggingface_hub import hf_hub_download, snapshot_download, login
 
+from tok.ar_dtok.ar_model import ARModel
 from t2i_inference import T2IConfig, TextToImageInference
 
 def generate_text(self, image: str, prompt: str) -> str:
@@ -29,16 +30,16 @@ def generate_text(self, image: str, prompt: str) -> str:
 login(token=os.getenv('HF_TOKEN'))
 config = T2IConfig()
 config.model = snapshot_download("csuhan/Tar-7B-v0.1")
-config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth")
+config.ar_path = {
+    "1024px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
+    "512px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth"),
+}
 config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
 config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
 inference = TextToImageInference(config)
 
-def generate_image(prompt, top_p, top_k, cfg_scale):
-    config.top_p = top_p
-    config.top_k = top_k
-    config.cfg_scale = cfg_scale
-    image = inference.generate_image(prompt)
+def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
+    image = inference.generate_image(prompt, resolution, top_p, top_k, cfg_scale)
     return image
 
 def clear_inputs_t2i():
@@ -68,6 +69,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
            with gr.Accordion("Advanced Settings", open=False):
+               resolution = gr.Choice(
+                   ["512px", "1024px"], value="1024px", label="Resolution"
+               )
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
                cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
@@ -79,7 +83,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
    generate_btn.click(
        generate_image,
-       inputs=[prompt, top_p, top_k, cfg_scale],
+       inputs=[prompt, resolution, top_p, top_k, cfg_scale],
        outputs=output_image
    )
    clear_btn.click(
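
For reference, the new per-call plumbing can be exercised outside the full demo. The following is a minimal, self-contained sketch with a stub backend; note that current Gradio releases do not ship a gr.Choice component, so the sketch uses gr.Radio for the resolution selector (assumed here to match the intent of the diff), and the stub generate_image only echoes its arguments.

import gradio as gr

# Stand-in for TextToImageInference.generate_image, used only to show how the new
# (prompt, resolution, top_p, top_k, cfg_scale) signature flows from the UI to the backend.
def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
    return f"[{resolution}] top_p={top_p} top_k={top_k} cfg={cfg_scale} :: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
    with gr.Accordion("Advanced Settings", open=False):
        # gr.Radio stands in for the commit's gr.Choice.
        resolution = gr.Radio(["512px", "1024px"], value="1024px", label="Resolution")
        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
        top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
        cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
    output = gr.Textbox(label="Output")
    gr.Button("Generate").click(
        generate_image,
        inputs=[prompt, resolution, top_p, top_k, cfg_scale],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()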
t2i_inference.py CHANGED
@@ -13,7 +13,7 @@ from tok.mm_autoencoder import MMAutoEncoder
 class T2IConfig:
     model_path: str = "csuhan/Tar-1.5B"
     # visual tokenizer config
-    ar_path: str = 'ar_dtok_lp_256px.pth'
+    ar_path = None
     encoder_path: str = 'ta_tok.pth'
     decoder_path: str = 'vq_ds16_t2i.pt'
 
@@ -39,17 +39,18 @@ class TextToImageInference:
 
         # Initialize visual tokenizer
         config = dict(
-            ar_path=self.config.ar_path,
+            ar_path_dict=self.config.ar_path,
             encoder_path=self.config.encoder_path,
             decoder_path=self.config.decoder_path,
             encoder_args={'input_type': 'rec'},
             decoder_args={},
         )
         self.visual_tokenizer = MMAutoEncoder(**config).eval().to(dtype=self.config.dtype, device=self.device)
-        self.visual_tokenizer.ar_model.cls_token_num = self.config.seq_len
+        for ar_model in self.visual_tokenizer.ar_model.values():
+            ar_model.cls_token_num = self.config.seq_len
         self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1
 
-    def generate_image(self, prompt: str) -> Image.Image:
+    def generate_image(self, prompt, resolution, top_p, top_k, cfg_scale) -> Image.Image:
         # Prepare prompt
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -69,8 +70,8 @@ class TextToImageInference:
             max_new_tokens=self.config.seq_len,
             do_sample=True,
             temperature=self.config.temperature,
-            top_p=self.config.top_p,
-            top_k=self.config.top_k)
+            top_p=top_p,
+            top_k=top_k)
 
         # Process generated tokens
         gen_text = self.tokenizer.batch_decode(gen_ids)[0]
@@ -80,21 +81,7 @@ class TextToImageInference:
 
         gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
             gen_code,
-            {'cfg_scale': self.config.cfg_scale}
+            {'cfg_scale': cfg_scale, 'resolution': resolution},
         )
         gen_image = Image.fromarray(gen_tensor[0].numpy())
-        return gen_image
-
-def main():
-    config = T2IConfig()
-    config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth")
-    config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
-    config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
-    inference = TextToImageInference(config)
-
-    prompt = "A photo of a macaw"
-    image = inference.generate_image(prompt)
-    image.save("generated_image.png")
-
-if __name__ == "__main__":
-    main()
+        return gen_image
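
Since main() is gone and the sampling knobs are now per-call arguments rather than config fields, driving the class directly would look roughly like the sketch below. It mirrors the setup in app.py; the prompt, output filename, and parameter values are illustrative (taken from the removed main() and the demo's slider defaults).

from huggingface_hub import hf_hub_download, snapshot_download

from t2i_inference import T2IConfig, TextToImageInference

config = T2IConfig()
config.model = snapshot_download("csuhan/Tar-7B-v0.1")
# One AR decoder checkpoint per resolution, keyed by the strings the UI exposes.
config.ar_path = {
    "1024px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
    "512px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth"),
}
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
inference = TextToImageInference(config)

# Sampling parameters are now passed per call instead of being read from the config.
image = inference.generate_image(
    "A photo of a macaw",  # prompt (from the removed main())
    "512px",               # resolution: selects the AR model loaded above
    0.95,                  # top_p
    1200,                  # top_k
    4.0,                   # cfg_scale
)
image.save("generated_image.png")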
tok/mm_autoencoder.py CHANGED
@@ -8,17 +8,18 @@ from tok.ta_tok import TextAlignedTokenizer
 
 class MMAutoEncoder(nn.Module):
     def __init__(self,
-                 ar_path,
+                 ar_path_dict,
                  encoder_path, decoder_path,
                  encoder_args={}, decoder_args={}):
         super().__init__()
-        self.ar_model = ARModel.from_checkpoint(ar_path)
+        self.ar_model = {resolution: ARModel.from_checkpoint(ar_path) for resolution, ar_path in ar_path_dict.items()}
 
         self.encoder = TextAlignedTokenizer.from_checkpoint(encoder_path, load_teacher=False, **encoder_args)
         self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)
 
     def ar_sample(self, x, args):
-        x = self.ar_model.sample(
+        resolution = args.get("resolution", "1024px")
+        x = self.ar_model[resolution].sample(
             x,
             cfg_scale=args.get('cfg_scale', 1.0),
             cfg_interval=args.get('cfg_interval', -1),
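
On the consuming side, dispatch is driven entirely by the "resolution" key in the args dict, defaulting to "1024px" so callers that never pass it keep the previous single-model behaviour. A rough construction sketch (checkpoint paths are placeholders for the downloaded files; the two args dicts show how a caller selects a model):

from tok.mm_autoencoder import MMAutoEncoder

# Mirrors the config built in t2i_inference.py: one AR checkpoint per resolution key.
autoencoder = MMAutoEncoder(
    ar_path_dict={
        "1024px": "ar_dtok_lp_1024px.pth",
        "512px": "ar_dtok_lp_512px.pth",
    },
    encoder_path="ta_tok.pth",
    decoder_path="vq_ds16_t2i.pt",
    encoder_args={"input_type": "rec"},
    decoder_args={},
).eval()

args_512 = {"cfg_scale": 4.0, "resolution": "512px"}  # picks ar_model["512px"]
args_default = {"cfg_scale": 4.0}                     # falls back to "1024px"

One detail worth keeping in mind: self.ar_model is a plain Python dict rather than an nn.ModuleDict, so the per-resolution ARModels are not registered as submodules of MMAutoEncoder, and the .eval().to(dtype, device) call in t2i_inference.py does not automatically propagate to them unless ARModel.from_checkpoint already places its weights appropriately.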