George-API committed on
Commit
a7d1f2a
·
verified ·
1 Parent(s): 5b6d8f0

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. app.py +7 -4
  2. requirements.txt +1 -0
  3. run_transformers_training.py +54 -311
app.py CHANGED
@@ -84,15 +84,17 @@ def display_config():
84
  <ul>
85
  <li><b>Model:</b> {model_name}</li>
86
  <li><b>Learning Rate:</b> {training_config.get('learning_rate', '2e-5')}</li>
87
- <li><b>Batch Size:</b> {training_config.get('per_device_train_batch_size', 4)} × {training_config.get('gradient_accumulation_steps', 4)} = {training_config.get('per_device_train_batch_size', 4) * training_config.get('gradient_accumulation_steps', 4)}</li>
88
- <li><b>Epochs:</b> {training_config.get('num_train_epochs', 3)}</li>
 
 
89
  <li><b>Precision:</b> {'BF16' if transformers_config.get('bf16', True) else 'FP16' if transformers_config.get('fp16', False) else 'FP32'}</li>
90
  <li><b>Max Sequence Length:</b> {transformers_config.get('tokenizer', {}).get('max_seq_length', 2048)}</li>
91
  </ul>
92
 
93
  <h3>Hardware</h3>
94
  <ul>
95
- <li><b>GPU:</b> {gpu_count}× {gpu_type} ({vram} GB)</li>
96
  <li><b>Multi-GPU Strategy:</b> {hardware_config.get('training_optimizations', {}).get('multi_gpu_strategy', 'data_parallel')}</li>
97
  <li><b>Memory Optimizations:</b> {'Gradient Checkpointing' if hardware_config.get('training_optimizations', {}).get('memory_optimizations', {}).get('use_gradient_checkpointing', True) else 'None'}</li>
98
  </ul>
@@ -154,9 +156,10 @@ def create_interface():
154
  gr.Markdown("## Training Information")
