Vibi007 committed on
Commit fceb8da · 1 Parent(s): 035761e

Updated inference

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -30,7 +30,7 @@ env.bak/
 venv.bak/
 
 # Training artifacts
-checkpoints/
+# checkpoints/
 runs/
 logs/
 *.ckpt
@@ -38,7 +38,7 @@ logs/
 *.pth
 wandb/
 lightning_logs/
-final_model/
+# final_model/
 
 # IDE
 .idea/
README.md CHANGED
@@ -15,6 +15,8 @@ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
 ```
 use config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
 
+https://github.com/huggingface/smollm/blob/main/pre-training/smollm2/config_smollm2_135M.yaml
+
 create model from above parameters
 
 Use it for training using pytorch lightning
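Note: as a companion to the README steps above, a minimal sketch (not part of this commit) of loading the referenced YAML config and building the model before handing it to PyTorch Lightning. The local filename `config_smollm2_135M.yaml` and the `model.model_config` key layout are assumptions taken from model.py in this repo.

```python
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig

# Assumed local copy of the SmolLM2-135M YAML linked above.
with open("config_smollm2_135M.yaml", "r") as f:
    model_cfg = yaml.safe_load(f)["model"]["model_config"]

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
tokenizer.pad_token = tokenizer.eos_token

# Build a randomly initialised model from the config values,
# mirroring what model.py does before training with PyTorch Lightning.
config = LlamaConfig(
    vocab_size=49152,
    hidden_size=model_cfg["hidden_size"],
    intermediate_size=model_cfg["intermediate_size"],
    num_hidden_layers=model_cfg["num_hidden_layers"],
    num_attention_heads=model_cfg["num_attention_heads"],
    num_key_value_heads=model_cfg["num_key_value_heads"],
    max_position_embeddings=model_cfg["max_position_embeddings"],
)
model = AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in model.parameters()))  # roughly 135M parameters
```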
final_model/config.json ADDED
@@ -0,0 +1,29 @@
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 0,
+  "eos_token_id": 0,
+  "head_dim": 64,
+  "hidden_act": "silu",
+  "hidden_size": 576,
+  "initializer_range": 0.041666666666666664,
+  "intermediate_size": 1536,
+  "max_position_embeddings": 2048,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 9,
+  "num_hidden_layers": 30,
+  "num_key_value_heads": 3,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "float32",
+  "transformers_version": "4.47.0",
+  "use_cache": true,
+  "vocab_size": 49152
+}
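The directory added above is a standard Hugging Face model folder, so it can be loaded directly with transformers; a minimal sketch (not part of this commit), with an illustrative prompt and sampling settings:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the saved model and tokenizer from the directory added in this commit.
model = AutoModelForCausalLM.from_pretrained("final_model")
tokenizer = AutoTokenizer.from_pretrained("final_model")
model.eval()

# Illustrative prompt and sampling settings.
inputs = tokenizer("Once upon a time", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```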
final_model/generation_config.json ADDED
@@ -0,0 +1,6 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 0,
+  "eos_token_id": 0,
+  "transformers_version": "4.47.0"
+}
final_model/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
final_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af74ecbdc60bc04c52a8dc7b42a79eecadadcbca7c300cc21e56583ab1c1e0b4
+size 651336704
final_model/special_tokens_map.json ADDED
@@ -0,0 +1,43 @@
+{
+  "additional_special_tokens": [
+    "<|endoftext|>",
+    "<|im_start|>",
+    "<|im_end|>",
+    "<repo_name>",
+    "<reponame>",
+    "<file_sep>",
+    "<filename>",
+    "<gh_stars>",
+    "<issue_start>",
+    "<issue_comment>",
+    "<issue_closed>",
+    "<jupyter_start>",
+    "<jupyter_text>",
+    "<jupyter_code>",
+    "<jupyter_output>",
+    "<jupyter_script>",
+    "<empty_output>"
+  ],
+  "bos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": "<|endoftext|>",
+  "unk_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
final_model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
final_model/tokenizer_config.json ADDED
@@ -0,0 +1,170 @@
+{
+  "add_prefix_space": false,
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<|endoftext|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1": {
+      "content": "<|im_start|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "2": {
+      "content": "<|im_end|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "3": {
+      "content": "<repo_name>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "4": {
+      "content": "<reponame>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "5": {
+      "content": "<file_sep>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "6": {
+      "content": "<filename>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "7": {
+      "content": "<gh_stars>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "8": {
+      "content": "<issue_start>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "9": {
+      "content": "<issue_comment>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "10": {
+      "content": "<issue_closed>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "11": {
+      "content": "<jupyter_start>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "12": {
+      "content": "<jupyter_text>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "13": {
+      "content": "<jupyter_code>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "14": {
+      "content": "<jupyter_output>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "15": {
+      "content": "<jupyter_script>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "16": {
+      "content": "<empty_output>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "additional_special_tokens": [
+    "<|endoftext|>",
+    "<|im_start|>",
+    "<|im_end|>",
+    "<repo_name>",
+    "<reponame>",
+    "<file_sep>",
+    "<filename>",
+    "<gh_stars>",
+    "<issue_start>",
+    "<issue_comment>",
+    "<issue_closed>",
+    "<jupyter_start>",
+    "<jupyter_text>",
+    "<jupyter_code>",
+    "<jupyter_output>",
+    "<jupyter_script>",
+    "<empty_output>"
+  ],
+  "bos_token": "<|endoftext|>",
+  "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "<|endoftext|>",
+  "extra_special_tokens": {},
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": "<|endoftext|>",
+  "tokenizer_class": "GPT2Tokenizer",
+  "unk_token": "<|endoftext|>",
+  "vocab_size": 49152
+}
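The tokenizer config above includes a ChatML-style chat_template; a minimal sketch (not part of this commit) of applying it through the standard transformers API, with an illustrative message:

```python
from transformers import AutoTokenizer

# Load the tokenizer saved in final_model/ and format a chat with its template.
tokenizer = AutoTokenizer.from_pretrained("final_model")
messages = [{"role": "user", "content": "Write a haiku about rain."}]  # illustrative
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
# <|im_start|>user
# Write a haiku about rain.<|im_end|>
# <|im_start|>assistant
```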
final_model/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
inference.py CHANGED
@@ -1,8 +1,8 @@
 import os
 import gradio as gr
 import torch
-from model import SmolLMModule, create_model_config
-from transformers import AutoTokenizer
+from model import SmolLMModule
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import yaml
 import glob
 
@@ -15,71 +15,94 @@ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
 tokenizer.pad_token = tokenizer.eos_token
 
 
-def load_model_from_checkpoint(checkpoint_path):
-    """Load model from checkpoint"""
-    model = SmolLMModule.load_from_checkpoint(checkpoint_path, config=config)
-    model.eval()  # Set to evaluation mode
-    return model
-
-
 def get_available_checkpoints():
-    """Get list of available checkpoints sorted by step number"""
-    checkpoints = glob.glob("checkpoints/*.ckpt")
-    if not checkpoints:
-        return [], []
+    """Get list of available checkpoints and final model"""
+    models = []
+    model_paths = {}
 
-    # Sort by step number
-    def get_step_number(filepath):
+    # Get checkpoints
+    checkpoints = glob.glob("checkpoints/*.ckpt")
+    for ckpt in checkpoints:
         try:
             # Extract step number from the filename
-            filename = os.path.basename(filepath)
-            # Remove .ckpt extension
-            filename = filename.replace(".ckpt", "")
-            # Get the step number
-            if "step=" in filename:
-                return int(filename.split("step=")[1])
-            elif "-step-" in filename:
-                return int(filename.split("-step-")[1])
-            else:
-                return int("".join(filter(str.isdigit, filename)))
-        except (ValueError, IndexError):
+            filename = os.path.basename(ckpt)
+            # Handle the format 'model-step=step=X.ckpt'
+            if "step=step=" in filename:
+                step = int(filename.split("step=step=")[1].split(".")[0])
+                display_name = f"Checkpoint Step {step}"
+                models.append(display_name)
+                model_paths[display_name] = ckpt
+        except (ValueError, IndexError) as e:
+            print(
+                f"Warning: Could not parse checkpoint filename: {filename}, Error: {e}"
+            )
+            continue
+
+    # Add final model if it exists
+    final_model_path = "final_model"
+    if os.path.exists(final_model_path):
+        display_name = "Final Model"
+        models.append(display_name)
+        model_paths[display_name] = final_model_path
+
+    # Sort checkpoints by step number (Final model will be at the end)
+    def get_step_number(name):
+        if name == "Final Model":
+            return float("inf")
+        try:
+            return int(name.split("Step ")[-1])
+        except:
             return 0
 
-    # Sort checkpoints by step number
-    checkpoints.sort(key=get_step_number)
+    models.sort(key=get_step_number)
+
+    if not models:
+        print(
+            "Warning: No checkpoints or final model found in the following locations:"
+        )
+        print("- Checkpoints directory:", os.path.abspath("checkpoints"))
+        print("- Final model directory:", os.path.abspath("final_model"))
+    else:
+        print(f"Found {len(models)} models:")
+        for model in models:
+            print(f"- {model}: {model_paths[model]}")
 
-    # Create display names
-    display_names = [f"Step {get_step_number(x)}" for x in checkpoints]
-    return display_names, checkpoints
+    return models, model_paths
 
 
-def generate_text(
-    prompt, checkpoint_choice, max_length=100, temperature=0.7, top_p=0.9
-):
-    """Generate text based on prompt using selected checkpoint"""
-    # Check if checkpoint is selected
-    if not checkpoint_choice:
-        return "Please select a checkpoint first!"
+def load_model_from_checkpoint(model_path):
+    """Load model from checkpoint or final model directory"""
+    if model_path == "final_model":
+        # Load the final saved model
+        model = SmolLMModule(config)
+        model.model = AutoModelForCausalLM.from_pretrained(model_path)
+    else:
+        # Load from checkpoint
+        model = SmolLMModule.load_from_checkpoint(model_path, config=config)
+
+    model.eval()  # Set to evaluation mode
+    return model
+
+
+def generate_text(prompt, model_choice, max_length=100, temperature=0.7, top_p=0.9):
+    """Generate text based on prompt using selected model"""
+    # Check if model is selected
+    if not model_choice:
+        return "Please select a model checkpoint!"
 
     if not prompt:
         return "Please enter a prompt!"
 
     try:
-        # Get actual checkpoint path
-        step_num = int("".join(filter(str.isdigit, checkpoint_choice)))
-        checkpoints = glob.glob("checkpoints/*.ckpt")
-        checkpoint_path = None
-
-        for ckpt in checkpoints:
-            if str(step_num) in ckpt:
-                checkpoint_path = ckpt
-                break
+        # Get model path from the mapping
+        _, model_paths = get_available_checkpoints()
+        model_path = model_paths.get(model_choice)
 
-        if not checkpoint_path or not os.path.exists(checkpoint_path):
-            return f"Checkpoint for step {step_num} not found!"
+        if not model_path or not os.path.exists(model_path):
+            return f"Model {model_choice} not found!"
 
-        # Load model from checkpoint
-        model = load_model_from_checkpoint(checkpoint_path)
+        # Load model
+        model = load_model_from_checkpoint(model_path)
 
         # Move model to GPU if available
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -108,7 +131,7 @@ def generate_text(
         return f"Error during generation: {str(e)}"
 
 
-# Get available checkpoints
+# Get available models
 display_names, _ = get_available_checkpoints()
 
 # Create Gradio interface
@@ -116,17 +139,18 @@ with gr.Blocks(title="SmolLM2 Inference") as demo:
     gr.Markdown("# SmolLM2 Text Generation")
 
     if not display_names:
-        gr.Markdown("⚠️ No checkpoints found! Please train the model first.")
+        gr.Markdown("⚠️ No models found! Please train the model first.")
     else:
         gr.Markdown(
-            f"Found {len(display_names)} checkpoints. Select one and enter a prompt to generate text."
+            f"Found {len(display_names)} models/checkpoints. Select one and enter a prompt to generate text."
         )
+        gr.Markdown("Available models: " + ", ".join(display_names))
 
     with gr.Row():
        with gr.Column():
-            checkpoint_dropdown = gr.Dropdown(
+            model_dropdown = gr.Dropdown(
                choices=display_names,
-                label="Select Checkpoint",
+                label="Select Model",
                value=display_names[-1] if display_names else None,
                interactive=True,
            )
@@ -149,7 +173,7 @@ with gr.Blocks(title="SmolLM2 Inference") as demo:
 
    generate_btn.click(
        fn=generate_text,
-        inputs=[prompt, checkpoint_dropdown, max_length, temperature, top_p],
+        inputs=[prompt, model_dropdown, max_length, temperature, top_p],
        outputs=output,
    )
 
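For reference, the dropdown choices above come from get_available_checkpoints(), which only recognises checkpoint files whose names contain `step=step=`; a small illustration of that parsing (the path is hypothetical):

```python
import os

# Hypothetical checkpoint path in the format the parser above expects,
# i.e. the older ModelCheckpoint(filename="model-step={step}") naming.
ckpt = "checkpoints/model-step=step=1500.ckpt"
filename = os.path.basename(ckpt)
if "step=step=" in filename:
    step = int(filename.split("step=step=")[1].split(".")[0])
    print(f"Checkpoint Step {step}")  # -> Checkpoint Step 1500
```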
model.py CHANGED
@@ -1,203 +1,228 @@
-# import libraries
+import os
+import torch
+import yaml
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
-from transformers import Trainer
-import pytorch_lightning as pl
-import yaml
-from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorch_lightning.callbacks import RichProgressBar
+from torch.utils.data import DataLoader, IterableDataset
+from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.callbacks import (
+    ModelCheckpoint,
+    LearningRateMonitor,
+    RichProgressBar,
+)
 from pytorch_lightning.loggers import TensorBoardLogger
-import torch
-from torch.utils.data import DataLoader
-
-# load dataset
-dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True)
-train_dataset = dataset["train"]
-for sample in train_dataset:
-    print(sample)
-    break
-# load tokenizer
-# use tokeniser from https://huggingface.co/HuggingFaceTB/cosmo2-tokenizer
-tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
-# Set padding token to be the same as EOS token
-tokenizer.pad_token = tokenizer.eos_token
-
-# load config
-# use config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
-# config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
-
-
-def collate_fn(examples):
-    # Tokenize the texts
-    encoding = tokenizer(
-        [example["text"] for example in examples],
-        padding=True,
-        truncation=True,
-        max_length=512,
-        return_tensors="pt",
-    )
-
-    # Create labels (same as input_ids for causal language modeling)
-    encoding["labels"] = encoding["input_ids"].clone()
-
-    return encoding
-
-
-def create_model_config(config):
-    model_config = config["model"]["model_config"]
-    return LlamaConfig(
-        vocab_size=49152,  # From the model architecture
-        hidden_size=model_config["hidden_size"],
-        intermediate_size=model_config["intermediate_size"],
-        num_hidden_layers=model_config["num_hidden_layers"],
-        num_attention_heads=model_config["num_attention_heads"],
-        num_key_value_heads=model_config["num_key_value_heads"],
-        hidden_act=model_config["hidden_act"],
-        max_position_embeddings=model_config["max_position_embeddings"],
-        initializer_range=model_config["initializer_range"],
-        rms_norm_eps=1e-5,  # From the model architecture
-        use_cache=True,
-        pad_token_id=model_config["pad_token_id"],
-        bos_token_id=model_config["bos_token_id"],
-        eos_token_id=model_config["eos_token_id"],
-    )
+from torch.nn.utils.rnn import pad_sequence
+from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
+
+# Set environment variable for memory management
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+
+# Function to log GPU memory usage
+def log_memory_usage(step):
+    if torch.cuda.is_available():
+        print(
+            f"Step {step}: "
+            f"Allocated = {torch.cuda.memory_allocated() / 1e9:.2f} GB, "
+            f"Reserved = {torch.cuda.memory_reserved() / 1e9:.2f} GB"
+        )
+
+
+# Custom Collate Function
+def collate_fn(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    input_ids = pad_sequence(
+        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
+    )
+    labels = pad_sequence(
+        labels, batch_first=True, padding_value=tokenizer.pad_token_id
+    )
+    return {"input_ids": input_ids, "labels": labels}
+
+
+# Streaming Dataset
+class StreamingDataset(IterableDataset):
+    def __init__(self, dataset, tokenizer, max_length=2048):
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __iter__(self):
+        for example in iter(self.dataset):
+            tokenized = self.tokenizer(
+                example["text"],
+                truncation=True,
+                max_length=self.max_length,
+                return_overflowing_tokens=True,
+                return_tensors="pt",
+            )
+            for chunk in tokenized["input_ids"]:
+                yield {
+                    "input_ids": chunk.squeeze(0),
+                    "labels": chunk.squeeze(0),
+                }
 
 
-# create model
-class SmolLMModule(pl.LightningModule):
+# Lightning Module
+class SmolLMModule(LightningModule):
     def __init__(self, config, learning_rate=1e-4):
         super().__init__()
         self.config = config
         self.learning_rate = learning_rate
-        self.save_hyperparameters()  # Save hyperparameters for resuming
-
-        # Create model from config
-        model_config = create_model_config(config)
+        self.save_hyperparameters()
+
+        model_config = LlamaConfig(
+            vocab_size=49152,
+            hidden_size=config["model"]["model_config"]["hidden_size"],
+            intermediate_size=config["model"]["model_config"]["intermediate_size"],
+            num_hidden_layers=config["model"]["model_config"]["num_hidden_layers"],
+            num_attention_heads=config["model"]["model_config"]["num_attention_heads"],
+            num_key_value_heads=config["model"]["model_config"]["num_key_value_heads"],
+            hidden_act=config["model"]["model_config"]["hidden_act"],
+            max_position_embeddings=config["model"]["model_config"][
+                "max_position_embeddings"
+            ],
+            initializer_range=config["model"]["model_config"]["initializer_range"],
+            rms_norm_eps=1e-5,
+            use_cache=True,
+            pad_token_id=config["model"]["model_config"]["pad_token_id"],
+            bos_token_id=config["model"]["model_config"]["bos_token_id"],
+            eos_token_id=config["model"]["model_config"]["eos_token_id"],
+        )
         self.model = AutoModelForCausalLM.from_config(model_config)
 
-    def forward(self, **inputs):
-        return self.model(**inputs)
-
     def training_step(self, batch, batch_idx):
-        outputs = self.model(**batch)
+        outputs = self.model(input_ids=batch["input_ids"], labels=batch["labels"])
         loss = outputs.loss
-        self.log("train_loss", loss, prog_bar=True)
+        self.log(
+            "train_loss", loss, prog_bar=True, on_step=True, on_epoch=True
+        )  # Log loss
+
+        # Log memory usage
+        if batch_idx % 10 == 0:
+            log_memory_usage(batch_idx)
+
+        # Release intermediate tensors
+        del outputs
+        torch.cuda.empty_cache()
+
         return loss
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(
+        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
-            weight_decay=0.1,
+            weight_decay=self.config["optimizer"]["weight_decay"],
        )
-        return optimizer
-
-    def on_save_checkpoint(self, checkpoint):
-        # Save additional info if needed
-        checkpoint["step"] = self.global_step
-        checkpoint["model_config"] = self.config
-
-    def on_load_checkpoint(self, checkpoint):
-        # Restore additional info if needed
-        self.global_step = checkpoint["step"]
-        self.config = checkpoint["model_config"]
-
 
-# train model
 
-# save model
-
-# training script
+# Main Script
 if __name__ == "__main__":
-    import os
-    from pytorch_lightning.callbacks import ModelCheckpoint
-
-    # parameters load from config file
-    with open("config_smollm2_135.yaml", "r") as file:
+    # Load config
+    with open("/kaggle/input/yaml-file/config_smollm2_135.yaml", "r") as file:
        config = yaml.safe_load(file)
-    max_steps = 5000  # Total training steps
-
-    # Create checkpoint directory if it doesn't exist
-    checkpoint_dir = "checkpoints"
-    os.makedirs(checkpoint_dir, exist_ok=True)
-
-    # Checkpoint callback
-    checkpoint_callback = ModelCheckpoint(
-        dirpath=checkpoint_dir,
-        filename="model-step={step}",
-        save_top_k=-1,  # Save all checkpoints
-        every_n_train_steps=500,  # Save every 500 steps
-        save_weights_only=False,  # Save the full model state
-    )
 
-    # load tokenizer
+    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
-    # Set padding token to be the same as EOS token
    tokenizer.pad_token = tokenizer.eos_token
 
-    # load dataset
+    # Load dataset
    dataset = load_dataset(
        "HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True
    )
    train_dataset = dataset["train"]
 
    # Create DataLoader
+    streaming_dataset = StreamingDataset(train_dataset, tokenizer, max_length=2048)
    train_loader = DataLoader(
-        train_dataset,
-        batch_size=4,  # Small batch size for testing
+        streaming_dataset,
+        batch_size=1,  # Reduced batch size
+        num_workers=4,
        collate_fn=collate_fn,
-        num_workers=2,
+        pin_memory=True,
    )
 
-    # create model
-    model = SmolLMModule(config, learning_rate=1e-4)
+    # Create model
+    model = SmolLMModule(
+        config,
+        learning_rate=config["optimizer"]["learning_rate_scheduler"]["learning_rate"],
+    )
 
-    # progress bar
-    progress_bar = RichProgressBar(leave=False, refresh_rate=1, console_kwargs=None)
+    # Initialize logger with version based on start_step
+    logger = TensorBoardLogger("logs", name="smollm2")
 
-    # Find latest checkpoint if exists
-    latest_checkpoint = None
-    if os.path.exists(checkpoint_dir):
-        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt")]
-        if checkpoints:
-            # Sort by step number and get the latest
-            latest_checkpoint = os.path.join(
-                checkpoint_dir,
-                sorted(checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0]))[
-                    -1
-                ],
-            )
-            print(f"Resuming from checkpoint: {latest_checkpoint}")
+    # Checkpoint callback configuration
+    checkpoint_callback = ModelCheckpoint(
+        dirpath="checkpoints",
+        filename="model-{epoch:02d}-{step}-{train_loss:.2f}",  # Include training loss in filename
+        monitor="train_loss",  # Monitor training loss
+        mode="min",  # Lower loss is better
+        save_top_k=3,  # Save the best 3 models
+        save_last=True,  # Additionally save the last model
+        every_n_train_steps=500,  # Save every 500 steps
+        save_weights_only=False,  # Save the full model state
+        auto_insert_metric_name=False,  # Don't insert metric name in filename
+    )
 
-    # create trainer
-    trainer = pl.Trainer(
-        max_steps=max_steps,
+    # Progress bar
+    progress_bar = RichProgressBar(
+        refresh_rate=1,
+        leave=False,
+        theme=RichProgressBarTheme(
+            description="",
+            progress_bar="#6206E0",
+            progress_bar_finished="#6206E0",
+            progress_bar_pulse="#6206E0",
+            batch_progress="",
+            time="dim",
+            processing_speed="dim underline",
+            metrics="italic",
+            metrics_text_delimiter=" ",
+            metrics_format=".3f",
+        ),
+        console_kwargs=None,
+    )
+
+    # Create trainer
+    trainer = Trainer(
+        logger=logger,
+        strategy="ddp",
        accelerator="gpu",
-        devices=1,
-        precision="bf16-mixed",
+        devices=2,
+        precision="16-mixed",
+        max_steps=5000,
+        accumulate_grad_batches=1,
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            progress_bar,
            checkpoint_callback,
        ],
-        log_every_n_steps=1,
        enable_progress_bar=True,
        enable_model_summary=True,
+        log_every_n_steps=10,
    )
 
-    # train model
-    if latest_checkpoint:
-        # Resume training from checkpoint if it exists
-        trainer.fit(model, train_loader, ckpt_path=latest_checkpoint)
+    # Find latest checkpoint if exists
+    if os.path.exists("checkpoints/last.ckpt"):
+        resume_from_checkpoint = "checkpoints/last.ckpt"
+        print(f"Resuming from checkpoint: {resume_from_checkpoint}")
    else:
-        # Start training from scratch
-        trainer.fit(model, train_loader)
+        resume_from_checkpoint = None
+        print("Starting training from scratch")
+
+    # Train with automatic checkpoint resumption
+    trainer.fit(model, train_loader, ckpt_path=resume_from_checkpoint)
+
+    # After training, print the best model path and score
+    print(f"Best model path: {checkpoint_callback.best_model_path}")
+    print(f"Best train loss: {checkpoint_callback.best_model_score:.4f}")
 
-    # Save final model and tokenizer
-    if trainer.is_global_zero:  # Only save on main process
+    # Save final model
+    if trainer.is_global_zero:
        output_dir = "final_model"
        os.makedirs(output_dir, exist_ok=True)
-        model.model.save_pretrained(os.path.join(output_dir, "model"))
-        tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
+        model.model.save_pretrained(output_dir)
+        tokenizer.save_pretrained(output_dir)
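The training script above reads several nested keys from the YAML config; a sketch of the minimal structure it appears to expect (key names inferred from the accesses in model.py; the model values mirror final_model/config.json from this commit, while the optimizer values are placeholders, not taken from the actual SmolLM2 YAML):

```python
# Hypothetical minimal config dict matching the lookups in model.py;
# real values should come from config_smollm2_135.yaml.
config = {
    "model": {
        "model_config": {
            "hidden_size": 576,
            "intermediate_size": 1536,
            "num_hidden_layers": 30,
            "num_attention_heads": 9,
            "num_key_value_heads": 3,
            "hidden_act": "silu",
            "max_position_embeddings": 2048,
            "initializer_range": 0.041666666666666664,
            "pad_token_id": None,  # placeholder; set to the tokenizer's pad/eos id
            "bos_token_id": 0,
            "eos_token_id": 0,
        }
    },
    "optimizer": {
        "weight_decay": 0.01,  # placeholder value
        "learning_rate_scheduler": {"learning_rate": 3e-4},  # placeholder value
    },
}
```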