Upload folder using huggingface_hub
run_transformers_training.py  CHANGED  (+329, -62)
@@ -10,6 +10,7 @@ import logging
 from datetime import datetime
 import time
 import warnings
+import traceback
 from importlib.util import find_spec
 import multiprocessing
 import torch
@@ -31,64 +32,36 @@ if CUDA_AVAILABLE:
         # Method already set, which is fine
         print("Multiprocessing start method already set")
 
-#
-
-
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(levelname)s - %(message)s",
-    handlers=[logging.StreamHandler(sys.stdout)]
-)
-logger = logging.getLogger(__name__)
-
-# Set other loggers to WARNING to reduce noise and ensure our logs are visible
-logging.getLogger("transformers").setLevel(logging.WARNING)
-logging.getLogger("datasets").setLevel(logging.WARNING)
-logging.getLogger("accelerate").setLevel(logging.WARNING)
-logging.getLogger("torch").setLevel(logging.WARNING)
-logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
-
-# Import Unsloth first, before other ML imports
-try:
-    from unsloth import FastLanguageModel
-    from unsloth.chat_templates import get_chat_template
-    unsloth_available = True
-    logger.info("Unsloth successfully imported")
-except ImportError:
-    unsloth_available = False
-    logger.warning("Unsloth not available. Please install with: pip install unsloth")
+# Import order is important: unsloth should be imported before transformers
+# Check for libraries without importing them
+unsloth_available = find_spec("unsloth") is not None
+if unsloth_available:
+    import unsloth
 
-#
-
+# Import torch first, then transformers if available
+import torch
+transformers_available = find_spec("transformers") is not None
+if transformers_available:
     import transformers
-    from transformers import (
-
-        AutoTokenizer,
-        TrainingArguments,
-        Trainer,
-        TrainerCallback,
-        set_seed,
-        BitsAndBytesConfig
-    )
-    logger.info(f"Transformers version: {transformers.__version__}")
-except ImportError:
-    logger.error("Transformers not available. This is a critical dependency.")
+    from transformers import AutoTokenizer, TrainingArguments, Trainer, set_seed
+    from torch.utils.data import DataLoader
 
-# Check availability of libraries
 peft_available = find_spec("peft") is not None
 if peft_available:
     import peft
-    logger.info(f"PEFT version: {peft.__version__}")
-else:
-    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
 
-#
-
+# Only import HF datasets if available
+datasets_available = find_spec("datasets") is not None
+if datasets_available:
     from datasets import load_dataset
-
-
-
+
+# Set up the logger
+logger = logging.getLogger(__name__)
+log_handler = logging.StreamHandler()
+log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+log_handler.setFormatter(log_format)
+logger.addHandler(log_handler)
+logger.setLevel(logging.INFO)
 
 # Define a clean logging function for HF Space compatibility
 def log_info(message):
@@ -243,6 +216,17 @@ def load_model_and_tokenizer(config):
     chat_template = get_config_value(tokenizer_config, "chat_template", None)
     padding_side = get_config_value(tokenizer_config, "padding_side", "right")
 
+    # Check for flash attention
+    use_flash_attention = get_config_value(config, "use_flash_attention", False)
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
+    except ImportError:
+        if use_flash_attention:
+            log_info("Flash Attention requested but not available")
+
     log_info(f"Loading model: {model_name} (revision: {model_revision})")
     log_info(f"Max sequence length: {max_seq_length}")
 
@@ -257,7 +241,7 @@ def load_model_and_tokenizer(config):
             dtype=get_config_value(config, "torch_dtype", "bfloat16"),
             revision=model_revision,
             trust_remote_code=trust_remote_code,
-            use_flash_attention_2=
+            use_flash_attention_2=use_flash_attention and flash_attention_available
         )
 
         # Configure tokenizer settings
@@ -294,11 +278,23 @@ def load_model_and_tokenizer(config):
             max_seq_length=max_seq_length,
            modules_to_save=None
        )
+
+        if use_flash_attention and flash_attention_available:
+            log_info("🚀 Using Flash Attention for faster training")
+        elif use_flash_attention and not flash_attention_available:
+            log_info("⚠️ Flash Attention requested but not available - using standard attention")
+
    else:
        # Standard HuggingFace loading
        log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
        from transformers import AutoModelForCausalLM, AutoTokenizer
 
+        # Check if flash attention should be enabled in config
+        use_attn_implementation = None
+        if use_flash_attention and flash_attention_available:
+            use_attn_implementation = "flash_attention_2"
+            log_info("🚀 Using Flash Attention for faster training")
+
        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
@@ -327,7 +323,8 @@ def load_model_and_tokenizer(config):
            trust_remote_code=trust_remote_code,
            revision=model_revision,
            torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
-            device_map="auto" if CUDA_AVAILABLE else None
+            device_map="auto" if CUDA_AVAILABLE else None,
+            attn_implementation=use_attn_implementation
        )
 
        # Apply PEFT/LoRA if enabled but using standard loading
@@ -760,6 +757,63 @@ class LoggingCallback(TrainerCallback):
         """Called at the beginning of a step"""
         pass
 
+def install_flash_attention():
+    """
+    Attempt to install Flash Attention for improved performance.
+    Returns True if installation was successful, False otherwise.
+    """
+    log_info("Attempting to install Flash Attention...")
+
+    # Check for CUDA before attempting installation
+    if not CUDA_AVAILABLE:
+        log_info("❌ Cannot install Flash Attention: CUDA not available")
+        return False
+
+    try:
+        # Check CUDA version to determine correct installation command
+        cuda_version = torch.version.cuda
+        if cuda_version is None:
+            log_info("❌ Cannot determine CUDA version for Flash Attention installation")
+            return False
+
+        import subprocess
+
+        # Use --no-build-isolation for better compatibility
+        install_cmd = [
+            sys.executable,
+            "-m",
+            "pip",
+            "install",
+            "flash-attn",
+            "--no-build-isolation"
+        ]
+
+        log_info(f"Running: {' '.join(install_cmd)}")
+        result = subprocess.run(
+            install_cmd,
+            capture_output=True,
+            text=True,
+            check=False
+        )
+
+        if result.returncode == 0:
+            log_info("✅ Flash Attention installed successfully!")
+            # Attempt to import to verify installation
+            try:
+                import flash_attn
+                log_info(f"✅ Flash Attention version {flash_attn.__version__} is now available")
+                return True
+            except ImportError:
+                log_info("⚠️ Flash Attention installed but import failed")
+                return False
+        else:
+            log_info(f"❌ Flash Attention installation failed with error: {result.stderr}")
+            return False
+
+    except Exception as e:
+        log_info(f"❌ Error installing Flash Attention: {str(e)}")
+        return False
+
 def check_dependencies():
     """
     Check for required and optional dependencies, ensuring proper versions and import order.
@@ -785,6 +839,7 @@ def check_dependencies():
     missing_packages = []
     package_versions = {}
     order_issues = []
+    missing_optional = []
 
     # Check required packages
     log_info("Checking required dependencies...")
@@ -822,6 +877,7 @@ def check_dependencies():
             log_info(f"✅ {package} - {feature} available")
         except ImportError:
             log_info(f"⚠️ {package} - {feature} not available")
+            missing_optional.append(package)
 
     # Check import order for optimal performance
     if "transformers" in package_versions and "unsloth" in package_versions:
@@ -835,11 +891,19 @@ def check_dependencies():
                 order_issue = "⚠️ For optimal performance, import unsloth before transformers"
                 order_issues.append(order_issue)
                 log_info(order_issue)
+                log_info("This might cause performance issues but won't prevent training")
            else:
                log_info("✅ Import order: unsloth before transformers (optimal)")
        except (ValueError, IndexError) as e:
            log_info(f"⚠️ Could not verify import order: {str(e)}")
 
+    # Try to install missing optional packages
+    if "flash_attn" in missing_optional and CUDA_AVAILABLE:
+        log_info("\nFlash Attention is missing but would improve performance.")
+        install_result = install_flash_attention()
+        if install_result:
+            missing_optional.remove("flash_attn")
+
    # Report missing required packages
    if missing_packages:
        log_info("\n❌ Critical dependencies missing:")
@@ -990,10 +1054,22 @@ def setup_environment(args):
         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:128,expandable_segments:True"
         log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
 
-    # Check dependencies
+    # Check dependencies and install optional ones if needed
     if not check_dependencies():
         raise RuntimeError("Critical dependencies missing")
 
+    # Check if flash attention was successfully installed
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention will be used (version: {flash_attn.__version__})")
+        # Update config to use flash attention
+        if "use_flash_attention" not in transformers_config:
+            transformers_config["use_flash_attention"] = True
+    except ImportError:
+        log_info("Flash Attention not available, will use standard attention mechanism")
+
     return transformers_config, seed
 
 def setup_model_and_tokenizer(config):
@@ -1001,21 +1077,206 @@ def setup_model_and_tokenizer(config):
     Load and configure the model and tokenizer.
 
     Args:
-        config: Complete configuration dictionary
+        config (dict): Complete configuration dictionary
 
     Returns:
         tuple: (model, tokenizer) - The loaded model and tokenizer
     """
-
-
 
-    if model is
-
-
-
-
 
-
+    # Extract model configuration
+    model_config = get_config_value(config, "model", {})
+    model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit")
+    use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True)
+    trust_remote_code = get_config_value(model_config, "trust_remote_code", True)
+    model_revision = get_config_value(config, "model_revision", "main")
+
+    # Detect if model is already pre-quantized (includes '4bit', 'bnb', or 'int4' in name)
+    is_prequantized = any(q in model_name.lower() for q in ['4bit', 'bnb', 'int4', 'quant'])
+    if is_prequantized:
+        log_info("⚠️ Detected pre-quantized model. No additional quantization will be applied.")
+
+    # Unsloth configuration
+    unsloth_config = get_config_value(config, "unsloth", {})
+    unsloth_enabled = get_config_value(unsloth_config, "enabled", True)
+
+    # Tokenizer configuration
+    tokenizer_config = get_config_value(config, "tokenizer", {})
+    max_seq_length = min(
+        get_config_value(tokenizer_config, "max_seq_length", 2048),
+        4096  # Maximum supported by most models
+    )
+    add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True)
+    chat_template = get_config_value(tokenizer_config, "chat_template", None)
+    padding_side = get_config_value(tokenizer_config, "padding_side", "right")
+
+    # Check for flash attention
+    use_flash_attention = get_config_value(config, "use_flash_attention", False)
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
+    except ImportError:
+        if use_flash_attention:
+            log_info("Flash Attention requested but not available")
+
+    log_info(f"Loading model: {model_name} (revision: {model_revision})")
+    log_info(f"Max sequence length: {max_seq_length}")
+
+    try:
+        if unsloth_enabled and unsloth_available:
+            log_info("Using Unsloth for LoRA fine-tuning")
+            if is_prequantized:
+                log_info("Using pre-quantized model - no additional quantization will be applied")
+            else:
+                log_info("Using 4-bit quantization for efficient training")
+
+            # Load using Unsloth
+            from unsloth import FastLanguageModel
+            model, tokenizer = FastLanguageModel.from_pretrained(
+                model_name=model_name,
+                max_seq_length=max_seq_length,
+                dtype=get_config_value(config, "torch_dtype", "bfloat16"),
+                revision=model_revision,
+                trust_remote_code=trust_remote_code,
+                use_flash_attention_2=use_flash_attention and flash_attention_available
+            )
+
+            # Configure tokenizer settings
+            tokenizer.padding_side = padding_side
+            if add_eos_token and tokenizer.eos_token is None:
+                log_info("Setting EOS token")
+                tokenizer.add_special_tokens({"eos_token": "</s>"})
+
+            # Set chat template if specified
+            if chat_template:
+                log_info(f"Setting chat template: {chat_template}")
+                if hasattr(tokenizer, "chat_template"):
+                    tokenizer.chat_template = chat_template
+                else:
+                    log_info("Tokenizer does not support chat templates, using default formatting")
+
+            # Apply LoRA
+            lora_r = get_config_value(unsloth_config, "r", 16)
+            lora_alpha = get_config_value(unsloth_config, "alpha", 32)
+            lora_dropout = get_config_value(unsloth_config, "dropout", 0)
+            target_modules = get_config_value(unsloth_config, "target_modules",
+                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
+
+            log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
+            model = FastLanguageModel.get_peft_model(
+                model,
+                r=lora_r,
+                target_modules=target_modules,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                bias="none",
+                use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True),
+                random_state=0,
+                max_seq_length=max_seq_length,
+                modules_to_save=None
+            )
+
+            if use_flash_attention and flash_attention_available:
+                log_info("🚀 Using Flash Attention for faster training")
+            elif use_flash_attention and not flash_attention_available:
+                log_info("⚠️ Flash Attention requested but not available - using standard attention")
+
+        else:
+            # Standard HuggingFace loading
+            log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+
+            # Check if flash attention should be enabled in config
+            use_attn_implementation = None
+            if use_flash_attention and flash_attention_available:
+                use_attn_implementation = "flash_attention_2"
+                log_info("🚀 Using Flash Attention for faster training")
+
+            # Load tokenizer first
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                use_fast=use_fast_tokenizer,
+                revision=model_revision,
+                padding_side=padding_side
+            )
+
+            # Configure tokenizer settings
+            if add_eos_token and tokenizer.eos_token is None:
+                log_info("Setting EOS token")
+                tokenizer.add_special_tokens({"eos_token": "</s>"})
+
+            # Set chat template if specified
+            if chat_template:
+                log_info(f"Setting chat template: {chat_template}")
+                if hasattr(tokenizer, "chat_template"):
+                    tokenizer.chat_template = chat_template
+                else:
+                    log_info("Tokenizer does not support chat templates, using default formatting")
+
+            # Only apply quantization config if model is not already pre-quantized
+            quantization_config = None
+            if not is_prequantized and CUDA_AVAILABLE:
+                try:
+                    from transformers import BitsAndBytesConfig
+                    log_info("Using 4-bit quantization (BitsAndBytes) for efficient training")
+                    quantization_config = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_compute_dtype=torch.float16,
+                        bnb_4bit_use_double_quant=True
+                    )
+                except ImportError:
+                    log_info("BitsAndBytes not available - quantization disabled")
+
+            # Now load model with updated tokenizer
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                revision=model_revision,
+                torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
+                device_map="auto" if CUDA_AVAILABLE else None,
+                attn_implementation=use_attn_implementation,
+                quantization_config=quantization_config
+            )
+
+            # Apply PEFT/LoRA if enabled but using standard loading
+            if peft_available and get_config_value(unsloth_config, "enabled", True):
+                log_info("Applying standard PEFT/LoRA configuration")
+                from peft import LoraConfig, get_peft_model
+
+                lora_r = get_config_value(unsloth_config, "r", 16)
+                lora_alpha = get_config_value(unsloth_config, "alpha", 32)
+                lora_dropout = get_config_value(unsloth_config, "dropout", 0)
+                target_modules = get_config_value(unsloth_config, "target_modules",
+                    ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
+
+                log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
+                lora_config = LoraConfig(
+                    r=lora_r,
+                    lora_alpha=lora_alpha,
+                    target_modules=target_modules,
+                    lora_dropout=lora_dropout,
+                    bias="none",
+                    task_type="CAUSAL_LM"
+                )
+                model = get_peft_model(model, lora_config)
+
+        # Print model summary
+        log_info(f"Model loaded successfully: {model.__class__.__name__}")
+        if hasattr(model, "print_trainable_parameters"):
+            model.print_trainable_parameters()
+        else:
+            total_params = sum(p.numel() for p in model.parameters())
+            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+            log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})")
+
+        return model, tokenizer
+
+    except Exception as e:
+        log_info(f"Error loading model: {str(e)}")
+        traceback.print_exc()
+        return None, None
 
 def setup_dataset_and_collator(config, tokenizer):
     """
@@ -1229,6 +1490,12 @@ def main():
     logger.info("Starting training process")
 
     try:
+        # Check for potential import order issue and warn early
+        if "transformers" in sys.modules and "unsloth" in sys.modules:
+            if list(sys.modules.keys()).index("transformers") < list(sys.modules.keys()).index("unsloth"):
+                log_info("⚠️ Warning: transformers was imported before unsloth. This may affect performance.")
+                log_info(" For optimal performance in future runs, import unsloth first.")
+
         # Parse command line arguments
         args = parse_args()
 
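
For reference, a minimal usage sketch (not part of the commit): it shows the configuration shape that the rewritten setup_model_and_tokenizer() reads through get_config_value(). The key names mirror the calls visible in the diff above; the concrete values and the run_transformers_training import path are illustrative assumptions, not values taken from this repository.

# Hypothetical sketch -- key names mirror the get_config_value() calls in the diff;
# the values and the import below are assumptions for illustration only.
from run_transformers_training import setup_model_and_tokenizer

config = {
    "model": {
        "name": "unsloth/phi-4-unsloth-bnb-4bit",  # '4bit'/'bnb' in the name triggers the pre-quantized path
        "use_fast_tokenizer": True,
        "trust_remote_code": True,
    },
    "model_revision": "main",
    "unsloth": {"enabled": True, "r": 16, "alpha": 32, "dropout": 0},
    "tokenizer": {"max_seq_length": 2048, "add_eos_token": True, "padding_side": "right"},
    "use_flash_attention": True,  # honored only when flash_attn actually imports
    "torch_dtype": "bfloat16",
}

model, tokenizer = setup_model_and_tokenizer(config)
if model is None:  # the function returns (None, None) on failure
    raise RuntimeError("Model loading failed; see the log output for details")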