Ashrafb commited on
Commit
00227e3
·
verified ·
1 Parent(s): f0c148a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +13 -4
main.py CHANGED
@@ -11,10 +11,13 @@ from io import BytesIO
11
  app = FastAPI()
12
 
13
  # Load model and necessary components
14
- from vtoonify_model import Model
15
- model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
16
- model.load_model('cartoon1-d')
17
 
 
 
 
 
 
18
  from fastapi.middleware.cors import CORSMiddleware
19
 
20
  app.add_middleware(
@@ -27,6 +30,10 @@ app.add_middleware(
27
 
28
  @app.post("/upload/")
29
  async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
 
 
 
 
30
  # Read the uploaded image file
31
  contents = await file.read()
32
 
@@ -36,7 +43,7 @@ async def process_image(file: UploadFile = File(...), top: int = Form(...), bott
36
 
37
  # Process the uploaded image
38
  aligned_face, instyle, message = model.detect_and_align_image(frame_rgb, top, bottom, left, right)
39
- processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1-d')
40
 
41
  # Convert BGR to RGB
42
  processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
@@ -46,3 +53,5 @@ async def process_image(file: UploadFile = File(...), top: int = Form(...), bott
46
 
47
  # Return the processed image as a streaming response
48
  return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
 
 
 
11
  app = FastAPI()
12
 
13
  # Load model and necessary components
14
+ model = None
 
 
15
 
16
+ def load_model():
17
+ global model
18
+ from vtoonify_model import Model
19
+ model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
20
+ model.load_model('cartoon4')
21
  from fastapi.middleware.cors import CORSMiddleware
22
 
23
  app.add_middleware(
 
30
 
31
  @app.post("/upload/")
32
  async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
33
+ global model
34
+ if model is None:
35
+ load_model()
36
+
37
  # Read the uploaded image file
38
  contents = await file.read()
39
 
 
43
 
44
  # Process the uploaded image
45
  aligned_face, instyle, message = model.detect_and_align_image(frame_rgb, top, bottom, left, right)
46
+ processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon4')
47
 
48
  # Convert BGR to RGB
49
  processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
 
53
 
54
  # Return the processed image as a streaming response
55
  return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
56
+
57
+