155
  gr.Markdown("""
156
  ### Hardware:
157
- - 4× NVIDIA L4 GPUs (24GB VRAM each)
158
  - Training with BF16 precision
159
  - Using Data Parallel for multi-GPU
 
160
 
161
  ### Notes:
162
  - Training may take several hours depending on dataset size
 
84
  <ul>
85
  <li><b>Model:</b> {model_name}</li>
86
  <li><b>Learning Rate:</b> {training_config.get('learning_rate', '2e-5')}</li>
87
+ <li><b>Per-Device Batch Size:</b> {batch_size}</li>
88
+ <li><b>Gradient Accumulation:</b> {grad_accum}</li>
89
+ <li><b>Total Effective Batch Size:</b> {batch_size} × {gpu_count} × {grad_accum} = {batch_size * gpu_count * grad_accum}</li>
90
+ <li><b>Epochs:</b> {epochs}</li>
91
  <li><b>Precision:</b> {'BF16' if transformers_config.get('bf16', True) else 'FP16' if transformers_config.get('fp16', False) else 'FP32'}</li>
92
  <li><b>Max Sequence Length:</b> {transformers_config.get('tokenizer', {}).get('max_seq_length', 2048)}</li>
93
  </ul>
94
 
95
  <h3>Hardware</h3>
96
  <ul>
97
+ <li><b>GPU:</b> {gpu_count}× {gpu_type} ({vram} GB VRAM per GPU, total: {vram * gpu_count} GB)</li>
98
  <li><b>Multi-GPU Strategy:</b> {hardware_config.get('training_optimizations', {}).get('multi_gpu_strategy', 'data_parallel')}</li>
99
  <li><b>Memory Optimizations:</b> {'Gradient Checkpointing' if hardware_config.get('training_optimizations', {}).get('memory_optimizations', {}).get('use_gradient_checkpointing', True) else 'None'}</li>
100
  </ul>
 
156
  gr.Markdown("## Training Information")
157
  gr.Markdown("""
158
  ### Hardware:
159
+ - 4× NVIDIA L4 GPUs (24GB VRAM per GPU, 96GB total)
160
  - Training with BF16 precision
161
  - Using Data Parallel for multi-GPU
162
+ - Effective batch size: 16 (per device) × 4 (GPUs) × 3 (gradient accumulation) = 192
163
 
164
  ### Notes:
165
  - Training may take several hours depending on dataset size
requirements.txt CHANGED
@@ -3,6 +3,7 @@ bitsandbytes>=0.41.0
3
  datasets>=2.15.0
4
  einops>=0.7.0
5
  filelock>=3.13.1
 
6
  gradio>=5.17.0
7
  huggingface-hub>=0.19.0
8
  matplotlib>=3.7.0
 
3
  datasets>=2.15.0
4
  einops>=0.7.0
5
  filelock>=3.13.1
6
+ flash-attn>=2.5.1
7
  gradio>=5.17.0
8
  huggingface-hub>=0.19.0
9
  matplotlib>=3.7.0
run_transformers_training.py CHANGED
@@ -309,315 +309,58 @@ def load_dataset_with_mapping(dataset_config):
309
  if source != target: # Only rename if names are different
310
  dataset = dataset.rename_column(source, target)
311
 
312
- # Add prompt_number field that increments based on original order
313
- def add_prompt_numbers(examples, indices):
314
- # Defensive check to ensure indices is not None and is iterable
315
- if indices is None:
316
- logger.warning("Warning: indices is None in add_prompt_numbers, using empty list")
317
- indices = []
318
- elif isinstance(indices, int):
319
- # Handle case where indices is a single integer
320
- logger.warning(f"Warning: indices is an integer ({indices}) in add_prompt_numbers, converting to list")
321
- indices = [indices]
 
 
 
 
 
 
 
 
 
322
 
323
- # Ensure indices is always a list/iterable
324
- try:
325
- # Create a new field with the dataset index as the prompt number, starting at 1
326
- examples["prompt_number"] = [idx + 1 for idx in indices] # Adding 1 to make it 1-indexed
327
- except TypeError:
328
- # Fallback for non-iterable types
329
- logger.warning(f"Warning: non-iterable indices in add_prompt_numbers: {type(indices)}, using default")
330
- examples["prompt_number"] = [1] * len(next(iter(examples.values())))
331
-
332
- return examples
333
-
334
- # Add prompt numbers to the dataset based on original order
335
- logger.info("Adding prompt numbers based on original dataset order (starting at 1)")
336
- try:
337
- dataset = dataset.map(
338
- add_prompt_numbers,
339
- with_indices=True,
340
- desc="Adding prompt numbers"
341
- )
342
- logger.info(f"Successfully added prompt_number field to dataset")
343
- except Exception as e:
344
- logger.error(f"Error adding prompt numbers: {e}")
345
- # Create a fallback implementation that doesn't rely on with_indices
346
- logger.info("Attempting fallback method for adding prompt numbers")
347
-
348
- def add_prompt_numbers_fallback(example, idx):
349
- example["prompt_number"] = idx + 1
350
  return example
351
 
352
- # Process each example one by one with explicit indices
353
- updated_examples = []
354
- for i, example in enumerate(dataset):
355
- updated_examples.append(add_prompt_numbers_fallback(dict(example), i))
356
-
357
- # Create a new dataset with the updated examples
358
- from datasets import Dataset
359
- dataset = Dataset.from_list(updated_examples)
360
- logger.info(f"Successfully added prompt_number field using fallback method")
361
-
362
- # Rename 'id' to 'article_id' if it exists
363
- if 'id' in dataset.column_names and 'article_id' not in dataset.column_names:
364
- logger.info("Renaming 'id' column to 'article_id'")
365
- dataset = dataset.rename_column('id', 'article_id')
366
-
367
- # Reorder columns to make prompt_number first if it exists
368
- if 'prompt_number' in dataset.column_names:
369
- logger.info("Reordering columns to place prompt_number first")
370
- # Get current column names
371
- current_columns = dataset.column_names
372
- # Create new column order with prompt_number first
373
- new_column_order = ['prompt_number'] + [col for col in current_columns if col != 'prompt_number']
374
- # Reorder columns
375
- dataset = dataset.select_columns(new_column_order)
376
-
377
- # Verify all new column names for logging
378
- logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
379
- logger.info(f"Dataset columns: {dataset.column_names}")
380
-
381
- # Verify dataset is not empty
382
- if len(dataset) == 0:
383
- logger.error("Dataset is empty! This will cause errors during training.")
384
- raise ValueError("Empty dataset loaded")
385
-
386
- # Check for required columns
387
- required_columns = ['conversations']
388
- for col in required_columns:
389
- if col not in dataset.column_names:
390
- logger.error(f"Required column '{col}' not found in dataset!")
391
- raise ValueError(f"Required column '{col}' missing from dataset")
392
-
393
- # Verify expected columns exist
394
- expected_columns = {"article_id", "conversations", "prompt_number"}
395
- missing_columns = expected_columns - set(dataset.column_names)
396
- if missing_columns:
397
- logger.warning(f"Some expected columns are missing: {missing_columns}")
398
-
399
- # If "conversations" is missing but "text" exists, attempt conversion
400
- if "conversations" not in dataset.column_names and "text" in dataset.column_names:
401
- logger.info("Converting 'text' field to 'conversations' format")
402
-
403
- def convert_text_to_conversations(example):
404
- # Check if text is already a list of conversation turns
405
- if isinstance(example.get("text"), list):
406
- example["conversations"] = example["text"]
407
- # Otherwise, create a simple conversation with the text as user message
408
- else:
409
- example["conversations"] = [
410
- {"role": "user", "content": str(example.get("text", ""))}
411
- ]
412
- return example
413
-
414
- dataset = dataset.map(convert_text_to_conversations)
415
- logger.info("Successfully converted 'text' to 'conversations'")
416
-
417
- # Verify data ordering requirements
418
- processing_config = dataset_config.get("dataset", {}).get("processing", {})
419
- data_loading_config = dataset_config.get("data_loading", {})
420
-
421
- # Check if sorting is required
422
- sort_by_article_id = processing_config.get("sort_by_article_id", False)
423
- if sort_by_article_id and 'article_id' in dataset.column_names:
424
- logger.info("Sorting dataset by article_id as specified in config")
425
- dataset = dataset.sort("article_id")
426
- sorted_ids = [example['article_id'] for example in dataset.select(range(min(5, len(dataset))))]
427
- logger.info(f"First few article_ids after sorting: {sorted_ids}")
428
-
429
- # Flag consolidation - we only need one flag to control sequence preservation
430
- # Default to True to ensure safety
431
- preserve_sequence = processing_config.get("preserve_entry_sequence", True)
432
- shuffle_disabled = not data_loading_config.get("shuffle", False)
433
-
434
- if not preserve_sequence:
435
- logger.warning("CRITICAL: preserve_entry_sequence is set to False. This is NOT RECOMMENDED!")
436
- logger.warning("Data sequence integrity is essential for proper model training.")
437
-
438
- if not shuffle_disabled:
439
- logger.error("CRITICAL: shuffle is enabled in the dataset config!")
440
- logger.error("This will RANDOMIZE your dataset and break sequential order.")
441
- logger.error("Please set shuffle: false in your data_loading configuration.")
442
- # Actually enforce sequence preservation by raising an error
443
- raise ValueError("Dataset shuffling is enabled but preserve_entry_sequence is required. " +
444
- "Please disable shuffling in your configuration.")
445
-
446
- # Verify the IDs are in sequential order if they're numeric
447
- try:
448
- if len(dataset) > 1:
449
- # Check prompt numbers are sequential
450
- sample_indices = range(min(10, len(dataset)))
451
- sample_prompt_numbers = []
452
-
453
- # Defensive collection of prompt numbers
454
- for i in sample_indices:
455
- try:
456
- if i < len(dataset) and "prompt_number" in dataset[i]:
457
- sample_prompt_numbers.append(dataset[i]["prompt_number"])
458
- else:
459
- # If prompt_number doesn't exist, use index+1 as fallback
460
- sample_prompt_numbers.append(i + 1)
461
- logger.warning(f"Sample at index {i} missing prompt_number, using {i+1} as fallback")
462
- except Exception as e:
463
- logger.warning(f"Error accessing sample at index {i}: {e}")
464
- sample_prompt_numbers.append(i + 1) # Use fallback
465
-
466
- logger.info(f"Verifying sequential integrity with prompt numbers: {sample_prompt_numbers}")
467
-
468
- # Check if prompt numbers are sequential (1-indexed)
469
- if sample_prompt_numbers:
470
- is_sequential = all(sample_prompt_numbers[i] == i + 1 for i in range(len(sample_prompt_numbers)))
471
- if not is_sequential:
472
- logger.warning("WARNING: Prompt numbers are not in sequential order!")
473
- logger.warning("This may indicate that data sequence is not preserved.")
474
- else:
475
- logger.info("Prompt numbers verify that samples are in sequential order.")
476
- else:
477
- logger.warning("Could not verify sequential integrity: no prompt numbers collected")
478
-
479
- # Also check original IDs as a backup if numeric
480
- try:
481
- sample_examples = []
482
- for i in sample_indices:
483
- try:
484
- if i < len(dataset):
485
- sample_examples.append(dataset[i])
486
- except Exception as e:
487
- logger.warning(f"Error accessing dataset at index {i}: {e}")
488
-
489
- if sample_examples:
490
- id_field = 'article_id' if 'article_id' in dataset.column_names else 'id'
491
- if all(isinstance(example.get(id_field, ''), (int, str)) for example in sample_examples):
492
- sample_ids = [example.get(id_field, '') for example in sample_examples if id_field in example]
493
-
494
- if sample_ids and all(isinstance(id, int) or (isinstance(id, str) and id.isdigit()) for id in sample_ids):
495
- numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
496
- if len(numeric_ids) > 1:
497
- is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
498
- if not is_ordered:
499
- logger.warning(f"WARNING: Sample {id_field}s are not in sequential order.")
500
- else:
501
- logger.info(f"Sample {id_field}s appear to be in sequential order.")
502
- except Exception as e:
503
- logger.warning(f"Error checking ID sequence: {e}")
504
- except Exception as e:
505
- logger.warning(f"Could not verify sequential integrity: {e}")
506
 
507
- # Log examples without printing full content - with defensive coding
508
- if "conversations" in dataset.column_names:
509
- try:
510
- # Safely get first few samples
511
- first_few_indices = range(min(5, len(dataset)))
512
- sample_prompt_numbers = []
513
- sample_article_ids = []
514
-
515
- for i in first_few_indices:
516
- try:
517
- example = dataset[i]
518
- if 'prompt_number' in example:
519
- sample_prompt_numbers.append(example['prompt_number'])
520
- if 'article_id' in example:
521
- sample_article_ids.append(example['article_id'])
522
- except Exception as e:
523
- logger.warning(f"Error accessing sample at index {i}: {e}")
524
-
525
- logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, Article IDs: {sample_article_ids}")
526
-
527
- # Log conversation structure without full content
528
- if len(dataset) > 0:
529
- try:
530
- sample_conv_structure = []
531
- first_example = dataset[0]
532
-
533
- if 'conversations' in first_example and first_example['conversations'] is not None:
534
- for msg in first_example['conversations']:
535
- if isinstance(msg, dict):
536
- content = msg.get('content', '')
537
- preview = content[:50] + "..." if len(content) > 50 else content
538
- sample_conv_structure.append({
539
- "role": msg.get('role', ''),
540
- "content_length": len(content),
541
- "preview": preview
542
- })
543
- logger.info(f"Conversation structure: {sample_conv_structure}")
544
- except Exception as e:
545
- logger.warning(f"Error logging conversation structure: {e}")
546
- except Exception as e:
547
- logger.warning(f"Error logging sample examples: {e}")
548
 
 
549
  logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
550
  logger.info(f"Dataset columns: {dataset.column_names}")
551
 
552
- # Verify dataset is not empty
553
- if len(dataset) == 0:
554
- logger.error("Dataset is empty! Cannot proceed with training.")
555
- return dataset
 
 
556
 
557
- # Check for required columns
558
- required_cols = ['conversations', 'prompt_number']
559
- for col in required_cols:
560
- if col not in dataset.column_names:
561
- logger.error(f"Required column '{col}' missing from dataset. Cannot proceed with training.")
562
- return dataset
563
-
564
- # Validate at least one sample can be processed
565
- try:
566
- if len(dataset) > 0:
567
- sample = dataset[0]
568
- if 'conversations' not in sample or not sample['conversations']:
569
- logger.error("First sample has no conversations! Data format may be incorrect.")
570
- return dataset
571
- if not isinstance(sample['conversations'], list):
572
- logger.error(f"Conversations field should be a list but got {type(sample['conversations'])}")
573
- return dataset
574
- except Exception as e:
575
- logger.error(f"Error validating first sample: {e}")
576
- return dataset
577
-
578
- # Add metadata if specified
579
- metadata_config = dataset_config.get("data_formatting", {}).get("metadata_handling", {})
580
- if metadata_config:
581
- include_article_id = metadata_config.get("include_article_id", False)
582
- include_prompt_number = metadata_config.get("include_prompt_number", False)
583
- metadata_format = metadata_config.get("metadata_format", "")
584
-
585
- if (include_article_id or include_prompt_number) and metadata_format:
586
- logger.info("Adding metadata to conversations")
587
-
588
- def add_metadata(example):
589
- if not example.get("conversations"):
590
- return example
591
-
592
- # Prepare metadata
593
- metadata = metadata_format
594
- if include_article_id and "article_id" in example:
595
- metadata = metadata.replace("{article_id}", str(example.get("article_id", "")))
596
- if include_prompt_number and "prompt_number" in example:
597
- metadata = metadata.replace("{prompt_number}", str(example.get("prompt_number", "")))
598
-
599
- # Add system message with metadata if not empty
600
- if metadata.strip():
601
- if example["conversations"] and isinstance(example["conversations"], list):
602
- # Check if first message is already a system message
603
- if (isinstance(example["conversations"][0], dict) and
604
- example["conversations"][0].get("role") == "system"):
605
- # Append to existing system message
606
- example["conversations"][0]["content"] += f"\n\nMetadata: {metadata}"
607
- else:
608
- # Add new system message at the beginning
609
- example["conversations"].insert(0, {
610
- "role": "system",
611
- "content": f"Metadata: {metadata}"
612
- })
613
-
614
- return example
615
-
616
- dataset = dataset.map(add_metadata)
617
- logger.info("Metadata added to conversations")
618
-
619
  return dataset
620
-
621
  except Exception as e:
622
  logger.error(f"Error loading dataset: {str(e)}")
623
  raise
@@ -1112,6 +855,10 @@ def main():
1112
  per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
1113
  gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)
