George-API committed
Commit 7e5a6ad · verified · 1 Parent(s): 0364d5c

Upload folder using huggingface_hub

Files changed (1)
  1. run_transformers_training.py +94 -133
run_transformers_training.py CHANGED
@@ -297,20 +297,27 @@ def load_dataset_with_mapping(dataset_config):
         else:
             logger.warning(f"Expected column '{col}' not found in dataset")

-    # Sort dataset if required
-    sort_by_id = dataset_config.get("dataset", {}).get("processing", {}).get("sort_by_id", False)
-    if sort_by_id and "id" in dataset.column_names:
-        logger.info("Sorting dataset by ID")
-        dataset = dataset.sort("id")
-
-        # Log the first few IDs to verify sorting
+    # Note: Explicitly NOT sorting the dataset to preserve original order
+    logger.info("Preserving original dataset order (no sorting)")
+
+    # Log examples without printing full content
+    if "conversations" in dataset.column_names:
         sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
-        logger.info(f"First few IDs after sorting: {sample_ids}")
+        logger.info(f"First few IDs: {sample_ids}")

-    # Log example of conversations structure to verify format
-    if "conversations" in dataset.column_names:
-        sample_conv = dataset["conversations"][0] if len(dataset) > 0 else []
-        logger.info(f"Example conversation structure: {sample_conv}")
+        # Log conversation structure without full content
+        if len(dataset) > 0:
+            sample_conv_structure = []
+            for msg in dataset["conversations"][0]:
+                if isinstance(msg, dict):
+                    content = msg.get('content', '')
+                    preview = content[:50] + "..." if len(content) > 50 else content
+                    sample_conv_structure.append({
+                        "role": msg.get('role', ''),
+                        "content_length": len(content),
+                        "preview": preview
+                    })
+            logger.info(f"Conversation structure: {sample_conv_structure}")

     logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
     logger.info(f"Dataset columns: {dataset.column_names}")
@@ -374,142 +381,91 @@ class SimpleDataCollator:
         self.dataset_config = dataset_config
         self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
         self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-        self.paper_counters = {}
         self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
-        self.include_metadata = False  # Disable automatic metadata inclusion as it's already in content
-        self.roles = dataset_config.get("data_formatting", {}).get("roles", {})
-        logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
-        logger.info("Metadata handling disabled - using metadata from content field")
+        logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
+        logger.info("Using exact dataset structure without reformatting")

         # Check if we're on GPU
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"SimpleDataCollator using device: {self.device}")

-    def normalize_conversation(self, conversation):
-        """Normalize conversation format to ensure consistent structure."""
-        normalized = []
-
-        # Handle non-list or empty inputs
-        if not isinstance(conversation, list):
-            logger.warning(f"Conversation is not a list: {type(conversation)}")
-            if hasattr(conversation, 'items'):  # It's a dict-like object
-                conversation = [conversation]
-            else:
-                return []
-
-        # Get introductory message if present (should be first and without chunk number)
-        intro_msg = None
-        for i, turn in enumerate(conversation):
-            if isinstance(turn, dict) and turn.get('content') and "[RESEARCH INTRODUCTION]" in turn.get('content', ''):
-                intro_msg = turn
-                break
-
-        # Process introduction message first if found
-        if intro_msg:
-            normalized.append({
-                "role": "system",
-                "content": intro_msg.get('content', '')
-            })
-            # Remove intro from further processing
-            conversation = [t for t in conversation if t != intro_msg]
-
-        # Process remaining messages
-        for turn in conversation:
-            # Skip empty or None entries
-            if not turn:
-                continue
-
-            # Handle string entries (convert to user message)
-            if isinstance(turn, str):
-                normalized.append({"role": "user", "content": turn})
-                continue
-
-            # Handle dict-like entries
-            if not isinstance(turn, dict) and hasattr(turn, 'get'):
-                # Convert to dict
-                turn = {k: turn.get(k) for k in ['role', 'content'] if hasattr(turn, 'get') and turn.get(k) is not None}
-
-            # Ensure both role and content exist
-            if not isinstance(turn, dict) or 'role' not in turn or 'content' not in turn:
-                logger.warning(f"Skipping malformatted conversation turn: {turn}")
-                continue
-
-            # Normalize role field
-            role = turn.get('role', '').lower()
-            if role == 'user' or role == 'human':
-                role = 'user'
-            elif role == 'assistant' or role == 'bot':
-                role = 'assistant'
-
-            # Add normalized turn
-            normalized.append({
-                "role": role,
-                "content": str(turn.get('content', ''))
-            })
-
-        return normalized
-
     def __call__(self, features):
+        """Process examples preserving exact JSONL structure"""
         batch = {"input_ids": [], "attention_mask": [], "labels": []}

         for example in features:
             try:
-                # Get ID and conversation fields
+                # Get ID
                 paper_id = example.get("id", "")

-                # Handle conversation field - could be under 'conversations' or 'text'
-                conversation = example.get("conversations", example.get("text", []))
-
-                # Normalize conversation format
-                conversation = self.normalize_conversation(conversation)
-
-                if not conversation:
+                # Get conversations - these should already contain role and content
+                conversations = example.get("conversations", [])
+                if not conversations:
                     self.stats["skipped"] += 1
                     continue

-                # Format conversation with research introduction and chunk info
-                formatted_content = format_phi_chat(conversation, self.dataset_config)
+                # Directly use the conversations array as input to the model's chat template
+                # This preserves the exact structure with roles and content as they are
+                try:
+                    # Let tokenizer handle the content with the model's chat template
+                    inputs = self.tokenizer.apply_chat_template(
+                        conversations,
+                        return_tensors=None,
+                        add_generation_prompt=False
+                    )
+                except:
+                    # Fallback if apply_chat_template fails
+                    logger.warning(f"Chat template application failed for example {paper_id}, using basic tokenization")
+
+                    # Create a basic representation of the conversation
+                    conversation_text = ""
+                    for msg in conversations:
+                        if isinstance(msg, dict) and 'content' in msg:
+                            conversation_text += msg.get('content', '') + "\n\n"
+
+                    # Basic tokenization
+                    inputs = self.tokenizer(
+                        conversation_text,
+                        add_special_tokens=True,
+                        return_tensors=None
+                    )

-                # Tokenize with the model's chat template
-                inputs = self.tokenizer(
-                    formatted_content,
-                    add_special_tokens=True,
-                    truncation=True,
-                    max_length=self.max_seq_length,
-                    return_tensors=None,
-                    padding=False,  # Don't pad here, we'll pad the batch later
-                )
+                # Apply length cap if needed (shouldn't be necessary for pre-audited data)
+                if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
+                    logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
+                    inputs = inputs[:self.max_seq_length]
+
+                # Create attention mask (1 for all tokens)
+                attention_mask = [1] * len(inputs)

-                if len(inputs["input_ids"]) > 0:
+                if len(inputs) > 0:
                     # For causal language modeling, labels are the same as inputs
-                    labels = inputs["input_ids"].copy()
+                    labels = inputs.copy()

-                    batch["input_ids"].append(inputs["input_ids"])
-                    batch["attention_mask"].append(inputs["attention_mask"])
+                    batch["input_ids"].append(inputs)
+                    batch["attention_mask"].append(attention_mask)
                     batch["labels"].append(labels)

                     self.stats["processed"] += 1
-                    self.stats["total_tokens"] += len(inputs["input_ids"])
+                    self.stats["total_tokens"] += len(inputs)

                     # Debug logging for first few examples
                     log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
                     if self.stats["processed"] <= log_samples:
-                        logger.info(f"Example {self.stats['processed']} format:")
+                        logger.info(f"Example {self.stats['processed']}:")
                         logger.info(f"Paper ID: {paper_id}")
-                        logger.info(f"Token count: {len(inputs['input_ids'])}")
-                        logger.info(f"Content preview:\n{formatted_content[:500]}...")
-                        logger.info(f"Conversation structure: {conversation[:2]}...")
+                        logger.info(f"Token count: {len(inputs)}")
+                        logger.info(f"Conversation entries: {len(conversations)}")
                 else:
                     self.stats["skipped"] += 1
             except Exception as e:
                 logger.warning(f"Error processing example: {str(e)[:100]}...")
-                logger.warning(f"Problematic example: {str(example)[:200]}...")
+                logger.warning(f"Problematic example ID: {example.get('id', 'unknown')}")
                 self.stats["skipped"] += 1
                 continue

         if not batch["input_ids"]:
             logger.warning("Empty batch, returning dummy tensors")
-            # Return tensors on the right device
             return {
                 "input_ids": torch.zeros((1, 1), dtype=torch.long),
                 "attention_mask": torch.zeros((1, 1), dtype=torch.long),
@@ -526,7 +482,7 @@ class SimpleDataCollator:
             batch["attention_mask"][i].extend([0] * padding_length)
             batch["labels"][i].extend([-100] * padding_length)

-        # Convert to tensors on CPU first
+        # Convert to tensors
         batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}

         # Log stats periodically
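The hunk above is the tail of the collator's batch-padding step: every sequence in the batch is padded to a common length, using the pad token id for input_ids, 0 for attention_mask, and -100 for labels so the padded positions are ignored by the cross-entropy loss. A self-contained sketch of that step, assuming the same three batch keys (pad_batch is a hypothetical helper, not part of the commit):

from typing import Dict, List

import torch

def pad_batch(batch: Dict[str, List[List[int]]], pad_token_id: int) -> Dict[str, torch.Tensor]:
    """Pad input_ids/attention_mask/labels to the longest sequence in the batch (sketch)."""
    max_length = max(len(ids) for ids in batch["input_ids"])
    for i in range(len(batch["input_ids"])):
        padding_length = max_length - len(batch["input_ids"][i])
        batch["input_ids"][i].extend([pad_token_id] * padding_length)
        batch["attention_mask"][i].extend([0] * padding_length)
        batch["labels"][i].extend([-100] * padding_length)  # -100 positions are ignored by the loss
    return {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}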
@@ -534,8 +490,7 @@ class SimpleDataCollator:
         if self.stats["processed"] % log_interval == 0 and self.stats["processed"] > 0:
             logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                         f"skipped={self.stats['skipped']}, "
-                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}, "
-                        f"unique_papers={len(self.paper_counters)}")
+                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")

         return batch

@@ -731,21 +686,35 @@ def main():
         no_cuda=False if torch.cuda.is_available() else True,  # Use CUDA if available
     )

-    # Custom dataloader to ensure no shuffling of dataset
-    # This preserves the order of chunks in papers
-    def get_train_dataloader_no_shuffle():
-        logger.info("Creating data loader with sequential sampler to maintain paper order")
+    # Create sequential sampler to maintain original dataset order
+    sequential_sampler = torch.utils.data.SequentialSampler(dataset)
+
+    # Initialize trainer first
+    logger.info("Initializing Trainer")
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,  # We'll override this with our custom dataloader
+        data_collator=data_collator,
+        callbacks=[LoggingCallback()],
+    )
+
+    # Then override the get_train_dataloader method
+    def custom_get_train_dataloader():
+        """Custom dataloader that preserves original dataset order"""
+        logger.info("Creating sequential dataloader to maintain original dataset order")
+
+        # Calculate batch size based on device availability
         if getattr(training_args, "no_cuda", False):
             batch_size = training_args.per_device_train_batch_size
         else:
-            batch_size = max(training_args.per_device_train_batch_size * torch.cuda.device_count(), 1)
+            batch_size = max(training_args.per_device_train_batch_size * max(1, torch.cuda.device_count()), 1)

-        # Use sequential sampler to preserve order
-        sequential_sampler = torch.utils.data.SequentialSampler(dataset["train"])
-        logger.info(f"Using sequential sampler for batch size {batch_size}")
+        logger.info(f"Using sequential sampler with batch size {batch_size}")

+        # Return DataLoader with sequential sampler
         return torch.utils.data.DataLoader(
-            dataset["train"],
+            dataset,
             batch_size=batch_size,
             sampler=sequential_sampler,
             collate_fn=data_collator,
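The custom dataloader above keeps examples in their original order by using a SequentialSampler in place of the shuffling sampler the Trainer would otherwise build, and it scales the batch size across visible GPUs. An isolated sketch of the same idea (make_sequential_loader and its arguments are hypothetical names, not from the script):

import torch
from torch.utils.data import DataLoader, SequentialSampler

def make_sequential_loader(train_dataset, collator, per_device_batch_size: int) -> DataLoader:
    """Build a non-shuffling DataLoader that preserves dataset order (illustrative sketch)."""
    # Scale the per-device batch size across GPUs; fall back to 1 on CPU-only machines.
    num_devices = max(1, torch.cuda.device_count())
    batch_size = max(per_device_batch_size * num_devices, 1)
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=SequentialSampler(train_dataset),  # iterate in original order, no shuffling
        collate_fn=collator,
    )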
@@ -754,16 +723,8 @@ def main():
             pin_memory=training_args.dataloader_pin_memory,
         )

-    # Set up trainer with custom dataloader
-    logger.info("Initializing Trainer")
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        get_train_dataloader=get_train_dataloader_no_shuffle,
-        tokenizer=tokenizer,
-        data_collator=data_collator,
-        callbacks=[LoggingCallback()]
-    )
+    # Override the get_train_dataloader method
+    trainer.get_train_dataloader = custom_get_train_dataloader

     # Start training
     logger.info("Starting training process")
 
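The commit assigns custom_get_train_dataloader onto the trainer instance after construction. An equivalent pattern, shown here only as a sketch, is to subclass Trainer so the non-shuffling behaviour lives in one place (SequentialTrainer is a hypothetical name, not part of the commit):

from torch.utils.data import DataLoader, SequentialSampler
from transformers import Trainer

class SequentialTrainer(Trainer):
    """Trainer variant whose training dataloader never shuffles (illustrative sketch)."""

    def get_train_dataloader(self) -> DataLoader:
        # Mirror the collation and dataloader settings configured on the Trainer,
        # but force a sequential sampler so example order is preserved.
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            sampler=SequentialSampler(self.train_dataset),
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            pin_memory=self.args.dataloader_pin_memory,
        )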