George-API committed
Commit 5a7635c · verified · 1 Parent(s): a7d1f2a

Upload folder using huggingface_hub
fixed_run_transformers_training.py CHANGED
@@ -201,10 +201,29 @@ class LoggingCallback(TrainerCallback):
         log_info("=== Training is starting ===")
 
         # Log important training parameters for visibility
-        log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {max(1, torch.cuda.device_count())} GPUs")
+        effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * max(1, torch.cuda.device_count())
+        log_info(f"Per device batch size: {args.per_device_train_batch_size}")
+        log_info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
+        log_info(f"Number of GPUs: {max(1, torch.cuda.device_count())}")
+        log_info(f"Total effective batch size: {effective_batch_size}")
         log_info(f"Learning rate: {args.learning_rate}")
         log_info(f"Epochs: {args.num_train_epochs}")
 
+        # Log dataset information
+        if hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None:
+            log_info(f"Dataset size: {len(trainer.train_dataset)} examples")
+            if len(trainer.train_dataset) > 0:
+                try:
+                    # Log first few prompt numbers to verify sequence
+                    prompt_numbers = []
+                    for i in range(min(5, len(trainer.train_dataset))):
+                        if 'prompt_number' in trainer.train_dataset[i]:
+                            prompt_numbers.append(trainer.train_dataset[i]['prompt_number'])
+                    if prompt_numbers:
+                        log_info(f"First few prompt numbers: {prompt_numbers}")
+                except Exception as e:
+                    log_info(f"Error accessing dataset samples: {e}")
+
         # Log memory information in compact format
         if torch.cuda.is_available():
             memory_info = []
@@ -227,4 +246,138 @@ class LoggingCallback(TrainerCallback):
             log_info(f"Final memory usage - {', '.join(memory_info)}")
 
         log_info(f"Total steps: {state.global_step}")
-        log_info(f"Final loss: {state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'}")
+        log_info(f"Final loss: {state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'}")
+
+def custom_get_train_dataloader():
+    """Custom dataloader that preserves original dataset order"""
+    log_info("Creating sequential dataloader to maintain original dataset order")
+
+    # Create a simple sequential sampler
+    sequential_sampler = torch.utils.data.SequentialSampler(dataset)
+
+    # Verify shuffle is disabled
+    data_loading_config = dataset_config.get("data_loading", {})
+    shuffle_enabled = data_loading_config.get("shuffle", False)
+
+    if shuffle_enabled:
+        log_info("CRITICAL ERROR: Shuffle is enabled! This will randomize data entry order!")
+        raise ValueError("Dataset shuffling is enabled but sequential processing is required. " +
+                         "Please disable shuffling in your configuration.")
+
+    # Log our sequential processing approach
+    log_info("Using SequentialSampler to guarantee original dataset order is preserved")
+    log_info("Data order preservation is critical for proper training sequence")
+
+    # Calculate batch size based on device availability
+    if getattr(training_args, "no_cuda", False):
+        batch_size = training_args.per_device_train_batch_size
+    else:
+        batch_size = max(training_args.per_device_train_batch_size * max(1, NUM_GPUS), 1)
+
+    log_info(f"Using sequential sampler with batch size {batch_size}")
+
+    # Return DataLoader with sequential sampler
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sequential_sampler,
+        collate_fn=data_collator,
+        drop_last=training_args.dataloader_drop_last,
+        num_workers=training_args.dataloader_num_workers,
+        pin_memory=training_args.dataloader_pin_memory,
+    )
+
+def check_dependencies():
+    """Check for critical dependencies and provide useful warnings."""
+    # Check for flash attention without attempting import
+    flash_attention_available = False
+    try:
+        import importlib.util
+        if importlib.util.find_spec("flash_attn") is not None:
+            flash_attention_available = True
+            log_info("flash-attn found! Using Flash Attention for faster training.")
+        else:
+            log_info("flash-attn not found. Training will continue but may be slower.")
+            log_info("To use flash attention, install: pip install flash-attn==2.5.2 --no-build-isolation")
+            # Still continue as this is optional
+    except Exception as e:
+        log_info(f"Error checking for flash-attn: {e}")
+
+    # Check for torch CUDA
+    if not torch.cuda.is_available():
+        log_info("WARNING: CUDA not available. Training will be extremely slow on CPU!")
+    else:
+        log_info(f"Found {torch.cuda.device_count()} CUDA devices")
+
+    # Check for unsloth
+    unsloth_available = False
+    try:
+        import importlib.util
+        if importlib.util.find_spec("unsloth") is not None:
+            unsloth_available = True
+            log_info("Unsloth found! Using Unsloth for optimized training.")
+        else:
+            log_info("CRITICAL: Unsloth not found. This pipeline requires Unsloth.")
+            log_info("Install with: pip install unsloth>=2024.3")
+            return False
+    except Exception as e:
+        log_info(f"Error checking for unsloth: {e}")
+        return False
+
+    return True
+
+def main():
+    """Main training function with error handling."""
+    try:
+        # Initialize logging
+        log_info("Starting Phi-4 training process")
+
+        # Parse arguments
+        args = parse_args()
+
+        # Load environment variables
+        load_env_variables()
+
+        # Load config from file
+        config = load_configs(args.config)
+
+        # Extract specific configurations
+        hardware_config = config.get("hardware", {})
+        dataset_config = config.get("dataset", {})
+
+        # Define multi_gpu_strategy early to prevent undefined errors
+        multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
+        log_info(f"Multi-GPU strategy: {multi_gpu_strategy}")
+
+        # Check dependencies
+        if not check_dependencies():
+            log_info("Aborting due to missing critical dependencies")
+            return 1
+
+        # Log hardware info
+        cuda_available = torch.cuda.is_available()
+        num_gpus = torch.cuda.device_count() if cuda_available else 0
+        log_info(f"Hardware: {num_gpus} GPUs detected" if cuda_available else "Hardware: CPU only")
+
+        # Rest of training code would go here
+        # ...
+
+        return 0
+    except Exception as e:
+        log_info(f"Error in main training loop: {str(e)}")
+        # Log CUDA memory if available
+        if torch.cuda.is_available():
+            try:
+                memory_info = []
+                for i in range(torch.cuda.device_count()):
+                    allocated = torch.cuda.memory_allocated(i) / 1024**2
+                    reserved = torch.cuda.memory_reserved(i) / 1024**2
+                    memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB")
+                log_info(f"GPU memory at failure: {', '.join(memory_info)}")
+            except:
+                pass
+        return 1
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(main())
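Note: the ordering guarantee that custom_get_train_dataloader() relies on above is simply that torch.utils.data.SequentialSampler yields indices 0..len(dataset)-1 in order, so batches follow the stored prompt_number sequence exactly. A minimal self-contained sketch of that behavior, using a toy list dataset and a pass-through collate function (both hypothetical, not from this repo):

    from torch.utils.data import DataLoader, SequentialSampler

    # Toy stand-in for the pre-processed dataset (1-indexed prompt numbers)
    data = [{"prompt_number": n} for n in range(1, 11)]

    # Pass-through collate keeps each batch as a list of dicts for inspection
    loader = DataLoader(data, batch_size=4,
                        sampler=SequentialSampler(data),
                        collate_fn=lambda batch: batch)

    for batch in loader:
        print([ex["prompt_number"] for ex in batch])
    # Prints [1, 2, 3, 4], [5, 6, 7, 8], [9, 10] - original order preserved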
requirements.txt CHANGED
@@ -1,9 +1,11 @@
+# Use pre-built wheels for flash-attn instead of building from source
+--find-links https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/
 accelerate>=0.27.0
 bitsandbytes>=0.41.0
 datasets>=2.15.0
 einops>=0.7.0
 filelock>=3.13.1
-flash-attn>=2.5.1
+flash-attn==2.5.2
 gradio>=5.17.0
 huggingface-hub>=0.19.0
 matplotlib>=3.7.0
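Note on the wheel pin: pip honors option lines such as --find-links inside a requirements file, so the URL above is searched as an extra location when resolving flash-attn==2.5.2, letting pip pick up a pre-built wheel from the GitHub release instead of compiling the CUDA extension from the PyPI sdist. Whether a wheel exists for the running Python/CUDA combination is an assumption; pip falls back to a source build if no matching wheel is found. Either form below resolves against the release page first:

    pip install -r requirements.txt
    pip install flash-attn==2.5.2 --find-links https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/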
run_transformers_training.py CHANGED
@@ -284,83 +284,61 @@ def load_dataset_with_mapping(dataset_config):
         if not dataset_name:
             raise ValueError("Dataset name not provided in configuration")
 
-        logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
+        logger.info(f"Loading pre-processed dataset {dataset_name}, split {dataset_split}")
         dataset = load_dataset(dataset_name, split=dataset_split)
 
-        # Map columns if specified - with checks to avoid conflicts
-        column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
-        if column_mapping:
-            logger.info(f"Checking column mapping: {column_mapping}")
-
-            # Only apply mappings for columns that need renaming and don't already exist
-            safe_mappings = {}
-            for target, source in column_mapping.items():
-                if source in dataset.column_names:
-                    # Skip if target already exists and is not the same as source
-                    if target in dataset.column_names and target != source:
-                        logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists")
-                    else:
-                        safe_mappings[source] = target
-
-            # Apply safe renames
-            if safe_mappings:
-                logger.info(f"Applying safe column mapping: {safe_mappings}")
-                for source, target in safe_mappings.items():
-                    if source != target:  # Only rename if names are different
-                        dataset = dataset.rename_column(source, target)
-
-        # Add prompt_number field that increments based on original order - simple approach
-        logger.info("Adding prompt_number based on original dataset order (starting at 1)")
-
-        # Simple approach 1: Add index as a column during dataset creation
-        # Create a list of dicts with indices
-        examples_with_idx = []
-        for i, example in enumerate(dataset):
-            example = dict(example)  # Make a copy to avoid modifying the original
-            example['prompt_number'] = i + 1  # 1-indexed
-            examples_with_idx.append(example)
+        # Apply minimal processing since the dataset has already been properly structured
+        # Just perform validation to ensure required fields exist
 
-        # Recreate dataset with prompt_number included
-        from datasets import Dataset
-        dataset = Dataset.from_list(examples_with_idx)
-        logger.info("Successfully added prompt_number to dataset")
+        # Check for required fields
+        required_fields = ["prompt_number", "article_id", "conversations"]
+        missing_fields = [field for field in required_fields if field not in dataset.column_names]
 
-        # If conversations is missing but text exists, attempt conversion
-        if "conversations" not in dataset.column_names and "text" in dataset.column_names:
-            logger.info("Converting 'text' field to 'conversations' format")
+        if missing_fields:
+            logger.warning(f"Dataset is missing required fields: {missing_fields}")
+            logger.warning("This may cause issues with sequence integrity and metadata management")
+        else:
+            logger.info(f"Dataset has all required fields: {required_fields}")
 
-            def convert_text_to_conversations(example):
-                # Check if text is already a list of conversation turns
-                if isinstance(example.get("text"), list):
-                    example["conversations"] = example["text"]
-                # Otherwise, create a simple conversation with the text as user message
-                else:
-                    example["conversations"] = [
-                        {"role": "user", "content": str(example.get("text", ""))}
-                    ]
-                return example
+        # Log a few samples for verification
+        if len(dataset) > 0:
+            sample_indices = range(min(5, len(dataset)))
+            sample_records = []
 
-            dataset = dataset.map(convert_text_to_conversations)
-            logger.info("Successfully converted 'text' to 'conversations'")
-
-        # Verify we have the required columns
-        if "conversations" not in dataset.column_names:
-            logger.error("Required 'conversations' column not found in dataset!")
-            raise ValueError("Required 'conversations' column missing from dataset")
+            for i in sample_indices:
+                record = {}
+                record["prompt_number"] = dataset[i].get("prompt_number", "N/A")
+                record["article_id"] = dataset[i].get("article_id", "N/A")
+                if "conversations" in dataset[i]:
+                    record["conversations_length"] = len(dataset[i]["conversations"])
+                sample_records.append(record)
+
+            logger.info(f"Sample records: {sample_records}")
+
+        # Verify sequential integrity
+        if "prompt_number" in dataset.column_names and len(dataset) > 1:
+            first_prompt_numbers = [dataset[i]["prompt_number"] for i in range(min(10, len(dataset)))]
+            is_sequential = all(first_prompt_numbers[i] == i + 1 for i in range(len(first_prompt_numbers)))
+
+            if is_sequential:
+                logger.info("Dataset prompt numbers are sequential (1-indexed) - sequence integrity preserved")
+            else:
+                logger.warning("Dataset prompt numbers are not sequential - sequence integrity may be compromised")
+                logger.info(f"First few prompt numbers: {first_prompt_numbers}")
 
-        # Log column names and a sample
         logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
         logger.info(f"Dataset columns: {dataset.column_names}")
 
-        # Log a sample for inspection
-        if len(dataset) > 0:
-            sample = dataset[0]
-            prompt_num = sample.get("prompt_number", "N/A")
-            article_id = sample.get("article_id", sample.get("id", "N/A"))
-            logger.info(f"First sample - Prompt number: {prompt_num}, ID: {article_id}")
+        # Data loading configuration - ensure shuffle is disabled
+        data_loading_config = dataset_config.get("data_loading", {})
+        if data_loading_config.get("shuffle", False):
+            logger.error("CRITICAL: shuffle is enabled in the dataset config!")
+            logger.error("This will RANDOMIZE your dataset and break sequential order.")
+            logger.error("Setting shuffle to False to preserve order")
+            data_loading_config["shuffle"] = False
 
         return dataset
-
+
     except Exception as e:
         logger.error(f"Error loading dataset: {str(e)}")
         raise
@@ -542,6 +520,72 @@ class LoggingCallback(TrainerCallback):
         self.sequence_samples = None
         self.sample_indices = None
 
+    def on_train_begin(self, args, state, control, **kwargs):
+        log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
+        log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
+
+        # Set up sequence verification with actual sample capturing
+        try:
+            self.verify_sequence = dataset_config.get("validation", {}).get("verify_sequence_integrity", False)
+            if self.verify_sequence:
+                log_info("Sequence integrity verification enabled during training")
+
+                # Save actual samples for later verification
+                if trainer and hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None:
+                    # Get some reference samples from the beginning of the dataset defensively
+                    self.sample_indices = []
+                    self.sequence_samples = []
+
+                    max_samples = min(5, len(trainer.train_dataset))
+                    for i in range(max_samples):
+                        try:
+                            if i < len(trainer.train_dataset):
+                                self.sample_indices.append(i)
+                                self.sequence_samples.append(trainer.train_dataset[i])
+                        except Exception as e:
+                            log_info(f"Warning: Error capturing reference sample at index {i}: {e}")
+
+                    if self.sequence_samples:
+                        log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
+
+                        # Log sample prompt numbers for debugging
+                        sample_prompt_numbers = []
+                        for s in self.sequence_samples:
+                            if isinstance(s, dict) and 'prompt_number' in s and s['prompt_number'] is not None:
+                                sample_prompt_numbers.append(s.get('prompt_number'))
+
+                        if sample_prompt_numbers:
+                            log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}")
+                            if sample_prompt_numbers == list(range(1, len(sample_prompt_numbers) + 1)):
+                                log_info("Prompt numbers are sequential (1-indexed) - sequence integrity confirmed")
+                            else:
+                                log_info("Prompt numbers are not in expected sequence - will verify during training")
+                    else:
+                        log_info("Warning: No reference samples were captured")
+                else:
+                    log_info("Warning: Could not capture reference samples - verification will be limited")
+        except Exception as e:
+            log_info(f"Warning: Could not set up sequence integrity verification: {e}")
+            self.verify_sequence = False
+
+        log_info("=== Training is starting ===")
+
+        # Log important training parameters for visibility
+        total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
+        log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
+        log_info(f"Learning rate: {args.learning_rate}")
+        log_info(f"Epochs: {args.num_train_epochs}")
+
+        # Log memory information in compact format
+        if CUDA_AVAILABLE:
+            memory_info = []
+            for i in range(NUM_GPUS):
+                allocated = torch.cuda.memory_allocated(i) / 1024**2
+                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
+                memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
+
+            log_info(f"Initial memory usage - {', '.join(memory_info)}")
+
     def on_step_end(self, args, state, control, **kwargs):
         # Log every 50 steps or every 5 minutes, whichever comes first
         current_time = time.time()
@@ -590,7 +634,7 @@ class LoggingCallback(TrainerCallback):
                     if i < len(current_samples):
                         current_sample = current_samples[i]
 
-                        # Compare prompt numbers if available
+                        # Compare prompt numbers if available - this is our primary check now
                         if ('prompt_number' in orig_sample and
                             'prompt_number' in current_sample and
                             orig_sample['prompt_number'] is not None and
@@ -599,8 +643,11 @@ class LoggingCallback(TrainerCallback):
                             if orig_sample['prompt_number'] != current_sample['prompt_number']:
                                 log_info(f"WARNING: Sequence integrity compromised! Sample {i} prompt number changed from {orig_sample['prompt_number']} to {current_sample['prompt_number']}")
                                 is_sequence_maintained = False
+                            else:
+                                # This is now our primary verification
+                                log_info(f"Prompt number match confirmed for sample {i}: {orig_sample['prompt_number']}")
 
-                        # Also compare IDs as a backup check
+                        # Also compare article_id as a backup check
                         elif ('article_id' in orig_sample and
                             'article_id' in current_sample and
                             orig_sample['article_id'] is not None and
@@ -609,21 +656,9 @@ class LoggingCallback(TrainerCallback):
                             if orig_sample['article_id'] != current_sample['article_id']:
                                 log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}")
                                 is_sequence_maintained = False
-
-                        # Compare input fingerprints
-                        if ('conversations' in orig_sample and
-                            'conversations' in current_sample and
-                            orig_sample['conversations'] is not None and
-                            current_sample['conversations'] is not None):
-
-                            orig_len = len(orig_sample['conversations'])
-                            curr_len = len(current_sample['conversations'])
-                            if orig_len != curr_len:
-                                log_info(f"WARNING: Sequence integrity compromised! Sample {i} conversation length changed from {orig_len} to {curr_len}")
-                                is_sequence_maintained = False
 
                     if is_sequence_maintained:
-                        log_info("Data sequence integrity check: OK")
+                        log_info("Data sequence integrity check: OK - prompt numbers preserved")
                     else:
                         log_info("CRITICAL WARNING: Data sequence integrity check FAILED!")
                 else:
@@ -635,90 +670,16 @@ class LoggingCallback(TrainerCallback):
             except Exception as e:
                 log_info(f"Warning: Couldn't verify sequence integrity: {e}")
 
-        time_interval = current_time - self.last_log_time
-        step_interval = state.global_step - self.last_step
-
-        if step_interval >= 50 or time_interval >= 300:  # 5 minutes = 300 seconds
-            # Calculate throughput
-            examples_per_second = step_interval * args.per_device_train_batch_size * args.gradient_accumulation_steps / max(time_interval, 1e-6)
-
-            elapsed_total = time.strftime("%H:%M:%S", time.gmtime(current_time - self.training_started))
-
-            # Log progress
-            log_info(f"Step: {state.global_step}, Loss: {state.log_history[-1]['loss']:.4f}, "
-                     f"Rate: {examples_per_second:.2f} examples/sec, Elapsed: {elapsed_total}")
-
-            # Report memory usage if CUDA is available
-            if CUDA_AVAILABLE:
-                log_info(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB allocated, "
-                         f"{torch.cuda.max_memory_reserved() / 1024**3:.2f} GB reserved")
-
-            # Reset for next interval
+        # Log progress at regular intervals
+        if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
+            if state.log_history:
+                loss = state.log_history[-1].get('loss', 'N/A')
+                # Use simple formatting for better Space log compatibility
+                log_info(f"Step {state.global_step}: Loss {loss}")
+            else:
+                log_info(f"Step {state.global_step}: No loss data available")
             self.last_log_time = current_time
-            self.last_step = state.global_step
 
-    def on_train_begin(self, args, state, control, **kwargs):
-        log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
-        log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
-
-        # Set up sequence verification with actual sample capturing
-        try:
-            self.verify_sequence = dataset_config.get("validation", {}).get("verify_sequence_integrity", False)
-            if self.verify_sequence:
-                log_info("Sequence integrity verification enabled during training")
-
-                # Save actual samples for later verification
-                if trainer and hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None:
-                    # Get some reference samples from the beginning of the dataset defensively
-                    self.sample_indices = []
-                    self.sequence_samples = []
-
-                    max_samples = min(5, len(trainer.train_dataset))
-                    for i in range(max_samples):
-                        try:
-                            if i < len(trainer.train_dataset):
-                                self.sample_indices.append(i)
-                                self.sequence_samples.append(trainer.train_dataset[i])
-                        except Exception as e:
-                            log_info(f"Warning: Error capturing reference sample at index {i}: {e}")
-
-                    if self.sequence_samples:
-                        log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
-
-                        # Log sample prompt numbers for debugging
-                        sample_prompt_numbers = []
-                        for s in self.sequence_samples:
-                            if isinstance(s, dict) and 'prompt_number' in s and s['prompt_number'] is not None:
-                                sample_prompt_numbers.append(s.get('prompt_number'))
-
-                        if sample_prompt_numbers:
-                            log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}")
-                    else:
-                        log_info("Warning: No reference samples were captured")
-                else:
-                    log_info("Warning: Could not capture reference samples - verification will be limited")
-        except Exception as e:
-            log_info(f"Warning: Could not set up sequence integrity verification: {e}")
-            self.verify_sequence = False
-
-        log_info("=== Training is starting ===")
-
-        # Log important training parameters for visibility
-        total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
-        log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
-        log_info(f"Learning rate: {args.learning_rate}")
-        log_info(f"Epochs: {args.num_train_epochs}")
-
-        # Log memory information in compact format
-        if CUDA_AVAILABLE:
-            memory_info = []
-            for i in range(NUM_GPUS):
-                allocated = torch.cuda.memory_allocated(i) / 1024**2
-                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
-                memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
-
-            log_info(f"Initial memory usage - {', '.join(memory_info)}")
-
     def on_train_end(self, args, state, control, **kwargs):
         training_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.training_started))
         log_info(f"=== Training completed in {training_time} ===")
@@ -968,9 +929,12 @@ def main():
         shuffle_enabled = data_loading_config.get("shuffle", False)
 
         if shuffle_enabled:
-            log_info("CRITICAL ERROR: Shuffle is enabled! This will randomize data entry order!")
-            raise ValueError("Dataset shuffling is enabled but sequential processing is required. " +
-                             "Please disable shuffling in your configuration.")
+            log_info("WARNING: Shuffle is enabled in configuration! This will be overridden to preserve order.")
+            # We enforce sequential processing regardless of config
+
+        # Log our approach clearly
+        log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
+        log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
 
         # Calculate batch size based on device availability
         if getattr(training_args, "no_cuda", False):
@@ -984,7 +948,7 @@ def main():
         return torch.utils.data.DataLoader(
             dataset,
            batch_size=batch_size,
-            sampler=sequential_sampler,
+            sampler=sequential_sampler,  # Always use sequential sampler
            collate_fn=data_collator,
            drop_last=training_args.dataloader_drop_last,
            num_workers=training_args.dataloader_num_workers,
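Note: the sequence-integrity validation added in load_dataset_with_mapping() reduces to a single invariant: the row at index i must carry prompt_number == i + 1. A self-contained sketch of that check, with a toy records list standing in for the loaded HF dataset (the list is hypothetical, not from this repo):

    # Toy records standing in for the loaded dataset
    records = [{"prompt_number": n, "article_id": f"a{n}"} for n in range(1, 11)]

    first_prompt_numbers = [r["prompt_number"] for r in records[:10]]
    is_sequential = all(first_prompt_numbers[i] == i + 1
                        for i in range(len(first_prompt_numbers)))

    print("sequence integrity preserved" if is_sequential else "sequence integrity compromised",
          first_prompt_numbers)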
transformers_config.json CHANGED
@@ -77,7 +77,7 @@
 
   "huggingface_hub": {
     "push_to_hub": true,
-    "hub_model_id": "phi-4-research-assistant",
+    "hub_model_id": "phi-4-cognitive-assistant",
     "hub_private_repo": true
   },
 
@@ -131,18 +131,9 @@
 
   "dataset": {
     "dataset": {
-      "name": "George-API/cognitive-data",
+      "name": "George-API/phi4-cognitive-dataset",
       "split": "train",
-      "column_mapping": {
-        "conversations": "text",
-        "article_id": "id"
-      },
-      "processing": {
-        "sort_by_article_id": true,
-        "maintain_paper_order": true,
-        "preserve_entry_sequence": true,
-        "max_seq_length": 2048
-      }
+      "column_mapping": {}
     },
     "data_formatting": {
      "chat_template": "phi",
@@ -171,7 +162,7 @@
     "log_samples": 3,
     "log_interval": 50,
     "verify_sequence_integrity": true,
-    "metrics": ["processed", "skipped", "avg_tokens", "unique_papers"]
+    "metrics": ["processed", "skipped", "avg_tokens", "unique_articles"]
   }
 }
- }
+}
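Note the nested layout: the outer "dataset" section holds loading and validation settings, while the inner "dataset" object names the hub dataset, which is why the training script reads dataset_config.get("dataset", {}) before reaching "name" or "column_mapping". A short sketch of walking that structure (the file path and plain json.load are assumptions; the repo goes through load_configs()):

    import json

    with open("transformers_config.json") as f:  # path is an assumption
        config = json.load(f)

    dataset_section = config.get("dataset", {})   # outer "dataset" section
    inner = dataset_section.get("dataset", {})    # nested "dataset" object
    print(inner.get("name"))                      # George-API/phi4-cognitive-dataset
    print(inner.get("split"))                     # train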