George-API committed on
Commit dc055e5 · verified · 1 Parent(s): 5f730a4

Upload folder using huggingface_hub

run_transformers_training.py CHANGED
@@ -337,10 +337,29 @@ def load_dataset_with_mapping(dataset_config):
337
  if len(dataset) == 0:
338
  raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
339
 
340
- # Verify conversations field specifically - this is critical for training
341
  if "conversations" not in dataset.column_names:
342
  raise ValueError(f"Dataset {dataset_name} missing required 'conversations' column")
343
 
344
  # Check a sample of conversation entries to validate structure
345
  logger.info("Validating conversation structure...")
346
  for i in range(min(5, len(dataset))):
@@ -354,9 +373,6 @@ def load_dataset_with_mapping(dataset_config):
354
  else:
355
  # Look at the first conversation entry
356
  first_entry = conv[0]
357
- logger.info(f"Sample conversation: {str(first_entry)[:100]}...")
358
-
359
- # Make sure content field exists
360
  if isinstance(first_entry, dict) and "content" in first_entry:
361
  logger.info(f"Content field example: {str(first_entry['content'])[:50]}...")
362
  else:
@@ -368,71 +384,6 @@ def load_dataset_with_mapping(dataset_config):
368
  logger.error("This could be due to authentication issues with your HF_TOKEN")
369
  raise
370
 
371
- # Apply minimal processing since the dataset has already been properly structured
372
- # Just perform validation to ensure required fields exist
373
-
374
- # Check for required fields
375
- required_fields = ["prompt_number", "article_id", "conversations"]
376
- missing_fields = [field for field in required_fields if field not in dataset.column_names]
377
-
378
- if missing_fields:
379
- logger.warning(f"Dataset is missing required fields: {missing_fields}")
380
- logger.warning("This may cause issues with sequence integrity and metadata management")
381
- else:
382
- logger.info(f"Dataset has all required fields: {required_fields}")
383
-
384
- # Verify that column order matches our expectation
385
- expected_order = ["prompt_number", "article_id", "conversations"]
386
- actual_order = dataset.column_names
387
-
388
- if actual_order == expected_order:
389
- logger.info("Dataset column order matches expected order (prompt_number, article_id, conversations)")
390
- else:
391
- logger.warning(f"Dataset column order ({', '.join(actual_order)}) differs from expected order ({', '.join(expected_order)})")
392
- logger.warning("This should not affect processing but is noted for debugging purposes")
393
-
394
- # Log a few samples for verification
395
- if len(dataset) > 0:
396
- sample_indices = range(min(5, len(dataset)))
397
- sample_records = []
398
-
399
- for i in sample_indices:
400
- record = {}
401
- record["prompt_number"] = dataset[i].get("prompt_number", "N/A")
402
- record["article_id"] = dataset[i].get("article_id", "N/A")
403
- # Safely get conversations length with None check
404
- conversations = dataset[i].get("conversations")
405
- if conversations is not None and isinstance(conversations, list):
406
- record["conversations_length"] = len(conversations)
407
- else:
408
- record["conversations_length"] = 0
409
- logger.warning(f"Invalid conversations for sample {i}: {type(conversations)}")
410
- sample_records.append(record)
411
-
412
- logger.info(f"Sample records: {sample_records}")
413
-
414
- # Verify sequential integrity
415
- if "prompt_number" in dataset.column_names and len(dataset) > 1:
416
- first_prompt_numbers = [dataset[i]["prompt_number"] for i in range(min(10, len(dataset)))]
417
- is_sequential = all(first_prompt_numbers[i] == i + 1 for i in range(len(first_prompt_numbers)))
418
-
419
- if is_sequential:
420
- logger.info("Dataset prompt numbers are sequential (1-indexed) - sequence integrity preserved")
421
- else:
422
- logger.warning("Dataset prompt numbers are not sequential - sequence integrity may be compromised")
423
- logger.info(f"First few prompt numbers: {first_prompt_numbers}")
424
-
425
- logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
426
- logger.info(f"Dataset columns: {dataset.column_names}")
427
-
428
- # Data loading configuration - ensure shuffle is disabled
429
- data_loading_config = dataset_config.get("data_loading", {})
430
- if data_loading_config.get("shuffle", False):
431
- logger.error("CRITICAL: shuffle is enabled in the dataset config!")
432
- logger.error("This will RANDOMIZE your dataset and break sequential order.")
433
- logger.error("Setting shuffle to False to preserve order")
434
- data_loading_config["shuffle"] = False
435
-
436
  return dataset
437
 
438
  except Exception as e:
@@ -447,42 +398,35 @@ def format_phi_chat(messages, dataset_config):
447
  roles = dataset_config.get("data_formatting", {}).get("roles", {
448
  "system": "System: {content}\n\n",
449
  "human": "Human: {content}\n\n",
450
- "user": "Human: {content}\n\n",
451
  "assistant": "Assistant: {content}\n\n"
452
  })
453
 
454
- # Handle research introduction metadata first
455
- metadata = next((msg for msg in messages if isinstance(msg, dict) and
456
- "[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
457
- if metadata:
458
- system_template = roles.get("system", "System: {content}\n\n")
459
- formatted_chat = system_template.format(content=metadata['content'])
460
- messages = [msg for msg in messages if msg != metadata]
461
-
462
- # Process remaining messages
463
  for message in messages:
464
  if not isinstance(message, dict) or "content" not in message:
465
  logger.warning(f"Skipping invalid message format: {message}")
466
  continue
467
 
468
- role = message.get("role", "").lower()
469
- content = message.get("content", "")
470
-
471
- # Format based on role
472
- if role == "human" or role == "user":
473
- template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
474
- formatted_chat += template.format(content=content)
475
- elif role == "assistant" or role == "bot":
476
- template = roles.get("assistant", "Assistant: {content}\n\n")
477
- formatted_chat += template.format(content=content)
478
- elif role == "system":
479
- # For system messages, prepend them
480
  template = roles.get("system", "System: {content}\n\n")
481
  formatted_chat = template.format(content=content) + formatted_chat
482
  else:
483
- # Default to system for unknown roles
484
- logger.warning(f"Unknown role '{role}' - treating as system message")
485
- template = roles.get("system", "System: {content}\n\n")
486
  formatted_chat += template.format(content=content)
487
 
488
  return formatted_chat.strip()
@@ -506,7 +450,7 @@ class SimpleDataCollator:
506
  paper_id = example.get("article_id", "unknown")
507
  prompt_num = example.get("prompt_number", "unknown")
508
 
509
- # Get the conversations list - should be a single item
510
  conversations = example.get("conversations", [])
511
 
512
  # Skip if no conversations
@@ -515,27 +459,17 @@ class SimpleDataCollator:
515
  self.stats["skipped"] += 1
516
  continue
517
 
518
- # Get the first conversation item (should be the only one)
519
- conv_item = conversations[0]
520
 
521
- # Skip if invalid format
522
- if not isinstance(conv_item, dict) or "content" not in conv_item:
523
- logger.warning(f"Invalid conversation format for paper_id {paper_id}, prompt {prompt_num}")
524
  self.stats["skipped"] += 1
525
  continue
526
 
527
- # Get the pre-tokenized content
528
- content = conv_item["content"]
529
-
530
- # Skip if empty content
531
- if not content:
532
- logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
533
- self.stats["skipped"] += 1
534
- continue
535
-
536
- # Create input IDs and attention mask directly from the content
537
- # The content is already pre-tokenized and properly chunked
538
- input_ids = self.tokenizer.encode(content, add_special_tokens=False)
539
 
540
  # Truncate if needed
541
  if len(input_ids) > self.max_seq_length:
@@ -553,6 +487,11 @@ class SimpleDataCollator:
553
  self.stats["processed"] += 1
554
  self.stats["total_tokens"] += len(input_ids)
555
 
556
  except Exception as e:
557
  logger.warning(f"Error processing example {paper_id}, prompt {prompt_num}: {str(e)}")
558
  self.stats["skipped"] += 1
@@ -588,31 +527,30 @@ class SimpleDataCollator:
588
  return batch
589
 
590
  class LoggingCallback(TrainerCallback):
591
- def __init__(self):
592
  super().__init__()
593
  self.training_started = time.time()
594
  self.last_log_time = time.time()
595
  self.last_step = 0
596
- self.verify_sequence = None
597
- self.sequence_samples = None
598
- self.sample_indices = None
599
 
600
  def on_train_begin(self, args, state, control, **kwargs):
601
  log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
602
- log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
603
-
604
- # Disable sequence verification
605
- self.verify_sequence = False
606
 
607
- log_info("=== Training is starting ===")
608
 
609
  # Log important training parameters for visibility
610
  total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
611
- total_steps = int(len(dataset) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs)
612
- log_info(f"Training plan: {len(dataset)} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
613
  log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
614
- log_info(f"Learning rate: {args.learning_rate}")
615
- log_info(f"Epochs: {args.num_train_epochs}")
616
 
617
  # Log memory information in compact format
618
  if CUDA_AVAILABLE:
@@ -621,85 +559,63 @@ class LoggingCallback(TrainerCallback):
621
  allocated = torch.cuda.memory_allocated(i) / 1024**2
622
  max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
623
  memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
624
-
625
  log_info(f"Initial memory usage - {', '.join(memory_info)}")
626
-
627
- def on_step_end(self, args, state, control, **kwargs):
628
- # Log every 50 steps or every 5 minutes, whichever comes first
629
- current_time = time.time()
630
-
631
- # Sequence verification removed
632
-
633
- # Log progress at regular intervals
634
- if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
635
- if state.log_history:
636
- loss = state.log_history[-1].get('loss', 'N/A')
637
- # Use simple formatting for better Space log compatibility
638
- log_info(f"Step {state.global_step}: Loss {loss}")
639
- else:
640
- log_info(f"Step {state.global_step}: No loss data available")
641
- self.last_log_time = current_time
642
-
643
- def on_train_end(self, args, state, control, **kwargs):
644
- training_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.training_started))
645
- log_info(f"=== Training completed in {training_time} ===")
646
-
647
- # Log final memory usage
648
- if CUDA_AVAILABLE:
649
- for i in range(NUM_GPUS):
650
- max_mem = torch.cuda.max_memory_allocated(i) / 1024**3 # GB
651
- log_info(f"GPU {i} max memory: {max_mem:.2f} GB")
652
-
653
- # Clear GPU memory
654
- torch.cuda.empty_cache()
655
- log_info("GPU memory cleared")
656
-
657
- log_info(f"Total steps: {state.global_step}")
658
- log_info(f"Final loss: {state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'}")
659
 
660
  def check_dependencies():
661
  """Check if all required dependencies are installed and in the correct order."""
662
  missing_packages = []
663
  order_issues = []
664
 
665
- # Check critical packages in the required order
666
-
667
- # 1. First check for unsloth as it should be imported before transformers
668
- if not unsloth_available:
669
- missing_packages.append("unsloth>=2024.3")
670
 
671
- # 2. Check transformers (imported at module level)
672
- try:
673
- import transformers
674
- logger.info(f"Using transformers version {transformers.__version__}")
675
- except ImportError:
676
- missing_packages.append("transformers>=4.38.0")
677
-
678
- # 3. Check for peft
679
- if not peft_available:
680
- missing_packages.append("peft>=0.9.0")
681
-
682
- # 4. Check for accelerate
683
- try:
684
- import accelerate
685
- logger.info(f"Using accelerate version {accelerate.__version__}")
686
- except ImportError:
687
- missing_packages.append("accelerate>=0.27.0")
688
 
689
- # Check for order-specific issues
690
  try:
691
  import sys
692
- modules = sys.modules.keys()
693
 
694
- # Unsloth should be imported before transformers for optimal performance
695
  if 'transformers' in modules and 'unsloth' in modules:
696
- if modules.index('transformers') < modules.index('unsloth'):
697
- order_issues.append("For optimal performance, unsloth should be imported before transformers")
698
- except Exception:
699
- # If we can't check order, just skip this check
700
- pass
701
 
702
- # If critical packages are missing, exit with instructions
703
  if missing_packages:
704
  logger.error("Critical dependencies missing:")
705
  for pkg in missing_packages:
@@ -712,35 +628,6 @@ def check_dependencies():
712
  for issue in order_issues:
713
  logger.warning(issue)
714
 
715
- # Optional packages - moved to the end
716
- if find_spec("flash_attn"):
717
- logger.info("flash-attn found. Flash attention will be used for faster training.")
718
- else:
719
- logger.warning("flash-attn not found. Training will work but may be slower.")
720
- logger.warning("Attempting to install flash-attn automatically...")
721
-
722
- try:
723
- import subprocess
724
- subprocess.check_call([sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"])
725
- logger.info("Successfully installed flash-attn!")
726
-
727
- # Try to import it now that it's installed
728
- try:
729
- import flash_attn
730
- logger.info("flash-attn imported successfully after installation.")
731
- except ImportError:
732
- logger.warning("flash-attn installed but import failed - may require restart.")
733
- except Exception as e:
734
- logger.warning(f"Failed to install flash-attn: {str(e)}")
735
- logger.warning("To manually install flash attention, run: pip install flash-attn --no-build-isolation")
736
-
737
- # Additional optional packages that improve performance
738
- if find_spec("bitsandbytes"):
739
- logger.info("bitsandbytes found. Quantization will be available.")
740
- else:
741
- logger.warning("bitsandbytes not found. Quantization may not be available.")
742
- logger.warning("To use quantization, install with: pip install bitsandbytes")
743
-
744
  return True
745
 
746
  def update_huggingface_space():
@@ -981,27 +868,28 @@ def main():
981
  # Set up training arguments
982
  log_info("Setting up training arguments")
983
 
984
- # Validate FSDP config before using it
985
  fsdp_args = None
986
- if fsdp_config is not None and is_distributed and multi_gpu_strategy == "fsdp":
987
- try:
988
- # Convert FSDP config to proper format expected by TrainingArguments
989
- fsdp_args = {
990
- "fsdp_transformer_layer_cls_to_wrap": fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", []),
991
- "fsdp_offload_params": fsdp_config.get("fsdp_offload_params", False),
992
- "fsdp_backward_prefetch": fsdp_config.get("fsdp_backward_prefetch", "BACKWARD_PRE"),
993
- "fsdp_min_num_params": fsdp_config.get("fsdp_min_num_params", 1e6),
994
- "fsdp_sharding_strategy": fsdp_config.get("fsdp_sharding_strategy", 1),
995
- }
996
- log_info("FSDP config validated and prepared")
997
- except Exception as e:
998
- log_info(f"Error preparing FSDP config: {str(e)}, disabling FSDP")
999
- fsdp_args = None
1000
 
1001
  # Check if we're running in a Space
1002
  is_space = bool(os.environ.get("SPACE_ID"))
1003
 
1004
- # Create training arguments with validated FSDP config
1005
  training_args = TrainingArguments(
1006
  output_dir=transformers_config.get("output_dir", "./results") or transformers_config.get("checkpointing", {}).get("output_dir", "./results"),
1007
  num_train_epochs=transformers_config.get("training", {}).get("num_train_epochs", 3),
@@ -1020,7 +908,6 @@ def main():
1020
  max_grad_norm=transformers_config.get("training", {}).get("max_grad_norm", 1.0),
1021
  push_to_hub=transformers_config.get("huggingface_hub", {}).get("push_to_hub", False),
1022
  hub_model_id=transformers_config.get("huggingface_hub", {}).get("hub_model_id", None),
1023
- # Don't set hub_token when running in a Space - it will use Space secrets automatically
1024
  hub_token=None if is_space else os.environ.get("HF_TOKEN", None),
1025
  report_to="tensorboard",
1026
  remove_unused_columns=False, # Keep all columns
@@ -1031,7 +918,7 @@ def main():
1031
  dataloader_drop_last=False, # Process all examples
1032
  dataloader_num_workers=dataloader_workers,
1033
  no_cuda=False if CUDA_AVAILABLE else True, # Use CUDA if available
1034
- fsdp=fsdp_args, # Use validated FSDP config
1035
  )
1036
 
1037
  log_info("Training arguments created successfully")
@@ -1049,9 +936,9 @@ def main():
1049
  trainer = Trainer(
1050
  model=model,
1051
  args=training_args,
1052
- train_dataset=dataset, # We'll override this with our custom dataloader
1053
  data_collator=data_collator,
1054
- callbacks=[LoggingCallback()],
1055
  )
1056
 
1057
  # Then override the get_train_dataloader method
@@ -1153,7 +1040,7 @@ def main():
1153
  log_info("Cleared CUDA cache before training")
1154
 
1155
  # Display compact training info
1156
- total_steps = int(len(dataset) / (per_device_batch_size * NUM_GPUS * gradient_accumulation_steps) * training_args.num_train_epochs)
1157
  log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps")
1158
 
1159
  trainer.train()
 
337
  if len(dataset) == 0:
338
  raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
339
 
340
+ # Verify conversations field specifically
341
  if "conversations" not in dataset.column_names:
342
  raise ValueError(f"Dataset {dataset_name} missing required 'conversations' column")
343
 
344
+ # Validate conversation structure
345
+ if len(dataset) > 0:
346
+ sample = dataset[0]
347
+ conversations = sample.get("conversations", [])
348
+
349
+ if conversations:
350
+ first_conv = conversations[0]
351
+ if isinstance(first_conv, dict):
352
+ # Check actual fields
353
+ fields = list(first_conv.keys())
354
+ logger.info(f"Conversation fields: {fields}")
355
+
356
+ # Verify only 'content' field exists
357
+ if fields == ["content"]:
358
+ logger.info("Confirmed conversations have correct format with only 'content' field")
359
+ else:
360
+ logger.warning(f"Unexpected conversation fields: {fields}")
361
+ logger.warning("Expected only 'content' field")
362
+
363
  # Check a sample of conversation entries to validate structure
364
  logger.info("Validating conversation structure...")
365
  for i in range(min(5, len(dataset))):
 
373
  else:
374
  # Look at the first conversation entry
375
  first_entry = conv[0]
376
  if isinstance(first_entry, dict) and "content" in first_entry:
377
  logger.info(f"Content field example: {str(first_entry['content'])[:50]}...")
378
  else:
 
384
  logger.error("This could be due to authentication issues with your HF_TOKEN")
385
  raise
386
 
387
  return dataset
388
 
389
  except Exception as e:
 
398
  roles = dataset_config.get("data_formatting", {}).get("roles", {
399
  "system": "System: {content}\n\n",
400
  "human": "Human: {content}\n\n",
 
401
  "assistant": "Assistant: {content}\n\n"
402
  })
403
 
404
+ # Handle each message in the conversation
405
  for message in messages:
406
  if not isinstance(message, dict) or "content" not in message:
407
  logger.warning(f"Skipping invalid message format: {message}")
408
  continue
409
 
410
+ content = message.get("content", "").strip()
411
+
412
+ # Skip empty content
413
+ if not content:
414
+ continue
415
+
416
+ # Infer role based on content patterns
417
+ if "[RESEARCH INTRODUCTION]" in content:
418
+ # System message
419
  template = roles.get("system", "System: {content}\n\n")
420
  formatted_chat = template.format(content=content) + formatted_chat
421
  else:
422
+ # Alternate between human and assistant for regular conversation turns
423
+ # In phi-4 format, human messages come first, followed by assistant responses
424
+ if len(formatted_chat.split("Human:")) == len(formatted_chat.split("Assistant:")):
425
+ # If equal numbers of Human and Assistant messages, next is Human
426
+ template = roles.get("human", "Human: {content}\n\n")
427
+ else:
428
+ # Otherwise, next is Assistant
429
+ template = roles.get("assistant", "Assistant: {content}\n\n")
430
  formatted_chat += template.format(content=content)
431
 
432
  return formatted_chat.strip()
 
450
  paper_id = example.get("article_id", "unknown")
451
  prompt_num = example.get("prompt_number", "unknown")
452
 
453
+ # Get the conversations list
454
  conversations = example.get("conversations", [])
455
 
456
  # Skip if no conversations
 
459
  self.stats["skipped"] += 1
460
  continue
461
 
462
+ # Format the conversation using phi chat template
463
+ formatted_chat = format_phi_chat(conversations, self.dataset_config)
464
 
465
+ # Skip if formatting resulted in empty content
466
+ if not formatted_chat:
467
+ logger.warning(f"Empty formatted chat for paper_id {paper_id}, prompt {prompt_num}")
468
  self.stats["skipped"] += 1
469
  continue
470
 
471
+ # Create input IDs and attention mask
472
+ input_ids = self.tokenizer.encode(formatted_chat, add_special_tokens=False)
473
 
474
  # Truncate if needed
475
  if len(input_ids) > self.max_seq_length:
 
487
  self.stats["processed"] += 1
488
  self.stats["total_tokens"] += len(input_ids)
489
 
490
+ # Log first few examples for verification
491
+ if self.stats["processed"] <= 3:
492
+ logger.info(f"Sample {self.stats['processed']} formatted chat:")
493
+ logger.info(f"{formatted_chat[:200]}...")
494
+
495
  except Exception as e:
496
  logger.warning(f"Error processing example {paper_id}, prompt {prompt_num}: {str(e)}")
497
  self.stats["skipped"] += 1
 
527
  return batch
528
 
529
  class LoggingCallback(TrainerCallback):
530
+ def __init__(self, model=None, dataset=None):
531
  super().__init__()
532
  self.training_started = time.time()
533
  self.last_log_time = time.time()
534
  self.last_step = 0
535
+ self.model = model
536
+ self.dataset = dataset
 
537
 
538
  def on_train_begin(self, args, state, control, **kwargs):
539
  log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
540
 
541
+ # Log model info if available
542
+ if self.model is not None:
543
+ log_info(f"Model parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")
544
+
545
+ # Log dataset info if available
546
+ if self.dataset is not None:
547
+ log_info(f"Dataset size: {len(self.dataset)} examples")
548
 
549
  # Log important training parameters for visibility
550
  total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
551
+ total_steps = int(len(self.dataset or []) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs)
552
+ log_info(f"Training plan: {len(self.dataset or [])} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
553
  log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
554
 
555
  # Log memory information in compact format
556
  if CUDA_AVAILABLE:
 
559
  allocated = torch.cuda.memory_allocated(i) / 1024**2
560
  max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
561
  memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
 
562
  log_info(f"Initial memory usage - {', '.join(memory_info)}")
563
 
564
  def check_dependencies():
565
  """Check if all required dependencies are installed and in the correct order."""
566
  missing_packages = []
567
  order_issues = []
568
 
569
+ # Define required packages with versions
570
+ required_packages = {
571
+ "unsloth": ">=2024.3",
572
+ "transformers": ">=4.38.0",
573
+ "peft": ">=0.9.0",
574
+ "accelerate": ">=0.27.0"
575
+ }
576
 
577
+ # Check for required packages
578
+ for package, version in required_packages.items():
579
+ try:
580
+ if package == "unsloth" and not unsloth_available:
581
+ missing_packages.append(f"{package}{version}")
582
+ elif package == "peft" and not peft_available:
583
+ missing_packages.append(f"{package}{version}")
584
+ else:
585
+ module = __import__(package)
586
+ logger.info(f"Using {package} version {getattr(module, '__version__', 'unknown')}")
587
+ except ImportError:
588
+ missing_packages.append(f"{package}{version}")
589
 
590
+ # Check import order
591
  try:
592
  import sys
593
+ modules = list(sys.modules.keys())
594
 
 
595
  if 'transformers' in modules and 'unsloth' in modules:
596
+ try:
597
+ transformers_idx = modules.index('transformers')
598
+ unsloth_idx = modules.index('unsloth')
599
+ if transformers_idx < unsloth_idx:
600
+ order_issues.append("For optimal performance, unsloth should be imported before transformers")
601
+ except ValueError:
602
+ pass
603
+ except Exception as e:
604
+ logger.warning(f"Could not check module import order: {str(e)}")
605
+
606
+ # Check optional dependencies
607
+ optional_packages = {
608
+ "flash_attn": "Flash attention support",
609
+ "bitsandbytes": "4-bit quantization support"
610
+ }
611
+
612
+ for package, feature in optional_packages.items():
613
+ if find_spec(package):
614
+ logger.info(f"Found {package} - {feature} enabled")
615
+ else:
616
+ logger.warning(f"{package} not found - {feature} will not be available")
617
 
618
+ # Report missing required packages
619
  if missing_packages:
620
  logger.error("Critical dependencies missing:")
621
  for pkg in missing_packages:
 
628
  for issue in order_issues:
629
  logger.warning(issue)
630
 
631
  return True
632
 
633
  def update_huggingface_space():
 
868
  # Set up training arguments
869
  log_info("Setting up training arguments")
870
 
871
+ # Handle FSDP configuration
872
+ fsdp_config = transformers_config.get("distributed_training", {}).get("fsdp_config", {})
873
+ fsdp_enabled = fsdp_config.get("enabled", False)
874
+
875
+ # Only set FSDP args if explicitly enabled
876
  fsdp_args = None
877
+ if fsdp_enabled and is_distributed and NUM_GPUS > 1:
878
+ fsdp_args = {
879
+ "fsdp": ["full_shard", "auto_wrap"],
880
+ "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
881
+ "fsdp_offload_params": fsdp_config.get("offload_params", False),
882
+ "fsdp_state_dict_type": "FULL_STATE_DICT",
883
+ "fsdp_sharding_strategy": 1, # FULL_SHARD
884
+ }
885
+ log_info("FSDP configuration enabled")
886
+ else:
887
+ log_info("FSDP disabled, using standard data parallel")
888
 
889
  # Check if we're running in a Space
890
  is_space = bool(os.environ.get("SPACE_ID"))
891
 
892
+ # Create training arguments
893
  training_args = TrainingArguments(
894
  output_dir=transformers_config.get("output_dir", "./results") or transformers_config.get("checkpointing", {}).get("output_dir", "./results"),
895
  num_train_epochs=transformers_config.get("training", {}).get("num_train_epochs", 3),
 
908
  max_grad_norm=transformers_config.get("training", {}).get("max_grad_norm", 1.0),
909
  push_to_hub=transformers_config.get("huggingface_hub", {}).get("push_to_hub", False),
910
  hub_model_id=transformers_config.get("huggingface_hub", {}).get("hub_model_id", None),
 
911
  hub_token=None if is_space else os.environ.get("HF_TOKEN", None),
912
  report_to="tensorboard",
913
  remove_unused_columns=False, # Keep all columns
 
918
  dataloader_drop_last=False, # Process all examples
919
  dataloader_num_workers=dataloader_workers,
920
  no_cuda=False if CUDA_AVAILABLE else True, # Use CUDA if available
921
+ **({} if fsdp_args is None else fsdp_args) # Only include FSDP args if configured
922
  )
923
 
924
  log_info("Training arguments created successfully")
 
936
  trainer = Trainer(
937
  model=model,
938
  args=training_args,
939
+ train_dataset=dataset,
940
  data_collator=data_collator,
941
+ callbacks=[LoggingCallback(model=model, dataset=dataset)],
942
  )
943
 
944
  # Then override the get_train_dataloader method
 
1040
  log_info("Cleared CUDA cache before training")
1041
 
1042
  # Display compact training info
1043
+ total_steps = int(len(dataset) / (per_device_batch_size * NUM_GPUS * gradient_accumulation_steps) * training_args.num_train_epochs)
1044
  log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps")
1045
 
1046
  trainer.train()
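
For reference, here is a minimal standalone sketch (not part of the commit) of the role-inference heuristic that format_phi_chat now uses: messages carry only a "content" field, system turns are detected via the "[RESEARCH INTRODUCTION]" marker, and the remaining turns alternate Human/Assistant depending on how many of each have already been formatted. The ROLES dict and the infer_and_format name are illustrative assumptions; the templates mirror the defaults in transformers_config.json.

ROLES = {
    "system": "System: {content}\n\n",
    "human": "Human: {content}\n\n",
    "assistant": "Assistant: {content}\n\n",
}

def infer_and_format(messages):
    formatted = ""
    for msg in messages:
        content = (msg.get("content") or "").strip()
        if not content:
            continue
        if "[RESEARCH INTRODUCTION]" in content:
            # System text is prepended, mirroring the script's behaviour
            formatted = ROLES["system"].format(content=content) + formatted
        elif formatted.count("Human:") == formatted.count("Assistant:"):
            # Equal turn counts: the next message is treated as a human turn
            formatted += ROLES["human"].format(content=content)
        else:
            formatted += ROLES["assistant"].format(content=content)
    return formatted.strip()

print(infer_and_format([
    {"content": "[RESEARCH INTRODUCTION] Paper overview..."},
    {"content": "Summarise the abstract."},
    {"content": "The abstract describes..."},
]))
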
transformers_config.json CHANGED
@@ -136,11 +136,14 @@
136
  },
137
  "data_formatting": {
138
  "chat_template": "phi",
139
  "roles": {
140
  "system": "System: {content}\n\n",
141
  "human": "Human: {content}\n\n",
142
- "assistant": "Assistant: {content}\n\n",
143
- "user": "Human: {content}\n\n"
144
  }
145
  },
146
  "data_loading": {
 
136
  },
137
  "data_formatting": {
138
  "chat_template": "phi",
139
+ "conversation_structure": {
140
+ "system_identifier": "[RESEARCH INTRODUCTION]",
141
+ "turn_order": ["human", "assistant"]
142
+ },
143
  "roles": {
144
  "system": "System: {content}\n\n",
145
  "human": "Human: {content}\n\n",
146
+ "assistant": "Assistant: {content}\n\n"
 
147
  }
148
  },
149
  "data_loading": {