Ashrafb commited on
Commit
f7edd62
·
verified ·
1 Parent(s): 3fc2316

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +2 -73
main.py CHANGED
@@ -1,75 +1,4 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi import FastAPI, File, UploadFile, Form, Request
3
- from fastapi.responses import HTMLResponse, FileResponse
4
- from fastapi.staticfiles import StaticFiles
5
- from fastapi.templating import Jinja2Templates
6
- from fastapi import FastAPI, File, UploadFile, HTTPException
7
- from fastapi.responses import JSONResponse
8
- from fastapi.responses import StreamingResponse
9
- import numpy as np
10
- import torch
11
- import torch.nn.functional as F
12
- from torchvision.transforms.functional import normalize
13
- from huggingface_hub import hf_hub_download
14
- from briarmbg import BriaRMBG
15
- import PIL
16
- from PIL import Image
17
- import io
18
-
19
- app = FastAPI()
20
-
21
- net = BriaRMBG()
22
- model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
23
-
24
- if torch.cuda.is_available():
25
- net.load_state_dict(torch.load(model_path))
26
- net = net.cuda()
27
- else:
28
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
29
- net.eval()
30
-
31
- def resize_image(image):
32
- image = image.convert('RGB')
33
- model_input_size = (1024, 1024)
34
- image = image.resize(model_input_size, Image.BILINEAR)
35
- return image
36
-
37
- def process_image(image):
38
- orig_image = image
39
- w, h = orig_image.size
40
- image = resize_image(orig_image)
41
- im_np = np.array(image)
42
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
43
- im_tensor = torch.unsqueeze(im_tensor, 0)
44
- im_tensor = torch.divide(im_tensor, 255.0)
45
- im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
46
-
47
- if torch.cuda.is_available():
48
- im_tensor = im_tensor.cuda()
49
-
50
- result = net(im_tensor)
51
- result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
52
- ma = torch.max(result)
53
- mi = torch.min(result)
54
- result = (result - mi) / (ma - mi)
55
- im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
56
- pil_im = Image.fromarray(np.squeeze(im_array))
57
- new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
58
- new_im.paste(orig_image, mask=pil_im)
59
-
60
- return new_im
61
-
62
- @app.post("/process-image/")
63
- async def process_image_endpoint(file: UploadFile = File(...)):
64
- contents = await file.read()
65
- pil_image = Image.open(io.BytesIO(contents))
66
- processed_image = process_image(pil_image)
67
-
68
- # Save the processed image temporarily
69
- temp_file_path = "processed_image.png"
70
- processed_image.save(temp_file_path)
71
-
72
- # Return the processed image
73
- return FileResponse(temp_file_path, media_type="image/png")
74
 
75
 
 
 
1
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
+ exec(os.environ.get('CODE'))