Image-to-Image
Diffusers
Safetensors
English
model_hub_mixin
pytorch_model_hub_mixin
SherryXTChen committed · verified
Commit 27fead3 · 1 Parent(s): 7317c13

Update README.md

Files changed (1)
  1. README.md (+71, -20)
README.md CHANGED
@@ -25,7 +25,6 @@ The model is based on the paper [Instruct-CLIP: Improving Instruction-Guided Ima
 ## Capabilities
 
 <p align="center">
- <img src="https://raw.githubusercontent.com/SherryXTChen/Instruct-CLIP/refs/heads/main/assets/teaser_1.png" alt="Figure 1" width="43%">
 <img src="https://raw.githubusercontent.com/SherryXTChen/Instruct-CLIP/refs/heads/main/assets/teaser_2.png" alt="Figure 2" width="50%">
 </p>
 
@@ -34,31 +33,83 @@ The model is based on the paper [Instruct-CLIP: Improving Instruction-Guided Ima
 pip install -r requirements.txt
 ```
 
- ## Inference
+ ## Edit Instruction Refinement Inference
 
 ```python
- import PIL
- import requests
+ import argparse
+
+ from PIL import Image
 import torch
- from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
+ from torchvision import transforms
+
+ from model import InstructCLIP
+ from utils import get_sd_components, normalize
 
- model_id = "timbrooks/instruct-pix2pix"
- pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
- pipe.load_lora_weights("SherryXTChen/InstructCLIP-InstructPix2Pix")
- pipe.to("cuda")
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-
- url = "https://raw.githubusercontent.com/SherryXTChen/Instruct-CLIP/refs/heads/main/assets/1_input.jpg"
- def download_image(url):
-     image = PIL.Image.open(requests.get(url, stream=True).raw)
-     image = PIL.ImageOps.exif_transpose(image)
-     image = image.convert("RGB")
-     return image
- image = download_image(url)
-
- prompt = "as a 3 d sculpture"
- images = pipe(prompt, image=image, num_inference_steps=20).images
- images[0].save("output.jpg")
+ parser = argparse.ArgumentParser(description="Simple example of estimating the edit instruction from an image pair")
+ parser.add_argument(
+     "--pretrained_instructclip_name_or_path",
+     type=str,
+     default="SherryXTChen/Instruct-CLIP",
+     help="Instruct-CLIP pretrained checkpoint",
+ )
+ parser.add_argument(
+     "--pretrained_model_name_or_path",
+     type=str,
+     default="runwayml/stable-diffusion-v1-5",
+     help="Stable Diffusion pretrained checkpoint",
+ )
+ parser.add_argument(
+     "--input_path",
+     type=str,
+     default="assets/1_input.jpg",
+     help="Input image path",
+ )
+ parser.add_argument(
+     "--output_path",
+     type=str,
+     default="assets/1_output.jpg",
+     help="Output image path",
+ )
+ args = parser.parse_args()
+ device = "cuda"
+
+ # load model for edit instruction estimation
+ model = InstructCLIP.from_pretrained(args.pretrained_instructclip_name_or_path)
+ model = model.to(device).eval()
+
+ # load components to preprocess/encode images into the latent space
+ tokenizer, _, vae, _, _ = get_sd_components(args, device, torch.float32)
+
+ # prepare the image pair
+ transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.5], std=[0.5]),
+ ])
+ image_list = [args.input_path, args.output_path]
+ image_list = [
+     transform(Image.open(f).resize((512, 512))).unsqueeze(0).to(device)
+     for f in image_list
+ ]
+
+ with torch.no_grad():
+     image_list = [vae.encode(x).latent_dist.sample() * vae.config.scaling_factor for x in image_list]
+
+ # get the image-pair feature
+ zero_timesteps = torch.zeros_like(torch.tensor([0])).to(device)
+ img_feat = model.get_image_features(
+     inp=image_list[0], out=image_list[1], inp_t=zero_timesteps, out_t=zero_timesteps)
+ img_feat = normalize(img_feat)
+
+ # decode the feature into an edit instruction
+ pred_instruct_input_ids = model.text_decoder.infer(img_feat[:1])[0]
+ pred_instruct = tokenizer.decode(pred_instruct_input_ids, skip_special_tokens=True)
+ print(pred_instruct)  # as a 3 d sculpture
 ```
 
 ## Citation
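
For context beyond this diff: the "Inference" section removed by this commit showed how to apply edits with the companion LoRA `SherryXTChen/InstructCLIP-InstructPix2Pix`. The sketch below, which is not part of the commit, feeds an instruction like the one predicted above back into that pipeline; the `pred_instruct` string and the local input path are assumptions taken from the snippets in this diff.

```python
# Minimal sketch (not part of this commit): feed the predicted instruction
# back into the InstructPix2Pix pipeline from the removed README section.
import torch
from PIL import Image
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
)
pipe.load_lora_weights("SherryXTChen/InstructCLIP-InstructPix2Pix")
pipe.to("cuda")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# assumed: the instruction printed by the estimation snippet above
pred_instruct = "as a 3 d sculpture"
image = Image.open("assets/1_input.jpg").convert("RGB")  # assumed local path
edited = pipe(pred_instruct, image=image, num_inference_steps=20).images[0]
edited.save("output.jpg")
```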