George-API committed
Commit 2b5da3a · verified · 1 parent: 4ce739a

Upload folder using huggingface_hub

Files changed (1)
  1. run_transformers_training.py +83 -93
run_transformers_training.py CHANGED
@@ -11,13 +11,25 @@ from datetime import datetime
 import time
 import warnings
 from importlib.util import find_spec
+import multiprocessing
 
 # Check hardware capabilities first
-import torch
-CUDA_AVAILABLE = torch.cuda.is_available()
+CUDA_AVAILABLE = "CUDA_VISIBLE_DEVICES" in os.environ or os.environ.get("NVIDIA_VISIBLE_DEVICES") != ""
 NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
 DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"
 
+# Set the multiprocessing start method to 'spawn' for CUDA compatibility
+if CUDA_AVAILABLE:
+    try:
+        multiprocessing.set_start_method('spawn', force=True)
+        print("Set multiprocessing start method to 'spawn' for CUDA compatibility")
+    except RuntimeError:
+        # Method already set, which is fine
+        print("Multiprocessing start method already set")
+
+# Now import the rest of the modules
+import torch
+
 # Configure logging early
 logging.basicConfig(
     level=logging.INFO,
@@ -435,104 +447,82 @@ def format_phi_chat(messages, dataset_config):
 class SimpleDataCollator:
     def __init__(self, tokenizer, dataset_config):
         self.tokenizer = tokenizer
-        self.dataset_config = dataset_config
-        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
-        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-        self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
-        logger.info(f"SimpleDataCollator initialized - using pre-tokenized chunks with max_seq_length={self.max_seq_length}")
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
+        self.max_seq_length = min(dataset_config.get("max_seq_length", 2048), tokenizer.model_max_length)
+        self.stats = {
+            "processed": 0,
+            "skipped": 0,
+            "total_tokens": 0
+        }
+        logger.info(f"Initialized SimpleDataCollator with max_seq_length={self.max_seq_length}")
+
     def __call__(self, features):
-        batch = {"input_ids": [], "attention_mask": [], "labels": []}
+        # Initialize tensors on CPU to save GPU memory
+        batch = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": []
+        }
 
-        for example in features:
-            try:
-                # Get ID for logging
-                paper_id = example.get("article_id", "unknown")
-                prompt_num = example.get("prompt_number", "unknown")
-
-                # Get the conversations list
-                conversations = example.get("conversations", [])
-
-                # Skip if no conversations
-                if not conversations:
-                    logger.warning(f"Empty conversations for paper_id {paper_id}, prompt {prompt_num}")
-                    self.stats["skipped"] += 1
-                    continue
-
-                # Get the pre-tokenized content directly
-                # The content should already be properly tokenized and formatted
-                content = conversations[0].get("content", "")
-                if not content:
-                    logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
-                    self.stats["skipped"] += 1
-                    continue
-
-                # Convert string of numbers to list of integers if needed
-                if isinstance(content, str):
-                    try:
-                        # Assuming content is space-separated numbers
-                        input_ids = [int(x) for x in content.split()]
-                    except ValueError:
-                        logger.warning(f"Invalid pre-tokenized content format for paper_id {paper_id}, prompt {prompt_num}")
-                        self.stats["skipped"] += 1
-                        continue
-                else:
-                    input_ids = content
-
-                # Truncate if needed
-                if len(input_ids) > self.max_seq_length:
-                    input_ids = input_ids[:self.max_seq_length]
-                    logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")
-
-                # Create attention mask (1s for all tokens)
-                attention_mask = [1] * len(input_ids)
-
-                # Add to batch
-                batch["input_ids"].append(input_ids)
-                batch["attention_mask"].append(attention_mask)
-                batch["labels"].append(input_ids.copy())  # For causal LM, labels = input_ids
-
-                self.stats["processed"] += 1
-                self.stats["total_tokens"] += len(input_ids)
-
-                # Log first few examples for verification
-                if self.stats["processed"] <= 3:
-                    logger.info(f"Sample {self.stats['processed']} token count: {len(input_ids)}")
+        for feature in features:
+            paper_id = feature.get("article_id", "unknown")
+            prompt_num = feature.get("prompt_number", 0)
+            conversations = feature.get("conversations", [])
+
+            if not conversations:
+                logger.warning(f"No conversations for paper_id {paper_id}, prompt {prompt_num}")
+                self.stats["skipped"] += 1
+                continue
 
-            except Exception as e:
-                logger.warning(f"Error processing example {paper_id}, prompt {prompt_num}: {str(e)}")
+            # Get the content directly
+            content = conversations[0].get("content", "")
+            if not content:
+                logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
                 self.stats["skipped"] += 1
                 continue
-
-        if not batch["input_ids"]:
-            logger.warning("Empty batch, returning dummy tensors")
-            return {
-                "input_ids": torch.zeros((1, 1), dtype=torch.long, device=self.device),
-                "attention_mask": torch.zeros((1, 1), dtype=torch.long, device=self.device),
-                "labels": torch.zeros((1, 1), dtype=torch.long, device=self.device)
-            }
-
-        # Pad the batch
-        max_length = max(len(ids) for ids in batch["input_ids"])
-
-        for i in range(len(batch["input_ids"])):
-            padding_length = max_length - len(batch["input_ids"][i])
-            if padding_length > 0:
-                batch["input_ids"][i].extend([self.pad_token_id] * padding_length)
-                batch["attention_mask"][i].extend([0] * padding_length)
-                batch["labels"][i].extend([-100] * padding_length)  # -100 is the ignore index for loss
-
-        # Convert to tensors
-        batch = {k: torch.tensor(v, dtype=torch.long, device=self.device) for k, v in batch.items()}
-
-        # Log stats periodically
+
+            # Process the content string by tokenizing it
+            if isinstance(content, str):
+                # Tokenize the content string
+                input_ids = self.tokenizer.encode(content, add_special_tokens=True)
+            else:
+                # If somehow the content is already tokenized (not a string), use it directly
+                input_ids = content
+
+            # Truncate if needed
+            if len(input_ids) > self.max_seq_length:
+                input_ids = input_ids[:self.max_seq_length]
+                logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")
+
+            # Create attention mask (1s for all tokens)
+            attention_mask = [1] * len(input_ids)
+
+            # Add to batch
+            batch["input_ids"].append(input_ids)
+            batch["attention_mask"].append(attention_mask)
+            batch["labels"].append(input_ids.copy())  # For causal LM, labels = input_ids
+
+            self.stats["processed"] += 1
+            self.stats["total_tokens"] += len(input_ids)
+
+        # Log statistics periodically
         if self.stats["processed"] % 100 == 0:
-            logger.info(f"Collator stats: processed={self.stats['processed']}, "
-                        f"skipped={self.stats['skipped']}, "
-                        f"avg_tokens={self.stats['total_tokens']/max(1, self.stats['processed']):.1f}")
+            avg_tokens = self.stats["total_tokens"] / max(1, self.stats["processed"])
+            logger.info(f"Data collation stats: processed={self.stats['processed']}, "
+                        f"skipped={self.stats['skipped']}, avg_tokens={avg_tokens:.1f}")
 
-        return batch
+        # Convert to tensors or pad sequences (PyTorch will handle)
+        if batch["input_ids"]:
+            # Pad sequences to max length in batch using the tokenizer
+            batch = self.tokenizer.pad(
+                batch,
+                padding="max_length",
+                max_length=self.max_seq_length,
+                return_tensors="pt"
+            )
+            return batch
+        else:
+            # Return empty batch if no valid examples
+            return {k: [] for k in batch}
 
 class LoggingCallback(TrainerCallback):
     def __init__(self, model=None, dataset=None):
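Note on the first hunk: it forces the 'spawn' start method before torch is imported so that CUDA initialization and worker processes do not conflict. The standalone sketch below is illustrative only and not part of this commit; the toy dataset and worker count are assumptions for demonstration.

# Illustrative sketch, not from the commit: DataLoader workers under the 'spawn'
# start method alongside CUDA. ToyDataset and the worker count are hypothetical.
import multiprocessing

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor([idx], dtype=torch.long)


if __name__ == "__main__":
    # 'spawn' gives each worker a fresh interpreter, avoiding the CUDA
    # re-initialization errors that fork'd workers can hit.
    multiprocessing.set_start_method("spawn", force=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=2)
    for batch in loader:
        print(batch.to(device).shape)  # torch.Size([2, 1])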
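Note on the second hunk: the collator now tokenizes raw text on the fly and defers padding to tokenizer.pad(). A quick smoke test is sketched below; it is illustrative and not part of this commit. The phi-2 checkpoint name, the flat dataset_config dict, and the sample record are assumptions, and SimpleDataCollator (with the script's module-level logger) is assumed to be in scope.

# Hypothetical smoke test for the reworked SimpleDataCollator, not from the commit.
# Assumes SimpleDataCollator and its module-level logger (defined in this script) are in scope.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")  # assumed base checkpoint
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # tokenizer.pad() needs a pad token

collator = SimpleDataCollator(tokenizer, {"max_seq_length": 2048})  # flat key, as read by the new __init__

sample = {
    "article_id": "demo-0001",
    "prompt_number": 1,
    "conversations": [{"role": "user", "content": "Summarize the key findings of the paper."}],
}

batch = collator([sample])
print(batch["input_ids"].shape)       # torch.Size([1, 2048]) after padding="max_length"
print(batch["attention_mask"].shape)  # torch.Size([1, 2048])
print(batch["labels"].shape)          # torch.Size([1, n_tokens]); pad() leaves the extra "labels" key unpadded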