George-API committed
Commit 127e6b1 · verified · 1 Parent(s): 75f9a64

Upload folder using huggingface_hub

Files changed (1)
  1. run_transformers_training.py +329 -62

run_transformers_training.py CHANGED
@@ -10,6 +10,7 @@ import logging
 from datetime import datetime
 import time
 import warnings
+import traceback
 from importlib.util import find_spec
 import multiprocessing
 import torch
@@ -31,64 +32,36 @@ if CUDA_AVAILABLE:
         # Method already set, which is fine
         print("Multiprocessing start method already set")
 
-# Now import the rest of the modules
-import torch
-
-# Configure logging early
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(levelname)s - %(message)s",
-    handlers=[logging.StreamHandler(sys.stdout)]
-)
-logger = logging.getLogger(__name__)
-
-# Set other loggers to WARNING to reduce noise and ensure our logs are visible
-logging.getLogger("transformers").setLevel(logging.WARNING)
-logging.getLogger("datasets").setLevel(logging.WARNING)
-logging.getLogger("accelerate").setLevel(logging.WARNING)
-logging.getLogger("torch").setLevel(logging.WARNING)
-logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
-
-# Import Unsloth first, before other ML imports
-try:
-    from unsloth import FastLanguageModel
-    from unsloth.chat_templates import get_chat_template
-    unsloth_available = True
-    logger.info("Unsloth successfully imported")
-except ImportError:
-    unsloth_available = False
-    logger.warning("Unsloth not available. Please install with: pip install unsloth")
+# Import order is important: unsloth should be imported before transformers
+# Check for libraries without importing them
+unsloth_available = find_spec("unsloth") is not None
+if unsloth_available:
+    import unsloth
 
-# Now import other ML libraries
-try:
+# Import torch first, then transformers if available
+import torch
+transformers_available = find_spec("transformers") is not None
+if transformers_available:
     import transformers
-    from transformers import (
-        AutoModelForCausalLM,
-        AutoTokenizer,
-        TrainingArguments,
-        Trainer,
-        TrainerCallback,
-        set_seed,
-        BitsAndBytesConfig
-    )
-    logger.info(f"Transformers version: {transformers.__version__}")
-except ImportError:
-    logger.error("Transformers not available. This is a critical dependency.")
+    from transformers import AutoTokenizer, TrainingArguments, Trainer, set_seed
+    from torch.utils.data import DataLoader
 
-# Check availability of libraries
 peft_available = find_spec("peft") is not None
 if peft_available:
     import peft
-    logger.info(f"PEFT version: {peft.__version__}")
-else:
-    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
 
-# Import datasets library after the main ML libraries
-try:
+# Only import HF datasets if available
+datasets_available = find_spec("datasets") is not None
+if datasets_available:
     from datasets import load_dataset
-    logger.info("Datasets library successfully imported")
-except ImportError:
-    logger.error("Datasets library not available. This is required for loading training data.")
+
+# Set up the logger
+logger = logging.getLogger(__name__)
+log_handler = logging.StreamHandler()
+log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+log_handler.setFormatter(log_format)
+logger.addHandler(log_handler)
+logger.setLevel(logging.INFO)
 
 # Define a clean logging function for HF Space compatibility
 def log_info(message):
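The guarded imports above rely on importlib.util.find_spec to test whether a package is installed without actually importing it, so the unsloth-before-transformers ordering can be enforced cheaply. A minimal, standalone sketch of the same pattern (standard library only; the package names mirror the commit):

from importlib.util import find_spec

# find_spec returns None when the package cannot be located, so availability
# can be checked without paying the import cost or disturbing import order.
unsloth_available = find_spec("unsloth") is not None
peft_available = find_spec("peft") is not None

if unsloth_available:
    import unsloth  # imported first, before transformers, as the commit intends

if peft_available:
    import peft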
@@ -243,6 +216,17 @@ def load_model_and_tokenizer(config):
     chat_template = get_config_value(tokenizer_config, "chat_template", None)
     padding_side = get_config_value(tokenizer_config, "padding_side", "right")
 
+    # Check for flash attention
+    use_flash_attention = get_config_value(config, "use_flash_attention", False)
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
+    except ImportError:
+        if use_flash_attention:
+            log_info("Flash Attention requested but not available")
+
     log_info(f"Loading model: {model_name} (revision: {model_revision})")
     log_info(f"Max sequence length: {max_seq_length}")
 
@@ -257,7 +241,7 @@ def load_model_and_tokenizer(config):
             dtype=get_config_value(config, "torch_dtype", "bfloat16"),
             revision=model_revision,
             trust_remote_code=trust_remote_code,
-            use_flash_attention_2=get_config_value(config, "use_flash_attention", True)
+            use_flash_attention_2=use_flash_attention and flash_attention_available
         )
 
         # Configure tokenizer settings
@@ -294,11 +278,23 @@ def load_model_and_tokenizer(config):
             max_seq_length=max_seq_length,
             modules_to_save=None
         )
+
+        if use_flash_attention and flash_attention_available:
+            log_info("🚀 Using Flash Attention for faster training")
+        elif use_flash_attention and not flash_attention_available:
+            log_info("⚠️ Flash Attention requested but not available - using standard attention")
+
     else:
         # Standard HuggingFace loading
         log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
+        # Check if flash attention should be enabled in config
+        use_attn_implementation = None
+        if use_flash_attention and flash_attention_available:
+            use_attn_implementation = "flash_attention_2"
+            log_info("🚀 Using Flash Attention for faster training")
+
         # Load tokenizer first
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
@@ -327,7 +323,8 @@ def load_model_and_tokenizer(config):
             trust_remote_code=trust_remote_code,
             revision=model_revision,
             torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
-            device_map="auto" if CUDA_AVAILABLE else None
+            device_map="auto" if CUDA_AVAILABLE else None,
+            attn_implementation=use_attn_implementation
         )
 
         # Apply PEFT/LoRA if enabled but using standard loading
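Both loading paths now condition Flash Attention on an actual import probe rather than a bare config flag. A minimal sketch of the transformers-only path, assuming a recent transformers release that accepts the attn_implementation argument (the model name here is a placeholder, not from the commit):

import torch
from transformers import AutoModelForCausalLM

def load_with_best_attention(model_name: str = "my-org/my-model"):  # placeholder name
    # Probe for the flash-attn package; fall back to the default attention kernels.
    try:
        import flash_attn  # noqa: F401 - presence check only
        attn_impl = "flash_attention_2"
    except ImportError:
        attn_impl = None  # None lets transformers choose its default implementation

    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,
    )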
@@ -760,6 +757,63 @@ class LoggingCallback(TrainerCallback):
         """Called at the beginning of a step"""
         pass
 
+def install_flash_attention():
+    """
+    Attempt to install Flash Attention for improved performance.
+    Returns True if installation was successful, False otherwise.
+    """
+    log_info("Attempting to install Flash Attention...")
+
+    # Check for CUDA before attempting installation
+    if not CUDA_AVAILABLE:
+        log_info("❌ Cannot install Flash Attention: CUDA not available")
+        return False
+
+    try:
+        # Check CUDA version to determine correct installation command
+        cuda_version = torch.version.cuda
+        if cuda_version is None:
+            log_info("❌ Cannot determine CUDA version for Flash Attention installation")
+            return False
+
+        import subprocess
+
+        # Use --no-build-isolation for better compatibility
+        install_cmd = [
+            sys.executable,
+            "-m",
+            "pip",
+            "install",
+            "flash-attn",
+            "--no-build-isolation"
+        ]
+
+        log_info(f"Running: {' '.join(install_cmd)}")
+        result = subprocess.run(
+            install_cmd,
+            capture_output=True,
+            text=True,
+            check=False
+        )
+
+        if result.returncode == 0:
+            log_info("✅ Flash Attention installed successfully!")
+            # Attempt to import to verify installation
+            try:
+                import flash_attn
+                log_info(f"✅ Flash Attention version {flash_attn.__version__} is now available")
+                return True
+            except ImportError:
+                log_info("⚠️ Flash Attention installed but import failed")
+                return False
+        else:
+            log_info(f"❌ Flash Attention installation failed with error: {result.stderr}")
+            return False
+
+    except Exception as e:
+        log_info(f"❌ Error installing Flash Attention: {str(e)}")
+        return False
+
 def check_dependencies():
     """
     Check for required and optional dependencies, ensuring proper versions and import order.
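One caveat with installing a package from inside the running process, as install_flash_attention() does above: the interpreter's import caches may need to be refreshed before the freshly installed package can be imported in the same session. A small sketch of that pattern (the helper and its names are illustrative, not part of the commit):

import importlib
import subprocess
import sys

def install_and_import(pip_name: str, module_name: str) -> bool:
    """Install a package with pip in a subprocess, then try to import it here."""
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", pip_name],
        capture_output=True, text=True, check=False
    )
    if result.returncode != 0:
        return False
    importlib.invalidate_caches()  # let this process see the freshly installed files
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        return False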
@@ -785,6 +839,7 @@ def check_dependencies():
     missing_packages = []
     package_versions = {}
     order_issues = []
+    missing_optional = []
 
     # Check required packages
     log_info("Checking required dependencies...")
@@ -822,6 +877,7 @@ def check_dependencies():
             log_info(f"✅ {package} - {feature} available")
         except ImportError:
             log_info(f"⚠️ {package} - {feature} not available")
+            missing_optional.append(package)
 
     # Check import order for optimal performance
     if "transformers" in package_versions and "unsloth" in package_versions:
@@ -835,11 +891,19 @@ def check_dependencies():
                 order_issue = "⚠️ For optimal performance, import unsloth before transformers"
                 order_issues.append(order_issue)
                 log_info(order_issue)
+                log_info("This might cause performance issues but won't prevent training")
             else:
                 log_info("✅ Import order: unsloth before transformers (optimal)")
         except (ValueError, IndexError) as e:
             log_info(f"⚠️ Could not verify import order: {str(e)}")
 
+    # Try to install missing optional packages
+    if "flash_attn" in missing_optional and CUDA_AVAILABLE:
+        log_info("\nFlash Attention is missing but would improve performance.")
+        install_result = install_flash_attention()
+        if install_result:
+            missing_optional.remove("flash_attn")
+
     # Report missing required packages
     if missing_packages:
         log_info("\n❌ Critical dependencies missing:")
@@ -990,10 +1054,22 @@ def setup_environment(args):
         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:128,expandable_segments:True"
         log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
 
-    # Check dependencies before proceeding
+    # Check dependencies and install optional ones if needed
     if not check_dependencies():
         raise RuntimeError("Critical dependencies missing")
 
+    # Check if flash attention was successfully installed
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention will be used (version: {flash_attn.__version__})")
+        # Update config to use flash attention
+        if "use_flash_attention" not in transformers_config:
+            transformers_config["use_flash_attention"] = True
+    except ImportError:
+        log_info("Flash Attention not available, will use standard attention mechanism")
+
     return transformers_config, seed
 
 def setup_model_and_tokenizer(config):
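The allocator setting kept as context in the hunk above only takes effect if it is in place before PyTorch's CUDA caching allocator initializes, which normally happens at the first CUDA allocation. A minimal sketch of the safe ordering (values copied from the hunk; the ordering claim is a general PyTorch behavior, not something stated in the commit):

import os

# Set before the first CUDA tensor is allocated; after that point the caching
# allocator has already read its configuration.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"

import torch

if torch.cuda.is_available():
    _probe = torch.zeros(1, device="cuda")  # first allocation picks up the setting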
@@ -1001,21 +1077,206 @@ def setup_model_and_tokenizer(config):
     Load and configure the model and tokenizer.
 
     Args:
-        config: Complete configuration dictionary
+        config (dict): Complete configuration dictionary
 
     Returns:
         tuple: (model, tokenizer) - The loaded model and tokenizer
     """
-    log_info("Loading model and tokenizer...")
-    model, tokenizer = load_model_and_tokenizer(config)
+    # Extract model configuration
+    model_config = get_config_value(config, "model", {})
+    model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit")
+    use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True)
+    trust_remote_code = get_config_value(model_config, "trust_remote_code", True)
+    model_revision = get_config_value(config, "model_revision", "main")
+
+    # Detect if model is already pre-quantized (includes '4bit', 'bnb', or 'int4' in name)
+    is_prequantized = any(q in model_name.lower() for q in ['4bit', 'bnb', 'int4', 'quant'])
+    if is_prequantized:
+        log_info("⚠️ Detected pre-quantized model. No additional quantization will be applied.")
+
+    # Unsloth configuration
+    unsloth_config = get_config_value(config, "unsloth", {})
+    unsloth_enabled = get_config_value(unsloth_config, "enabled", True)
+
+    # Tokenizer configuration
+    tokenizer_config = get_config_value(config, "tokenizer", {})
+    max_seq_length = min(
+        get_config_value(tokenizer_config, "max_seq_length", 2048),
+        4096 # Maximum supported by most models
+    )
+    add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True)
+    chat_template = get_config_value(tokenizer_config, "chat_template", None)
+    padding_side = get_config_value(tokenizer_config, "padding_side", "right")
+
+    # Check for flash attention
+    use_flash_attention = get_config_value(config, "use_flash_attention", False)
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
+    except ImportError:
+        if use_flash_attention:
+            log_info("Flash Attention requested but not available")
+
+    log_info(f"Loading model: {model_name} (revision: {model_revision})")
+    log_info(f"Max sequence length: {max_seq_length}")
 
-    if model is None or tokenizer is None:
-        raise ValueError("Failed to load model or tokenizer")
-
-    log_info(f"Model loaded successfully: {model.__class__.__name__}")
-    log_info(f"Tokenizer loaded: {tokenizer.__class__.__name__} (vocab size: {tokenizer.vocab_size})")
+    try:
+        if unsloth_enabled and unsloth_available:
+            log_info("Using Unsloth for LoRA fine-tuning")
+            if is_prequantized:
+                log_info("Using pre-quantized model - no additional quantization will be applied")
+            else:
+                log_info("Using 4-bit quantization for efficient training")
+
+            # Load using Unsloth
+            from unsloth import FastLanguageModel
+            model, tokenizer = FastLanguageModel.from_pretrained(
+                model_name=model_name,
+                max_seq_length=max_seq_length,
+                dtype=get_config_value(config, "torch_dtype", "bfloat16"),
+                revision=model_revision,
+                trust_remote_code=trust_remote_code,
+                use_flash_attention_2=use_flash_attention and flash_attention_available
+            )
+
+            # Configure tokenizer settings
+            tokenizer.padding_side = padding_side
+            if add_eos_token and tokenizer.eos_token is None:
+                log_info("Setting EOS token")
+                tokenizer.add_special_tokens({"eos_token": "</s>"})
+
+            # Set chat template if specified
+            if chat_template:
+                log_info(f"Setting chat template: {chat_template}")
+                if hasattr(tokenizer, "chat_template"):
+                    tokenizer.chat_template = chat_template
+                else:
+                    log_info("Tokenizer does not support chat templates, using default formatting")
+
+            # Apply LoRA
+            lora_r = get_config_value(unsloth_config, "r", 16)
+            lora_alpha = get_config_value(unsloth_config, "alpha", 32)
+            lora_dropout = get_config_value(unsloth_config, "dropout", 0)
+            target_modules = get_config_value(unsloth_config, "target_modules",
+                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
+
+            log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
+            model = FastLanguageModel.get_peft_model(
+                model,
+                r=lora_r,
+                target_modules=target_modules,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                bias="none",
+                use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True),
+                random_state=0,
+                max_seq_length=max_seq_length,
+                modules_to_save=None
+            )
+
+            if use_flash_attention and flash_attention_available:
+                log_info("🚀 Using Flash Attention for faster training")
+            elif use_flash_attention and not flash_attention_available:
+                log_info("⚠️ Flash Attention requested but not available - using standard attention")
+
+        else:
+            # Standard HuggingFace loading
+            log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+
+            # Check if flash attention should be enabled in config
+            use_attn_implementation = None
+            if use_flash_attention and flash_attention_available:
+                use_attn_implementation = "flash_attention_2"
+                log_info("🚀 Using Flash Attention for faster training")
+
+            # Load tokenizer first
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                use_fast=use_fast_tokenizer,
+                revision=model_revision,
+                padding_side=padding_side
+            )
+
+            # Configure tokenizer settings
+            if add_eos_token and tokenizer.eos_token is None:
+                log_info("Setting EOS token")
+                tokenizer.add_special_tokens({"eos_token": "</s>"})
+
+            # Set chat template if specified
+            if chat_template:
+                log_info(f"Setting chat template: {chat_template}")
+                if hasattr(tokenizer, "chat_template"):
+                    tokenizer.chat_template = chat_template
+                else:
+                    log_info("Tokenizer does not support chat templates, using default formatting")
+
+            # Only apply quantization config if model is not already pre-quantized
+            quantization_config = None
+            if not is_prequantized and CUDA_AVAILABLE:
+                try:
+                    from transformers import BitsAndBytesConfig
+                    log_info("Using 4-bit quantization (BitsAndBytes) for efficient training")
+                    quantization_config = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_compute_dtype=torch.float16,
+                        bnb_4bit_use_double_quant=True
+                    )
+                except ImportError:
+                    log_info("BitsAndBytes not available - quantization disabled")
+
+            # Now load model with updated tokenizer
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                revision=model_revision,
+                torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
+                device_map="auto" if CUDA_AVAILABLE else None,
+                attn_implementation=use_attn_implementation,
+                quantization_config=quantization_config
+            )
+
+            # Apply PEFT/LoRA if enabled but using standard loading
+            if peft_available and get_config_value(unsloth_config, "enabled", True):
+                log_info("Applying standard PEFT/LoRA configuration")
+                from peft import LoraConfig, get_peft_model
+
+                lora_r = get_config_value(unsloth_config, "r", 16)
+                lora_alpha = get_config_value(unsloth_config, "alpha", 32)
+                lora_dropout = get_config_value(unsloth_config, "dropout", 0)
+                target_modules = get_config_value(unsloth_config, "target_modules",
+                    ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
+
+                log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
+                lora_config = LoraConfig(
+                    r=lora_r,
+                    lora_alpha=lora_alpha,
+                    target_modules=target_modules,
+                    lora_dropout=lora_dropout,
+                    bias="none",
+                    task_type="CAUSAL_LM"
+                )
+                model = get_peft_model(model, lora_config)
 
-    return model, tokenizer
+        # Print model summary
+        log_info(f"Model loaded successfully: {model.__class__.__name__}")
+        if hasattr(model, "print_trainable_parameters"):
+            model.print_trainable_parameters()
+        else:
+            total_params = sum(p.numel() for p in model.parameters())
+            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+            log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})")
+
+        return model, tokenizer
+
+    except Exception as e:
+        log_info(f"Error loading model: {str(e)}")
+        traceback.print_exc()
+        return None, None
 
 def setup_dataset_and_collator(config, tokenizer):
     """
@@ -1229,6 +1490,12 @@ def main():
     logger.info("Starting training process")
 
     try:
+        # Check for potential import order issue and warn early
+        if "transformers" in sys.modules and "unsloth" in sys.modules:
+            if list(sys.modules.keys()).index("transformers") < list(sys.modules.keys()).index("unsloth"):
+                log_info("⚠️ Warning: transformers was imported before unsloth. This may affect performance.")
+                log_info(" For optimal performance in future runs, import unsloth first.")
+
         # Parse command line arguments
         args = parse_args()
 
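Taken together, the get_config_value lookups in the new code imply a nested configuration roughly shaped like the sketch below. This is inferred from the defaults in the hunks above, not a file in the repository, and the dotted "training.gradient_checkpointing" lookup assumes get_config_value can resolve dotted paths:

example_config = {
    "model": {
        "name": "unsloth/phi-4-unsloth-bnb-4bit",
        "use_fast_tokenizer": True,
        "trust_remote_code": True,
    },
    "model_revision": "main",
    "torch_dtype": "bfloat16",
    "use_flash_attention": False,
    "unsloth": {
        "enabled": True,
        "r": 16,
        "alpha": 32,
        "dropout": 0,
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj",
                           "gate_proj", "up_proj", "down_proj"],
    },
    "tokenizer": {
        "max_seq_length": 2048,
        "add_eos_token": True,
        "chat_template": None,
        "padding_side": "right",
    },
    "training": {
        "gradient_checkpointing": True,
    },
}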