Ashrafb commited on
Commit
b09de3d
·
verified ·
1 Parent(s): e530dd3

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +9 -23
main.py CHANGED
@@ -2,9 +2,10 @@ from fastapi import FastAPI, File, UploadFile, Form
2
  from fastapi.responses import StreamingResponse
3
  from fastapi.staticfiles import StaticFiles
4
  import torch
 
5
  import cv2
6
  import numpy as np
7
- import logging
8
  from io import BytesIO
9
 
10
  app = FastAPI()
@@ -18,9 +19,6 @@ def load_model():
18
  model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
19
  model.load_model('cartoon4')
20
 
21
- # Configure logging
22
- logging.basicConfig(level=logging.INFO)
23
-
24
  @app.post("/upload/")
25
  async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
26
  global model
@@ -32,26 +30,13 @@ async def process_image(file: UploadFile = File(...), top: int = Form(...), bott
32
 
33
  # Convert the uploaded image to numpy array
34
  nparr = np.frombuffer(contents, np.uint8)
35
- frame_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
36
-
37
- if frame_bgr is None:
38
- logging.error("Failed to decode the image.")
39
- return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}
40
-
41
- logging.info(f"Uploaded image shape: {frame_bgr.shape}")
42
 
43
  # Process the uploaded image
44
- aligned_face, instyle, message = model.detect_and_align_image(frame_bgr, top, bottom, left, right)
45
- if aligned_face is None or instyle is None:
46
- logging.error("Failed to process the image: No face detected or alignment failed.")
47
- return {"error": message}
48
-
49
- processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')
50
- if processed_image is None:
51
- logging.error("Failed to toonify the image.")
52
- return {"error": message}
53
 
54
- # Convert BGR to RGB for display
55
  processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
56
 
57
  # Convert processed image to bytes
@@ -60,11 +45,12 @@ async def process_image(file: UploadFile = File(...), top: int = Form(...), bott
60
  # Return the processed image as a streaming response
61
  return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
62
 
 
63
  # Mount static files directory
64
  app.mount("/", StaticFiles(directory="AB", html=True), name="static")
65
 
66
  # Define index route
67
  @app.get("/")
68
  def index():
69
- from fastapi.responses import FileResponse
70
- return FileResponse(path="/app/AB/index.html", media_type="text/html")
 
2
  from fastapi.responses import StreamingResponse
3
  from fastapi.staticfiles import StaticFiles
4
  import torch
5
+ import shutil
6
  import cv2
7
  import numpy as np
8
+ import io
9
  from io import BytesIO
10
 
11
  app = FastAPI()
 
19
  model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
20
  model.load_model('cartoon4')
21
 
 
 
 
22
  @app.post("/upload/")
23
  async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
24
  global model
 
30
 
31
  # Convert the uploaded image to numpy array
32
  nparr = np.frombuffer(contents, np.uint8)
33
+ frame_rgb = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
 
 
 
 
 
 
34
 
35
  # Process the uploaded image
36
+ aligned_face, instyle, message = model.detect_and_align_image(frame_rgb, top, bottom, left, right)
37
+ processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon4')
 
 
 
 
 
 
 
38
 
39
+ # Convert BGR to RGB
40
  processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
41
 
42
  # Convert processed image to bytes
 
45
  # Return the processed image as a streaming response
46
  return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
47
 
48
+
49
  # Mount static files directory
50
  app.mount("/", StaticFiles(directory="AB", html=True), name="static")
51
 
52
  # Define index route
53
  @app.get("/")
54
  def index():
55
+ return FileResponse(path="/app/AB/index.html", media_type="text/html")
56
+