1114
 
 
 
 
 
1115
  # For multi-GPU setup, adjust for better balance
1116
  if CUDA_AVAILABLE and NUM_GPUS > 1:
1117
  log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
@@ -1213,21 +960,17 @@ def main():
1213
  """Custom dataloader that preserves original dataset order"""
1214
  log_info("Creating sequential dataloader to maintain original dataset order")
1215
 
1216
- # Verification of sequence preservation flags - consolidated
1217
- data_loading_config = dataset_config.get("data_loading", {})
1218
- sequential_processing = data_loading_config.get("sequential_processing", True)
1219
- shuffle_disabled = not data_loading_config.get("shuffle", False)
1220
 
1221
- if not sequential_processing:
1222
- log_info("CRITICAL WARNING: sequential_processing flag is disabled! This may affect data order.")
1223
- log_info("Data sequence integrity is essential - using sequential sampler regardless of flag.")
1224
- # Force sequential processing regardless of flag
1225
 
1226
- if not shuffle_disabled:
1227
- log_info("CRITICAL ERROR: Shuffle is not disabled! This will randomize data entry order!")
1228
- # Actually handle the error rather than just logging it
1229
  raise ValueError("Dataset shuffling is enabled but sequential processing is required. " +
1230
- "Please disable shuffling in your configuration.")
1231
 
