sentencebird commited on
Commit
f3e20a5
·
1 Parent(s): 85a1bed

add: 背景マスキングのモデルをpickleでload

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ deeplabv3_resnet101.pkl filter=lfs diff=lfs merge=lfs -text
deeplabv3_resnet101.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d3cc589cf318a3e3b010dd2b701d38d0fbe779da680cead45e63a16a3016e08
3
+ size 244712524
torchvision_funcs.py CHANGED
@@ -1,11 +1,13 @@
1
  import numpy as np
2
  import cv2
3
  from PIL import Image
 
4
 
5
  import torch
6
  import torchvision
7
  from torchvision import transforms
8
 
 
9
  def deeplabv3_remove_bg(img):
10
  img = np.array(img, dtype=np.uint8)
11
  # img = cv2.imread(image_path)
@@ -14,10 +16,8 @@ def deeplabv3_remove_bg(img):
14
  # img = cv2.resize(img,(1000,1000))
15
 
16
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17
-
18
- model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
19
- model = model.to(device)
20
- model.eval();
21
 
22
  preprocess = transforms.Compose([
23
  transforms.ToTensor(),
 
1
  import numpy as np
2
  import cv2
3
  from PIL import Image
4
+ import pickle
5
 
6
  import torch
7
  import torchvision
8
  from torchvision import transforms
9
 
10
+
11
  def deeplabv3_remove_bg(img):
12
  img = np.array(img, dtype=np.uint8)
13
  # img = cv2.imread(image_path)
 
16
  # img = cv2.resize(img,(1000,1000))
17
 
18
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
+ with open('deeplabv3_resnet101.pkl', 'rb') as f:
20
+ model = pickle.load(f)
 
 
21
 
22
  preprocess = transforms.Compose([
23
  transforms.ToTensor(),