pratyyush commited on
Commit
a3e3a44
·
verified ·
1 Parent(s): adf6220

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -117
app.py CHANGED
@@ -1,158 +1,110 @@
1
- # Based on: https://github.com/jantic/DeOldify
2
- import os, re, time
3
-
4
- os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
5
- os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")
6
-
7
- import streamlit as st
8
  import PIL
9
  import cv2
10
  import numpy as np
11
- import uuid
12
- from zipfile import ZipFile, ZIP_DEFLATED
13
  from io import BytesIO
14
- from random import randint
15
- from datetime import datetime
16
-
17
- from src.deoldify import device
18
- from src.deoldify.device_id import DeviceId
19
- from src.deoldify.visualize import *
20
  from src.app_utils import get_model_bin
 
21
 
 
 
 
22
 
23
- device.set(device=DeviceId.CPU)
24
-
25
-
26
- @st.cache(allow_output_mutation=True, show_spinner=False)
27
- def load_model(model_dir, option):
28
- if option.lower() == 'artistic':
29
- model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
30
- get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
31
- colorizer = get_image_colorizer(artistic=True)
32
- elif option.lower() == 'stable':
33
- model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
34
- get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
35
- colorizer = get_image_colorizer(artistic=False)
36
 
37
- return colorizer
 
 
 
 
38
 
39
 
40
- def resize_img(input_img, max_size):
41
  img = input_img.copy()
42
- img_height, img_width = img.shape[0],img.shape[1]
43
 
44
  if max(img_height, img_width) > max_size:
45
  if img_height > img_width:
46
- new_width = img_width*(max_size/img_height)
47
  new_height = max_size
48
- resized_img = cv2.resize(img,(int(new_width), int(new_height)))
49
- return resized_img
 
50
 
51
- elif img_height <= img_width:
52
- new_width = img_height*(max_size/img_width)
53
- new_height = max_size
54
- resized_img = cv2.resize(img,(int(new_width), int(new_height)))
55
- return resized_img
56
 
57
  return img
58
 
59
 
60
- def colorize_image(pil_image, img_size=800) -> "PIL.Image":
61
- # Open the image
62
- pil_img = pil_image.convert("RGB")
63
- img_rgb = np.array(pil_img)
64
- resized_img_rgb = resize_img(img_rgb, img_size)
65
  resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
66
 
67
- # Send the image to the model
68
  output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
69
-
70
- return output_pil_img
71
-
72
-
73
- def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
74
- if fmt not in ["jpg", "png"]:
75
- raise Exception(f"Unknown image format (Available: {fmt} - case sensitive)")
76
-
77
- pil_format = "JPEG" if fmt == "jpg" else "PNG"
78
- file_format = "jpg" if fmt == "jpg" else "png"
79
- mime = "image/jpeg" if fmt == "jpg" else "image/png"
80
-
81
- buf = BytesIO()
82
- pil_image.save(buf, format=pil_format)
83
-
84
- return st.download_button(
85
- label=label,
86
- data=buf.getvalue(),
87
- file_name=f'{filename}.{file_format}',
88
- mime=mime,
89
- )
90
-
91
-
92
- ###########################
93
- ###### STREAMLIT CODE #####
94
- ###########################
95
-
96
 
97
- st_color_option = "Artistic"
 
 
 
 
98
 
99
- # Load models
100
- try:
101
- with st.spinner("Loading..."):
102
- print('before loading the model')
103
- colorizer = load_model('models/', st_color_option)
104
- print('after loading the model')
105
 
106
- except Exception as e:
107
- colorizer = None
108
- print('Error while loading the model. Please refresh the page')
109
- print(e)
110
- st.write("**App loading error. Please try again later.**")
 
 
111
 
 
 
112
 
113
 
114
- if colorizer is not None:
 
115
  st.title("AI Photo Colorization")
 
116
 
117
- st.image(open("assets/demo.jpg", "rb").read())
 
 
118
 
119
- st.markdown(
120
- """
121
- Colorizing black & white photo can be expensive and time consuming. We introduce AI that can colorize
122
- grayscale photo in seconds. **Just upload your grayscale image, then click colorize.**
123
- """
124
- )
125
-
126
- uploaded_file = st.file_uploader("Upload photo", accept_multiple_files=False, type=["png", "jpg", "jpeg"])
127
 
128
- if uploaded_file is not None:
129
  bytes_data = uploaded_file.getvalue()
130
  img_input = PIL.Image.open(BytesIO(bytes_data)).convert("RGB")
131
 
132
  with st.expander("Original photo", True):
133
  st.image(img_input)
134
 
135
- if st.button("Colorize!") and uploaded_file is not None:
136
-
137
  with st.spinner("AI is doing the magic!"):
