Spaces:
Running
on
Zero
Running
on
Zero
| import random | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| from diffusers import StableDiffusionXLPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler | |
| from compel import Compel, ReturnedEmbeddingsType | |
| from PIL import Image, PngImagePlugin | |
| import json | |
| import io | |
| # ===================================== | |
| # Prompt weights | |
| # ===================================== | |
| import re | |
def parse_prompt_attention(text):
    r"""Split a prompt into ``[text, weight]`` pairs using bracket emphasis.

    Syntax handled by the tokenizer below:

    * ``(abc)``      multiplies the weight of ``abc`` by 1.1
    * ``[abc]``      multiplies the weight of ``abc`` by 1 / 1.1
    * ``(abc:1.5)``  multiplies the weight of ``abc`` by 1.5
    * ``\(``, ``\)``, ``\[``, ``\]``, ``\\``  keep the escaped character as text
    * ``BREAK``      becomes its own ``["BREAK", -1]`` entry

    Brackets still open at the end of the prompt apply their default
    multiplier to everything after them. Neighbouring entries with equal
    weights are merged, and an empty prompt yields ``[["", 1.0]]``.

    NOTE: an explicit weight such as ``:..)`` matches the pattern but is not
    a valid float, so ``float()`` raises ``ValueError`` for it.
    """
    token_pattern = re.compile(r"""
        \\\(|
        \\\)|
        \\\[|
        \\]|
        \\\\|
        \\|
        \(|
        \[|
        :([+-]?[.\d]+)\)|
        \)|
        ]|
        [^\\()\[\]:]+|
        :
        """, re.X)
    break_pattern = re.compile(r"\s*\bBREAK\b\s*", re.S)

    paren_scale = 1.1
    square_scale = 1 / 1.1

    entries = []
    open_parens = []   # indices into `entries` where each "(" was opened
    open_squares = []  # indices into `entries` where each "[" was opened

    def scale_from(first_index, factor):
        # Apply `factor` to every entry added since the bracket was opened.
        for entry in entries[first_index:]:
            entry[1] *= factor

    for found in token_pattern.finditer(text):
        token = found.group(0)
        explicit_weight = found.group(1)

        if token.startswith('\\'):
            # Escaped character: drop the backslash, keep the rest verbatim.
            entries.append([token[1:], 1.0])
        elif token == '(':
            open_parens.append(len(entries))
        elif token == '[':
            open_squares.append(len(entries))
        elif explicit_weight is not None and open_parens:
            scale_from(open_parens.pop(), float(explicit_weight))
        elif token == ')' and open_parens:
            scale_from(open_parens.pop(), paren_scale)
        elif token == ']' and open_squares:
            scale_from(open_squares.pop(), square_scale)
        else:
            # Plain text (or an unmatched closer): split out BREAK keywords.
            for index, part in enumerate(break_pattern.split(token)):
                if index:
                    entries.append(["BREAK", -1])
                entries.append([part, 1.0])

    # Unclosed brackets behave as if they were closed at the end.
    for start in open_parens:
        scale_from(start, paren_scale)
    for start in open_squares:
        scale_from(start, square_scale)

    if not entries:
        entries = [["", 1.0]]

    # Collapse runs of identical weights into a single entry.
    merged = []
    for entry in entries:
        if merged and merged[-1][1] == entry[1]:
            merged[-1][0] += entry[0]
        else:
            merged.append(entry)
    return merged
