Spaces:
Runtime error
Upload folder using huggingface_hub
run_transformers_training.py +83 -93
run_transformers_training.py
CHANGED
@@ -11,13 +11,25 @@ from datetime import datetime
 import time
 import warnings
 from importlib.util import find_spec
+import multiprocessing
 
 # Check hardware capabilities first
-CUDA_AVAILABLE = torch.cuda.is_available()
+CUDA_AVAILABLE = "CUDA_VISIBLE_DEVICES" in os.environ or os.environ.get("NVIDIA_VISIBLE_DEVICES") != ""
 NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
 DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"
 
+# Set the multiprocessing start method to 'spawn' for CUDA compatibility
+if CUDA_AVAILABLE:
+    try:
+        multiprocessing.set_start_method('spawn', force=True)
+        print("Set multiprocessing start method to 'spawn' for CUDA compatibility")
+    except RuntimeError:
+        # Method already set, which is fine
+        print("Multiprocessing start method already set")
+
+# Now import the rest of the modules
+import torch
+
 # Configure logging early
 logging.basicConfig(
     level=logging.INFO,
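The new 'spawn' block exists because a CUDA context cannot be re-initialized in a fork()ed child process, which is also why the commit defers `import torch` until after the start method is chosen. A minimal standalone sketch of the same pattern, not taken from this Space (the worker function and process count are illustrative):

import multiprocessing

def _worker(rank):
    # Hypothetical worker: torch is imported inside the spawned child, after the
    # parent has fixed the start method, so each child gets a fresh CUDA context.
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"worker {rank} is using {device}")

if __name__ == "__main__":
    # Mirrors the hunk above: pick 'spawn' before any processes are created.
    multiprocessing.set_start_method("spawn", force=True)
    procs = [multiprocessing.Process(target=_worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()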
@@ -435,104 +447,82 @@ def format_phi_chat(messages, dataset_config):
 class SimpleDataCollator:
     def __init__(self, tokenizer, dataset_config):
         self.tokenizer = tokenizer
-        self.stats = {
+        self.max_seq_length = min(dataset_config.get("max_seq_length", 2048), tokenizer.model_max_length)
+        self.stats = {
+            "processed": 0,
+            "skipped": 0,
+            "total_tokens": 0
+        }
+        logger.info(f"Initialized SimpleDataCollator with max_seq_length={self.max_seq_length}")
+
     def __call__(self, features):
+        # Initialize tensors on CPU to save GPU memory
+        batch = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": []
+        }
 
-            # Skip if no conversations
-            if not conversations:
-                logger.warning(f"Empty conversations for paper_id {paper_id}, prompt {prompt_num}")
-                self.stats["skipped"] += 1
-                continue
-
-            # Get the pre-tokenized content directly
-            # The content should already be properly tokenized and formatted
-            content = conversations[0].get("content", "")
-            if not content:
-                logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
-                self.stats["skipped"] += 1
-                continue
-
-            # Convert string of numbers to list of integers if needed
-            if isinstance(content, str):
-                try:
-                    # Assuming content is space-separated numbers
-                    input_ids = [int(x) for x in content.split()]
-                except ValueError:
-                    logger.warning(f"Invalid pre-tokenized content format for paper_id {paper_id}, prompt {prompt_num}")
-                    self.stats["skipped"] += 1
-                    continue
-            else:
-                input_ids = content
-
-            # Truncate if needed
-            if len(input_ids) > self.max_seq_length:
-                input_ids = input_ids[:self.max_seq_length]
-                logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")
-
-            # Create attention mask (1s for all tokens)
-            attention_mask = [1] * len(input_ids)
-
-            # Add to batch
-            batch["input_ids"].append(input_ids)
-            batch["attention_mask"].append(attention_mask)
-            batch["labels"].append(input_ids.copy())  # For causal LM, labels = input_ids
-
-            self.stats["processed"] += 1
-            self.stats["total_tokens"] += len(input_ids)
-
-            # Log first few examples for verification
-            if self.stats["processed"] <= 3:
-                logger.info(f"Sample {self.stats['processed']} token count: {len(input_ids)}")
+        for feature in features:
+            paper_id = feature.get("article_id", "unknown")
+            prompt_num = feature.get("prompt_number", 0)
+            conversations = feature.get("conversations", [])
+
+            if not conversations:
+                logger.warning(f"No conversations for paper_id {paper_id}, prompt {prompt_num}")
+                self.stats["skipped"] += 1
+                continue
 
+            # Get the content directly
+            content = conversations[0].get("content", "")
+            if not content:
+                logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
                 self.stats["skipped"] += 1
                 continue
+
+            # Process the content string by tokenizing it
+            if isinstance(content, str):
+                # Tokenize the content string
+                input_ids = self.tokenizer.encode(content, add_special_tokens=True)
+            else:
+                # If somehow the content is already tokenized (not a string), use it directly
+                input_ids = content
+
+            # Truncate if needed
+            if len(input_ids) > self.max_seq_length:
+                input_ids = input_ids[:self.max_seq_length]
+                logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")
+
+            # Create attention mask (1s for all tokens)
+            attention_mask = [1] * len(input_ids)
+
+            # Add to batch
+            batch["input_ids"].append(input_ids)
+            batch["attention_mask"].append(attention_mask)
+            batch["labels"].append(input_ids.copy())  # For causal LM, labels = input_ids
+
+            self.stats["processed"] += 1
+            self.stats["total_tokens"] += len(input_ids)
+
+            # Log statistics periodically
             if self.stats["processed"] % 100 == 0:
+                avg_tokens = self.stats["total_tokens"] / max(1, self.stats["processed"])
+                logger.info(f"Data collation stats: processed={self.stats['processed']}, "
+                            f"skipped={self.stats['skipped']}, avg_tokens={avg_tokens:.1f}")
 
+        # Convert to tensors or pad sequences (PyTorch will handle)
+        if batch["input_ids"]:
+            # Pad sequences to max length in batch using the tokenizer
+            batch = self.tokenizer.pad(
+                batch,
+                padding="max_length",
+                max_length=self.max_seq_length,
+                return_tensors="pt"
+            )
+            return batch
+        else:
+            # Return empty batch if no valid examples
+            return {k: [] for k in batch}
 
 class LoggingCallback(TrainerCallback):
     def __init__(self, model=None, dataset=None):
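Two details in the reworked collator are worth a standalone illustration. First, `labels` can simply copy `input_ids` because causal LM heads shift the labels internally, scoring each position against the next token. A quick sketch of that behaviour; gpt2 below is only a stand-in checkpoint, not something this Space uses:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = torch.tensor([tokenizer.encode("Causal language models predict the next token.")])
out = model(input_ids=ids, labels=ids)  # labels == input_ids, as in the collator
print(out.loss)  # next-token cross-entropy computed from the internally shifted labels

Second, the final `tokenizer.pad(...)` call is what turns the accumulated Python lists into fixed-size tensors. The same call shown standalone, with a toy max_length in place of self.max_seq_length and gpt2 again as a stand-in tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

# Two variable-length examples, shaped like batch["input_ids"] in the collator.
batch = {
    "input_ids": [
        tokenizer.encode("Summarize the key findings."),
        tokenizer.encode("List the main limitations of the study."),
    ],
}
batch["attention_mask"] = [[1] * len(ids) for ids in batch["input_ids"]]

padded = tokenizer.pad(
    batch,
    padding="max_length",
    max_length=32,
    return_tensors="pt",
)
print(padded["input_ids"].shape)            # torch.Size([2, 32])
print(padded["attention_mask"].sum(dim=1))  # real (unpadded) token count per row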
|