Spaces:
Sleeping
Sleeping
Updated inference
Browse files- .gitattributes +2 -0
- .gitignore +2 -2
- README.md +2 -0
- final_model/config.json +29 -0
- final_model/generation_config.json +6 -0
- final_model/merges.txt +0 -0
- final_model/model.safetensors +3 -0
- final_model/special_tokens_map.json +43 -0
- final_model/tokenizer.json +0 -0
- final_model/tokenizer_config.json +170 -0
- final_model/vocab.json +0 -0
- inference.py +81 -57
- model.py +170 -145
.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
|
@@ -30,7 +30,7 @@ env.bak/
|
|
| 30 |
venv.bak/
|
| 31 |
|
| 32 |
# Training artifacts
|
| 33 |
-
checkpoints/
|
| 34 |
runs/
|
| 35 |
logs/
|
| 36 |
*.ckpt
|
|
@@ -38,7 +38,7 @@ logs/
|
|
| 38 |
*.pth
|
| 39 |
wandb/
|
| 40 |
lightning_logs/
|
| 41 |
-
final_model/
|
| 42 |
|
| 43 |
# IDE
|
| 44 |
.idea/
|
|
|
|
| 30 |
venv.bak/
|
| 31 |
|
| 32 |
# Training artifacts
|
| 33 |
+
# checkpoints/
|
| 34 |
runs/
|
| 35 |
logs/
|
| 36 |
*.ckpt
|
|
|
|
| 38 |
*.pth
|
| 39 |
wandb/
|
| 40 |
lightning_logs/
|
| 41 |
+
# final_model/
|
| 42 |
|
| 43 |
# IDE
|
| 44 |
.idea/
|
README.md
CHANGED
|
@@ -15,6 +15,8 @@ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
|
|
| 15 |
```
|
| 16 |
use config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
|
| 17 |
|
|
|
|
|
|
|
| 18 |
create model from above parameters
|
| 19 |
|
| 20 |
Use it for training using pytorch lightning
|
|
|
|
| 15 |
```
|
| 16 |
use config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
|
| 17 |
|
| 18 |
+
https://github.com/huggingface/smollm/blob/main/pre-training/smollm2/config_smollm2_135M.yaml
|
| 19 |
+
|
| 20 |
create model from above parameters
|
| 21 |
|
| 22 |
Use it for training using pytorch lightning
|
final_model/config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"LlamaForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"eos_token_id": 0,
|
| 9 |
+
"head_dim": 64,
|
| 10 |
+
"hidden_act": "silu",
|
| 11 |
+
"hidden_size": 576,
|
| 12 |
+
"initializer_range": 0.041666666666666664,
|
| 13 |
+
"intermediate_size": 1536,
|
| 14 |
+
"max_position_embeddings": 2048,
|
| 15 |
+
"mlp_bias": false,
|
| 16 |
+
"model_type": "llama",
|
| 17 |
+
"num_attention_heads": 9,
|
| 18 |
+
"num_hidden_layers": 30,
|
| 19 |
+
"num_key_value_heads": 3,
|
| 20 |
+
"pretraining_tp": 1,
|
| 21 |
+
"rms_norm_eps": 1e-05,
|
| 22 |
+
"rope_scaling": null,
|
| 23 |
+
"rope_theta": 10000.0,
|
| 24 |
+
"tie_word_embeddings": false,
|
| 25 |
+
"torch_dtype": "float32",
|
| 26 |
+
"transformers_version": "4.47.0",
|
| 27 |
+
"use_cache": true,
|
| 28 |
+
"vocab_size": 49152
|
| 29 |
+
}
|
final_model/generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 0,
|
| 4 |
+
"eos_token_id": 0,
|
| 5 |
+
"transformers_version": "4.47.0"
|
| 6 |
+
}
|
final_model/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
final_model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:af74ecbdc60bc04c52a8dc7b42a79eecadadcbca7c300cc21e56583ab1c1e0b4
|
| 3 |
+
size 651336704
|
final_model/special_tokens_map.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|endoftext|>",
|
| 4 |
+
"<|im_start|>",
|
| 5 |
+
"<|im_end|>",
|
| 6 |
+
"<repo_name>",
|
| 7 |
+
"<reponame>",
|
| 8 |
+
"<file_sep>",
|
| 9 |
+
"<filename>",
|
| 10 |
+
"<gh_stars>",
|
| 11 |
+
"<issue_start>",
|
| 12 |
+
"<issue_comment>",
|
| 13 |
+
"<issue_closed>",
|
| 14 |
+
"<jupyter_start>",
|
| 15 |
+
"<jupyter_text>",
|
| 16 |
+
"<jupyter_code>",
|
| 17 |
+
"<jupyter_output>",
|
| 18 |
+
"<jupyter_script>",
|
| 19 |
+
"<empty_output>"
|
| 20 |
+
],
|
| 21 |
+
"bos_token": {
|
| 22 |
+
"content": "<|endoftext|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false
|
| 27 |
+
},
|
| 28 |
+
"eos_token": {
|
| 29 |
+
"content": "<|endoftext|>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": false,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false
|
| 34 |
+
},
|
| 35 |
+
"pad_token": "<|endoftext|>",
|
| 36 |
+
"unk_token": {
|
| 37 |
+
"content": "<|endoftext|>",
|
| 38 |
+
"lstrip": false,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"rstrip": false,
|
| 41 |
+
"single_word": false
|
| 42 |
+
}
|
| 43 |
+
}
|
final_model/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
final_model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"0": {
|
| 5 |
+
"content": "<|endoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": false,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"1": {
|
| 13 |
+
"content": "<|im_start|>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"2": {
|
| 21 |
+
"content": "<|im_end|>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
},
|
| 28 |
+
"3": {
|
| 29 |
+
"content": "<repo_name>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": false,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false,
|
| 34 |
+
"special": true
|
| 35 |
+
},
|
| 36 |
+
"4": {
|
| 37 |
+
"content": "<reponame>",
|
| 38 |
+
"lstrip": false,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"rstrip": false,
|
| 41 |
+
"single_word": false,
|
| 42 |
+
"special": true
|
| 43 |
+
},
|
| 44 |
+
"5": {
|
| 45 |
+
"content": "<file_sep>",
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"normalized": false,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false,
|
| 50 |
+
"special": true
|
| 51 |
+
},
|
| 52 |
+
"6": {
|
| 53 |
+
"content": "<filename>",
|
| 54 |
+
"lstrip": false,
|
| 55 |
+
"normalized": false,
|
| 56 |
+
"rstrip": false,
|
| 57 |
+
"single_word": false,
|
| 58 |
+
"special": true
|
| 59 |
+
},
|
| 60 |
+
"7": {
|
| 61 |
+
"content": "<gh_stars>",
|
| 62 |
+
"lstrip": false,
|
| 63 |
+
"normalized": false,
|
| 64 |
+
"rstrip": false,
|
| 65 |
+
"single_word": false,
|
| 66 |
+
"special": true
|
| 67 |
+
},
|
| 68 |
+
"8": {
|
| 69 |
+
"content": "<issue_start>",
|
| 70 |
+
"lstrip": false,
|
| 71 |
+
"normalized": false,
|
| 72 |
+
"rstrip": false,
|
| 73 |
+
"single_word": false,
|
| 74 |
+
"special": true
|
| 75 |
+
},
|
| 76 |
+
"9": {
|
| 77 |
+
"content": "<issue_comment>",
|
| 78 |
+
"lstrip": false,
|
| 79 |
+
"normalized": false,
|
| 80 |
+
"rstrip": false,
|
| 81 |
+
"single_word": false,
|
| 82 |
+
"special": true
|
| 83 |
+
},
|
| 84 |
+
"10": {
|
| 85 |
+
"content": "<issue_closed>",
|
| 86 |
+
"lstrip": false,
|
| 87 |
+
"normalized": false,
|
| 88 |
+
"rstrip": false,
|
| 89 |
+
"single_word": false,
|
| 90 |
+
"special": true
|
| 91 |
+
},
|
| 92 |
+
"11": {
|
| 93 |
+
"content": "<jupyter_start>",
|
| 94 |
+
"lstrip": false,
|
| 95 |
+
"normalized": false,
|
| 96 |
+
"rstrip": false,
|
| 97 |
+
"single_word": false,
|
| 98 |
+
"special": true
|
| 99 |
+
},
|
| 100 |
+
"12": {
|
| 101 |
+
"content": "<jupyter_text>",
|
| 102 |
+
"lstrip": false,
|
| 103 |
+
"normalized": false,
|
| 104 |
+
"rstrip": false,
|
| 105 |
+
"single_word": false,
|
| 106 |
+
"special": true
|
| 107 |
+
},
|
| 108 |
+
"13": {
|
| 109 |
+
"content": "<jupyter_code>",
|
| 110 |
+
"lstrip": false,
|
| 111 |
+
"normalized": false,
|
| 112 |
+
"rstrip": false,
|
| 113 |
+
"single_word": false,
|
| 114 |
+
"special": true
|
| 115 |
+
},
|
| 116 |
+
"14": {
|
| 117 |
+
"content": "<jupyter_output>",
|
| 118 |
+
"lstrip": false,
|
| 119 |
+
"normalized": false,
|
| 120 |
+
"rstrip": false,
|
| 121 |
+
"single_word": false,
|
| 122 |
+
"special": true
|
| 123 |
+
},
|
| 124 |
+
"15": {
|
| 125 |
+
"content": "<jupyter_script>",
|
| 126 |
+
"lstrip": false,
|
| 127 |
+
"normalized": false,
|
| 128 |
+
"rstrip": false,
|
| 129 |
+
"single_word": false,
|
| 130 |
+
"special": true
|
| 131 |
+
},
|
| 132 |
+
"16": {
|
| 133 |
+
"content": "<empty_output>",
|
| 134 |
+
"lstrip": false,
|
| 135 |
+
"normalized": false,
|
| 136 |
+
"rstrip": false,
|
| 137 |
+
"single_word": false,
|
| 138 |
+
"special": true
|
| 139 |
+
}
|
| 140 |
+
},
|
| 141 |
+
"additional_special_tokens": [
|
| 142 |
+
"<|endoftext|>",
|
| 143 |
+
"<|im_start|>",
|
| 144 |
+
"<|im_end|>",
|
| 145 |
+
"<repo_name>",
|
| 146 |
+
"<reponame>",
|
| 147 |
+
"<file_sep>",
|
| 148 |
+
"<filename>",
|
| 149 |
+
"<gh_stars>",
|
| 150 |
+
"<issue_start>",
|
| 151 |
+
"<issue_comment>",
|
| 152 |
+
"<issue_closed>",
|
| 153 |
+
"<jupyter_start>",
|
| 154 |
+
"<jupyter_text>",
|
| 155 |
+
"<jupyter_code>",
|
| 156 |
+
"<jupyter_output>",
|
| 157 |
+
"<jupyter_script>",
|
| 158 |
+
"<empty_output>"
|
| 159 |
+
],
|
| 160 |
+
"bos_token": "<|endoftext|>",
|
| 161 |
+
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
| 162 |
+
"clean_up_tokenization_spaces": false,
|
| 163 |
+
"eos_token": "<|endoftext|>",
|
| 164 |
+
"extra_special_tokens": {},
|
| 165 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 166 |
+
"pad_token": "<|endoftext|>",
|
| 167 |
+
"tokenizer_class": "GPT2Tokenizer",
|
| 168 |
+
"unk_token": "<|endoftext|>",
|
| 169 |
+
"vocab_size": 49152
|
| 170 |
+
}
|
final_model/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
inference.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
-
from model import SmolLMModule
|
| 5 |
-
from transformers import AutoTokenizer
|
| 6 |
import yaml
|
| 7 |
import glob
|
| 8 |
|
|
@@ -15,71 +15,94 @@ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
|
|
| 15 |
tokenizer.pad_token = tokenizer.eos_token
|
| 16 |
|
| 17 |
|
| 18 |
-
def load_model_from_checkpoint(checkpoint_path):
|
| 19 |
-
"""Load model from checkpoint"""
|
| 20 |
-
model = SmolLMModule.load_from_checkpoint(checkpoint_path, config=config)
|
| 21 |
-
model.eval() # Set to evaluation mode
|
| 22 |
-
return model
|
| 23 |
-
|
| 24 |
-
|
| 25 |
def get_available_checkpoints():
|
| 26 |
-
"""Get list of available checkpoints
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
return [], []
|
| 30 |
|
| 31 |
-
#
|
| 32 |
-
|
|
|
|
| 33 |
try:
|
| 34 |
# Extract step number from the filename
|
| 35 |
-
filename = os.path.basename(
|
| 36 |
-
#
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
return 0
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
display_names = [f"Step {get_step_number(x)}" for x in checkpoints]
|
| 53 |
-
return display_names, checkpoints
|
| 54 |
|
| 55 |
|
| 56 |
-
def
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
if not prompt:
|
| 65 |
return "Please enter a prompt!"
|
| 66 |
|
| 67 |
try:
|
| 68 |
-
# Get
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
checkpoint_path = None
|
| 72 |
-
|
| 73 |
-
for ckpt in checkpoints:
|
| 74 |
-
if str(step_num) in ckpt:
|
| 75 |
-
checkpoint_path = ckpt
|
| 76 |
-
break
|
| 77 |
|
| 78 |
-
if not
|
| 79 |
-
return f"
|
| 80 |
|
| 81 |
-
# Load model
|
| 82 |
-
model = load_model_from_checkpoint(
|
| 83 |
|
| 84 |
# Move model to GPU if available
|
| 85 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -108,7 +131,7 @@ def generate_text(
|
|
| 108 |
return f"Error during generation: {str(e)}"
|
| 109 |
|
| 110 |
|
| 111 |
-
# Get available
|
| 112 |
display_names, _ = get_available_checkpoints()
|
| 113 |
|
| 114 |
# Create Gradio interface
|
|
@@ -116,17 +139,18 @@ with gr.Blocks(title="SmolLM2 Inference") as demo:
|
|
| 116 |
gr.Markdown("# SmolLM2 Text Generation")
|
| 117 |
|
| 118 |
if not display_names:
|
| 119 |
-
gr.Markdown("⚠️ No
|
| 120 |
else:
|
| 121 |
gr.Markdown(
|
| 122 |
-
f"Found {len(display_names)} checkpoints. Select one and enter a prompt to generate text."
|
| 123 |
)
|
|
|
|
| 124 |
|
| 125 |
with gr.Row():
|
| 126 |
with gr.Column():
|
| 127 |
-
|
| 128 |
choices=display_names,
|
| 129 |
-
label="Select
|
| 130 |
value=display_names[-1] if display_names else None,
|
| 131 |
interactive=True,
|
| 132 |
)
|
|
@@ -149,7 +173,7 @@ with gr.Blocks(title="SmolLM2 Inference") as demo:
|
|
| 149 |
|
| 150 |
generate_btn.click(
|
| 151 |
fn=generate_text,
|
| 152 |
-
inputs=[prompt,
|
| 153 |
outputs=output,
|
| 154 |
)
|
| 155 |
|
|
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
+
from model import SmolLMModule
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
import yaml
|
| 7 |
import glob
|
| 8 |
|
|
|
|
| 15 |
tokenizer.pad_token = tokenizer.eos_token
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def get_available_checkpoints():
|
| 19 |
+
"""Get list of available checkpoints and final model"""
|
| 20 |
+
models = []
|
| 21 |
+
model_paths = {}
|
|
|
|
| 22 |
|
| 23 |
+
# Get checkpoints
|
| 24 |
+
checkpoints = glob.glob("checkpoints/*.ckpt")
|
| 25 |
+
for ckpt in checkpoints:
|
| 26 |
try:
|
| 27 |
# Extract step number from the filename
|
| 28 |
+
filename = os.path.basename(ckpt)
|
| 29 |
+
# Handle the format 'model-step=step=X.ckpt'
|
| 30 |
+
if "step=step=" in filename:
|
| 31 |
+
step = int(filename.split("step=step=")[1].split(".")[0])
|
| 32 |
+
display_name = f"Checkpoint Step {step}"
|
| 33 |
+
models.append(display_name)
|
| 34 |
+
model_paths[display_name] = ckpt
|
| 35 |
+
except (ValueError, IndexError) as e:
|
| 36 |
+
print(
|
| 37 |
+
f"Warning: Could not parse checkpoint filename: {filename}, Error: {e}"
|
| 38 |
+
)
|
| 39 |
+
continue
|
| 40 |
+
|
| 41 |
+
# Add final model if it exists
|
| 42 |
+
final_model_path = "final_model"
|
| 43 |
+
if os.path.exists(final_model_path):
|
| 44 |
+
display_name = "Final Model"
|
| 45 |
+
models.append(display_name)
|
| 46 |
+
model_paths[display_name] = final_model_path
|
| 47 |
+
|
| 48 |
+
# Sort checkpoints by step number (Final model will be at the end)
|
| 49 |
+
def get_step_number(name):
|
| 50 |
+
if name == "Final Model":
|
| 51 |
+
return float("inf")
|
| 52 |
+
try:
|
| 53 |
+
return int(name.split("Step ")[-1])
|
| 54 |
+
except:
|
| 55 |
return 0
|
| 56 |
|
| 57 |
+
models.sort(key=get_step_number)
|
| 58 |
+
|
| 59 |
+
if not models:
|
| 60 |
+
print(
|
| 61 |
+
"Warning: No checkpoints or final model found in the following locations:"
|
| 62 |
+
)
|
| 63 |
+
print("- Checkpoints directory:", os.path.abspath("checkpoints"))
|
| 64 |
+
print("- Final model directory:", os.path.abspath("final_model"))
|
| 65 |
+
else:
|
| 66 |
+
print(f"Found {len(models)} models:")
|
| 67 |
+
for model in models:
|
| 68 |
+
print(f"- {model}: {model_paths[model]}")
|
| 69 |
|
| 70 |
+
return models, model_paths
|
|
|
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
+
def load_model_from_checkpoint(model_path):
|
| 74 |
+
"""Load model from checkpoint or final model directory"""
|
| 75 |
+
if model_path == "final_model":
|
| 76 |
+
# Load the final saved model
|
| 77 |
+
model = SmolLMModule(config)
|
| 78 |
+
model.model = AutoModelForCausalLM.from_pretrained(model_path)
|
| 79 |
+
else:
|
| 80 |
+
# Load from checkpoint
|
| 81 |
+
model = SmolLMModule.load_from_checkpoint(model_path, config=config)
|
| 82 |
+
|
| 83 |
+
model.eval() # Set to evaluation mode
|
| 84 |
+
return model
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def generate_text(prompt, model_choice, max_length=100, temperature=0.7, top_p=0.9):
|
| 88 |
+
"""Generate text based on prompt using selected model"""
|
| 89 |
+
# Check if model is selected
|
| 90 |
+
if not model_choice:
|
| 91 |
+
return "Please select a model checkpoint!"
|
| 92 |
|
| 93 |
if not prompt:
|
| 94 |
return "Please enter a prompt!"
|
| 95 |
|
| 96 |
try:
|
| 97 |
+
# Get model path from the mapping
|
| 98 |
+
_, model_paths = get_available_checkpoints()
|
| 99 |
+
model_path = model_paths.get(model_choice)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
if not model_path or not os.path.exists(model_path):
|
| 102 |
+
return f"Model {model_choice} not found!"
|
| 103 |
|
| 104 |
+
# Load model
|
| 105 |
+
model = load_model_from_checkpoint(model_path)
|
| 106 |
|
| 107 |
# Move model to GPU if available
|
| 108 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 131 |
return f"Error during generation: {str(e)}"
|
| 132 |
|
| 133 |
|
| 134 |
+
# Get available models
|
| 135 |
display_names, _ = get_available_checkpoints()
|
| 136 |
|
| 137 |
# Create Gradio interface
|
|
|
|
| 139 |
gr.Markdown("# SmolLM2 Text Generation")
|
| 140 |
|
| 141 |
if not display_names:
|
| 142 |
+
gr.Markdown("⚠️ No models found! Please train the model first.")
|
| 143 |
else:
|
| 144 |
gr.Markdown(
|
| 145 |
+
f"Found {len(display_names)} models/checkpoints. Select one and enter a prompt to generate text."
|
| 146 |
)
|
| 147 |
+
gr.Markdown("Available models: " + ", ".join(display_names))
|
| 148 |
|
| 149 |
with gr.Row():
|
| 150 |
with gr.Column():
|
| 151 |
+
model_dropdown = gr.Dropdown(
|
| 152 |
choices=display_names,
|
| 153 |
+
label="Select Model",
|
| 154 |
value=display_names[-1] if display_names else None,
|
| 155 |
interactive=True,
|
| 156 |
)
|
|
|
|
| 173 |
|
| 174 |
generate_btn.click(
|
| 175 |
fn=generate_text,
|
| 176 |
+
inputs=[prompt, model_dropdown, max_length, temperature, top_p],
|
| 177 |
outputs=output,
|
| 178 |
)
|
| 179 |
|
model.py
CHANGED
|
@@ -1,203 +1,228 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
from datasets import load_dataset
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
|
| 4 |
-
from
|
| 5 |
-
|
| 6 |
-
import
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
from pytorch_lightning.loggers import TensorBoardLogger
|
| 10 |
-
import
|
| 11 |
-
from
|
| 12 |
-
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
-
train_dataset = dataset["train"]
|
| 16 |
-
for sample in train_dataset:
|
| 17 |
-
print(sample)
|
| 18 |
-
break
|
| 19 |
-
# load tokenizer
|
| 20 |
-
# use tokeniser from https://huggingface.co/HuggingFaceTB/cosmo2-tokenizer
|
| 21 |
-
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
|
| 22 |
-
# Set padding token to be the same as EOS token
|
| 23 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 24 |
-
|
| 25 |
-
# load config
|
| 26 |
-
# use config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
|
| 27 |
-
# config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def collate_fn(examples):
|
| 31 |
-
# Tokenize the texts
|
| 32 |
-
encoding = tokenizer(
|
| 33 |
-
[example["text"] for example in examples],
|
| 34 |
-
padding=True,
|
| 35 |
-
truncation=True,
|
| 36 |
-
max_length=512,
|
| 37 |
-
return_tensors="pt",
|
| 38 |
-
)
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
pad_token_id=model_config["pad_token_id"],
|
| 61 |
-
bos_token_id=model_config["bos_token_id"],
|
| 62 |
-
eos_token_id=model_config["eos_token_id"],
|
| 63 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
-
#
|
| 67 |
-
class SmolLMModule(
|
| 68 |
def __init__(self, config, learning_rate=1e-4):
|
| 69 |
super().__init__()
|
| 70 |
self.config = config
|
| 71 |
self.learning_rate = learning_rate
|
| 72 |
-
self.save_hyperparameters()
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
self.model = AutoModelForCausalLM.from_config(model_config)
|
| 77 |
|
| 78 |
-
def forward(self, **inputs):
|
| 79 |
-
return self.model(**inputs)
|
| 80 |
-
|
| 81 |
def training_step(self, batch, batch_idx):
|
| 82 |
-
outputs = self.model(
|
| 83 |
loss = outputs.loss
|
| 84 |
-
self.log(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
return loss
|
| 86 |
|
| 87 |
def configure_optimizers(self):
|
| 88 |
-
|
| 89 |
self.model.parameters(),
|
| 90 |
lr=self.learning_rate,
|
| 91 |
betas=(0.9, 0.95),
|
| 92 |
eps=1e-8,
|
| 93 |
-
weight_decay=
|
| 94 |
)
|
| 95 |
-
return optimizer
|
| 96 |
-
|
| 97 |
-
def on_save_checkpoint(self, checkpoint):
|
| 98 |
-
# Save additional info if needed
|
| 99 |
-
checkpoint["step"] = self.global_step
|
| 100 |
-
checkpoint["model_config"] = self.config
|
| 101 |
-
|
| 102 |
-
def on_load_checkpoint(self, checkpoint):
|
| 103 |
-
# Restore additional info if needed
|
| 104 |
-
self.global_step = checkpoint["step"]
|
| 105 |
-
self.config = checkpoint["model_config"]
|
| 106 |
-
|
| 107 |
|
| 108 |
-
# train model
|
| 109 |
|
| 110 |
-
#
|
| 111 |
-
|
| 112 |
-
# training script
|
| 113 |
if __name__ == "__main__":
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# parameters load from config file
|
| 118 |
-
with open("config_smollm2_135.yaml", "r") as file:
|
| 119 |
config = yaml.safe_load(file)
|
| 120 |
-
max_steps = 5000 # Total training steps
|
| 121 |
-
|
| 122 |
-
# Create checkpoint directory if it doesn't exist
|
| 123 |
-
checkpoint_dir = "checkpoints"
|
| 124 |
-
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 125 |
-
|
| 126 |
-
# Checkpoint callback
|
| 127 |
-
checkpoint_callback = ModelCheckpoint(
|
| 128 |
-
dirpath=checkpoint_dir,
|
| 129 |
-
filename="model-step={step}",
|
| 130 |
-
save_top_k=-1, # Save all checkpoints
|
| 131 |
-
every_n_train_steps=500, # Save every 500 steps
|
| 132 |
-
save_weights_only=False, # Save the full model state
|
| 133 |
-
)
|
| 134 |
|
| 135 |
-
#
|
| 136 |
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
|
| 137 |
-
# Set padding token to be the same as EOS token
|
| 138 |
tokenizer.pad_token = tokenizer.eos_token
|
| 139 |
|
| 140 |
-
#
|
| 141 |
dataset = load_dataset(
|
| 142 |
"HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True
|
| 143 |
)
|
| 144 |
train_dataset = dataset["train"]
|
| 145 |
|
| 146 |
# Create DataLoader
|
|
|
|
| 147 |
train_loader = DataLoader(
|
| 148 |
-
|
| 149 |
-
batch_size=
|
|
|
|
| 150 |
collate_fn=collate_fn,
|
| 151 |
-
|
| 152 |
)
|
| 153 |
|
| 154 |
-
#
|
| 155 |
-
model = SmolLMModule(
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
#
|
| 158 |
-
|
| 159 |
|
| 160 |
-
#
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
print(f"Resuming from checkpoint: {latest_checkpoint}")
|
| 173 |
|
| 174 |
-
#
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
accelerator="gpu",
|
| 178 |
-
devices=
|
| 179 |
-
precision="
|
|
|
|
|
|
|
| 180 |
callbacks=[
|
| 181 |
LearningRateMonitor(logging_interval="step"),
|
| 182 |
progress_bar,
|
| 183 |
checkpoint_callback,
|
| 184 |
],
|
| 185 |
-
log_every_n_steps=1,
|
| 186 |
enable_progress_bar=True,
|
| 187 |
enable_model_summary=True,
|
|
|
|
| 188 |
)
|
| 189 |
|
| 190 |
-
#
|
| 191 |
-
if
|
| 192 |
-
|
| 193 |
-
|
| 194 |
else:
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
-
# Save final model
|
| 199 |
-
if trainer.is_global_zero:
|
| 200 |
output_dir = "final_model"
|
| 201 |
os.makedirs(output_dir, exist_ok=True)
|
| 202 |
-
model.model.save_pretrained(
|
| 203 |
-
tokenizer.save_pretrained(
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import yaml
|
| 4 |
from datasets import load_dataset
|
| 5 |
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
|
| 6 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 7 |
+
from pytorch_lightning import Trainer, LightningModule
|
| 8 |
+
from pytorch_lightning.callbacks import (
|
| 9 |
+
ModelCheckpoint,
|
| 10 |
+
LearningRateMonitor,
|
| 11 |
+
RichProgressBar,
|
| 12 |
+
)
|
| 13 |
from pytorch_lightning.loggers import TensorBoardLogger
|
| 14 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 15 |
+
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
|
| 16 |
+
|
| 17 |
+
# Set environment variable for memory management
|
| 18 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
|
| 21 |
+
# Function to log GPU memory usage
|
| 22 |
+
def log_memory_usage(step):
|
| 23 |
+
if torch.cuda.is_available():
|
| 24 |
+
print(
|
| 25 |
+
f"Step {step}: "
|
| 26 |
+
f"Allocated = {torch.cuda.memory_allocated() / 1e9:.2f} GB, "
|
| 27 |
+
f"Reserved = {torch.cuda.memory_reserved() / 1e9:.2f} GB"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Custom Collate Function
|
| 32 |
+
def collate_fn(batch):
|
| 33 |
+
input_ids = [item["input_ids"] for item in batch]
|
| 34 |
+
labels = [item["labels"] for item in batch]
|
| 35 |
+
input_ids = pad_sequence(
|
| 36 |
+
input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
|
| 37 |
+
)
|
| 38 |
+
labels = pad_sequence(
|
| 39 |
+
labels, batch_first=True, padding_value=tokenizer.pad_token_id
|
|
|
|
|
|
|
|
|
|
| 40 |
)
|
| 41 |
+
return {"input_ids": input_ids, "labels": labels}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Streaming Dataset
|
| 45 |
+
class StreamingDataset(IterableDataset):
|
| 46 |
+
def __init__(self, dataset, tokenizer, max_length=2048):
|
| 47 |
+
self.dataset = dataset
|
| 48 |
+
self.tokenizer = tokenizer
|
| 49 |
+
self.max_length = max_length
|
| 50 |
+
|
| 51 |
+
def __iter__(self):
|
| 52 |
+
for example in iter(self.dataset):
|
| 53 |
+
tokenized = self.tokenizer(
|
| 54 |
+
example["text"],
|
| 55 |
+
truncation=True,
|
| 56 |
+
max_length=self.max_length,
|
| 57 |
+
return_overflowing_tokens=True,
|
| 58 |
+
return_tensors="pt",
|
| 59 |
+
)
|
| 60 |
+
for chunk in tokenized["input_ids"]:
|
| 61 |
+
yield {
|
| 62 |
+
"input_ids": chunk.squeeze(0),
|
| 63 |
+
"labels": chunk.squeeze(0),
|
| 64 |
+
}
|
| 65 |
|
| 66 |
|
| 67 |
+
# Lightning Module
|
| 68 |
+
class SmolLMModule(LightningModule):
|
| 69 |
def __init__(self, config, learning_rate=1e-4):
|
| 70 |
super().__init__()
|
| 71 |
self.config = config
|
| 72 |
self.learning_rate = learning_rate
|
| 73 |
+
self.save_hyperparameters()
|
| 74 |
+
|
| 75 |
+
model_config = LlamaConfig(
|
| 76 |
+
vocab_size=49152,
|
| 77 |
+
hidden_size=config["model"]["model_config"]["hidden_size"],
|
| 78 |
+
intermediate_size=config["model"]["model_config"]["intermediate_size"],
|
| 79 |
+
num_hidden_layers=config["model"]["model_config"]["num_hidden_layers"],
|
| 80 |
+
num_attention_heads=config["model"]["model_config"]["num_attention_heads"],
|
| 81 |
+
num_key_value_heads=config["model"]["model_config"]["num_key_value_heads"],
|
| 82 |
+
hidden_act=config["model"]["model_config"]["hidden_act"],
|
| 83 |
+
max_position_embeddings=config["model"]["model_config"][
|
| 84 |
+
"max_position_embeddings"
|
| 85 |
+
],
|
| 86 |
+
initializer_range=config["model"]["model_config"]["initializer_range"],
|
| 87 |
+
rms_norm_eps=1e-5,
|
| 88 |
+
use_cache=True,
|
| 89 |
+
pad_token_id=config["model"]["model_config"]["pad_token_id"],
|
| 90 |
+
bos_token_id=config["model"]["model_config"]["bos_token_id"],
|
| 91 |
+
eos_token_id=config["model"]["model_config"]["eos_token_id"],
|
| 92 |
+
)
|
| 93 |
self.model = AutoModelForCausalLM.from_config(model_config)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
| 95 |
def training_step(self, batch, batch_idx):
|
| 96 |
+
outputs = self.model(input_ids=batch["input_ids"], labels=batch["labels"])
|
| 97 |
loss = outputs.loss
|
| 98 |
+
self.log(
|
| 99 |
+
"train_loss", loss, prog_bar=True, on_step=True, on_epoch=True
|
| 100 |
+
) # Log loss
|
| 101 |
+
|
| 102 |
+
# Log memory usage
|
| 103 |
+
if batch_idx % 10 == 0:
|
| 104 |
+
log_memory_usage(batch_idx)
|
| 105 |
+
|
| 106 |
+
# Release intermediate tensors
|
| 107 |
+
del outputs
|
| 108 |
+
torch.cuda.empty_cache()
|
| 109 |
+
|
| 110 |
return loss
|
| 111 |
|
| 112 |
def configure_optimizers(self):
|
| 113 |
+
return torch.optim.AdamW(
|
| 114 |
self.model.parameters(),
|
| 115 |
lr=self.learning_rate,
|
| 116 |
betas=(0.9, 0.95),
|
| 117 |
eps=1e-8,
|
| 118 |
+
weight_decay=self.config["optimizer"]["weight_decay"],
|
| 119 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
|
|
|
| 121 |
|
| 122 |
+
# Main Script
|
|
|
|
|
|
|
| 123 |
if __name__ == "__main__":
|
| 124 |
+
# Load config
|
| 125 |
+
with open("/kaggle/input/yaml-file/config_smollm2_135.yaml", "r") as file:
|
|
|
|
|
|
|
|
|
|
| 126 |
config = yaml.safe_load(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
+
# Load tokenizer
|
| 129 |
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
|
|
|
|
| 130 |
tokenizer.pad_token = tokenizer.eos_token
|
| 131 |
|
| 132 |
+
# Load dataset
|
| 133 |
dataset = load_dataset(
|
| 134 |
"HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True
|
| 135 |
)
|
| 136 |
train_dataset = dataset["train"]
|
| 137 |
|
| 138 |
# Create DataLoader
|
| 139 |
+
streaming_dataset = StreamingDataset(train_dataset, tokenizer, max_length=2048)
|
| 140 |
train_loader = DataLoader(
|
| 141 |
+
streaming_dataset,
|
| 142 |
+
batch_size=1, # Reduced batch size
|
| 143 |
+
num_workers=4,
|
| 144 |
collate_fn=collate_fn,
|
| 145 |
+
pin_memory=True,
|
| 146 |
)
|
| 147 |
|
| 148 |
+
# Create model
|
| 149 |
+
model = SmolLMModule(
|
| 150 |
+
config,
|
| 151 |
+
learning_rate=config["optimizer"]["learning_rate_scheduler"]["learning_rate"],
|
| 152 |
+
)
|
| 153 |
|
| 154 |
+
# Initialize logger with version based on start_step
|
| 155 |
+
logger = TensorBoardLogger("logs", name="smollm2")
|
| 156 |
|
| 157 |
+
# Checkpoint callback configuration
|
| 158 |
+
checkpoint_callback = ModelCheckpoint(
|
| 159 |
+
dirpath="checkpoints",
|
| 160 |
+
filename="model-{epoch:02d}-{step}-{train_loss:.2f}", # Include training loss in filename
|
| 161 |
+
monitor="train_loss", # Monitor training loss
|
| 162 |
+
mode="min", # Lower loss is better
|
| 163 |
+
save_top_k=3, # Save the best 3 models
|
| 164 |
+
save_last=True, # Additionally save the last model
|
| 165 |
+
every_n_train_steps=500, # Save every 500 steps
|
| 166 |
+
save_weights_only=False, # Save the full model state
|
| 167 |
+
auto_insert_metric_name=False, # Don't insert metric name in filename
|
| 168 |
+
)
|
|
|
|
| 169 |
|
| 170 |
+
# Progress bar
|
| 171 |
+
progress_bar = RichProgressBar(
|
| 172 |
+
refresh_rate=1,
|
| 173 |
+
leave=False,
|
| 174 |
+
theme=RichProgressBarTheme(
|
| 175 |
+
description="",
|
| 176 |
+
progress_bar="#6206E0",
|
| 177 |
+
progress_bar_finished="#6206E0",
|
| 178 |
+
progress_bar_pulse="#6206E0",
|
| 179 |
+
batch_progress="",
|
| 180 |
+
time="dim",
|
| 181 |
+
processing_speed="dim underline",
|
| 182 |
+
metrics="italic",
|
| 183 |
+
metrics_text_delimiter=" ",
|
| 184 |
+
metrics_format=".3f",
|
| 185 |
+
),
|
| 186 |
+
console_kwargs=None,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Create trainer
|
| 190 |
+
trainer = Trainer(
|
| 191 |
+
logger=logger,
|
| 192 |
+
strategy="ddp",
|
| 193 |
accelerator="gpu",
|
| 194 |
+
devices=2,
|
| 195 |
+
precision="16-mixed",
|
| 196 |
+
max_steps=5000,
|
| 197 |
+
accumulate_grad_batches=1,
|
| 198 |
callbacks=[
|
| 199 |
LearningRateMonitor(logging_interval="step"),
|
| 200 |
progress_bar,
|
| 201 |
checkpoint_callback,
|
| 202 |
],
|
|
|
|
| 203 |
enable_progress_bar=True,
|
| 204 |
enable_model_summary=True,
|
| 205 |
+
log_every_n_steps=10,
|
| 206 |
)
|
| 207 |
|
| 208 |
+
# Find latest checkpoint if exists
|
| 209 |
+
if os.path.exists("checkpoints/last.ckpt"):
|
| 210 |
+
resume_from_checkpoint = "checkpoints/last.ckpt"
|
| 211 |
+
print(f"Resuming from checkpoint: {resume_from_checkpoint}")
|
| 212 |
else:
|
| 213 |
+
resume_from_checkpoint = None
|
| 214 |
+
print("Starting training from scratch")
|
| 215 |
+
|
| 216 |
+
# Train with automatic checkpoint resumption
|
| 217 |
+
trainer.fit(model, train_loader, ckpt_path=resume_from_checkpoint)
|
| 218 |
+
|
| 219 |
+
# After training, print the best model path and score
|
| 220 |
+
print(f"Best model path: {checkpoint_callback.best_model_path}")
|
| 221 |
+
print(f"Best train loss: {checkpoint_callback.best_model_score:.4f}")
|
| 222 |
|
| 223 |
+
# Save final model
|
| 224 |
+
if trainer.is_global_zero:
|
| 225 |
output_dir = "final_model"
|
| 226 |
os.makedirs(output_dir, exist_ok=True)
|
| 227 |
+
model.model.save_pretrained(output_dir)
|
| 228 |
+
tokenizer.save_pretrained(output_dir)
|