def prompt_attention_to_invoke_prompt(attention):
    """Render ``[text, weight]`` pairs as a Compel-style weighted prompt.

    Each weight is rounded to two decimals and then rendered as:

    * ``1.0``              -> the text unchanged
    * ``< 0.8``            -> ``(text)<weight>`` (explicit number)
    * ``0.8 <= w < 1.0``   -> ``(text)-`` for 0.90..0.99, ``(text)--`` for 0.80..0.89
    * ``1.0 < w < 1.3``    -> ``(text)`` followed by one ``+`` per full 0.1 above 1.0
    * ``>= 1.3``           -> ``(text)<weight>`` (explicit number)

    Args:
        attention: iterable of ``(text, weight)`` pairs.

    Returns:
        The rendered fragments concatenated without a separator.
    """
    rendered = []
    for text, weight in attention:
        weight = round(weight, 2)
        if weight == 1.0:
            rendered.append(text)
        elif weight < 1.0:
            if weight < 0.8:
                rendered.append(f"({text}){weight}")
            else:
                # Count in integer hundredths instead of truncating a float
                # product. This is the mapping the previous float expression
                # produced (one "-" down to 0.90, two down to 0.80), made
                # independent of binary rounding error.
                hundredths_below = round((1.0 - weight) * 100)
                rendered.append(f"({text})" + "-" * ((hundredths_below + 9) // 10))
        else:
            if weight < 1.3:
                # (1.2 - 1.0) * 10 evaluates to 1.9999999999999996, so plain
                # int() truncation rendered 1.2 with a single "+". Integer
                # hundredths give the intended one "+" per 0.1 step.
                hundredths_above = round((weight - 1.0) * 100)
                rendered.append(f"({text})" + "+" * (hundredths_above // 10))
            else:
                rendered.append(f"({text}){weight}")
    return "".join(rendered)
def concat_tensor(t):
    """Lay the slices of ``t`` along dim 0 side by side along dim 1.

    ``t`` is cut into size-1 slices along its first dimension and those
    slices are concatenated along the second, so the result has size 1 in
    dim 0.
    """
    return torch.cat(torch.split(t, 1, dim=0), dim=1)
def merge_embeds(prompt_chanks, compel):
    """Blend the embeddings of several prompt chunks into one tensor.

    ``compel`` is called once with the whole chunk list. The result is split
    along dim 0, each slice is scaled, and the slices are summed. Counting
    from the end, the k-th slice is scaled by ``k / (n * (n + 1) // 2)``, so
    the last chunk gets the smallest factor and the factors sum to 1.

    With no chunks the embedding of the empty string is returned.
    """
    chunk_count = len(prompt_chanks)
    if chunk_count == 0:
        return compel('')

    unit = 1 / (chunk_count * (chunk_count + 1) // 2)
    slices = list(torch.split(compel(prompt_chanks), 1, dim=0))
    for rank in range(1, chunk_count + 1):
        # rank 1 is the last slice, rank n the first
        slices[-rank] = slices[-rank] * (rank * unit)
    return torch.stack(slices, dim=0).sum(dim=0)
def detokenize(chunk, actual_prompt):
    """Turn a run of tokens back into prompt text and consume it from the prompt.

    Tokens carry a ``</w>`` end-of-word marker. The marker on the final token
    is dropped; every other marker becomes a space when the prompt has a
    space at that same index, and is removed otherwise.

    Args:
        chunk: list of token strings; it is not modified.
        actual_prompt: the not-yet-consumed prompt text the tokens came from.

    Returns:
        ``(chunk_text, remaining_prompt)``, both stripped. Only the first
        occurrence of ``chunk_text`` is removed from the prompt, so identical
        text later in the prompt is preserved for the following chunks.
    """
    if not chunk:
        return '', actual_prompt.strip()

    # Work on a copy so the caller's token list is left untouched.
    pieces = list(chunk)
    pieces[-1] = pieces[-1].replace('</w>', '')
    chunk_text = ''.join(pieces).strip()

    while '</w>' in chunk_text:
        marker_at = chunk_text.find('</w>')
        # Guard the index: if the rebuilt text runs past the prompt, treat
        # the marker as "no space here" rather than raising IndexError.
        followed_by_space = (
            marker_at < len(actual_prompt) and actual_prompt[marker_at] == ' '
        )
        chunk_text = chunk_text.replace('</w>', ' ' if followed_by_space else '', 1)

    # count=1: consume just this chunk, not every repetition of its text.
    remaining = actual_prompt.replace(chunk_text, '', 1)
    return chunk_text.strip(), remaining.strip()
def tokenize_line(line, tokenizer):
    """Split a prompt into text chunks of at most ``model_max_length - 2`` tokens.

    The line is lower-cased, stripped and tokenized. Whenever the pending
    token run reaches the limit it is cut after its last comma token; the
    tokens after that comma carry over to the next chunk. If the run
    contains no comma the whole run is emitted. Leftover tokens form a final
    chunk. Each chunk is converted back to text with ``detokenize``.

    Returns:
        list of chunk strings, in prompt order.
    """
    remaining = line.lower().strip()
    tokens = tokenizer.tokenize(remaining)
    # presumably the 2 reserved slots are the start/end tokens; confirm
    limit = tokenizer.model_max_length - 2
    comma = tokenizer.tokenize(',')[0]

    pieces = []
    pending = []
    for token in tokens:
        pending.append(token)
        if len(pending) != limit:
            continue
        # Cut just after the right-most comma, or take everything if none.
        cut = next(
            (pos + 1 for pos in range(limit - 1, -1, -1) if pending[pos] == comma),
            limit,
        )
        text, remaining = detokenize(pending[:cut], remaining)
        pieces.append(text)
        pending = pending[cut:]

    if pending:
        text, _ = detokenize(pending, remaining)
        pieces.append(text)
    return pieces
def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):
    """Convert a bracket-weighted prompt to Compel syntax and optionally embed it.

    Args:
        prompt: prompt text using ``(text:weight)`` / ``[text]`` emphasis.
        pipeline: object exposing ``tokenizer`` (``tokenize`` and
            ``model_max_length`` are used).
        compel: callable that embeds prompt strings (used via ``merge_embeds``).
        only_convert_string: when true, return the converted prompt string
            instead of embeddings.
        compel_process_sd: when true, skip weight conversion; the raw prompt
            is chunked by ``tokenize_line`` and embedded directly.

    Returns:
        The converted string (groups joined by a space) when
        ``only_convert_string`` is set, otherwise the merged embedding.
    """
    if compel_process_sd:
        return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)

    tokenizer = pipeline.tokenizer

    # Collapse doubled brackets to avoid excessive emphasis after conversion,
    # and expand every backslash to three.
    prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")

    # Break every weighted span into comma-terminated pieces and record the
    # token length of each one.
    pieces = []
    for span_text, span_weight in parse_prompt_attention(prompt):
        for fragment in span_text.split(','):
            for small in tokenize_line(fragment, tokenizer):
                piece_text = f'{small},'
                pieces.append({
                    "weight": round(span_weight, 2),
                    "lenght": len(tokenizer.tokenize(piece_text)),
                    "prompt": piece_text,
                })

    # Pack pieces into groups that stay within the token limit. Inside a
    # group, consecutive pieces with the same weight share one entry.
    limit = tokenizer.model_max_length - 2
    groups = []
    current = []
    used = 0
    for piece in pieces:
        if used + piece['lenght'] > limit:
            groups.append(current)
            current = [[piece['prompt'], piece['weight']]]
            used = piece['lenght']
            continue
        if current and piece['weight'] == current[-1][1]:
            current[-1][0] += f" {piece['prompt']}"
        else:
            current.append([piece['prompt'], piece['weight']])
        used += piece['lenght']
    if current:
        groups.append(current)

    converted = [prompt_attention_to_invoke_prompt(group) for group in groups]
    if only_convert_string:
        return ' '.join(converted)
    return merge_embeds(converted, compel)
def add_metadata_to_image(image, metadata):
    """Return a PNG re-encoding of ``image`` that carries ``metadata``.

    ``metadata`` is serialised with ``json.dumps`` and stored as a PNG text
    chunk under the key ``"parameters"``. A copy of the image is encoded into
    an in-memory buffer and reopened from it, so the returned image is backed
    by that buffer.
    """
    info = PngImagePlugin.PngInfo()
    info.add_text("parameters", json.dumps(metadata))

    stream = io.BytesIO()
    image.copy().save(stream, format="PNG", pnginfo=info)

    # Rewind so the reader starts at the PNG header.
    stream.seek(0)
    return Image.open(stream)
def add_comma_after_pattern_ti(text):
    """Append a comma to every whole word of the form ``<word>_<digits>``.

    A match must start and end on a word boundary, so ``foo_12`` becomes
    ``foo_12,`` while ``foo_12bar`` is left alone.
    """
    return re.sub(r'\b\w+_\d+\b', r'\g<0>,', text)
# Notice text for the page. It has to exist before the CPU-only warning is
# appended below; without this assignment the `+=` raised NameError on any
# host without CUDA.
DESCRIPTION = ""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>你现在运行在CPU上 但是此项目只支持GPU.</p>"

# Largest seed value (maximum signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 2048
# Load the checkpoint and build the SDXL pipeline only when CUDA is present.
if torch.cuda.is_available():
    # vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    # Access token is read from the environment (None when HF_TOKEN is unset).
    token = os.environ.get("HF_TOKEN")
    model_path = hf_hub_download(
        repo_id="Menyu/miaomiaoHarem_vPredDogma10",  # repository name, not a full URL
        filename="MiaoMiao Harem_V-Pred_dogma_1.1_FP16.safetensors",
        # `token` is the current keyword; `use_auth_token` is its legacy spelling.
        token=token,
    )
    pipe = StableDiffusionXLPipeline.from_single_file(
        model_path,
        # no separate VAE is passed: the one built into the checkpoint is used
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    # The v-prediction / zero-terminal-SNR options must reach the scheduler's
    # constructor: the beta rescaling happens while the scheduler is built, so
    # registering the flags on an already-built scheduler changes only the
    # stored config and leaves the noise schedule as it was.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config,
        prediction_type="v_prediction",
        rescale_betas_zero_snr=True,
    )
    pipe.to("cuda")
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed``, or a random value in ``[0, MAX_SEED]`` when asked to randomize."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def infer(
    prompt: str,
    negative_prompt: str = "lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    use_negative_prompt: bool = True,
    seed: int = 7,
    width: int = 1024,
    height: int = 1536,
    guidance_scale: float = 3,
    num_inference_steps: int = 30,
    randomize_seed: bool = True,
    use_resolution_binning: bool = True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a weighted prompt and tag it with its settings.

    The prompt and negative prompt are converted to Compel syntax with
    ``get_embed_new`` and embedded together in a single Compel call. The
    module-level ``pipe`` is then run with those embeddings.

    Returns:
        ``(image, seed)``: the generated image re-encoded as PNG with the
        generation settings stored as JSON metadata, and the seed actually
        used (random when ``randomize_seed`` is true).

    NOTE(review): this file imports ``spaces`` but no ``@spaces.GPU``
    decorator is visible on this function; confirm whether the deployment
    needs one.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator().manual_seed(seed)

    # One Compel instance spanning both SDXL tokenizers / text encoders.
    compel = Compel(
        tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
        text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True],
        truncate_long_prompts=False
    )

    # Keep the untouched prompt for the metadata record.
    original_prompt = prompt
    negative_source = negative_prompt if use_negative_prompt else ""

    positive_text = get_embed_new(original_prompt, pipe, compel, only_convert_string=True)
    negative_text = get_embed_new(negative_source, pipe, compel, only_convert_string=True)

    # Both prompts must be embedded in the same call so their lengths match.
    conditioning, pooled = compel([positive_text, negative_text])

    pipe_kwargs = dict(
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        use_resolution_binning=use_resolution_binning,
    )
    image = pipe(**pipe_kwargs).images[0]

    # Settings embedded in the output PNG.
    metadata = {
        "prompt": original_prompt,
        "processed_prompt": positive_text,
        "negative_prompt": negative_text,
        "seed": seed,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "model": "miaomiaoHarem_vPredDogma11",
        "use_resolution_binning": use_resolution_binning,
        "PreUrl": "https://huggingface.co/spaces/Menyu/miaomiaoHaremDogma11"
    }

    return add_metadata_to_image(image, metadata), seed
# Sample prompts shown in the examples panel of the UI.
examples = [
    "nahida (genshin impact)",
    "klee (genshin impact)",
]

# Page CSS: centre the app in a 560px-wide column and centre the title.
css = """
.gradio-container {
max-width: 560px !important;
margin-left: auto !important;
margin-right: auto !important;
}
h1{text-align:center}
"""
def _toggle_negative_prompt(enabled):
    """Show the negative-prompt box only while its checkbox is ticked."""
    return gr.update(visible=enabled)


# NOTE(review): the nesting below is reconstructed from a flattened listing;
# confirm the row / accordion grouping against the deployed layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown("""# 梦羽的模型生成器
### 快速生成 MiaomiaoHarem vPred Dogma 1.1 模型的图片""")

    # Prompt box, run button and result image.
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="关键词",
                show_label=True,
                max_lines=5,
                placeholder="输入你要的图片关键词",
                container=False,
            )
            run_button = gr.Button("生成", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False, format="png")

    # Advanced options, collapsed by default.
    with gr.Accordion("高级选项", open=False):
        with gr.Row():
            use_negative_prompt = gr.Checkbox(label="使用反向词条", value=True)
        negative_prompt = gr.Text(
            label="反向词条",
            max_lines=5,
            lines=4,
            placeholder="输入你要排除的图片关键词",
            value="lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
            visible=True,
        )
        seed = gr.Slider(
            label="种子",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="随机种子", value=True)
        with gr.Row(visible=True):
            width = gr.Slider(
                label="宽度",
                minimum=512,
                maximum=MAX_IMAGE_SIZE,
                step=64,
                value=832,
            )
            height = gr.Slider(
                label="高度",
                minimum=512,
                maximum=MAX_IMAGE_SIZE,
                step=64,
                value=1216,
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=0.1,
                maximum=10,
                step=0.1,
                value=7.0,
            )
            num_inference_steps = gr.Slider(
                label="生成步数",
                minimum=1,
                maximum=50,
                step=1,
                value=28,
            )

    # Clickable sample prompts; they fill the prompt box.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=infer
    )

    use_negative_prompt.change(
        fn=_toggle_negative_prompt,
        inputs=use_negative_prompt,
        outputs=negative_prompt,
    )

    # Pressing Enter in the prompt box and clicking the button both run
    # `infer`; the inputs are listed in the order of its positional parameters.
    infer_inputs = [
        prompt,
        negative_prompt,
        use_negative_prompt,
        seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        randomize_seed,
    ]
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=infer,
        inputs=infer_inputs,
        outputs=[result, seed],
    )

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch(share=True)