1232
  # Calculate batch size based on device availability
1233
  if getattr(training_args, "no_cuda", False):
 
309
  if source != target: # Only rename if names are different
310
  dataset = dataset.rename_column(source, target)
311
 
312
+ # Add prompt_number field that increments based on original order - simple approach
313
+ logger.info("Adding prompt_number based on original dataset order (starting at 1)")
314
+
315
+ # Simple approach 1: Add index as a column during dataset creation
316
+ # Create a list of dicts with indices
317
+ examples_with_idx = []
318
+ for i, example in enumerate(dataset):
319
+ example = dict(example) # Make a copy to avoid modifying the original
320
+ example['prompt_number'] = i + 1 # 1-indexed
321
+ examples_with_idx.append(example)
322
+
323
+ # Recreate dataset with prompt_number included
324
+ from datasets import Dataset
325
+ dataset = Dataset.from_list(examples_with_idx)
326
+ logger.info("Successfully added prompt_number to dataset")
327
+
328
+ # If conversations is missing but text exists, attempt conversion
329
+ if "conversations" not in dataset.column_names and "text" in dataset.column_names:
330
+ logger.info("Converting 'text' field to 'conversations' format")
331
 
332
+ def convert_text_to_conversations(example):
333
+ # Check if text is already a list of conversation turns
334
+ if isinstance(example.get("text"), list):
335
+ example["conversations"] = example["text"]
336
+ # Otherwise, create a simple conversation with the text as user message
337
+ else:
338
+ example["conversations"] = [
339
+ {"role": "user", "content": str(example.get("text", ""))}
340
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  return example
342
 
343
+ dataset = dataset.map(convert_text_to_conversations)
344
+ logger.info("Successfully converted 'text' to 'conversations'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
+ # Verify we have the required columns
347
+ if "conversations" not in dataset.column_names:
348
+ logger.error("Required 'conversations' column not found in dataset!")
349
+ raise ValueError("Required 'conversations' column missing from dataset")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
+ # Log column names and a sample
352
  logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
353
  logger.info(f"Dataset columns: {dataset.column_names}")
354
 
355
+ # Log a sample for inspection
356
+ if len(dataset) > 0:
357
+ sample = dataset[0]
358
+ prompt_num = sample.get("prompt_number", "N/A")
359
+ article_id = sample.get("article_id", sample.get("id", "N/A"))
360
+ logger.info(f"First sample - Prompt number: {prompt_num}, ID: {article_id}")
361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  return dataset
363
+
364
  except Exception as e:
365
  logger.error(f"Error loading dataset: {str(e)}")
366
  raise
 
855
  per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
856
  gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)
857
 
858
+ # Get multi-GPU strategy from hardware config (default to data_parallel)
859
+ multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
860
+ logger.info(f"Multi-GPU strategy: {multi_gpu_strategy}")
861
+
862
  # For multi-GPU setup, adjust for better balance
863
  if CUDA_AVAILABLE and NUM_GPUS > 1:
864
  log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
 
960
  """Custom dataloader that preserves original dataset order"""
961
  log_info("Creating sequential dataloader to maintain original dataset order")
962
 
963
+ # Create a simple sequential sampler
964
+ sequential_sampler = torch.utils.data.SequentialSampler(dataset)
 
 
965
 
966
+ # Verification of sequence preservation flags - simplified
967
+ data_loading_config = dataset_config.get("data_loading", {})
968
+ shuffle_enabled = data_loading_config.get("shuffle", False)
 
969
 
970
+ if shuffle_enabled:
971
+ log_info("CRITICAL ERROR: Shuffle is enabled! This will randomize data entry order!")
 
972
  raise ValueError("Dataset shuffling is enabled but sequential processing is required. " +
973
+ "Please disable shuffling in your configuration.")
974
 
975
  # Calculate batch size based on device availability
976
  if getattr(training_args, "no_cuda", False):