Saad0KH commited on
Commit
954e47d
Β·
verified Β·
1 Parent(s): c6be1d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -25,24 +25,22 @@ executor = ThreadPoolExecutor(max_workers=4)
25
  # GPU model setup
26
  birefnet = None
27
  transform_image = None
 
28
 
29
  def load_model():
30
  global birefnet, transform_image
31
- birefnet = AutoModelForImageSegmentation.from_pretrained(
32
- "ZhengPeng7/BiRefNet", trust_remote_code=True
33
- )
34
- birefnet.to("cuda")
35
- birefnet.eval()
36
- transform_image = transforms.Compose([
37
- transforms.Resize((1024, 1024)),
38
- transforms.ToTensor(),
39
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
40
- ])
41
-
42
- # Lazy load the model on the first request
43
- @app.before_first_request
44
- def initialize():
45
- threading.Thread(target=load_model).start()
46
 
47
  # Helper functions
48
  def decode_image_from_base64(image_data):
@@ -68,6 +66,9 @@ async def process_image(image):
68
  """Process the image asynchronously, including background removal."""
69
  global birefnet, transform_image
70
 
 
 
 
71
  # Convert image to tensor
72
  input_images = transform_image(image).unsqueeze(0).to("cuda")
73
 
 
25
  # GPU model setup
26
  birefnet = None
27
  transform_image = None
28
+ model_loaded = threading.Event()
29
 
30
  def load_model():
31
  global birefnet, transform_image
32
+ if not model_loaded.is_set():
33
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
34
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
35
+ )
36
+ birefnet.to("cuda")
37
+ birefnet.eval()
38
+ transform_image = transforms.Compose([
39
+ transforms.Resize((1024, 1024)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
42
+ ])
43
+ model_loaded.set()
 
 
 
44
 
45
  # Helper functions
46
  def decode_image_from_base64(image_data):
 
66
  """Process the image asynchronously, including background removal."""
67
  global birefnet, transform_image
68
 
69
+ # Ensure the model is loaded
70
+ load_model()
71
+
72
  # Convert image to tensor
73
  input_images = transform_image(image).unsqueeze(0).to("cuda")
74