Anuj-Panthri committed
Commit edb1d95 · 1 Parent(s): ffaa6bd

made some improvement

.gitignore CHANGED
@@ -88,4 +88,5 @@ target/
 # Mypy cache
 .mypy_cache/
 
-/models
+/models
+/artifacts
configs/experiment1.yaml CHANGED
@@ -12,5 +12,5 @@ shuffle: False
 
 # training related
 batch_size: 16
-epochs: 15
-# epochs: 02
+# epochs: 15
+epochs: 02
constants.yaml CHANGED
@@ -1,3 +1,8 @@
 RAW_DATASET_DIR: data/raw/
 INTERIM_DATASET_DIR: data/interim/
-PROCESSED_DATASET_DIR: data/processed/
+PROCESSED_DATASET_DIR: data/processed/
+
+
+ARTIFACT_MODEL_DIR: artifacts/model/
+ARTIFACT_DATASET_VISUALIZATION_DIR: artifacts/dataset/
+ARTIFACT_RESULT_VISUALIZATION_DIR: artifacts/result/
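The three new ARTIFACT_* entries give every script a single place to resolve where exported models and visualizations are written. A minimal sketch of how such a file can be exposed as attribute-style constants (hypothetical; the repo's actual src/utils/config_loader may do this differently):

import yaml  # assumes PyYAML is installed
from types import SimpleNamespace

# Load constants.yaml into an object so entries read as constants.ARTIFACT_MODEL_DIR
with open("constants.yaml") as f:
    constants = SimpleNamespace(**yaml.safe_load(f))

print(constants.ARTIFACT_MODEL_DIR)  # -> artifacts/model/
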
models/.gitkeep DELETED
File without changes
src/scripts/train.py CHANGED
@@ -1,7 +1,7 @@
 import os,shutil
 import argparse
 from comet_ml import Experiment
-from src.utils.config_loader import Config
+from src.utils.config_loader import Config,constants
 from src.utils import config_loader
 from src.utils.data_utils import print_title
 from src.utils.script_utils import validate_config
@@ -24,42 +24,51 @@ def train(args):
     Model = importlib.import_module(f"src.{config.task}.model.models.{config.model}").Model
 
 
-    model_dir = os.path.join("models",config.task,config.model)
+    model_dir = constants.ARTIFACT_MODEL_DIR
     os.makedirs(model_dir,exist_ok=True)
     model_save_path = os.path.join(model_dir,"model.weights.h5")
 
     # save config to exported model folder
     shutil.copy(config_file_path,model_dir)
+    # rename it to config.yaml
+    shutil.move(os.path.join(model_dir,Path(config_file_path).name),os.path.join(model_dir,"config.yaml"))
 
-    experiment = Experiment(
-        api_key=os.environ["COMET_API_KEY"],
-        project_name="image-colorization",
-        workspace="anujpanthri",
-        auto_histogram_activation_logging=True,
-        auto_histogram_epoch_rate=True,
-        auto_histogram_gradient_logging=True,
-        auto_histogram_weight_logging=True,
-        auto_param_logging=True,
-    )
+    experiment = None
+    if args.log:
+        experiment = Experiment(
+            api_key=os.environ["COMET_API_KEY"],
+            project_name="image-colorization",
+            workspace="anujpanthri",
+            auto_histogram_activation_logging=True,
+            auto_histogram_epoch_rate=True,
+            auto_histogram_gradient_logging=True,
+            auto_histogram_weight_logging=True,
+            auto_param_logging=True,
+        )
 
     model = Model(experiment=experiment)
+
+    print_title("\nTraining Model")
     model.train()
     model.save(model_save_path)
 
     # log model to comet
     if "LOCAL_SYSTEM" not in os.environ:
-        experiment.log_model(f"{config.task}_{config.dataset}_{config.model}",model_dir)
-
+        if experiment:
+            experiment.log_model(f"{config.task}_{config.dataset}_{config.model}",model_dir)
+
     # evaluate model
     print_title("\nEvaluating Model")
     metrics = model.evaluate()
     print("Model Evaluation Metrics:",metrics)
 
-    experiment.end()
+    if experiment:
+        experiment.end()
 
 def main():
     parser = argparse.ArgumentParser(description="train model based on config yaml file")
     parser.add_argument("config_file",type=str)
+    parser.add_argument("--log",action="store_true",default=False)
     args = parser.parse_args()
     train(args)
 
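The net effect in train.py is that Comet logging becomes opt-in: experiment stays None unless the new --log flag is passed, and every Comet call is guarded by if experiment. A stripped-down sketch of that pattern (stdlib only, with comet_ml.Experiment stubbed out, so nothing here beyond the flag names is taken from the repo):

import argparse

def train(args):
    experiment = None
    if args.log:               # logging is now opt-in
        experiment = object()  # stand-in for comet_ml.Experiment(...)
    # ... model training and saving would happen here ...
    if experiment:             # every Comet call is guarded
        print("would log the model and call experiment.end()")

def main():
    parser = argparse.ArgumentParser(description="train model based on config yaml file")
    parser.add_argument("config_file", type=str)
    parser.add_argument("--log", action="store_true", default=False)
    train(parser.parse_args())

if __name__ == "__main__":
    main()

Invoked as, e.g., python -m src.scripts.train configs/experiment1.yaml --log (module path assumed from the repo layout); without --log no experiment is created and nothing is sent to Comet.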
 
src/scripts/visualize_results.py CHANGED
@@ -1,12 +1,12 @@
 import os
 import argparse
-from src.utils.config_loader import Config
+from src.utils.config_loader import Config,constants
 from src.utils import config_loader
 from src.utils.script_utils import validate_config
 import importlib
 
 
-def visualize_dataset(args):
+def visualize_results(args):
     config_file_path = args.config_file
     config = Config(config_file_path)
 
@@ -17,7 +17,7 @@ def visualize_dataset(args):
     config_loader.config = config
 
     # now load model and visualize the results
-    model_dir = os.path.join("models",config.task,config.model)
+    model_dir = constants.ARTIFACT_MODEL_DIR
     model_save_path = os.path.join(model_dir,"model.weights.h5")
 
     if not os.path.exists(model_save_path):
@@ -32,10 +32,10 @@ def visualize_dataset(args):
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Prepare dataset based on config yaml file")
+    parser = argparse.ArgumentParser(description="visualize results based on config yaml file and trained model")
     parser.add_argument("config_file",type=str)
     args = parser.parse_args()
-    visualize_dataset(args)
+    visualize_results(args)
 
 if __name__=="__main__":
     main()
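With model_dir now read from constants.ARTIFACT_MODEL_DIR, training and result visualization agree on where the weights live. A quick way to confirm a run actually produced them (assumes the repo layout so the import resolves, as in the diff above):

import os
from src.utils.config_loader import constants

weights = os.path.join(constants.ARTIFACT_MODEL_DIR, "model.weights.h5")
print(weights, "exists:", os.path.exists(weights))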
src/simple_regression_colorization/data/visualize_dataset.py CHANGED
@@ -12,10 +12,16 @@ def visualize():
     choosen_paths = np.random.choice(image_paths,n)
     show_images_from_paths(choosen_paths,
                            title="sample of train_val dataset",
-                           image_size=config.image_size)
+                           image_size=config.image_size,
+                           save=True,
+                           label="trainval",
+                           )
 
     image_paths = glob(f"{constants.PROCESSED_DATASET_DIR}/test/*")
     choosen_paths = np.random.choice(image_paths,n)
     show_images_from_paths(choosen_paths,
                            title="sample of test dataset",
-                           image_size=config.image_size)
+                           image_size=config.image_size,
+                           save=True,
+                           label="test",
+                           )
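The extra save=True / label=... arguments make visualize() persist its sample grids instead of only showing them. A hypothetical standalone check of the same call (writes a dummy image first so cv2.imread has something to load; assumes it runs from the repo root):

import cv2
import numpy as np
from src.utils.data_utils import show_images_from_paths

# Create a throwaway RGB image on disk for the helper to read back.
cv2.imwrite("dummy.png", np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
# With save=True this writes artifacts/dataset/demo_image.png before plt.show().
show_images_from_paths(["dummy.png"], image_size=64, title="demo", save=True, label="demo")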
src/simple_regression_colorization/model/base_model_interface.py CHANGED
@@ -68,17 +68,32 @@ class BaseModel(ABC):
         L_batch,AB_batch = next(iter(self.train_ds))
         L_batch = L_batch.numpy()
         AB_pred = self.model.predict(L_batch,verbose=0)
-        see_batch(L_batch,AB_pred,title="Train dataset Results")
+        see_batch(L_batch,
+                  AB_pred,
+                  title="Train dataset Results",
+                  save = True,
+                  label = "train",
+                  )
 
         L_batch,AB_batch = next(iter(self.val_ds))
         L_batch = L_batch.numpy()
         AB_pred = self.model.predict(L_batch,verbose=0)
-        see_batch(L_batch,AB_pred,title="Val dataset Results")
+        see_batch(L_batch,
+                  AB_pred,
+                  title="Val dataset Results",
+                  save = True,
+                  label = "val",
+                  )
 
         L_batch,AB_batch = next(iter(self.test_ds))
         L_batch = L_batch.numpy()
         AB_pred = self.model.predict(L_batch,verbose=0)
-        see_batch(L_batch,AB_pred,title="Test dataset Results")
+        see_batch(L_batch,
+                  AB_pred,
+                  title="Test dataset Results",
+                  save = True,
+                  label = "test",
+                  )
 
 
     @abstractmethod
src/utils/data_utils.py CHANGED
@@ -2,101 +2,148 @@ from src.utils.config_loader import constants
 from huggingface_hub import snapshot_download
 from zipfile import ZipFile
 import numpy as np
-import os,shutil
+import os, shutil
 import matplotlib.pyplot as plt
 import cv2
 import math
 
 
-def download_hf_dataset(repo_id,allow_patterns=None):
+def download_hf_dataset(repo_id, allow_patterns=None):
     """Used to download dataset from any public hugging face dataset"""
-    snapshot_download(repo_id=repo_id,
-                      repo_type="dataset",
-                      local_dir=constants.RAW_DATASET_DIR,
-                      allow_patterns=allow_patterns)
+    snapshot_download(
+        repo_id=repo_id,
+        repo_type="dataset",
+        local_dir=constants.RAW_DATASET_DIR,
+        allow_patterns=allow_patterns,
+    )
 
 
 def download_personal_hf_dataset(name):
     """Used to download dataset from a specific hugging face dataset"""
-    download_hf_dataset(repo_id="Anuj-Panthri/Image-Colorization-Datasets",
-                        allow_patterns=f"{name}/*")
+    download_hf_dataset(
+        repo_id="Anuj-Panthri/Image-Colorization-Datasets", allow_patterns=f"{name}/*"
+    )
 
 
-def unzip_file(file_path,destination_dir):
+def unzip_file(file_path, destination_dir):
     """unzips file to destination_dir"""
    if os.path.exists(destination_dir):
        shutil.rmtree(destination_dir)
    os.makedirs(destination_dir)
-    with ZipFile(file_path,"r") as zip:
+    with ZipFile(file_path, "r") as zip:
         zip.extractall(destination_dir)
 
-def is_bw(img:np.ndarray):
+
+def is_bw(img: np.ndarray):
     """checks if RGB image is black and white"""
-    rg,gb,rb = img[:,:,0]-img[:,:,1] , img[:,:,1]-img[:,:,2] , img[:,:,0]-img[:,:,2]
-    rg,gb,rb = np.abs(rg).sum(),np.abs(gb).sum(),np.abs(rb).sum()
-    avg = np.mean([rg,gb,rb])
-
-    return avg<10
+    rg, gb, rb = (
+        img[:, :, 0] - img[:, :, 1],
+        img[:, :, 1] - img[:, :, 2],
+        img[:, :, 0] - img[:, :, 2],
+    )
+    rg, gb, rb = np.abs(rg).sum(), np.abs(gb).sum(), np.abs(rb).sum()
+    avg = np.mean([rg, gb, rb])
+
+    return avg < 10
+
 
+def print_title(msg: str, max_chars=105):
+    n = (max_chars - len(msg)) // 2
+    print("=" * n, msg.upper(), "=" * n, sep="")
 
-def print_title(msg:str,max_chars=105):
-    n = (max_chars-len(msg))//2
-    print("="*n,msg.upper(),"="*n,sep="")
 
 def scale_L(L):
-    return L/100
+    return L / 100
+
 
 def rescale_L(L):
-    return L*100
+    return L * 100
+
 
 def scale_AB(AB):
-    return AB/128
+    return AB / 128
+
 
 def rescale_AB(AB):
-    return AB*128
-
+    return AB * 128
+
 
+def show_images_from_paths(
+    image_paths: list[str],
+    image_size=64,
+    cols=4,
+    row_size=5,
+    col_size=5,
+    show_BW=False,
+    title=None,
+    save=False,
+    label="",
+):
 
-def show_images_from_paths(image_paths:list[str],image_size=64,cols=4,row_size=5,col_size=5,show_BW=False,title=None):
     n = len(image_paths)
-    rows = math.ceil(n/cols)
-    fig = plt.figure(figsize=(col_size*cols,row_size*rows))
+    rows = math.ceil(n / cols)
+    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
     if title:
         plt.title(title)
     plt.axis("off")
 
     for i in range(n):
-        fig.add_subplot(rows,cols,i+1)
-
-        img = cv2.imread(image_paths[i])[:,:,::-1]
-        img = cv2.resize(img,[image_size,image_size])
+        fig.add_subplot(rows, cols, i + 1)
+
+        img = cv2.imread(image_paths[i])[:, :, ::-1]
+        img = cv2.resize(img, [image_size, image_size])
 
         if show_BW:
-            BW = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
-            BW = np.tile(BW,(1,1,3))
-            img = np.concatenate([BW,img],axis=1)
+            BW = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+            BW = np.tile(BW, (1, 1, 3))
+            img = np.concatenate([BW, img], axis=1)
         plt.imshow(img.astype("uint8"))
+
+    if save:
+        os.makedirs(constants.ARTIFACT_DATASET_VISUALIZATION_DIR, exist_ok=True)
+        plt.savefig(
+            os.path.join(
+                constants.ARTIFACT_DATASET_VISUALIZATION_DIR, f"{label}_image.png"
+            )
+        )
     plt.show()
 
 
-def see_batch(L_batch,AB_batch,show_L=False,cols=4,row_size=5,col_size=5,title=None):
+def see_batch(
+    L_batch,
+    AB_batch,
+    show_L=False,
+    cols=4,
+    row_size=5,
+    col_size=5,
+    title=None,
+    save=False,
+    label="",
+):
     n = L_batch.shape[0]
-    rows = math.ceil(n/cols)
-    fig = plt.figure(figsize=(col_size*cols,row_size*rows))
+    rows = math.ceil(n / cols)
+    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
     if title:
         plt.title(title)
     plt.axis("off")
-
+
     for i in range(n):
-        fig.add_subplot(rows,cols,i+1)
-        L,AB = L_batch[i],AB_batch[i]
-        L,AB = rescale_L(L), rescale_AB(AB)
-        # print(L.shape,AB.shape)
-        img = np.concatenate([L,AB],axis=-1)
-        img = cv2.cvtColor(img,cv2.COLOR_LAB2RGB)*255
-        # print(img.min(),img.max())
+        fig.add_subplot(rows, cols, i + 1)
+        L, AB = L_batch[i], AB_batch[i]
+        L, AB = rescale_L(L), rescale_AB(AB)
+        # print(L.shape,AB.shape)
+        img = np.concatenate([L, AB], axis=-1)
+        img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB) * 255
+        # print(img.min(),img.max())
         if show_L:
-            L = np.tile(L,(1,1,3))/100*255
-            img = np.concatenate([L,img],axis=1)
+            L = np.tile(L, (1, 1, 3)) / 100 * 255
+            img = np.concatenate([L, img], axis=1)
         plt.imshow(img.astype("uint8"))
+    if save:
+        os.makedirs(constants.ARTIFACT_RESULT_VISUALIZATION_DIR, exist_ok=True)
+        plt.savefig(
+            os.path.join(
+                constants.ARTIFACT_RESULT_VISUALIZATION_DIR, f"{label}_image.png"
+            )
+        )
     plt.show()
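Both helpers now save a PNG named f"{label}_image.png" under the artifact directories from constants.yaml before calling plt.show(). A hypothetical smoke test for see_batch with the new arguments (assumes it runs from the repo root so src.utils.data_utils and constants.yaml resolve):

import numpy as np
from src.utils.data_utils import see_batch

# Random batch in the scaled ranges the helpers expect (L roughly [0, 1], AB roughly [-1, 1]);
# float32 because cv2.cvtColor rejects float64 input.
L_batch = np.random.rand(4, 64, 64, 1).astype("float32")
AB_batch = (np.random.rand(4, 64, 64, 2).astype("float32") - 0.5) * 2

# Writes artifacts/result/smoke_image.png (ARTIFACT_RESULT_VISUALIZATION_DIR), then shows the figure.
see_batch(L_batch, AB_batch, title="smoke test", save=True, label="smoke")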