tight-inversion committed on
Commit
4d0ddc3
·
1 Parent(s): 10d3d92

Align with pulid demo

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. flux/util.py +3 -22
app.py CHANGED
@@ -431,6 +431,7 @@ if __name__ == "__main__":
431
  args.offload = True
432
 
433
  print(f"Using device: {args.device}")
 
434
  print(f"Offload: {args.offload}")
435
 
436
  demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
 
431
  args.offload = True
432
 
433
  print(f"Using device: {args.device}")
434
+ print(f"fp8: {args.fp8}")
435
  print(f"Offload: {args.offload}")
436
 
437
  demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
flux/util.py CHANGED
@@ -123,36 +123,17 @@ def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
123
  ):
124
  ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
125
 
126
- # Initialize the model on the 'meta' device, which doesn't allocate real memory
127
- with torch.device('meta'):
128
- model = Flux(configs[name].params)
129
- model = model.to_empty(device=device)
130
 
131
  if ckpt_path is not None:
132
  print("Loading checkpoint")
133
- # Load the state dictionary directly to the desired device
134
  sd = load_sft(ckpt_path, device=str(device))
135
- # Load the state dictionary into the model
136
  missing, unexpected = model.load_state_dict(sd, strict=False)
137
  print_load_warning(missing, unexpected)
138
- model.to(torch.bfloat16)
139
  return model
140
 
141
- # from XLabs-AI https://github.com/XLabs-AI/x-flux/blob/1f8ef54972105ad9062be69fe6b7f841bce02a08/src/flux/util.py#L330
142
- def load_flow_model_quintized(name: str, device: str = "cuda", hf_download: bool = True):
143
- # Loading Flux
144
- print("Init model")
145
- ckpt_path = 'models/flux-dev-fp8.safetensors'
146
- if (
147
- not os.path.exists(ckpt_path)
148
- and hf_download
149
- ):
150
- print("Downloading model")
151
- ckpt_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux-dev-fp8.safetensors")
152
- print("Model downloaded to", ckpt_path)
153
- json_path = hf_hub_download("XLabs-AI/flux-dev-fp8", 'flux_dev_quantization_map.json')
154
-
155
- model = Flux(configs[name].params).to(torch.bfloat16)
156
  def load_flow_model_quintized(
157
  name: str,
158
  device: str = "cuda",
 
123
  ):
124
  ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
125
 
126
+ with torch.device(device):
127
+ model = Flux(configs[name].params).to(torch.bfloat16)
 
 
128
 
129
  if ckpt_path is not None:
130
  print("Loading checkpoint")
131
+ # load_sft doesn't support torch.device
132
  sd = load_sft(ckpt_path, device=str(device))
 
133
  missing, unexpected = model.load_state_dict(sd, strict=False)
134
  print_load_warning(missing, unexpected)
 
135
  return model
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def load_flow_model_quintized(
138
  name: str,
139
  device: str = "cuda",