George-API committed (verified)
Commit 4a1fd53 · 1 Parent(s): adb15f9

Upload folder using huggingface_hub
DEPLOY_CHECKLIST.md ADDED
@@ -0,0 +1,107 @@
# Phi-4 Training Space Deployment Checklist

## Critical Configuration Review

Before updating the Hugging Face Space, verify each of these items to prevent deployment issues (each section below is followed by a short verification sketch):

### 1. Model Configuration ✓

- [ ] Confirmed model name in transformers_config.json: `unsloth/phi-4-unsloth-bnb-4bit`
- [ ] BF16 precision enabled, FP16 disabled (`"bf16": true, "fp16": false`)
- [ ] Chat template correctly set to `"phi"` in config
- [ ] LoRA parameters properly configured:
  - [ ] `r`: 32
  - [ ] `lora_alpha`: 16
  - [ ] `target_modules`: all required attention modules included
- [ ] Max sequence length matches dataset needs (default: 2048)
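
A quick way to spot-check the items above is to load the config and assert the critical values. This is a minimal sketch; the key paths follow how the scripts in this commit read `transformers_config.json` and may need adjusting (the LoRA block living under `"unsloth"` is an assumption):

```python
import json

with open("transformers_config.json") as f:
    cfg = json.load(f)

# Model name and precision flags from the checklist above
assert cfg.get("model", {}).get("name") == "unsloth/phi-4-unsloth-bnb-4bit"
assert cfg.get("bf16") is True and cfg.get("fp16") is False

# LoRA settings (assumed to sit under the "unsloth" section used by the training script)
lora = cfg.get("unsloth", {})
print("r =", lora.get("r"), "| lora_alpha =", lora.get("lora_alpha"))
print("max_seq_length =", cfg.get("tokenizer", {}).get("max_seq_length", 2048))
```
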
### 2. GPU & Memory Management ✓

- [ ] Per-device batch size set to 16 or lower
- [ ] Gradient accumulation steps set to 3 or higher
- [ ] Device mapping set to "auto" for multi-GPU
- [ ] Max memory limit set to 85% of each GPU's capacity
- [ ] `PYTORCH_CUDA_ALLOC_CONF` includes `"expandable_segments:True"`
- [ ] Gradient checkpointing enabled (`"gradient_checkpointing": true`)
- [ ] Dataloader workers reduced to 2 (from 4)
- [ ] FSDP configuration enabled for multi-GPU setups
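
The allocator and memory-cap items above can be exercised directly. This sketch only illustrates the values: the training script sets `PYTORCH_CUDA_ALLOC_CONF` itself and uses the 85% figure (`cuda_memory_fraction`) for headroom calculations, so calling `set_per_process_memory_fraction` here is one assumed way to enforce the cap, not necessarily what the script does:

```python
import os
import torch

# Allocator settings the checklist expects (set before CUDA is initialized)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        # Cap each GPU at roughly 85% of its capacity
        torch.cuda.set_per_process_memory_fraction(0.85, device=i)
        total_gb = torch.cuda.get_device_properties(i).total_memory / 1024**3
        print(f"GPU {i}: {total_gb:.0f} GB total, cap ~{0.85 * total_gb:.0f} GB")
```
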
### 3. Dataset Handling ✓

- [ ] Dataset configuration correctly specified (the `dataset` section of transformers_config.json)
- [ ] Conversation structure preserved (`id` + `conversations` fields)
- [ ] SimpleDataCollator configured to use `apply_chat_template`
- [ ] No re-ordering or sorting of the dataset (preserves original order)
- [ ] Sequential sampler used in the dataloader (no shuffling)
- [ ] Max sequence length of 2048 applied
- [ ] Format validation for the first few examples enabled
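
For reference, this is roughly how the collator in this commit turns one example's `conversations` list into token ids while preserving order; the record below is hypothetical and the roles follow the dataset's formatting config (if the template rejects them, the collator falls back to plain concatenation):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("unsloth/phi-4-unsloth-bnb-4bit")

example = {
    "id": "paper_0001",  # hypothetical record
    "conversations": [
        {"role": "system", "content": "[RESEARCH INTRODUCTION] ..."},
        {"role": "user", "content": "Summarize the key claim of this chunk."},
        {"role": "assistant", "content": "The paper argues that ..."},
    ],
}

# The collator passes the conversations straight to the model's chat template
input_ids = tokenizer.apply_chat_template(
    example["conversations"], return_tensors=None, add_generation_prompt=False
)
print(f"{len(input_ids)} tokens (capped at max_seq_length=2048)")
```
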
### 4. Dependency Management ✓

- [ ] requirements.txt includes all necessary packages:
  - [ ] unsloth
  - [ ] peft
  - [ ] bitsandbytes
  - [ ] einops
  - [ ] sentencepiece
  - [ ] datasets
  - [ ] transformers
- [ ] Optional packages marked as such (e.g., flash-attn)
- [ ] Dependency version constraints avoid known conflicts
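
A small import check in the Space environment catches missing packages before a run; a sketch:

```python
import importlib.util

required = ["unsloth", "peft", "bitsandbytes", "einops",
            "sentencepiece", "datasets", "transformers"]

missing = [name for name in required if importlib.util.find_spec(name) is None]
print("missing required packages:", missing or "none")
# flash-attn is optional; only warn if it is absent
print("flash_attn available:", importlib.util.find_spec("flash_attn") is not None)
```
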
### 5. Error Handling & Logging ✓

- [ ] Proper error catching for dataset loading
- [ ] Fallback mechanisms for chat template application
- [ ] Clear, concise log messages that work with the HF Space interface
- [ ] Memory usage tracking at key points (start, end, periodic)
- [ ] Third-party loggers set to WARNING to reduce noise
- [ ] Low-verbosity log format for better HF Space compatibility
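
The logging pattern these items refer to mirrors the setup in `app.py` from this commit: keep project messages at INFO, silence noisy libraries, and flush so output streams in the Space UI:

```python
import logging
import sys

logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Third-party loggers at WARNING to reduce noise
for noisy in ("transformers", "datasets", "accelerate", "torch", "bitsandbytes"):
    logging.getLogger(noisy).setLevel(logging.WARNING)

def log_info(message):
    """Log in a format compatible with Hugging Face Spaces."""
    logger.info(message)
    sys.stdout.flush()  # ensure output is streamed immediately
```
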
### 6. Training Setup ✓

- [ ] Number of epochs properly configured (default: 3)
- [ ] Learning rate appropriate (default: 2e-5)
- [ ] Warmup ratio set (default: 0.05)
- [ ] Checkpointing frequency set to a reasonable value (default: 100 steps)
- [ ] Output directory correctly configured
- [ ] Hugging Face Hub parameters set correctly if pushing models
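
The defaults above translate into `TrainingArguments` roughly like this; in this repo the actual values come from `transformers_config.json`, so treat the snippet as an illustrative sketch (the output directory and Hub flags are assumptions):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./results",            # checkpoint directory (assumed)
    num_train_epochs=3,
    learning_rate=2e-5,
    warmup_ratio=0.05,
    save_steps=100,                    # checkpoint frequency
    per_device_train_batch_size=16,
    gradient_accumulation_steps=3,
    bf16=True,
    push_to_hub=False,                 # set True (plus hub_model_id) if pushing checkpoints
)
```
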
### 7. Pre-Flight Verification ✓

- [ ] No linting errors or indentation issues
- [ ] Updated config values are consistent across files
- [ ] Batch size × gradient accumulation × GPUs gives a reasonable total batch
- [ ] Verified that requirements.txt matches the actual imports in the code
- [ ] Confirmed tokenizer settings match the model requirements
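
For the batch-size check, the arithmetic with the current defaults is:

```python
per_device_batch_size = 16
gradient_accumulation_steps = 3
gpu_count = 4

effective_batch = per_device_batch_size * gradient_accumulation_steps * gpu_count
print(f"Effective batch size: {effective_batch}")  # 16 × 3 × 4 = 192
```
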
---

## Last-Minute Configuration Changes

If you've made any configuration changes, record them here before deployment:

| Date | Parameter Changed | Old Value | New Value | Reason | Reviewer |
|------|-------------------|-----------|-----------|--------|----------|
|      |                   |           |           |        |          |
|      |                   |           |           |        |          |

---

## Deployment Notes

**Current Space Hardware**: 4× NVIDIA L4 GPUs (24 GB VRAM each)

**Expected Training Speed**: ~XXX examples/second with the current configuration

**Memory Requirements**: Peak usage expected to be ~20 GB per GPU

**Common Issues to Watch For**:
- OOM errors on GPU 0: if seen, reduce the batch size by 2 and increase gradient accumulation by 1
- Imbalanced GPU usage: check the device mapping and FSDP configuration
- Slow training: verify that all GPUs are being utilized efficiently
- Log flooding: reduce the verbosity of component logs (transformers, datasets, etc.)
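
When investigating OOM or imbalanced usage, a quick per-GPU snapshot (the same counters the `LoggingCallback` in this commit reports during training) helps locate the problem device:

```python
import torch

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 1024**2
        reserved = torch.cuda.memory_reserved(i) / 1024**2
        print(f"GPU {i}: {allocated:.1f} MB allocated / {reserved:.1f} MB reserved")
```
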
---

*Last Updated: 2025-03-09*
app.py CHANGED
@@ -1,14 +1,15 @@
1
  import os
2
  import sys
3
  import json
4
  import logging
5
- import gradio as gr
6
- from pathlib import Path
7
  import subprocess
8
  import time
9
  from datetime import datetime
10
 
11
- # Configure logging
12
  logging.basicConfig(
13
  level=logging.INFO,
14
  format="%(asctime)s - %(levelname)s - %(message)s",
@@ -16,11 +17,23 @@ logging.basicConfig(
16
  )
17
  logger = logging.getLogger(__name__)
18
 
19
  # Configuration paths
20
  CONFIG_DIR = "."
21
  TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, "transformers_config.json")
22
- HARDWARE_CONFIG = os.path.join(CONFIG_DIR, "hardware_config.json")
23
- DATASET_CONFIG = os.path.join(CONFIG_DIR, "dataset_config.json")
24
 
25
  def load_config(config_path):
26
  """Load configuration from JSON file."""
@@ -29,207 +42,145 @@ def load_config(config_path):
29
  with open(config_path, 'r') as f:
30
  return json.load(f)
31
  else:
32
- logger.error(f"Config file not found: {config_path}")
33
  return None
34
  except Exception as e:
35
- logger.error(f"Error loading config: {str(e)}")
36
  return None
37
 
38
  def display_config():
39
  """Display current training configuration."""
40
- transformers_config = load_config(TRANSFORMERS_CONFIG)
41
- hardware_config = load_config(HARDWARE_CONFIG)
42
- dataset_config = load_config(DATASET_CONFIG)
43
 
44
- if not all([transformers_config, hardware_config, dataset_config]):
45
- return "Error loading configuration files."
46
 
47
- # Extract key parameters
48
- model_name = transformers_config.get("model", {}).get("name", "")
49
- dataset_name = dataset_config.get("dataset", {}).get("name", "")
50
- batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 0)
51
- gradient_accum = transformers_config.get("training", {}).get("gradient_accumulation_steps", 0)
52
- lr = transformers_config.get("training", {}).get("learning_rate", 0)
53
- epochs = transformers_config.get("training", {}).get("num_train_epochs", 0)
54
- gpu_count = hardware_config.get("specs", {}).get("gpu_count", 0)
55
- gpu_type = hardware_config.get("specs", {}).get("gpu_type", "")
56
 
57
- config_info = f"""
58
- ## Current Training Configuration
 
 
 
 
59
 
60
- **Model**: {model_name}
61
- **Dataset**: {dataset_name}
 
 
62
 
63
- **Training Parameters**:
64
- - Learning Rate: {lr}
65
- - Epochs: {epochs}
66
- - Batch Size/GPU: {batch_size}
67
- - Gradient Accumulation: {gradient_accum}
68
- - Effective Batch Size: {batch_size * gradient_accum * gpu_count}
69
 
70
- **Hardware**:
71
- - GPUs: {gpu_count}x {gpu_type}
72
- - Flash Attention: {hardware_config.get("memory_optimization", {}).get("use_flash_attention", False)}
73
- - Gradient Checkpointing: {hardware_config.get("memory_optimization", {}).get("use_gradient_checkpointing", False)}
 
 
74
 
75
- **Pre-quantized 4-bit Training**: Enabled
 
 
 
 
76
  """
77
 
78
- return config_info
79
 
80
  def start_training():
81
  """Start the training process."""
82
  try:
83
- # Check if already running
84
- if os.path.exists("training.pid"):
85
- with open("training.pid", "r") as f:
86
- pid = f.read().strip()
87
- try:
88
- # Check if process is still running
89
- os.kill(int(pid), 0)
90
- return f"Training is already running with PID {pid}"
91
- except OSError:
92
- # Process not running, remove stale PID file
93
- os.remove("training.pid")
 
 
 
 
94
 
95
- # Start training in background
96
  cmd = "python run_transformers_training.py"
97
- process = subprocess.Popen(
98
- cmd,
99
- shell=True,
100
- stdout=open('training.log', 'a'),
101
- stderr=subprocess.STDOUT
102
- )
103
 
104
- # Save PID
105
- with open("training.pid", "w") as f:
106
- f.write(str(process.pid))
107
 
108
- # Log start time
109
- with open("training_history.log", "a") as f:
110
- f.write(f"{datetime.now().isoformat()}: Training started (PID: {process.pid})\n")
111
 
112
- return f"Training started with PID {process.pid}. Check status for updates."
113
-
114
- except Exception as e:
115
- return f"Error starting training: {str(e)}"
116
-
117
- def check_training_status():
118
- """Check the status of training."""
119
- try:
120
- # Check if training is running
121
- if os.path.exists("training.pid"):
122
- with open("training.pid", "r") as f:
123
- pid = f.read().strip()
124
- try:
125
- # Check if process is still running
126
- os.kill(int(pid), 0)
127
- status = f"Training is running with PID {pid}"
128
- except OSError:
129
- status = "Training process has stopped"
130
- os.remove("training.pid")
131
- else:
132
- status = "No training process is currently running"
133
-
134
- # Get last lines from training log
135
- log_content = "No training log available"
136
- if os.path.exists("training.log"):
137
- with open("training.log", "r") as f:
138
- lines = f.readlines()
139
- log_content = "".join(lines[-20:]) if lines else "Log file is empty"
140
 
141
- return f"{status}\n\n**Recent Log:**\n```\n{log_content}\n```"
142
-
143
  except Exception as e:
144
- return f"Error checking status: {str(e)}"
 
 
145
 
146
- # Create the Gradio interface
147
- with gr.Blocks(title="Phi-4 Unsloth Training", theme=gr.themes.Soft(primary_hue="blue")) as app:
148
- gr.Markdown("# Phi-4 Unsloth 4-bit Training Interface")
149
 
150
- with gr.Tabs():
151
- with gr.TabItem("Configuration"):
152
- config_output = gr.Markdown(display_config())
153
- refresh_btn = gr.Button("Refresh Configuration")
154
- refresh_btn.click(fn=display_config, outputs=config_output)
155
-
156
- with gr.TabItem("Training Control"):
157
- gr.Markdown("## Training Management")
158
-
159
- with gr.Row():
160
- start_btn = gr.Button("Start Training", variant="primary")
161
- check_btn = gr.Button("Check Status")
162
-
163
- status_output = gr.Markdown("Click 'Check Status' to see training progress")
164
-
165
- start_btn.click(fn=start_training, outputs=status_output)
166
- check_btn.click(fn=check_training_status, outputs=status_output)
167
-
168
- # Auto-refresh status
169
- gr.HTML('''
170
- <script>
171
- let intervalId;
172
-
173
- document.addEventListener('DOMContentLoaded', function() {
174
- // Find the "Check Status" button
175
- const buttons = Array.from(document.querySelectorAll('button'));
176
- const checkBtn = buttons.find(btn => btn.textContent.includes('Check Status'));
177
 
178
- // Set up interval to click the button every 30 seconds
179
- if (checkBtn) {
180
- intervalId = setInterval(() => {
181
- checkBtn.click();
182
- }, 30000);
183
- }
184
- });
185
-
186
- // Clean up on tab/window close
187
- window.addEventListener('beforeunload', function() {
188
- clearInterval(intervalId);
189
- });
190
- </script>
191
- ''')
192
 
193
- with gr.TabItem("Help"):
194
- gr.Markdown("""
195
- ## Phi-4 Unsloth Training Help
196
-
197
- This interface allows you to manage training of the Phi-4 model with Unsloth 4-bit optimizations.
198
-
199
- ### Installation
200
-
201
- Before starting training, ensure all dependencies are installed:
202
-
203
- ```bash
204
- pip install -r requirements.txt
205
- ```
206
-
207
- Critical packages:
208
- - unsloth (>=2024.3)
209
- - peft (>=0.9.0)
210
- - transformers (>=4.36.0)
211
-
212
- ### Quick Start
213
-
214
- 1. Review the configuration in the Configuration tab
215
- 2. Click "Start Training" to begin the process
216
- 3. Use "Check Status" to monitor progress
217
-
218
- ### Notes
219
-
220
- - Training uses the pre-quantized model `unsloth/phi-4-unsloth-bnb-4bit`
221
- - The process maintains paper order and handles metadata appropriately
222
- - Training progress will be regularly saved to HuggingFace Hub
223
-
224
- ### Troubleshooting
225
-
226
- If training stops unexpectedly:
227
- - Check the logs for out-of-memory errors
228
- - Verify the VRAM usage on each GPU
229
- - Check for CUDA version compatibility
230
- - If you see "Unsloth not available" error, run: `pip install unsloth>=2024.3 peft>=0.9.0`
231
- """)
232
 
233
- # Launch the app
234
  if __name__ == "__main__":
235
- app.launch()
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
  import os
5
  import sys
6
  import json
7
  import logging
 
 
8
  import subprocess
9
  import time
10
  from datetime import datetime
11
 
12
+ # Configure logging to match HF Space logs
13
  logging.basicConfig(
14
  level=logging.INFO,
15
  format="%(asctime)s - %(levelname)s - %(message)s",
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
+ # Set other loggers to WARNING to reduce noise and ensure our logs are visible
21
+ logging.getLogger("transformers").setLevel(logging.WARNING)
22
+ logging.getLogger("datasets").setLevel(logging.WARNING)
23
+ logging.getLogger("accelerate").setLevel(logging.WARNING)
24
+ logging.getLogger("torch").setLevel(logging.WARNING)
25
+ logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
26
+
27
+ # Define a clean logging function for HF Space compatibility
28
+ def log_info(message):
29
+ """Log information in a format compatible with Hugging Face Spaces"""
30
+ logger.info(message)
31
+ # Ensure output is flushed immediately for streaming
32
+ sys.stdout.flush()
33
+
34
  # Configuration paths
35
  CONFIG_DIR = "."
36
  TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, "transformers_config.json")
 
 
37
 
38
  def load_config(config_path):
39
  """Load configuration from JSON file."""
 
42
  with open(config_path, 'r') as f:
43
  return json.load(f)
44
  else:
45
+ log_info(f"Config file not found: {config_path}")
46
  return None
47
  except Exception as e:
48
+ log_info(f"Error loading config: {str(e)}")
49
  return None
50
 
51
  def display_config():
52
  """Display current training configuration."""
53
+ config = load_config(TRANSFORMERS_CONFIG)
 
 
54
 
55
+ if not config:
56
+ return "Error loading configuration file."
57
 
58
+ # Extract sub-configurations
59
+ transformers_config = config
60
+ hardware_config = config.get("hardware", {})
61
+ dataset_config = config.get("dataset", {})
62
+
63
+ model_name = transformers_config.get("model", {}).get("name") or transformers_config.get("model_name_or_path", "")
 
 
 
64
 
65
+ # Training parameters
66
+ training_config = transformers_config.get("training", {})
67
+ batch_size = training_config.get("per_device_train_batch_size", 16)
68
+ grad_accum = training_config.get("gradient_accumulation_steps", 3)
69
+ epochs = training_config.get("num_train_epochs", 3)
70
+ learning_rate = training_config.get("learning_rate", 2e-5)
71
 
72
+ # Hardware settings
73
+ gpu_count = hardware_config.get("specs", {}).get("gpu_count", 4)
74
+ gpu_type = hardware_config.get("specs", {}).get("gpu_type", "L4")
75
+ vram = hardware_config.get("specs", {}).get("vram_per_gpu", 24)
76
 
77
+ # Dataset info
78
+ dataset_name = dataset_config.get("dataset", {}).get("name", "")
79
+
80
+ # Format response as HTML for better display
81
+ html = f"""
82
+ <h2>Training Configuration</h2>
83
+ <h3>Model</h3>
84
+ <ul>
85
+ <li><b>Model:</b> {model_name}</li>
86
+ <li><b>Learning Rate:</b> {training_config.get('learning_rate', '2e-5')}</li>
87
+ <li><b>Batch Size:</b> {training_config.get('per_device_train_batch_size', 4)} × {training_config.get('gradient_accumulation_steps', 4)} = {training_config.get('per_device_train_batch_size', 4) * training_config.get('gradient_accumulation_steps', 4)}</li>
88
+ <li><b>Epochs:</b> {training_config.get('num_train_epochs', 3)}</li>
89
+ <li><b>Precision:</b> {'BF16' if transformers_config.get('bf16', True) else 'FP16' if transformers_config.get('fp16', False) else 'FP32'}</li>
90
+ <li><b>Max Sequence Length:</b> {transformers_config.get('tokenizer', {}).get('max_seq_length', 2048)}</li>
91
+ </ul>
92
 
93
+ <h3>Hardware</h3>
94
+ <ul>
95
+ <li><b>GPU:</b> {gpu_count}× {gpu_type} ({vram} GB)</li>
96
+ <li><b>Multi-GPU Strategy:</b> {hardware_config.get('training_optimizations', {}).get('multi_gpu_strategy', 'data_parallel')}</li>
97
+ <li><b>Memory Optimizations:</b> {'Gradient Checkpointing' if hardware_config.get('training_optimizations', {}).get('memory_optimizations', {}).get('use_gradient_checkpointing', True) else 'None'}</li>
98
+ </ul>
99
 
100
+ <h3>Dataset</h3>
101
+ <ul>
102
+ <li><b>Dataset:</b> {dataset_name}</li>
103
+ <li><b>Dataset Split:</b> {dataset_config.get('dataset', {}).get('split', 'train')}</li>
104
+ </ul>
105
  """
106
 
107
+ return html
108
 
109
  def start_training():
110
  """Start the training process."""
111
  try:
112
+ # Run verification script first
113
+ log_info("Running pre-training verification...")
114
+ verify_cmd = "python verify_deployment.py"
115
+ try:
116
+ result = subprocess.run(verify_cmd, shell=True, check=True, capture_output=True, text=True)
117
+ if "All critical checks passed!" not in result.stdout:
118
+ log_info("Verification found issues. Please review:")
119
+ log_info(result.stdout)
120
+ return "Verification detected potential issues. Please review the logs before proceeding."
121
+ except subprocess.CalledProcessError as e:
122
+ log_info(f"Verification failed: {e.stderr}")
123
+ return "Verification failed. Please check the logs for details."
124
+
125
+ # Start training
126
+ log_info("Starting training process...")
127
 
128
+ # Run in a background process for HF Space
129
  cmd = "python run_transformers_training.py"
130
 
131
+ # In HF Spaces, we don't need to handle process management ourselves
132
+ subprocess.Popen(cmd, shell=True, stdout=sys.stdout, stderr=sys.stderr)
 
133
 
134
+ log_info("Training process has been started. You can monitor progress in the logs.")
 
 
135
 
136
+ return "Training started successfully. Monitor progress in the Hugging Face Space logs."
137
 
 
 
138
  except Exception as e:
139
+ error_msg = f"Error starting training: {str(e)}"
140
+ log_info(error_msg)
141
+ return error_msg
142
 
143
+ # Interface setup for gradio
144
+ def create_interface():
145
+ import gradio as gr
146
 
147
+ with gr.Blocks(title="Phi-4 Training Center") as demo:
148
+ gr.Markdown("# Phi-4 Research Assistant Training")
149
+
150
+ with gr.Row():
151
+ with gr.Column():
152
+ gr.Markdown("## Control Panel")
153
 
154
+ # Display current config
155
+ config_html = gr.HTML(display_config())
156
+ refresh_btn = gr.Button("Refresh Configuration")
157
+
158
+ # Training controls
159
+ train_btn = gr.Button("Start Training", variant="primary")
160
+ train_output = gr.Textbox(label="Status", interactive=False)
161
+
162
+ with gr.Column():
163
+ gr.Markdown("## Training Information")
164
+ gr.Markdown("""
165
+ ### Hardware:
166
+ - 4× NVIDIA L4 GPUs (24GB VRAM each)
167
+ - Training with BF16 precision
168
+ - Using Data Parallel for multi-GPU
169
+
170
+ ### Notes:
171
+ - Training may take several hours depending on dataset size
172
+ - Check the Space logs for real-time progress
173
+ - Model checkpoints will be saved to ./results directory
174
+ """)
175
+
176
+ # Connect buttons to functions
177
+ refresh_btn.click(lambda: gr.update(value=display_config()), outputs=config_html)
178
+ train_btn.click(start_training, outputs=train_output)
179
 
180
+ return demo
181
 
 
182
  if __name__ == "__main__":
183
+ # If run directly, create and launch the Gradio interface
184
+ demo = create_interface()
185
+ demo.queue()
186
+ demo.launch()
fixed_run_transformers_training.py ADDED
@@ -0,0 +1,230 @@
1
+ def format_phi_chat(messages, dataset_config):
2
+ """Format messages according to phi-4's chat template and dataset config."""
3
+ formatted_chat = ""
4
+
5
+ # Get role templates from config
6
+ roles = dataset_config.get("data_formatting", {}).get("roles", {
7
+ "system": "System: {content}\n\n",
8
+ "human": "Human: {content}\n\n",
9
+ "user": "Human: {content}\n\n",
10
+ "assistant": "Assistant: {content}\n\n"
11
+ })
12
+
13
+ # Handle research introduction metadata first
14
+ metadata = next((msg for msg in messages if isinstance(msg, dict) and
15
+ "[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
16
+ if metadata:
17
+ system_template = roles.get("system", "System: {content}\n\n")
18
+ formatted_chat = system_template.format(content=metadata['content'])
19
+ messages = [msg for msg in messages if msg != metadata]
20
+
21
+ # Process remaining messages
22
+ for message in messages:
23
+ if not isinstance(message, dict) or "content" not in message:
24
+ logger.warning(f"Skipping invalid message format: {message}")
25
+ continue
26
+
27
+ role = message.get("role", "").lower()
28
+ content = message.get("content", "")
29
+
30
+ # Format based on role
31
+ if role == "human" or role == "user":
32
+ template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
33
+ formatted_chat += template.format(content=content)
34
+ elif role == "assistant" or role == "bot":
35
+ template = roles.get("assistant", "Assistant: {content}\n\n")
36
+ formatted_chat += template.format(content=content)
37
+ elif role == "system":
38
+ # For system messages, prepend them
39
+ template = roles.get("system", "System: {content}\n\n")
40
+ formatted_chat = template.format(content=content) + formatted_chat
41
+ else:
42
+ # Default to system for unknown roles
43
+ logger.warning(f"Unknown role '{role}' - treating as system message")
44
+ template = roles.get("system", "System: {content}\n\n")
45
+ formatted_chat += template.format(content=content)
46
+
47
+ return formatted_chat.strip()
48
+
49
+ class SimpleDataCollator:
50
+ def __init__(self, tokenizer, dataset_config):
51
+ self.tokenizer = tokenizer
52
+ self.dataset_config = dataset_config
53
+ self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
54
+ self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
55
+ self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
56
+ logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
57
+ logger.info("Using exact dataset structure without reformatting")
58
+
59
+ # Check if we're on GPU
60
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
61
+ logger.info(f"SimpleDataCollator using device: {self.device}")
62
+
63
+ def __call__(self, features):
64
+ """Process examples preserving exact JSONL structure"""
65
+ batch = {"input_ids": [], "attention_mask": [], "labels": []}
66
+
67
+ for example in features:
68
+ try:
69
+ # Get ID
70
+ paper_id = example.get("id", "")
71
+
72
+ # Get conversations - these should already contain role and content
73
+ conversations = example.get("conversations", [])
74
+ if not conversations:
75
+ self.stats["skipped"] += 1
76
+ continue
77
+
78
+ # Directly use the conversations array as input to the model's chat template
79
+ # This preserves the exact structure with roles and content as they are
80
+ try:
81
+ # Let tokenizer handle the content with the model's chat template
82
+ inputs = self.tokenizer.apply_chat_template(
83
+ conversations,
84
+ return_tensors=None,
85
+ add_generation_prompt=False
86
+ )
87
+ except Exception as chat_error:
88
+ # Fallback if apply_chat_template fails
89
+ logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}")
90
+
91
+ # Create a basic representation of the conversation
92
+ conversation_text = ""
93
+ for msg in conversations:
94
+ if isinstance(msg, dict) and 'content' in msg:
95
+ conversation_text += msg.get('content', '') + "\n\n"
96
+
97
+ # Basic tokenization
98
+ inputs = self.tokenizer(
99
+ conversation_text,
100
+ add_special_tokens=True,
101
+ return_tensors=None
102
+ )
103
+
104
+ # Apply length cap if needed (shouldn't be necessary for pre-audited data)
105
+ if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
106
+ logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
107
+ inputs = inputs[:self.max_seq_length]
108
+
109
+ # Create attention mask (1 for all tokens)
110
+ attention_mask = [1] * len(inputs)
111
+
112
+ if len(inputs) > 0:
113
+ # For causal language modeling, labels are the same as inputs
114
+ labels = inputs.copy()
115
+
116
+ batch["input_ids"].append(inputs)
117
+ batch["attention_mask"].append(attention_mask)
118
+ batch["labels"].append(labels)
119
+
120
+ self.stats["processed"] += 1
121
+ self.stats["total_tokens"] += len(inputs)
122
+
123
+ # Debug logging for first few examples
124
+ log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
125
+ if self.stats["processed"] <= log_samples:
126
+ logger.info(f"Example {self.stats['processed']}:")
127
+ logger.info(f"Paper ID: {paper_id}")
128
+ logger.info(f"Token count: {len(inputs)}")
129
+ logger.info(f"Conversation entries: {len(conversations)}")
130
+ else:
131
+ self.stats["skipped"] += 1
132
+ except Exception as e:
133
+ logger.warning(f"Error processing example: {str(e)[:100]}...")
134
+ logger.warning(f"Problematic example ID: {example.get('id', 'unknown')}")
135
+ self.stats["skipped"] += 1
136
+ continue
137
+
138
+ if not batch["input_ids"]:
139
+ logger.warning("Empty batch, returning dummy tensors")
140
+ return {
141
+ "input_ids": torch.zeros((1, 1), dtype=torch.long),
142
+ "attention_mask": torch.zeros((1, 1), dtype=torch.long),
143
+ "labels": torch.zeros((1, 1), dtype=torch.long)
144
+ }
145
+
146
+ # Pad the batch
147
+ max_length = max(len(ids) for ids in batch["input_ids"])
148
+
149
+ for i in range(len(batch["input_ids"])):
150
+ padding_length = max_length - len(batch["input_ids"][i])
151
+ if padding_length > 0:
152
+ batch["input_ids"][i].extend([self.pad_token_id] * padding_length)
153
+ batch["attention_mask"][i].extend([0] * padding_length)
154
+ batch["labels"][i].extend([-100] * padding_length)
155
+
156
+ # Convert to tensors
157
+ batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}
158
+
159
+ # Log stats periodically
160
+ log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
161
+ if self.stats["processed"] % log_interval == 0 and self.stats["processed"] > 0:
162
+ logger.info(f"Data collator stats: processed={self.stats['processed']}, "
163
+ f"skipped={self.stats['skipped']}, "
164
+ f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")
165
+
166
+ return batch
167
+
168
+ class LoggingCallback(TrainerCallback):
169
+ def __init__(self):
170
+ self.last_log_time = time.time()
171
+ self.last_memory_log_time = time.time()
172
+
173
+ def on_step_end(self, args, state, control, **kwargs):
174
+ # Log every 50 steps or every 5 minutes, whichever comes first
175
+ current_time = time.time()
176
+
177
+ # Log loss every 50 steps or 5 minutes
178
+ if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
179
+ if state.log_history:
180
+ loss = state.log_history[-1].get('loss', 'N/A')
181
+ # Use simple formatting for better HF Space log compatibility
182
+ log_info(f"Step {state.global_step}: Loss {loss}")
183
+ else:
184
+ log_info(f"Step {state.global_step}: No loss data available")
185
+ self.last_log_time = current_time
186
+
187
+ # Log memory usage every 15 minutes
188
+ if current_time - self.last_memory_log_time > 900: # 15 minutes
189
+ if torch.cuda.is_available():
190
+ memory_info = []
191
+ for i in range(torch.cuda.device_count()):
192
+ allocated = torch.cuda.memory_allocated(i) / 1024**2
193
+ reserved = torch.cuda.memory_reserved(i) / 1024**2
194
+ memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB")
195
+
196
+ # Log in compact format for better visibility
197
+ log_info(f"Memory usage - {', '.join(memory_info)}")
198
+ self.last_memory_log_time = current_time
199
+
200
+ def on_train_begin(self, args, state, control, **kwargs):
201
+ log_info("=== Training is starting ===")
202
+
203
+ # Log important training parameters for visibility
204
+ log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {max(1, torch.cuda.device_count())} GPUs")
205
+ log_info(f"Learning rate: {args.learning_rate}")
206
+ log_info(f"Epochs: {args.num_train_epochs}")
207
+
208
+ # Log memory information in compact format
209
+ if torch.cuda.is_available():
210
+ memory_info = []
211
+ for i in range(torch.cuda.device_count()):
212
+ allocated = torch.cuda.memory_allocated(i) / 1024**2
213
+ max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
214
+ memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
215
+
216
+ log_info(f"Initial memory usage - {', '.join(memory_info)}")
217
+
218
+ def on_train_end(self, args, state, control, **kwargs):
219
+ log_info("=== Training completed ===")
220
+ if torch.cuda.is_available():
221
+ memory_info = []
222
+ for i in range(torch.cuda.device_count()):
223
+ allocated = torch.cuda.memory_allocated(i) / 1024**2
224
+ max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
225
+ memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
226
+
227
+ log_info(f"Final memory usage - {', '.join(memory_info)}")
228
+
229
+ log_info(f"Total steps: {state.global_step}")
230
+ log_info(f"Final loss: {state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'}")
run_transformers_training.py CHANGED
@@ -113,26 +113,24 @@ def load_env_variables():
113
  os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
114
 
115
  def load_configs(base_path):
116
- """Load all configuration files."""
117
  configs = {}
118
 
119
- # List of config files to load
120
- config_files = [
121
- "transformers_config.json",
122
- "hardware_config.json",
123
- "dataset_config.json"
124
- ]
125
 
126
- for config_file in config_files:
127
- file_path = os.path.join(base_path, config_file)
128
- try:
129
- with open(file_path, "r") as f:
130
- config_name = config_file.replace("_config.json", "")
131
- configs[config_name] = json.load(f)
132
- logger.info(f"Loaded {config_name} configuration from {file_path}")
133
- except Exception as e:
134
- logger.error(f"Error loading {config_file}: {e}")
135
- raise
 
 
136
 
137
  return configs
138
 
@@ -238,7 +236,7 @@ def load_model_and_tokenizer(config):
238
 
239
  # Ensure model and optimizer init is on the same device
240
  logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
241
-
242
  # Apply Unsloth's training optimizations with config parameters
243
  unsloth_config = config.get("unsloth", {})
244
  model = FastLanguageModel.get_peft_model(
@@ -640,25 +638,32 @@ def main():
640
  # Load environment variables
641
  load_env_variables()
642
 
643
  # Load all configurations
644
  try:
645
  configs = load_configs(args.config_dir)
646
 
647
  # Extract specific configs
648
  if not configs:
649
- logger.error("Failed to load configurations")
650
  return 1
651
 
652
- # Verify configurations exist
653
  if "transformers" not in configs:
654
  logger.error("transformers_config.json not found or invalid")
655
  return 1
656
 
657
- if "hardware" not in configs:
658
- logger.warning("hardware_config.json not found. Using default hardware configuration.")
659
 
660
- if "dataset" not in configs:
661
- logger.error("dataset_config.json not found or invalid")
662
  return 1
663
 
664
  # Validate model configuration
@@ -679,22 +684,36 @@ def main():
679
 
680
  # Apply hardware-specific settings if available
681
  if hardware_config:
 
682
  training_opts = hardware_config.get("training_optimizations", {})
683
- per_device_batch_size = training_opts.get("per_device_batch_size")
684
- gradient_accumulation = training_opts.get("gradient_accumulation_steps")
685
 
686
- if per_device_batch_size and model_config.get("training"):
687
- model_config["training"]["per_device_train_batch_size"] = per_device_batch_size
688
- log_info(f"Applied hardware-specific batch size: {per_device_batch_size}")
689
-
690
- if gradient_accumulation and model_config.get("training"):
691
- model_config["training"]["gradient_accumulation_steps"] = gradient_accumulation
692
- log_info(f"Applied hardware-specific gradient accumulation: {gradient_accumulation}")
693
694
  # Apply memory optimizations
695
  memory_opts = training_opts.get("memory_optimizations", {})
696
  if memory_opts.get("use_gradient_checkpointing") is not None and model_config.get("training"):
697
- model_config["training"]["gradient_checkpointing"] = memory_opts["use_gradient_checkpointing"]
698
 
699
  except Exception as e:
700
  logger.error(f"Error loading configurations: {e}")
@@ -713,13 +732,17 @@ def main():
713
  # Set memory management env vars for better fragmentation handling
714
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
715
 
 
 
 
716
  # Log initial memory information in a compact form
717
  gpu_info = []
718
  for i in range(torch.cuda.device_count()):
719
  name = torch.cuda.get_device_name(i)
720
  allocated = torch.cuda.memory_allocated(i) / 1024**3
721
  total = torch.cuda.get_device_properties(i).total_memory / 1024**3
722
- gpu_info.append(f"GPU {i}: {name} ({allocated:.1f}GB/{total:.1f}GB)")
 
723
 
724
  log_info(f"Hardware: {torch.cuda.device_count()} GPUs detected")
725
  log_info(f"GPU details: {', '.join(gpu_info)}")
@@ -739,28 +762,44 @@ def main():
739
  except Exception as e:
740
  logger.error(f"Error loading dataset: {e}")
741
  return 1
742
-
743
  # Create data collator
744
  data_collator = SimpleDataCollator(tokenizer, dataset_config)
745
 
746
  # Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
747
- use_bf16 = model_config.get("bf16", False) or model_config.get("torch_dtype", "") == "bfloat16"
748
- use_fp16 = model_config.get("fp16", False) and not use_bf16 # Only use fp16 if bf16 is not set
749
-
750
- log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")
751
 
752
- # Get per device batch size - temporarily reduce if necessary for multi-GPU setup
753
- per_device_batch_size = model_config.get("training", {}).get("per_device_train_batch_size", 24)
754
- gradient_accumulation_steps = model_config.get("training", {}).get("gradient_accumulation_steps", 2)
755
 
756
  # For multi-GPU setup, adjust for better balance
757
  if torch.cuda.device_count() > 1:
758
  log_info(f"Multi-GPU setup with {torch.cuda.device_count()} GPUs")
759
  log_info(f"Training config: {per_device_batch_size} samples/GPU × {gradient_accumulation_steps} accumulation steps")
760
 
761
- # Set up FSDP for multi-GPU training if available
 
 
 
762
  fsdp_config = None
763
- if torch.cuda.device_count() > 1:
764
  try:
765
  from torch.distributed.fsdp import (
766
  FullyShardedDataParallel as FSDP,
@@ -793,6 +832,15 @@ def main():
793
  except ImportError:
794
  log_info("FSDP imports failed, falling back to standard DDP")
795
  fsdp_config = None
796
 
797
  # Set up training arguments
798
  log_info("Setting up training arguments")
@@ -818,13 +866,14 @@ def main():
818
  report_to="tensorboard",
819
  remove_unused_columns=False, # Keep all columns
820
  gradient_checkpointing=model_config.get("training", {}).get("gradient_checkpointing", True),
821
- dataloader_pin_memory=True, # Keep data in pinned memory for faster transfer
822
  optim=model_config.get("training", {}).get("optim", "adamw_torch"),
823
  ddp_find_unused_parameters=False, # Improve distributed training efficiency
824
  dataloader_drop_last=False, # Process all examples
825
- dataloader_num_workers=2, # Reduced worker count
826
  no_cuda=False if torch.cuda.is_available() else True, # Use CUDA if available
827
- fsdp=fsdp_config, # Add FSDP configuration if available
 
828
  )
829
 
830
  # Create sequential sampler to maintain original dataset order
@@ -907,7 +956,7 @@ def main():
907
  memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)")
908
  logger.error(f"GPU memory at failure: {', '.join(memory_info)}")
909
  raise
910
-
911
  except Exception as e:
912
  logger.error(f"Error in main training loop: {str(e)}")
913
  return 1
 
113
  os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
114
 
115
  def load_configs(base_path):
116
+ """Load all configuration from a single consolidated file."""
117
  configs = {}
118
 
119
+ # Using a single consolidated config file
120
+ config_file = "transformers_config.json"
 
 
 
 
121
 
122
+ file_path = os.path.join(base_path, config_file)
123
+ try:
124
+ with open(file_path, "r") as f:
125
+ config = json.load(f)
126
+ # Extract sections into separate config dictionaries for compatibility
127
+ configs["transformers"] = config
128
+ configs["hardware"] = config.get("hardware", {})
129
+ configs["dataset"] = config.get("dataset", {})
130
+ logger.info(f"Loaded consolidated configuration from {file_path}")
131
+ except Exception as e:
132
+ logger.error(f"Error loading {config_file}: {e}")
133
+ raise
134
 
135
  return configs
136
 
 
236
 
237
  # Ensure model and optimizer init is on the same device
238
  logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
239
+
240
  # Apply Unsloth's training optimizations with config parameters
241
  unsloth_config = config.get("unsloth", {})
242
  model = FastLanguageModel.get_peft_model(
 
638
  # Load environment variables
639
  load_env_variables()
640
 
641
+ # Check if we're in distributed mode
642
+ is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
643
+ if is_distributed:
644
+ log_info(f"Running in distributed mode with world size: {os.environ.get('WORLD_SIZE')}")
645
+ else:
646
+ log_info("Running in non-distributed mode (single process)")
647
+
648
  # Load all configurations
649
  try:
650
  configs = load_configs(args.config_dir)
651
 
652
  # Extract specific configs
653
  if not configs:
654
+ logger.error("Failed to load configuration")
655
  return 1
656
 
657
+ # Verify configuration sections exist
658
  if "transformers" not in configs:
659
  logger.error("transformers_config.json not found or invalid")
660
  return 1
661
 
662
+ if "hardware" not in configs or not configs["hardware"]:
663
+ logger.warning("Hardware configuration section not found in transformers_config.json. Using default hardware configuration.")
664
 
665
+ if "dataset" not in configs or not configs["dataset"]:
666
+ logger.error("Dataset configuration section not found in transformers_config.json")
667
  return 1
668
 
669
  # Validate model configuration
 
684
 
685
  # Apply hardware-specific settings if available
686
  if hardware_config:
687
+ # Get training optimizations from hardware config
688
  training_opts = hardware_config.get("training_optimizations", {})
 
 
689
 
690
+ # Apply batch size and gradient accumulation settings
691
+ if training_opts.get("per_device_batch_size") and model_config.get("training"):
692
+ batch_size = training_opts.get("per_device_batch_size")
693
+ model_config["training"]["per_device_train_batch_size"] = batch_size
694
+ log_info(f"Applied hardware-optimized batch size: {batch_size}")
 
 
695
 
696
+ if training_opts.get("gradient_accumulation_steps") and model_config.get("training"):
697
+ grad_steps = training_opts.get("gradient_accumulation_steps")
698
+ model_config["training"]["gradient_accumulation_steps"] = grad_steps
699
+ log_info(f"Applied hardware-optimized gradient accumulation: {grad_steps}")
700
+
701
  # Apply memory optimizations
702
  memory_opts = training_opts.get("memory_optimizations", {})
703
  if memory_opts.get("use_gradient_checkpointing") is not None and model_config.get("training"):
704
+ grad_ckpt = memory_opts.get("use_gradient_checkpointing")
705
+ model_config["training"]["gradient_checkpointing"] = grad_ckpt
706
+ log_info(f"Applied hardware-optimized gradient checkpointing: {grad_ckpt}")
707
+
708
+ # Apply system settings
709
+ system_settings = hardware_config.get("system_settings", {})
710
+ if system_settings.get("dataloader_num_workers") is not None:
711
+ workers = system_settings.get("dataloader_num_workers")
712
+ log_info(f"Using {workers} dataloader workers from hardware config")
713
+
714
+ # Get distribution strategy
715
+ multi_gpu_strategy = training_opts.get("multi_gpu_strategy", "data_parallel")
716
+ log_info(f"Hardware config specifies {multi_gpu_strategy} for multi-GPU training")
717
 
718
  except Exception as e:
719
  logger.error(f"Error loading configurations: {e}")
 
732
  # Set memory management env vars for better fragmentation handling
733
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
734
 
735
+ # Get memory fraction from hardware config
736
+ cuda_memory_fraction = hardware_config.get("system_settings", {}).get("cuda_memory_fraction", 0.85)
737
+
738
  # Log initial memory information in a compact form
739
  gpu_info = []
740
  for i in range(torch.cuda.device_count()):
741
  name = torch.cuda.get_device_name(i)
742
  allocated = torch.cuda.memory_allocated(i) / 1024**3
743
  total = torch.cuda.get_device_properties(i).total_memory / 1024**3
744
+ reserved_memory = total * cuda_memory_fraction
745
+ gpu_info.append(f"GPU {i}: {name} ({allocated:.1f}GB/{reserved_memory:.1f}GB)")
746
 
747
  log_info(f"Hardware: {torch.cuda.device_count()} GPUs detected")
748
  log_info(f"GPU details: {', '.join(gpu_info)}")
 
762
  except Exception as e:
763
  logger.error(f"Error loading dataset: {e}")
764
  return 1
765
+
766
  # Create data collator
767
  data_collator = SimpleDataCollator(tokenizer, dataset_config)
768
 
769
  # Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
770
+ # First check hardware config, then transformers config
771
+ use_bf16 = False
772
+ use_fp16 = False
773
+
774
+ # Check hardware config first
775
+ hardware_precision = hardware_config.get("training_optimizations", {}).get("mixed_precision", "")
776
+ if hardware_precision.lower() == "bf16":
777
+ use_bf16 = True
778
+ log_info("Using BF16 precision from hardware config")
779
+ elif hardware_precision.lower() == "fp16":
780
+ use_fp16 = True
781
+ log_info("Using FP16 precision from hardware config")
782
+ else:
783
+ # Fall back to transformers config
784
+ use_bf16 = model_config.get("bf16", False) or model_config.get("torch_dtype", "") == "bfloat16"
785
+ use_fp16 = model_config.get("fp16", False) and not use_bf16 # Only use fp16 if bf16 is not set
786
+ log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")
787
 
788
+ # Get per device batch size - from transformers config, but possibly overridden by hardware config
789
+ per_device_batch_size = model_config.get("training", {}).get("per_device_train_batch_size", 16)
790
+ gradient_accumulation_steps = model_config.get("training", {}).get("gradient_accumulation_steps", 3)
791
 
792
  # For multi-GPU setup, adjust for better balance
793
  if torch.cuda.device_count() > 1:
794
  log_info(f"Multi-GPU setup with {torch.cuda.device_count()} GPUs")
795
  log_info(f"Training config: {per_device_batch_size} samples/GPU × {gradient_accumulation_steps} accumulation steps")
796
 
797
+ # Determine multi-GPU strategy from hardware config
798
+ multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
799
+
800
+ # Set up FSDP for multi-GPU training if specified and in distributed mode
801
  fsdp_config = None
802
+ if multi_gpu_strategy == "fsdp" and is_distributed and torch.cuda.device_count() > 1:
803
  try:
804
  from torch.distributed.fsdp import (
805
  FullyShardedDataParallel as FSDP,
 
832
  except ImportError:
833
  log_info("FSDP imports failed, falling back to standard DDP")
834
  fsdp_config = None
835
+ elif multi_gpu_strategy == "fsdp" and not is_distributed:
836
+ log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)")
837
+ log_info("Using DataParallel for multi-GPU training instead")
838
+ else:
839
+ log_info(f"Using {multi_gpu_strategy} for multi-GPU training")
840
+
841
+ # Get system settings from hardware config
842
+ dataloader_workers = hardware_config.get("system_settings", {}).get("dataloader_num_workers", 2)
843
+ pin_memory = hardware_config.get("system_settings", {}).get("dataloader_pin_memory", True)
844
 
845
  # Set up training arguments
846
  log_info("Setting up training arguments")
 
866
  report_to="tensorboard",
867
  remove_unused_columns=False, # Keep all columns
868
  gradient_checkpointing=model_config.get("training", {}).get("gradient_checkpointing", True),
869
+ dataloader_pin_memory=pin_memory,
870
  optim=model_config.get("training", {}).get("optim", "adamw_torch"),
871
  ddp_find_unused_parameters=False, # Improve distributed training efficiency
872
  dataloader_drop_last=False, # Process all examples
873
+ dataloader_num_workers=dataloader_workers,
874
  no_cuda=False if torch.cuda.is_available() else True, # Use CUDA if available
875
+ # Only add FSDP if we're in distributed mode with FSDP strategy
876
+ fsdp=fsdp_config if is_distributed and multi_gpu_strategy == "fsdp" else None,
877
  )
878
 
879
  # Create sequential sampler to maintain original dataset order
 
956
  memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)")
957
  logger.error(f"GPU memory at failure: {', '.join(memory_info)}")
958
  raise
959
+
960
  except Exception as e:
961
  logger.error(f"Error in main training loop: {str(e)}")
962
  return 1
transformers_config.json CHANGED
@@ -60,7 +60,7 @@
60
 
61
  "distributed_training": {
62
  "fsdp_config": {
63
- "enabled": true,
64
  "sharding_strategy": "FULL_SHARD",
65
  "mixed_precision": "BF16",
66
  "activation_checkpointing": true,
@@ -86,5 +86,88 @@
86
  "use_flash_attention": true,
87
  "torch_dtype": "bfloat16",
88
  "bf16": true,
89
- "fp16": false
90
  }
 
60
 
61
  "distributed_training": {
62
  "fsdp_config": {
63
+ "enabled": false,
64
  "sharding_strategy": "FULL_SHARD",
65
  "mixed_precision": "BF16",
66
  "activation_checkpointing": true,
 
86
  "use_flash_attention": true,
87
  "torch_dtype": "bfloat16",
88
  "bf16": true,
89
+ "fp16": false,
90
+
91
+ "hardware": {
92
+ "hardware_name": "4xL4",
93
+ "specs": {
94
+ "gpu_count": 4,
95
+ "gpu_type": "L4",
96
+ "vram_per_gpu": 24,
97
+ "total_vram": 96,
98
+ "vcpu_count": 48,
99
+ "ram": 186
100
+ },
101
+ "hardware_setup": {
102
+ "use_cpu": false,
103
+ "num_gpus": 4,
104
+ "device_map": "auto"
105
+ },
106
+ "training_optimizations": {
107
+ "per_device_batch_size": 16,
108
+ "gradient_accumulation_steps": 3,
109
+ "mixed_precision": "bf16",
110
+ "torch_compile": false,
111
+ "memory_optimizations": {
112
+ "use_gradient_checkpointing": true,
113
+ "use_flash_attention": true
114
+ },
115
+ "multi_gpu_strategy": "data_parallel"
116
+ },
117
+ "system_settings": {
118
+ "cuda_memory_fraction": 0.85,
119
+ "dataloader_num_workers": 2,
120
+ "dataloader_pin_memory": true
121
+ },
122
+ "memory_breakdown": {
123
+ "model_size": "~3.5GB (pre-quantized 4-bit)",
124
+ "optimizer_states": "~1GB",
125
+ "batch_memory_per_gpu": "~3GB",
126
+ "peak_memory_estimate": "~18GB",
127
+ "safe_headroom": "~6GB"
128
+ },
129
+ "compute_environment": "L4_CLOUD"
130
+ },
131
+
132
+ "dataset": {
133
+ "dataset": {
134
+ "name": "George-API/cognitive-data",
135
+ "split": "train",
136
+ "column_mapping": {
137
+ "conversations": "text"
138
+ },
139
+ "processing": {
140
+ "sort_by_id": true,
141
+ "maintain_paper_order": true,
142
+ "max_seq_length": 2048
143
+ }
144
+ },
145
+ "data_formatting": {
146
+ "chat_template": "phi",
147
+ "roles": {
148
+ "system": "System: {content}\n\n",
149
+ "human": "Human: {content}\n\n",
150
+ "assistant": "Assistant: {content}\n\n",
151
+ "user": "Human: {content}\n\n"
152
+ },
153
+ "metadata_handling": {
154
+ "include_paper_id": true,
155
+ "include_chunk_number": true,
156
+ "metadata_format": "Paper ID: {paper_id} | Chunk: {chunk_number}"
157
+ }
158
+ },
159
+ "data_loading": {
160
+ "batch_size": 24,
161
+ "shuffle": false,
162
+ "drop_last": false,
163
+ "num_workers": 4,
164
+ "pin_memory": true,
165
+ "prefetch_factor": 4
166
+ },
167
+ "validation": {
168
+ "log_samples": 3,
169
+ "log_interval": 50,
170
+ "metrics": ["processed", "skipped", "avg_tokens", "unique_papers"]
171
+ }
172
+ }
173
  }
update_space.py CHANGED
@@ -74,8 +74,6 @@ def verify_configs():
74
  current_dir = Path(__file__).parent
75
  required_files = [
76
  "transformers_config.json",
77
- "hardware_config.json",
78
- "dataset_config.json",
79
  "requirements.txt",
80
  "run_transformers_training.py"
81
  ]
 
74
  current_dir = Path(__file__).parent
75
  required_files = [
76
  "transformers_config.json",
 
 
77
  "requirements.txt",
78
  "run_transformers_training.py"
79
  ]