Spaces:
Paused
Paused
| import getpass | |
| import math | |
| import os | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import requests | |
| import torch | |
| from einops import rearrange | |
| from huggingface_hub import hf_hub_download, login | |
| from PIL import ExifTags, Image | |
| from safetensors.torch import load_file as load_sft | |
| from flux.model import Flux, FluxLoraWrapper, FluxParams | |
| from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams | |
| from flux.modules.conditioner import HFEmbedder | |
# Local cache directory for downloaded model checkpoints.
CHECKPOINTS_DIR = Path("checkpoints")
# API key for BFL usage tracking; None when the environment variable is unset.
BFL_API_KEY = os.getenv("BFL_API_KEY")
def ensure_hf_auth():
    """Attempt HuggingFace authentication without user interaction.

    Prefers the HF_TOKEN environment variable; otherwise falls back to a
    token previously cached by the HuggingFace CLI. Returns True when
    authenticated, False otherwise.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        print("Trying to authenticate to HuggingFace with the HF_TOKEN environment variable.")
        try:
            login(token=token)
            print("Successfully authenticated with HuggingFace using HF_TOKEN")
            return True
        except Exception as e:
            # A bad token is not fatal; fall through to the cached-token check.
            print(f"Warning: Failed to authenticate with HF_TOKEN: {e}")

    cached_token_file = os.path.expanduser("~/.cache/huggingface/token")
    if os.path.exists(cached_token_file):
        print("Already authenticated with HuggingFace")
        return True
    return False
def prompt_for_hf_auth():
    """Interactively prompt for a HuggingFace token and log in with it.

    Returns True on successful authentication; False when the user supplies
    no token, cancels with Ctrl-C, or the login attempt fails.
    """
    try:
        token = getpass.getpass("HF Token (hidden input): ").strip()
        if token:
            login(token=token)
            print("Successfully authenticated!")
            return True
        print("No token provided. Aborting.")
        return False
    except KeyboardInterrupt:
        print("\nAuthentication cancelled by user.")
        return False
    except Exception as auth_e:
        print(f"Authentication failed: {auth_e}")
        print("Tip: You can also run 'huggingface-cli login' or set HF_TOKEN environment variable")
        return False
def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path:
    """Get the local path for a checkpoint file.

    This build resolves checkpoints locally: `filename` is treated as the
    on-disk path and returned directly. The HuggingFace download path (with
    the `env_var` override) that existed upstream was disabled here.

    Args:
        repo_id: HuggingFace repository id (kept for interface compatibility).
        filename: Local path of the checkpoint file.
        env_var: Name of an override environment variable (currently unused).

    Returns:
        Path to the checkpoint file. Existence is NOT checked here.
    """
    # Bug fix: the function is annotated `-> Path` but previously returned the
    # raw string; callers wrap the result in str(), so returning a real Path is
    # backward compatible. Also removed: a dead `if False:` download branch that
    # referenced an undefined `checkpoint_dir`, and an unused mmgp import.
    return Path(filename)
def download_onnx_models_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
    """Download ONNX models for TRT to our checkpoints directory.

    Resolves the HuggingFace ONNX repository for `model_name`, downloads the
    per-module ONNX files (clip / transformer / t5 / vae plus their external
    data blobs) under CHECKPOINTS_DIR, and returns the comma-separated
    "module:path" string TRT expects for custom_onnx_paths. Returns None when
    `model_name` has no associated ONNX repository.

    Raises:
        RuntimeError: When the repository is gated and the license has not
            been accepted.
    """
    onnx_repo_map = {
        "flux-dev": "black-forest-labs/FLUX.1-dev-onnx",
        "flux-schnell": "black-forest-labs/FLUX.1-schnell-onnx",
        "flux-dev-canny": "black-forest-labs/FLUX.1-Canny-dev-onnx",
        "flux-dev-depth": "black-forest-labs/FLUX.1-Depth-dev-onnx",
        "flux-dev-redux": "black-forest-labs/FLUX.1-Redux-dev-onnx",
        "flux-dev-fill": "black-forest-labs/FLUX.1-Fill-dev-onnx",
        "flux-dev-kontext": "black-forest-labs/FLUX.1-Kontext-dev-onnx",
    }
    if model_name not in onnx_repo_map:
        return None  # No ONNX repository required for this model
    repo_id = onnx_repo_map[model_name]
    # Repo ids contain "/", which is not usable as a directory name.
    safe_repo_name = repo_id.replace("/", "_")
    onnx_dir = CHECKPOINTS_DIR / safe_repo_name
    # Map of module names to their ONNX file paths (using specified precision).
    # "_data" entries are external weight blobs accompanying the .onnx files;
    # they must be downloaded but are excluded from the returned path string.
    onnx_file_map = {
        "clip": "clip.opt/model.onnx",
        "transformer": f"transformer.opt/{trt_transformer_precision}/model.onnx",
        "transformer_data": f"transformer.opt/{trt_transformer_precision}/backbone.onnx_data",
        "t5": "t5.opt/model.onnx",
        "t5_data": "t5.opt/backbone.onnx_data",
        "vae": "vae.opt/model.onnx",
    }
    # If all files exist locally, return the custom_onnx_paths format
    if onnx_dir.exists():
        all_files_exist = True
        custom_paths = []
        for module, onnx_file in onnx_file_map.items():
            if module.endswith("_data"):
                continue  # Skip data files
            local_path = onnx_dir / onnx_file
            if not local_path.exists():
                all_files_exist = False
                break
            custom_paths.append(f"{module}:{local_path}")
        if all_files_exist:
            print(f"ONNX models ready in {onnx_dir}")
            return ",".join(custom_paths)
    # If not all files exist, download them
    print(f"Downloading ONNX models from {repo_id} to {onnx_dir}")
    print(f"Using transformer precision: {trt_transformer_precision}")
    onnx_dir.mkdir(exist_ok=True)
    # Download all ONNX files
    for module, onnx_file in onnx_file_map.items():
        local_path = onnx_dir / onnx_file
        if local_path.exists():
            continue  # Already downloaded
        # Create parent directories
        local_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            print(f"Downloading {onnx_file}")
            hf_hub_download(repo_id=repo_id, filename=onnx_file, local_dir=onnx_dir)
        except Exception as e:
            # Missing optional files are tolerated; gated repos abort with
            # actionable guidance; anything else propagates unchanged.
            if "does not exist" in str(e).lower() or "not found" in str(e).lower():
                continue
            elif "gated repo" in str(e).lower() or "restricted" in str(e).lower():
                print(f"Cannot access {repo_id} - requires license acceptance")
                print("Please follow these steps:")
                print(f" 1. Visit: https://huggingface.co/{repo_id}")
                print(" 2. Log in to your HuggingFace account")
                print(" 3. Accept the license terms and conditions")
                print(" 4. Then retry this command")
                raise RuntimeError(f"License acceptance required for {model_name}")
            else:
                # Re-raise other errors
                raise
    print(f"ONNX models ready in {onnx_dir}")
    # Return the custom_onnx_paths format that TRT expects: "module1:path1,module2:path2"
    # Note: Only return the actual module paths, not the data file
    custom_paths = []
    for module, onnx_file in onnx_file_map.items():
        if module.endswith("_data"):
            continue  # Skip the data file in the return paths
        full_path = onnx_dir / onnx_file
        if full_path.exists():
            custom_paths.append(f"{module}:{full_path}")
    return ",".join(custom_paths)
def check_onnx_access_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
    """Check ONNX access and download models for TRT - returns ONNX directory path"""
    # Thin wrapper kept for API compatibility with callers.
    paths = download_onnx_models_for_trt(model_name, trt_transformer_precision)
    return paths
def track_usage_via_api(name: str, n=1) -> None:
    """
    Track usage of licensed models via the BFL API for commercial licensing compliance.

    For more information on licensing BFL's models for commercial use and usage reporting,
    see the README.md or visit: https://dashboard.bfl.ai/licensing/subscriptions?showInstructions=true

    Args:
        name: Model name. Names without a known license slug are skipped
            with a printed notice (no-op).
        n: Number of generations to report.

    Raises:
        RuntimeError: If BFL_API_KEY is unset or the API request fails.
    """
    # Bug fix: `assert` is stripped under `python -O`, so validate explicitly.
    if BFL_API_KEY is None:
        raise RuntimeError("BFL_API_KEY is not set")
    # All the "flux-tools" variants report against a single shared license.
    model_slug_map = {
        "flux-dev": "flux-1-dev",
        "flux-dev-kontext": "flux-1-kontext-dev",
        "flux-dev-fill": "flux-tools",
        "flux-dev-depth": "flux-tools",
        "flux-dev-canny": "flux-tools",
        "flux-dev-canny-lora": "flux-tools",
        "flux-dev-depth-lora": "flux-tools",
        "flux-dev-redux": "flux-tools",
    }
    if name not in model_slug_map:
        print(f"Skipping tracking usage for {name}, as it cannot be tracked. Please check the model name.")
        return
    model_slug = model_slug_map[name]
    url = f"https://api.bfl.ai/v1/licenses/models/{model_slug}/usage"
    headers = {"x-key": BFL_API_KEY, "Content-Type": "application/json"}
    payload = {"number_of_generations": n}
    response = requests.post(url, headers=headers, json=payload)
    if response.status_code != 200:
        # RuntimeError (was bare Exception) is still caught by `except Exception`.
        raise RuntimeError(f"Failed to track usage: {response.status_code} {response.text}")
    print(f"Successfully tracked usage for {name} with {n} generations")
def save_image(
    nsfw_classifier,
    name: str,
    output_name: str,
    idx: int,
    x: torch.Tensor,
    add_sampling_metadata: bool,
    prompt: str,
    nsfw_threshold: float = 0.85,
    track_usage: bool = False,
) -> int:
    """Convert a decoded image tensor to PIL, run the NSFW gate, and save.

    The tensor is clamped to [-1, 1] and mapped to 8-bit RGB. When the NSFW
    score is below `nsfw_threshold` the image is written with EXIF metadata
    and `idx + 1` is returned; otherwise nothing is saved and `idx` is
    returned unchanged.
    """
    fn = output_name.format(idx=idx)
    print(f"Saving {fn}")
    # bring into PIL format and save
    clamped = x.clamp(-1, 1)
    hwc = rearrange(clamped[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (hwc + 1.0)).cpu().byte().numpy())

    if nsfw_classifier is None:
        # No classifier available: force a score below any threshold.
        nsfw_score = nsfw_threshold - 1.0
    else:
        nsfw_score = [r["score"] for r in nsfw_classifier(img) if r["label"] == "nsfw"][0]

    if nsfw_score >= nsfw_threshold:
        print("Your generated image may contain NSFW content.")
        return idx

    exif_data = Image.Exif()
    if name in ["flux-dev", "flux-schnell"]:
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
    else:
        exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
    exif_data[ExifTags.Base.Model] = name
    if add_sampling_metadata:
        exif_data[ExifTags.Base.ImageDescription] = prompt
    img.save(fn, exif=exif_data, quality=95, subsampling=0)
    if track_usage:
        track_usage_via_api(name, 1)
    return idx + 1
@dataclass
class ModelSpec:
    """Configuration bundle for one Flux model variant.

    Bug fix: the @dataclass decorator was missing, so the keyword-argument
    construction `ModelSpec(repo_id=..., ...)` used by the `configs`
    registry raised TypeError at import time.
    """

    params: FluxParams  # transformer hyperparameters
    ae_params: AutoEncoderParams  # autoencoder hyperparameters
    repo_id: str  # HuggingFace repo id ("" when resolved purely locally)
    repo_flow: str  # flow checkpoint location (unused in this local build)
    repo_ae: str  # autoencoder checkpoint path, loaded by load_ae()
    lora_repo_id: str | None = None  # optional LoRA repo id (LoRA variants only)
    lora_filename: str | None = None  # optional LoRA checkpoint filename
# Registry of supported Flux model variants. Every entry shares the same
# autoencoder configuration; the transformer configs differ mainly in
# in_channels (64 for txt2img/redux/kontext, 128 for canny/depth, 384 for
# fill) and guidance_embed (False only for flux-schnell). The "-lora"
# variants reuse the base FLUX.1-dev repo plus a LoRA checkpoint.
configs = {
    "flux-dev": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-canny": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Canny-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-canny-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        lora_repo_id="black-forest-labs/FLUX.1-Canny-dev-lora",
        lora_filename="flux1-canny-dev-lora.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-depth": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Depth-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-depth-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        lora_repo_id="black-forest-labs/FLUX.1-Depth-dev-lora",
        lora_filename="flux1-depth-dev-lora.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-redux": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Redux-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-fill": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Fill-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=384,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-kontext": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Kontext-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
# Preferred resolutions for the Kontext model: each pair multiplies to roughly
# one megapixel and both sides are multiples of 16. The list is symmetric
# (portrait and landscape orientations of the same aspect ratios).
# NOTE(review): ordering is presumably (width, height) — confirm with callers.
# The historical "PREFERED" spelling is part of the public name; do not rename.
PREFERED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]
def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2) -> tuple[int, int]:
    """Convert a "W:H" aspect-ratio string into pixel dimensions.

    The returned (width, height) multiply to approximately `area`, with each
    side rounded down to a multiple of 16.
    """
    parts = aspect_ratio.split(":")
    ratio = float(parts[0]) / float(parts[1])
    # Solve w*h = area with w/h = ratio, then snap down to the 16-pixel grid.
    raw_width = round(math.sqrt(area * ratio))
    raw_height = round(math.sqrt(area / ratio))
    return 16 * (raw_width // 16), 16 * (raw_height // 16)
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Report missing and unexpected state-dict keys from a non-strict load.

    Prints nothing when both lists are empty; prints a separator line only
    when both kinds of mismatch are present.
    """
    if missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if missing and unexpected:
        print("\n" + "-" * 79 + "\n")
    if unexpected:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_flow_model(name: str, model_filename, device: str | torch.device = "cuda", verbose: bool = True) -> Flux:
    """Instantiate a Flux model on the meta device and load checkpoint data.

    Args:
        name: Key into `configs` selecting the model spec.
        model_filename: Path of the checkpoint handed to mmgp's loader.
        device: Kept for interface compatibility; weight placement is
            handled by mmgp, not by this function.
        verbose: Kept for interface compatibility (unused by the mmgp path).

    Returns:
        The Flux model (FluxLoraWrapper when the spec declares a LoRA).
    """
    # Cleanup: removed the large commented-out safetensors/LoRA loading path
    # (superseded by mmgp) and the unused `ckpt_path` local.
    config = configs[name]
    # Build on the meta device so no real memory is allocated until mmgp
    # attaches the checkpoint data below.
    with torch.device("meta"):
        if config.lora_repo_id is not None and config.lora_filename is not None:
            model = FluxLoraWrapper(params=config.params).to(torch.bfloat16)
        else:
            model = Flux(config.params).to(torch.bfloat16)

    from mmgp import offload
    offload.load_model_data(model, model_filename)
    return model
