marta-0 commited on
Commit
11ae501
1 Parent(s): 6da6215

initate cfg and model outside inf

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -13,23 +13,25 @@ def download_file(http_address, file_name):
13
  r = requests.get(http_address, allow_redirects=True)
14
  open(file_name, 'wb').write(r.content)
15
 
16
-
17
- cfgs = ['configs/modnet/modnet_mobilenetv2.yml', 'configs/modnet/modnet_resnet50_vd.yml', 'configs/modnet/modnet_hrnet_w18.yml']
18
 
19
  download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-mobilenetv2.pdparams', 'modnet-mobilenetv2.pdparams')
20
  download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-resnet50_vd.pdparams', 'modnet-resnet50_vd.pdparams')
21
  download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-hrnet_w18.pdparams', 'modnet-hrnet_w18.pdparams')
22
  models_paths = ['modnet-mobilenetv2.pdparams', 'modnet-resnet50_vd.pdparams', 'modnet-hrnet_w18.pdparams']
 
23
 
24
 
25
  def inference(image, chosen_model):
26
  paddle.set_device('cpu')
27
- cfg = Config(cfgs[chosen_model])
28
 
 
29
  val_dataset = cfg.val_dataset
30
- model = cfg.model
31
  img_transforms = val_dataset.transforms
32
 
 
 
33
  alpha_pred = predict(model,
34
  model_path=models_paths[chosen_model],
35
  transforms=img_transforms,
@@ -41,7 +43,6 @@ def inference(image, chosen_model):
41
  inputs = [gr.inputs.Image(label='Input Image'),
42
  gr.inputs.Radio(['MobileNetV2', 'ResNet50_vd', 'HRNet_W18'], label='Model', type='index')]
43
 
44
-
45
  gr.Interface(
46
  inference,
47
  inputs,
 
13
  r = requests.get(http_address, allow_redirects=True)
14
  open(file_name, 'wb').write(r.content)
15
 
16
+ cfg_paths = ['configs/modnet/modnet_mobilenetv2.yml', 'configs/modnet/modnet_resnet50_vd.yml', 'configs/modnet/modnet_hrnet_w18.yml']
17
+ cfgs = [Config(cfg) for cfg in cfg_paths]
18
 
19
  download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-mobilenetv2.pdparams', 'modnet-mobilenetv2.pdparams')
20
  download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-resnet50_vd.pdparams', 'modnet-resnet50_vd.pdparams')
21
  download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-hrnet_w18.pdparams', 'modnet-hrnet_w18.pdparams')
22
  models_paths = ['modnet-mobilenetv2.pdparams', 'modnet-resnet50_vd.pdparams', 'modnet-hrnet_w18.pdparams']
23
+ models = [cfg.model for cfg in cfgs]
24
 
25
 
26
  def inference(image, chosen_model):
27
  paddle.set_device('cpu')
 
28
 
29
+ cfg = cfgs[chosen_model]
30
  val_dataset = cfg.val_dataset
 
31
  img_transforms = val_dataset.transforms
32
 
33
+ model = models[chosen_model]
34
+
35
  alpha_pred = predict(model,
36
  model_path=models_paths[chosen_model],
37
  transforms=img_transforms,
 
43
  inputs = [gr.inputs.Image(label='Input Image'),
44
  gr.inputs.Radio(['MobileNetV2', 'ResNet50_vd', 'HRNet_W18'], label='Model', type='index')]
45
 
 
46
  gr.Interface(
47
  inference,
48
  inputs,