Commit edb1d95
Parent(s): ffaa6bd

made some improvement
Files changed:
- .gitignore +2 -1
- configs/experiment1.yaml +2 -2
- constants.yaml +6 -1
- models/.gitkeep +0 -0
- src/scripts/train.py +24 -15
- src/scripts/visualize_results.py +5 -5
- src/simple_regression_colorization/data/visualize_dataset.py +8 -2
- src/simple_regression_colorization/model/base_model_interface.py +18 -3
- src/utils/data_utils.py +94 -47
.gitignore
CHANGED
@@ -88,4 +88,5 @@ target/
 # Mypy cache
 .mypy_cache/

-/models
+/models
+/artifacts
configs/experiment1.yaml
CHANGED
@@ -12,5 +12,5 @@ shuffle: False

 # training related
 batch_size: 16
-epochs: 15
-
+# epochs: 15
+epochs: 02
constants.yaml
CHANGED
@@ -1,3 +1,8 @@
 RAW_DATASET_DIR: data/raw/
 INTERIM_DATASET_DIR: data/interim/
-PROCESSED_DATASET_DIR: data/processed/
+PROCESSED_DATASET_DIR: data/processed/
+
+
+ARTIFACT_MODEL_DIR: artifacts/model/
+ARTIFACT_DATASET_VISUALIZATION_DIR: artifacts/dataset/
+ARTIFACT_RESULT_VISUALIZATION_DIR: artifacts/result/
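Note: the new ARTIFACT_* keys are read as attributes elsewhere in this commit (e.g. constants.ARTIFACT_MODEL_DIR in train.py below). The repo's src/utils/config_loader.py is not part of this diff, so the following is only a minimal sketch, assuming it builds the constants object from constants.yaml in roughly this way:

# hypothetical sketch; the real src/utils/config_loader.py is not shown in this commit
import yaml
from types import SimpleNamespace

def load_constants(path="constants.yaml"):
    # expose the YAML key/value pairs as attributes, e.g. constants.ARTIFACT_MODEL_DIR
    with open(path) as f:
        return SimpleNamespace(**yaml.safe_load(f))

constants = load_constants()
print(constants.ARTIFACT_MODEL_DIR)  # artifacts/model/
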
models/.gitkeep
DELETED
File without changes
src/scripts/train.py
CHANGED
@@ -1,7 +1,7 @@
 import os,shutil
 import argparse
 from comet_ml import Experiment
-from src.utils.config_loader import Config
+from src.utils.config_loader import Config,constants
 from src.utils import config_loader
 from src.utils.data_utils import print_title
 from src.utils.script_utils import validate_config
@@ -24,42 +24,51 @@ def train(args):
     Model = importlib.import_module(f"src.{config.task}.model.models.{config.model}").Model


-    model_dir =
+    model_dir = constants.ARTIFACT_MODEL_DIR
     os.makedirs(model_dir,exist_ok=True)
     model_save_path = os.path.join(model_dir,"model.weights.h5")

     # save config to exported model folder
     shutil.copy(config_file_path,model_dir)
+    # rename it to config.yaml
+    shutil.move(os.path.join(model_dir,Path(config_file_path).name),os.path.join(model_dir,"config.yaml"))

-    experiment =
-
-
-
-
-
-
-
-
-
+    experiment = None
+    if args.log:
+        experiment = Experiment(
+            api_key=os.environ["COMET_API_KEY"],
+            project_name="image-colorization",
+            workspace="anujpanthri",
+            auto_histogram_activation_logging=True,
+            auto_histogram_epoch_rate=True,
+            auto_histogram_gradient_logging=True,
+            auto_histogram_weight_logging=True,
+            auto_param_logging=True,
+        )

     model = Model(experiment=experiment)
+
+    print_title("\nTraining Model")
     model.train()
     model.save(model_save_path)

     # log model to comet
     if "LOCAL_SYSTEM" not in os.environ:
-        experiment
-
+        if experiment:
+            experiment.log_model(f"{config.task}_{config.dataset}_{config.model}",model_dir)
+
     # evaluate model
     print_title("\nEvaluating Model")
     metrics = model.evaluate()
     print("Model Evaluation Metrics:",metrics)

-    experiment
+    if experiment:
+        experiment.end()

 def main():
     parser = argparse.ArgumentParser(description="train model based on config yaml file")
     parser.add_argument("config_file",type=str)
+    parser.add_argument("--log",action="store_true",default=False)
     args = parser.parse_args()
     train(args)

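With the new --log flag, Comet logging is opt-in: experiment stays None unless --log is passed, and log_model/end are guarded by if experiment:. A hypothetical invocation from the repository root (assuming the package layout implied by the src.* imports; COMET_API_KEY must be set when --log is used):

python -m src.scripts.train configs/experiment1.yaml --log
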
src/scripts/visualize_results.py
CHANGED
@@ -1,12 +1,12 @@
 import os
 import argparse
-from src.utils.config_loader import Config
+from src.utils.config_loader import Config,constants
 from src.utils import config_loader
 from src.utils.script_utils import validate_config
 import importlib


-def visualize_dataset(args):
+def visualize_results(args):
     config_file_path = args.config_file
     config = Config(config_file_path)

@@ -17,7 +17,7 @@ def visualize_dataset(args):
     config_loader.config = config

     # now load model and visualize the results
-    model_dir =
+    model_dir = constants.ARTIFACT_MODEL_DIR
     model_save_path = os.path.join(model_dir,"model.weights.h5")

     if not os.path.exists(model_save_path):
@@ -32,10 +32,10 @@ def visualize_dataset(args):


 def main():
-    parser = argparse.ArgumentParser(description="
+    parser = argparse.ArgumentParser(description="visualize results based on config yaml file and trained model")
     parser.add_argument("config_file",type=str)
     args = parser.parse_args()
-    visualize_dataset(args)
+    visualize_results(args)

 if __name__=="__main__":
     main()
src/simple_regression_colorization/data/visualize_dataset.py
CHANGED
@@ -12,10 +12,16 @@ def visualize():
     choosen_paths = np.random.choice(image_paths,n)
     show_images_from_paths(choosen_paths,
                            title="sample of train_val dataset",
-                           image_size=config.image_size)
+                           image_size=config.image_size,
+                           save=True,
+                           label="trainval",
+                           )

     image_paths = glob(f"{constants.PROCESSED_DATASET_DIR}/test/*")
     choosen_paths = np.random.choice(image_paths,n)
     show_images_from_paths(choosen_paths,
                            title="sample of test dataset",
-                           image_size=config.image_size)
+                           image_size=config.image_size,
+                           save=True,
+                           label="test",
+                           )
src/simple_regression_colorization/model/base_model_interface.py
CHANGED
@@ -68,17 +68,32 @@ class BaseModel(ABC):
         L_batch,AB_batch = next(iter(self.train_ds))
         L_batch = L_batch.numpy()
         AB_pred = self.model.predict(L_batch,verbose=0)
-        see_batch(L_batch,
+        see_batch(L_batch,
+                  AB_pred,
+                  title="Train dataset Results",
+                  save = True,
+                  label = "train",
+                  )

         L_batch,AB_batch = next(iter(self.val_ds))
         L_batch = L_batch.numpy()
         AB_pred = self.model.predict(L_batch,verbose=0)
-        see_batch(L_batch,
+        see_batch(L_batch,
+                  AB_pred,
+                  title="Val dataset Results",
+                  save = True,
+                  label = "val",
+                  )

         L_batch,AB_batch = next(iter(self.test_ds))
         L_batch = L_batch.numpy()
         AB_pred = self.model.predict(L_batch,verbose=0)
-        see_batch(L_batch,
+        see_batch(L_batch,
+                  AB_pred,
+                  title="Test dataset Results",
+                  save = True,
+                  label = "test",
+                  )


     @abstractmethod
src/utils/data_utils.py
CHANGED
@@ -2,101 +2,148 @@ from src.utils.config_loader import constants
 from huggingface_hub import snapshot_download
 from zipfile import ZipFile
 import numpy as np
-import os,shutil
+import os, shutil
 import matplotlib.pyplot as plt
 import cv2
 import math


-def download_hf_dataset(repo_id,allow_patterns=None):
+def download_hf_dataset(repo_id, allow_patterns=None):
     """Used to download dataset from any public hugging face dataset"""
-    snapshot_download(
-
-
-
+    snapshot_download(
+        repo_id=repo_id,
+        repo_type="dataset",
+        local_dir=constants.RAW_DATASET_DIR,
+        allow_patterns=allow_patterns,
+    )


 def download_personal_hf_dataset(name):
     """Used to download dataset from a specific hugging face dataset"""
-    download_hf_dataset(
-
+    download_hf_dataset(
+        repo_id="Anuj-Panthri/Image-Colorization-Datasets", allow_patterns=f"{name}/*"
+    )


-def unzip_file(file_path,destination_dir):
+def unzip_file(file_path, destination_dir):
     """unzips file to destination_dir"""
     if os.path.exists(destination_dir):
         shutil.rmtree(destination_dir)
     os.makedirs(destination_dir)
-    with ZipFile(file_path,"r") as zip:
+    with ZipFile(file_path, "r") as zip:
         zip.extractall(destination_dir)

-
+
+def is_bw(img: np.ndarray):
     """checks if RGB image is black and white"""
-    rg,gb,rb =
-
-
-
-
+    rg, gb, rb = (
+        img[:, :, 0] - img[:, :, 1],
+        img[:, :, 1] - img[:, :, 2],
+        img[:, :, 0] - img[:, :, 2],
+    )
+    rg, gb, rb = np.abs(rg).sum(), np.abs(gb).sum(), np.abs(rb).sum()
+    avg = np.mean([rg, gb, rb])
+
+    return avg < 10
+

-def print_title(msg:str,max_chars=105):
-    n = (max_chars-len(msg))//2
-    print("="*n,msg.upper(),"="*n,sep="")
+def print_title(msg: str, max_chars=105):
+    n = (max_chars - len(msg)) // 2
+    print("=" * n, msg.upper(), "=" * n, sep="")


 def scale_L(L):
-    return L/100
+    return L / 100
+

 def rescale_L(L):
-    return L*100
+    return L * 100
+

 def scale_AB(AB):
-    return AB/128
+    return AB / 128
+

 def rescale_AB(AB):
-    return AB*128
-
+    return AB * 128
+


-def show_images_from_paths(image_paths:list[str],image_size=64,cols=4,row_size=5,col_size=5,show_BW=False,title=None):
+def show_images_from_paths(
+    image_paths: list[str],
+    image_size=64,
+    cols=4,
+    row_size=5,
+    col_size=5,
+    show_BW=False,
+    title=None,
+    save=False,
+    label="",
+):
+
     n = len(image_paths)
-    rows = math.ceil(n/cols)
-    fig = plt.figure(figsize=(col_size*cols,row_size*rows))
+    rows = math.ceil(n / cols)
+    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
     if title:
         plt.title(title)
         plt.axis("off")

     for i in range(n):
-        fig.add_subplot(rows,cols,i+1)
-
-        img = cv2.imread(image_paths[i])[
-        img = cv2.resize(img,[image_size,image_size])
+        fig.add_subplot(rows, cols, i + 1)
+
+        img = cv2.imread(image_paths[i])[:, :, ::-1]
+        img = cv2.resize(img, [image_size, image_size])

         if show_BW:
-            BW = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
-            BW = np.tile(BW,(1,1,3))
-            img = np.concatenate([BW,img],axis=1)
+            BW = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+            BW = np.tile(BW, (1, 1, 3))
+            img = np.concatenate([BW, img], axis=1)
         plt.imshow(img.astype("uint8"))
+
+    if save:
+        os.makedirs(constants.ARTIFACT_DATASET_VISUALIZATION_DIR, exist_ok=True)
+        plt.savefig(
+            os.path.join(
+                constants.ARTIFACT_DATASET_VISUALIZATION_DIR, f"{label}_image.png"
+            )
+        )
     plt.show()


-def see_batch(
+def see_batch(
+    L_batch,
+    AB_batch,
+    show_L=False,
+    cols=4,
+    row_size=5,
+    col_size=5,
+    title=None,
+    save=False,
+    label="",
+):
     n = L_batch.shape[0]
-    rows = math.ceil(n/cols)
-    fig = plt.figure(figsize=(col_size*cols,row_size*rows))
+    rows = math.ceil(n / cols)
+    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
     if title:
         plt.title(title)
         plt.axis("off")
-
+
     for i in range(n):
-        fig.add_subplot(rows,cols,i+1)
-        L,AB = L_batch[i],AB_batch[i]
-        L,AB = rescale_L(L), rescale_AB(AB)
-        # print(L.shape,AB.shape)
-        img = np.concatenate([L,AB],axis=-1)
-        img = cv2.cvtColor(img,cv2.COLOR_LAB2RGB)*255
-        # print(img.min(),img.max())
+        fig.add_subplot(rows, cols, i + 1)
+        L, AB = L_batch[i], AB_batch[i]
+        L, AB = rescale_L(L), rescale_AB(AB)
+        # print(L.shape,AB.shape)
+        img = np.concatenate([L, AB], axis=-1)
+        img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB) * 255
+        # print(img.min(),img.max())
         if show_L:
-            L = np.tile(L,(1,1,3))/100*255
-            img = np.concatenate([L,img],axis=1)
+            L = np.tile(L, (1, 1, 3)) / 100 * 255
+            img = np.concatenate([L, img], axis=1)
         plt.imshow(img.astype("uint8"))
+    if save:
+        os.makedirs(constants.ARTIFACT_RESULT_VISUALIZATION_DIR, exist_ok=True)
+        plt.savefig(
+            os.path.join(
+                constants.ARTIFACT_RESULT_VISUALIZATION_DIR, f"{label}_image.png"
+            )
+        )
     plt.show()
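For illustration, a hypothetical call exercising the new save/label parameters of show_images_from_paths; the paths and sizes are assumptions based on constants.yaml and visualize_dataset.py above, not part of this commit:

# hypothetical usage sketch, not part of this commit
from glob import glob
import numpy as np
from src.utils.data_utils import show_images_from_paths

image_paths = list(np.random.choice(glob("data/processed/test/*"), 4))
show_images_from_paths(
    image_paths,
    image_size=64,
    title="sample of test dataset",
    save=True,     # writes the figure under ARTIFACT_DATASET_VISUALIZATION_DIR
    label="test",  # -> artifacts/dataset/test_image.png
)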