138
- img_output = colorize_image(img_input)
139
- img_output = img_output.resize(img_input.size)
140
-
141
- # NOTE: Calm! I'm not logging the input and outputs.
142
- # It is impossible to access the filesystem in spaces environment.
143
- now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
144
- img_input.convert("RGB").save(f"./output/{now}-input.jpg")
145
- img_output.convert("RGB").save(f"./output/{now}-output.jpg")
146
-
147
- st.write("AI has finished the job!")
148
- st.image(img_output)
149
- # reuse = st.button('Edit again (Re-use this image)', on_click=set_image, args=(inpainted_img, ))
150
-
151
- uploaded_name = os.path.splitext(uploaded_file.name)[0]
152
- image_download_button(
153
- pil_image=img_output,
154
- filename=uploaded_name,
155
- fmt="jpg",
156
- label="Download Image"
157
- )
158
 
 
1
+ import os
2
+ import uvicorn
 
 
 
 
 
3
  import PIL
4
  import cv2
5
  import numpy as np
6
+ from fastapi import FastAPI, File, UploadFile
7
+ from fastapi.responses import JSONResponse
8
  from io import BytesIO
9
+ from src.deoldify.visualize import get_image_colorizer
 
 
 
 
 
10
  from src.app_utils import get_model_bin
11
+ import streamlit as st
12
 
13
+ # Set model cache directory
14
+ os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
15
+ os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")
16
 
17
+ # Initialize FastAPI app
18
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # Load the DeOldify model
21
+ MODEL_DIR = "models/"
22
+ MODEL_URL = "https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth"
23
+ get_model_bin(MODEL_URL, os.path.join(MODEL_DIR, "ColorizeArtistic_gen.pth"))
24
+ colorizer = get_image_colorizer(artistic=True)
25
 
26
 
27
+ def resize_img(input_img, max_size=800):
28
  img = input_img.copy()
29
+ img_height, img_width = img.shape[0], img.shape[1]
30
 
31
  if max(img_height, img_width) > max_size:
32
  if img_height > img_width:
33
+ new_width = int(img_width * (max_size / img_height))
34
  new_height = max_size
35
+ else:
36
+ new_width = max_size
37
+ new_height = int(img_height * (max_size / img_width))
38
 
39
+ resized_img = cv2.resize(img, (new_width, new_height))
40
+ return resized_img
 
 
 
41
 
42
  return img
43
 
44
 
45
+ def colorize_image(image_bytes) -> bytes:
46
+ # Convert uploaded image to PIL format
47
+ pil_image = PIL.Image.open(BytesIO(image_bytes)).convert("RGB")
48
+ img_rgb = np.array(pil_image)
49
+ resized_img_rgb = resize_img(img_rgb, 800)
50
  resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
51
 
52
+ # Colorize image
53
  output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # Convert back to bytes
56
+ img_io = BytesIO()
57
+ output_pil_img.save(img_io, format="JPEG")
58
+ img_io.seek(0)
59
+ return img_io.getvalue()
60
 
 
 
 
 
 
 
61
 
62
+ @app.post("/colorize")
63
+ async def colorize(file: UploadFile = File(...)):
64
+ try:
65
+ image_bytes = await file.read()
66
+ colorized_image = colorize_image(image_bytes)
67
+
68
+ return JSONResponse(content={"status": "success", "image": colorized_image}, media_type="image/jpeg")
69
 
70
+ except Exception as e:
71
+ return JSONResponse(content={"error": str(e)}, status_code=500)
72
 
73
 
74
+ # Start Streamlit UI
75
+ def start_streamlit():
76
  st.title("AI Photo Colorization")
77
+ st.image("assets/demo.jpg")
78
 
79
+ st.markdown("""
80
+ Upload a black-and-white image, and AI will colorize it in seconds.
81
+ """)
82
 
83
+ uploaded_file = st.file_uploader("Upload photo", type=["png", "jpg", "jpeg"])
 
 
 
 
 
 
 
84
 
85
+ if uploaded_file:
86
  bytes_data = uploaded_file.getvalue()
87
  img_input = PIL.Image.open(BytesIO(bytes_data)).convert("RGB")
88
 
89
  with st.expander("Original photo", True):
90
  st.image(img_input)
91
 
92
+ if st.button("Colorize!"):
 
93
  with st.spinner("AI is doing the magic!"):
94
+ img_output = colorize_image(bytes_data)
95
+ st.image(img_output)
96
+
97
+ st.download_button(
98
+ label="Download Image",
99
+ data=img_output,
100
+ file_name="colorized.jpg",
101
+ mime="image/jpeg"
102
+ )
103
+
104
+
105
+ if __name__ == "__main__":
106
+ import threading
107
+ threading.Thread(target=start_streamlit, daemon=True).start()
108
+ uvicorn.run(app, host="0.0.0.0", port=7860)
109
+
 
 
 
 
110