def load_t5(device: str | torch.device = "cuda", text_encoder_filename = None, max_length: int = 512) -> HFEmbedder:
    """Load the T5 text encoder from a local checkpoint file.

    max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    """
    embedder = HFEmbedder("", text_encoder_filename, max_length=max_length, torch_dtype=torch.bfloat16)
    return embedder.to(device)
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Load the CLIP text encoder from the local checkpoints directory."""
    clip_embedder = HFEmbedder(
        "ckpts/clip_vit_large_patch14", "", max_length=77, torch_dtype=torch.bfloat16, is_clip=True
    )
    return clip_embedder.to(device)
def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
    """Build the autoencoder for `name` and load its checkpoint weights.

    Missing/unexpected state-dict keys are reported via print_load_warning.
    """
    config = configs[name]
    ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_ae, "FLUX_AE"))
    # Instantiate on the meta device; assign=True materializes real tensors
    # straight from the checkpoint.
    with torch.device("meta"):
        ae = AutoEncoder(config.ae_params)
    state = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(state, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return ae
def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """
    Optionally expand the state dict to match the model's parameters shapes.

    Entries whose shape differs from the corresponding model parameter are
    zero-padded up to the parameter's shape, with the stored values placed in
    the leading slice of each dimension. The dict is updated in place and
    also returned.
    """
    for param_name, param in model.named_parameters():
        if param_name not in state_dict:
            continue
        stored = state_dict[param_name]
        if stored.shape == param.shape:
            continue
        print(
            f"Expanding '{param_name}' with shape {state_dict[param_name].shape} to model parameter with shape {param.shape}."
        )
        # expand with zeros: copy the stored tensor into the leading slice of
        # a zero tensor shaped like the model parameter.
        padded = torch.zeros_like(param, device=stored.device)
        padded[tuple(slice(0, dim) for dim in stored.shape)] = stored
        state_dict[param_name] = padded
    return state_dict