George-API committed · verified
Commit 20852a7 · 1 Parent(s): 3da7418

Upload folder using huggingface_hub

app.py CHANGED
@@ -1,162 +1,221 @@
- import gradio as gr
  import os
- import subprocess
  import sys
  import json
- import re
- from threading import Thread
- import datetime
- import torch
- import threading
+ import logging
+ import gradio as gr
+ from pathlib import Path
+ import subprocess
+ import time
+ from datetime import datetime

- def load_env_variables():
-     """Load environment variables from system or .env file."""
-     if os.environ.get("SPACE_ID"):
-         print("Running in Hugging Face Space")
-         if "/" in os.environ.get("SPACE_ID", ""):
-             username = os.environ.get("SPACE_ID").split("/")[0]
-             os.environ["HF_USERNAME"] = username
-             print(f"Set HF_USERNAME from SPACE_ID: {username}")
-     else:
-         try:
-             from dotenv import load_dotenv
-             env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
-             if os.path.exists(env_path):
-                 load_dotenv(env_path)
-                 print(f"Loaded environment variables from {env_path}")
-         except ImportError:
-             print("python-dotenv not installed, skipping .env loading")
-
- def check_environment():
-     """Check the environment for GPU availability and other requirements."""
-     env_info = {
-         "System": {
-             "Platform": sys.platform,
-             "Python Version": sys.version.split()[0]
-         },
-         "GPU": {
-             "CUDA Available": torch.cuda.is_available(),
-             "Device Count": torch.cuda.device_count() if torch.cuda.is_available() else 0
-         },
-         "Environment Variables": {
-             "HF_TOKEN": bool(os.environ.get("HF_TOKEN")),
-             "HF_USERNAME": bool(os.environ.get("HF_USERNAME")),
-             "HF_SPACE_NAME": bool(os.environ.get("HF_SPACE_NAME"))
-         }
-     }
-
-     if torch.cuda.is_available():
-         env_info["GPU"]["Device Name"] = torch.cuda.get_device_name(0)
-         env_info["GPU"]["Memory (GB)"] = round(torch.cuda.get_device_properties(0).total_memory / (1024**3), 2)
-
-     return env_info
-
- def run_training_process():
-     """Run the training process using the configuration files."""
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     handlers=[logging.StreamHandler(sys.stdout)]
+ )
+ logger = logging.getLogger(__name__)
+
+ # Configuration paths
+ CONFIG_DIR = "."
+ TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, "transformers_config.json")
+ HARDWARE_CONFIG = os.path.join(CONFIG_DIR, "hardware_config.json")
+ DATASET_CONFIG = os.path.join(CONFIG_DIR, "dataset_config.json")
+
+ def load_config(config_path):
+     """Load configuration from JSON file."""
+     try:
+         if os.path.exists(config_path):
+             with open(config_path, 'r') as f:
+                 return json.load(f)
+         else:
+             logger.error(f"Config file not found: {config_path}")
+             return None
+     except Exception as e:
+         logger.error(f"Error loading config: {str(e)}")
+         return None
+
+ def display_config():
+     """Display current training configuration."""
+     transformers_config = load_config(TRANSFORMERS_CONFIG)
+     hardware_config = load_config(HARDWARE_CONFIG)
+     dataset_config = load_config(DATASET_CONFIG)
+
+     if not all([transformers_config, hardware_config, dataset_config]):
+         return "Error loading configuration files."
+
+     # Extract key parameters
+     model_name = transformers_config.get("model", {}).get("name", "")
+     dataset_name = dataset_config.get("dataset", {}).get("name", "")
+     batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 0)
+     gradient_accum = transformers_config.get("training", {}).get("gradient_accumulation_steps", 0)
+     lr = transformers_config.get("training", {}).get("learning_rate", 0)
+     epochs = transformers_config.get("training", {}).get("num_train_epochs", 0)
+     gpu_count = hardware_config.get("specs", {}).get("gpu_count", 0)
+     gpu_type = hardware_config.get("specs", {}).get("gpu_type", "")
+
+     config_info = f"""
+ ## Current Training Configuration
+
+ **Model**: {model_name}
+ **Dataset**: {dataset_name}
+
+ **Training Parameters**:
+ - Learning Rate: {lr}
+ - Epochs: {epochs}
+ - Batch Size/GPU: {batch_size}
+ - Gradient Accumulation: {gradient_accum}
+ - Effective Batch Size: {batch_size * gradient_accum * gpu_count}
+
+ **Hardware**:
+ - GPUs: {gpu_count}x {gpu_type}
+ - Flash Attention: {hardware_config.get("memory_optimization", {}).get("use_flash_attention", False)}
+ - Gradient Checkpointing: {hardware_config.get("memory_optimization", {}).get("use_gradient_checkpointing", False)}
+
+ **Pre-quantized 4-bit Training**: Enabled
+ """
+
+     return config_info
+
+ def start_training():
+     """Start the training process."""
      try:
-         current_dir = os.path.dirname(os.path.abspath(__file__))
-         training_script = os.path.join(current_dir, "run_transformers_training.py")
-
-         # Start the training process
+         # Check if already running
+         if os.path.exists("training.pid"):
+             with open("training.pid", "r") as f:
+                 pid = f.read().strip()
+             try:
+                 # Check if process is still running
+                 os.kill(int(pid), 0)
+                 return f"Training is already running with PID {pid}"
+             except OSError:
+                 # Process not running, remove stale PID file
+                 os.remove("training.pid")
+
+         # Start training in background
+         cmd = "python run_transformers_training.py"
          process = subprocess.Popen(
-             [sys.executable, training_script],
-             stdout=subprocess.PIPE,
-             stderr=subprocess.STDOUT,
-             text=True,
-             bufsize=1
+             cmd,
+             shell=True,
+             stdout=open('training.log', 'a'),
+             stderr=subprocess.STDOUT
          )

-         # Process the output line by line
-         for line in process.stdout:
-             print(line.strip())
-
-         process.wait()
-         return process.returncode
+         # Save PID
+         with open("training.pid", "w") as f:
+             f.write(str(process.pid))
+
+         # Log start time
+         with open("training_history.log", "a") as f:
+             f.write(f"{datetime.now().isoformat()}: Training started (PID: {process.pid})\n")
+
+         return f"Training started with PID {process.pid}. Check status for updates."
+
      except Exception as e:
-         print(f"Error in training process: {e}")
-         return 1
-
- def start_training(learning_rate, num_train_epochs, per_device_train_batch_size,
-                    gradient_accumulation_steps):
-     """Start the training process with the specified parameters."""
+         return f"Error starting training: {str(e)}"
+
+ def check_training_status():
+     """Check the status of training."""
      try:
-         load_env_variables()
-         current_dir = os.path.dirname(os.path.abspath(__file__))
-
-         # Load and update transformers config
-         with open(os.path.join(current_dir, "transformers_config.json"), "r") as f:
-             config = json.load(f)
-
-         # Update training parameters
-         config["training"].update({
-             "num_train_epochs": num_train_epochs,
-             "learning_rate": learning_rate,
-             "per_device_train_batch_size": per_device_train_batch_size,
-             "gradient_accumulation_steps": gradient_accumulation_steps
-         })
-
-         # Update hub settings if username is available
-         if os.environ.get("HF_USERNAME"):
-             config["huggingface_hub"].update({
-                 "hub_model_id": f"{os.environ['HF_USERNAME']}/Phi4-Cognitive-Science"
-             })
-
-         # Save updated config
-         with open(os.path.join(current_dir, "transformers_config.json"), "w") as f:
-             json.dump(config, f, indent=4)
-
-         # Start training in a separate thread
-         thread = threading.Thread(target=run_training_process)
-         thread.daemon = True
-         thread.start()
-
-         return "Training started! Check the Hugging Face Space logs for progress."
+         # Check if training is running
+         if os.path.exists("training.pid"):
+             with open("training.pid", "r") as f:
+                 pid = f.read().strip()
+             try:
+                 # Check if process is still running
+                 os.kill(int(pid), 0)
+                 status = f"Training is running with PID {pid}"
+             except OSError:
+                 status = "Training process has stopped"
+                 os.remove("training.pid")
+         else:
+             status = "No training process is currently running"
+
+         # Get last lines from training log
+         log_content = "No training log available"
+         if os.path.exists("training.log"):
+             with open("training.log", "r") as f:
+                 lines = f.readlines()
+                 log_content = "".join(lines[-20:]) if lines else "Log file is empty"
+
+         return f"{status}\n\n**Recent Log:**\n```\n{log_content}\n```"
+
      except Exception as e:
-         return f"Error starting training: {str(e)}"
-
- with gr.Blocks(title="Phi-4 Training Interface") as demo:
-     gr.Markdown("# Phi-4 Unsupervised Training for Cognitive Science")
-
-     with gr.Tab("Training"):
-         with gr.Row():
-             with gr.Column():
-                 gr.Markdown("## Model Configuration")
-                 gr.Markdown("**Model**: unsloth/phi-4-unsloth-bnb-4bit")
-                 gr.Markdown("**Dataset**: George-API/cognitive-data")
-
-                 gr.Markdown("## Training Parameters")
-                 learning_rate = gr.Slider(minimum=1e-6, maximum=1e-4, value=2e-5, step=1e-6,
-                                           label="Learning Rate")
-                 num_train_epochs = gr.Slider(minimum=1, maximum=5, value=3, step=1,
-                                              label="Number of Epochs")
-                 per_device_train_batch_size = gr.Slider(minimum=4, maximum=24, value=12, step=4,
-                                                         label="Per Device Train Batch Size (Unsloth Optimized)")
-                 gradient_accumulation_steps = gr.Slider(minimum=1, maximum=8, value=4, step=1,
-                                                         label="Gradient Accumulation Steps")
-
+         return f"Error checking status: {str(e)}"
+
+ # Create the Gradio interface
+ with gr.Blocks(title="Phi-4 Unsloth Training", theme=gr.themes.Soft(primary_hue="blue")) as app:
+     gr.Markdown("# Phi-4 Unsloth 4-bit Training Interface")
+
+     with gr.Tabs():
+         with gr.TabItem("Configuration"):
+             config_output = gr.Markdown(display_config())
+             refresh_btn = gr.Button("Refresh Configuration")
+             refresh_btn.click(fn=display_config, outputs=config_output)
+
+         with gr.TabItem("Training Control"):
+             gr.Markdown("## Training Management")
+
+             with gr.Row():
                  start_btn = gr.Button("Start Training", variant="primary")
-                 training_output = gr.Textbox(label="Training Output", interactive=False)
-
-     with gr.Tab("Environment"):
-         with gr.Row():
-             with gr.Column():
-                 gr.Markdown("## Environment Information")
-                 env_info = gr.JSON(label="Environment Info")
-                 check_env_btn = gr.Button("Check Environment")
-
-     # Set up event handlers
-     start_btn.click(
-         fn=start_training,
-         inputs=[learning_rate, num_train_epochs, per_device_train_batch_size, gradient_accumulation_steps],
-         outputs=training_output
-     )
-
-     check_env_btn.click(
-         fn=check_environment,
-         inputs=[],
-         outputs=env_info
-     )
+                 check_btn = gr.Button("Check Status")
+
+             status_output = gr.Markdown("Click 'Check Status' to see training progress")
+
+             start_btn.click(fn=start_training, outputs=status_output)
+             check_btn.click(fn=check_training_status, outputs=status_output)
+
+             # Auto-refresh status
+             gr.HTML('''
+             <script>
+                 let intervalId;
+
+                 document.addEventListener('DOMContentLoaded', function() {
+                     // Find the "Check Status" button
+                     const buttons = Array.from(document.querySelectorAll('button'));
+                     const checkBtn = buttons.find(btn => btn.textContent.includes('Check Status'));
+
+                     // Set up interval to click the button every 30 seconds
+                     if (checkBtn) {
+                         intervalId = setInterval(() => {
+                             checkBtn.click();
+                         }, 30000);
+                     }
+                 });
+
+                 // Clean up on tab/window close
+                 window.addEventListener('beforeunload', function() {
+                     clearInterval(intervalId);
+                 });
+             </script>
+             ''')
+
+         with gr.TabItem("Help"):
+             gr.Markdown("""
+             ## Phi-4 Unsloth Training Help
+
+             This interface allows you to manage training of the Phi-4 model with Unsloth 4-bit optimizations.
+
+             ### Quick Start
+
+             1. Review the configuration in the Configuration tab
+             2. Click "Start Training" to begin the process
+             3. Use "Check Status" to monitor progress
+
+             ### Notes
+
+             - Training uses the pre-quantized model `unsloth/phi-4-unsloth-bnb-4bit`
+             - The process maintains paper order and handles metadata appropriately
+             - Training progress will be regularly saved to HuggingFace Hub
+
+             ### Troubleshooting
+
+             If training stops unexpectedly:
+             - Check the logs for out-of-memory errors
+             - Verify the VRAM usage on each GPU
+             - Check for CUDA version compatibility
+             """)
+
+ # Launch the app
  if __name__ == "__main__":
-     load_env_variables()
-     demo.launch()
+     app.launch()
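Two implementation notes on the rewritten app.py. Both `start_training` and `check_training_status` rest on the POSIX idiom of sending signal 0 with `os.kill`, which delivers nothing but raises if the process no longer exists. A minimal sketch of that check in isolation (the `is_pid_alive` helper is our illustration, not part of the commit):

```python
import os

def is_pid_alive(pid: int) -> bool:
    """Return True if a process with the given PID exists (POSIX only)."""
    try:
        os.kill(pid, 0)  # signal 0 sends nothing; it only checks existence/permissions
    except ProcessLookupError:
        return False     # no process with this PID
    except PermissionError:
        return True      # process exists but is owned by another user
    return True
```

The app catches the broader `OSError`, which classifies a permission error as "not running" and deletes the PID file; distinguishing the two exceptions as above avoids that edge case. Separately, the injected `<script>` that clicks "Check Status" every 30 seconds is a DOM workaround; recent Gradio releases expose polling directly (for example, the `every=` argument on `Blocks.load`), which may be less brittle.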
 
hardware_config.json CHANGED
@@ -9,13 +9,13 @@
     "ram": 186
   },
   "training_optimizations": {
-    "per_device_batch_size": 32,
+    "per_device_batch_size": 24,
     "gradient_accumulation_steps": 2,
-    "effective_batch_size": 256,
+    "effective_batch_size": 192,
     "memory_optimizations": {
       "use_gradient_checkpointing": true,
       "pin_memory": true,
-      "num_workers": 8,
+      "num_workers": 4,
       "use_flash_attention": true
     },
     "distributed_settings": {
@@ -41,9 +41,9 @@
     "mixed_precision": "bf16",
     "num_gpus": 4,
     "training_parameters": {
-      "per_device_train_batch_size": 32,
+      "per_device_train_batch_size": 24,
       "gradient_accumulation_steps": 2,
-      "dataloader_num_workers": 8,
+      "dataloader_num_workers": 4,
       "dataloader_pin_memory": true,
       "gradient_checkpointing": true,
       "max_grad_norm": 1.0
run_transformers_training.py CHANGED
@@ -127,13 +127,12 @@ def parse_args():
  def load_model_and_tokenizer(config):
      """Load model and tokenizer with proper error handling and optimizations."""
      try:
-         if config.get("use_unsloth", False) and unsloth_available:
-             logger.info("Using Unsloth optimizations")
+         if unsloth_available:
+             logger.info("Using Unsloth optimizations with pre-quantized model")
              model, tokenizer = FastLanguageModel.from_pretrained(
                  model_name=config.get("model_name"),
                  max_seq_length=config.get("max_seq_length", 2048),
                  dtype=None,  # Let Unsloth choose optimal dtype
-                 load_in_4bit=config.get("load_in_4bit", True),
                  device_map="auto",
              )
@@ -151,49 +150,14 @@ def load_model_and_tokenizer(config):
              )
              logger.info("Unsloth optimizations applied successfully")
          else:
-             if config.get("use_unsloth", False):
-                 logger.warning("Unsloth requested but not available. Falling back to standard training.")
-
-             # Standard quantization setup
-             quantization_config = None
-             if config.get("load_in_4bit", False) and bitsandbytes_available:
-                 logger.info("Using 4-bit quantization")
-                 quantization_config = BitsAndBytesConfig(
-                     load_in_4bit=True,
-                     bnb_4bit_quant_type="nf4",
-                     bnb_4bit_compute_dtype=torch.float16,
-                     bnb_4bit_use_double_quant=True
-                 )
-
-             # Load model with standard settings
-             model = AutoModelForCausalLM.from_pretrained(
-                 config.get("model_name"),
-                 quantization_config=quantization_config,
-                 device_map="auto",
-                 trust_remote_code=config.get("trust_remote_code", True),
-                 use_cache=not config.get("gradient_checkpointing", True)
-             )
-
-             # Load tokenizer
-             tokenizer = AutoTokenizer.from_pretrained(
-                 config.get("model_name"),
-                 use_fast=config.get("use_fast_tokenizer", True),
-                 trust_remote_code=config.get("trust_remote_code", True)
-             )
-
-             # Enable gradient checkpointing if requested
-             if config.get("gradient_checkpointing", True) and hasattr(model, "gradient_checkpointing_enable"):
-                 model.gradient_checkpointing_enable(use_reentrant=False)
-                 logger.info("Gradient checkpointing enabled")
+             logger.error("Unsloth is required for training with pre-quantized model")
+             raise ImportError("Unsloth is required for this training setup")

          # Set up tokenizer settings
          if config.get("chat_template"):
-             if unsloth_available and config.get("use_unsloth", False):
-                 chat_template = get_chat_template("phi")
-                 tokenizer.chat_template = chat_template
-             else:
-                 tokenizer.chat_template = config.get("chat_template")
-             logger.info(f"Set chat template to {config.get('chat_template')}")
+             chat_template = get_chat_template("phi")
+             tokenizer.chat_template = chat_template
+             logger.info("Set phi chat template")

          # Ensure proper token settings
          if tokenizer.pad_token_id is None:
@@ -210,33 +174,191 @@ def load_dataset_with_mapping(dataset_config):
      """Load and prepare dataset with proper column mapping."""
      try:
          # Load dataset
-         dataset = load_dataset(
-             dataset_config["dataset"]["name"],
-             split=dataset_config["dataset"]["split"]
-         )
-         logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
+         dataset_name = dataset_config.get("dataset", {}).get("name", "")
+         dataset_split = dataset_config.get("dataset", {}).get("split", "train")
+
+         if not dataset_name:
+             raise ValueError("Dataset name not provided in configuration")

-         # Apply column mapping if specified
-         if "column_mapping" in dataset_config["dataset"]:
-             mapping = dataset_config["dataset"]["column_mapping"]
-             dataset = dataset.rename_columns({v: k for k, v in mapping.items()})
-             logger.info(f"Applied column mapping: {mapping}")
+         logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
+         dataset = load_dataset(dataset_name, split=dataset_split)
+
+         # Map columns if specified
+         column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
+         if column_mapping:
+             logger.info(f"Applying column mapping: {column_mapping}")
+
+             # Rename columns according to mapping
+             for target, source in column_mapping.items():
+                 if source in dataset.column_names:
+                     dataset = dataset.rename_column(source, target)

          # Sort dataset if required
-         if dataset_config["dataset"]["processing"]["sort_by_id"]:
-             logger.info("Sorting dataset by ID to maintain paper chunk order")
+         sort_by_id = dataset_config.get("dataset", {}).get("processing", {}).get("sort_by_id", False)
+         if sort_by_id and "id" in dataset.column_names:
+             logger.info("Sorting dataset by ID")
              dataset = dataset.sort("id")

-             # Log first few IDs to verify sorting
-             sample_ids = [example["id"] for example in dataset.select(range(min(5, len(dataset))))]
+             # Log the first few IDs to verify sorting
+             sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
              logger.info(f"First few IDs after sorting: {sample_ids}")

+         logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
          return dataset
-
+
      except Exception as e:
          logger.error(f"Error loading dataset: {str(e)}")
          raise

+ def format_phi_chat(messages, dataset_config):
+     """Format messages according to phi-4's chat template and dataset config."""
+     formatted_chat = ""
+
+     # Get role templates from config
+     roles = dataset_config.get("data_formatting", {}).get("roles", {
+         "system": "System: {content}\n\n",
+         "human": "Human: {content}\n\n",
+         "assistant": "Assistant: {content}\n\n"
+     })
+
+     # Handle research introduction metadata first
+     metadata = next((msg for msg in messages if "[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
+     if metadata:
+         system_template = roles.get("system", "System: {content}\n\n")
+         formatted_chat = system_template.format(content=metadata['content'])
+         messages = [msg for msg in messages if msg != metadata]
+
+     # Process remaining messages
+     for message in messages:
+         role = message.get("role", "").lower()
+         content = message.get("content", "")
+
+         # Format based on role
+         if role == "human" or role == "user":
+             template = roles.get("human", "Human: {content}\n\n")
+             formatted_chat += template.format(content=content)
+         elif role == "assistant":
+             template = roles.get("assistant", "Assistant: {content}\n\n")
+             formatted_chat += template.format(content=content)
+         elif role == "system":
+             # For system messages, prepend them
+             template = roles.get("system", "System: {content}\n\n")
+             formatted_chat = template.format(content=content) + formatted_chat
+
+     return formatted_chat.strip()
+
+ class SimpleDataCollator:
+     def __init__(self, tokenizer, dataset_config):
+         self.tokenizer = tokenizer
+         self.dataset_config = dataset_config
+         self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
+         self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+         self.prompt_counter = 0
+         self.paper_counters = {}
+         self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
+         self.include_metadata = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_paper_id", True)
+         self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
+         self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
+         logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
+
+     def __call__(self, features):
+         batch = {"input_ids": [], "attention_mask": [], "labels": []}
+
+         for example in features:
+             try:
+                 # Get ID and conversation fields
+                 paper_id = example.get("id", "")
+                 conversation = example.get("conversations", [])
+
+                 if not conversation:
+                     self.stats["skipped"] += 1
+                     continue
+
+                 # Track paper chunks
+                 if paper_id not in self.paper_counters:
+                     self.paper_counters[paper_id] = 0
+                 self.paper_counters[paper_id] += 1
+
+                 # Add metadata if configured
+                 if self.include_metadata:
+                     # Format metadata according to configured format
+                     metadata_content = self.metadata_format.format(
+                         paper_id=paper_id,
+                         chunk_number=self.paper_counters[paper_id]
+                     )
+
+                     # Add as system message if not already in conversation
+                     if not any(msg.get("role") == "system" for msg in conversation):
+                         conversation = [{"role": "system", "content": metadata_content}] + conversation
+
+                 # Format conversation with research introduction and chunk info
+                 formatted_content = format_phi_chat(conversation, self.dataset_config)
+
+                 # Tokenize with the model's chat template
+                 inputs = self.tokenizer(
+                     formatted_content,
+                     add_special_tokens=True,
+                     truncation=True,
+                     max_length=self.max_seq_length,
+                     return_tensors=None,
+                 )
+
+                 if len(inputs["input_ids"]) > 0:
+                     # For causal language modeling, labels are the same as inputs
+                     labels = inputs["input_ids"].copy()
+
+                     batch["input_ids"].append(inputs["input_ids"])
+                     batch["attention_mask"].append(inputs["attention_mask"])
+                     batch["labels"].append(labels)
+
+                     self.stats["processed"] += 1
+                     self.stats["total_tokens"] += len(inputs["input_ids"])
+
+                     # Debug logging for first few examples
+                     log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
+                     if self.stats["processed"] <= log_samples:
+                         logger.info(f"Example {self.stats['processed']} format:")
+                         logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
+                         logger.info(f"Token count: {len(inputs['input_ids'])}")
+                         logger.info(f"Content preview:\n{formatted_content[:500]}...")
+                 else:
+                     self.stats["skipped"] += 1
+             except Exception as e:
+                 logger.warning(f"Error processing example: {str(e)[:100]}...")
+                 self.stats["skipped"] += 1
+                 continue
+
+         if not batch["input_ids"]:
+             logger.warning("Empty batch, returning dummy tensors")
+             return {
+                 "input_ids": torch.zeros((1, 1), dtype=torch.long),
+                 "attention_mask": torch.zeros((1, 1), dtype=torch.long),
+                 "labels": torch.zeros((1, 1), dtype=torch.long)
+             }
+
+         # Pad the batch
+         max_length = max(len(ids) for ids in batch["input_ids"])
+
+         for i in range(len(batch["input_ids"])):
+             padding_length = max_length - len(batch["input_ids"][i])
+             if padding_length > 0:
+                 batch["input_ids"][i].extend([self.pad_token_id] * padding_length)
+                 batch["attention_mask"][i].extend([0] * padding_length)
+                 batch["labels"][i].extend([-100] * padding_length)
+
+         # Convert to tensors
+         batch = {k: torch.tensor(v) for k, v in batch.items()}
+
+         # Log stats periodically
+         log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
+         if self.stats["processed"] % log_interval == 0 and self.stats["processed"] > 0:
+             logger.info(f"Data collator stats: processed={self.stats['processed']}, "
+                         f"skipped={self.stats['skipped']}, "
+                         f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}, "
+                         f"unique_papers={len(self.paper_counters)}")
+
+         return batch
+
  def main():
      # Set up logging
      logger.info("Starting training process")
@@ -322,148 +444,34 @@ def main():
          logger.error(f"Error setting up PEFT: {e}")
          return 1

-     # Load dataset with proper mapping
+     # Load dataset
+     logger.info(f"Loading dataset: {dataset_config.get('dataset_name')}")
      try:
-         dataset = load_dataset_with_mapping(dataset_config)
-         logger.info("Dataset loaded and prepared successfully")
+         dataset = load_dataset(dataset_config.get("dataset_name"))
+         logger.info(f"Dataset loaded successfully with {len(dataset['train'])} training examples")
+
+         # Sort dataset by ID to ensure chunks from the same paper are processed together
+         logger.info("Sorting dataset by ID to maintain paper chunk order")
+         def sort_by_id(example):
+             # Extract ID as integer if possible, otherwise keep as string
+             try:
+                 return int(example['id'])
+             except (ValueError, TypeError):
+                 return example['id']
+
+         # Apply sorting to the dataset
+         dataset['train'] = dataset['train'].sort('id')
+         logger.info("Dataset sorted by ID")
+
+         # Log the first few IDs to verify sorting
+         sample_ids = [example['id'] for example in dataset['train'].select(range(min(5, len(dataset['train']))))]
+         logger.info(f"First few IDs after sorting: {sample_ids}")
      except Exception as e:
-         logger.error(f"Error loading dataset: {e}")
+         logger.error(f"Error loading or sorting dataset: {e}")
          return 1

-     # Simple data collator that processes each entry independently
-     class SimpleDataCollator:
-         def __init__(self, tokenizer):
-             self.tokenizer = tokenizer
-             self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
-             self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-             self.prompt_counter = 0
-             self.paper_counters = {}
-             logger.info("SimpleDataCollator initialized - using phi-4 chat format")
-
-         def format_phi_chat(self, messages):
-             """Format messages according to phi-4's chat template."""
-             formatted_chat = ""
-             for message in messages:
-                 # Extract role and content
-                 if isinstance(message, dict):
-                     role = message.get("role", "").lower()
-                     content = message.get("content", "")
-                 else:
-                     role = getattr(message, "role", "").lower()
-                     content = getattr(message, "content", "")
-
-                 # Format based on role
-                 if role == "human" or role == "user":
-                     formatted_chat += f"Human: {content}\n\n"
-                 elif role == "assistant":
-                     formatted_chat += f"Assistant: {content}\n\n"
-                 elif role == "system":
-                     # For system messages, we prepend them with a special format
-                     formatted_chat = f"System: {content}\n\n" + formatted_chat
-                 else:
-                     logger.warning(f"Unknown role '{role}' - treating as system message")
-                     formatted_chat += f"System: {content}\n\n"
-
-             return formatted_chat.strip()
-
-         def __call__(self, features):
-             batch = {"input_ids": [], "attention_mask": [], "labels": []}
-
-             for example in features:
-                 try:
-                     # Get ID and conversation fields
-                     paper_id = example.get("id", "") if isinstance(example, dict) else getattr(example, "id", "")
-                     conversation = example.get("conversations", []) if isinstance(example, dict) else getattr(example, "conversations", [])
-
-                     if not conversation:
-                         self.stats["skipped"] += 1
-                         continue
-
-                     # Increment counters
-                     self.prompt_counter += 1
-                     if paper_id not in self.paper_counters:
-                         self.paper_counters[paper_id] = 0
-                     self.paper_counters[paper_id] += 1
-
-                     # Add metadata as system message
-                     metadata = {
-                         "role": "system",
-                         "content": f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}"
-                     }
-
-                     # Format the conversation using phi-4's chat template
-                     formatted_content = self.format_phi_chat([metadata] + conversation)
-
-                     # Tokenize with the model's chat template
-                     inputs = self.tokenizer(
-                         formatted_content,
-                         add_special_tokens=True,
-                         truncation=True,
-                         max_length=model_config.get("max_seq_length", 2048),
-                         return_tensors=None,  # Return list instead of tensors
-                     )
-
-                     input_ids = inputs["input_ids"]
-                     attention_mask = inputs["attention_mask"]
-
-                     if len(input_ids) > 0:
-                         # For causal language modeling, labels are the same as inputs
-                         labels = input_ids.copy()
-
-                         batch["input_ids"].append(input_ids)
-                         batch["attention_mask"].append(attention_mask)
-                         batch["labels"].append(labels)
-
-                         self.stats["processed"] += 1
-                         self.stats["total_tokens"] += len(input_ids)
-
-                         # Debug logging for first few examples
-                         if self.stats["processed"] <= 3:
-                             logger.info(f"Example {self.stats['processed']} format:")
-                             logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
-                             logger.info(f"Token count: {len(input_ids)}")
-                             logger.info(f"Content preview:\n{formatted_content[:500]}...")
-                     else:
-                         self.stats["skipped"] += 1
-
-                 except Exception as e:
-                     logger.warning(f"Error processing example: {str(e)[:100]}...")
-                     self.stats["skipped"] += 1
-                     continue
-
-             # Handle empty batches
-             if not batch["input_ids"]:
-                 logger.warning("Empty batch, returning dummy tensors")
-                 return {
-                     "input_ids": torch.zeros((1, 1), dtype=torch.long),
-                     "attention_mask": torch.zeros((1, 1), dtype=torch.long),
-                     "labels": torch.zeros((1, 1), dtype=torch.long)
-                 }
-
-             # Pad the batch
-             max_length = max(len(ids) for ids in batch["input_ids"])
-
-             for i in range(len(batch["input_ids"])):
-                 padding_length = max_length - len(batch["input_ids"][i])
-                 if padding_length > 0:
-                     batch["input_ids"][i].extend([self.pad_token_id] * padding_length)
-                     batch["attention_mask"][i].extend([0] * padding_length)
-                     batch["labels"][i].extend([-100] * padding_length)  # Don't compute loss on padding
-
-             # Convert to tensors
-             batch = {k: torch.tensor(v) for k, v in batch.items()}
-
-             # Log stats periodically
-             if self.stats["processed"] % 100 == 0 and self.stats["processed"] > 0:
-                 logger.info(f"Data collator stats: processed={self.stats['processed']}, "
-                             f"skipped={self.stats['skipped']}, "
-                             f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}, "
-                             f"unique_papers={len(self.paper_counters)}")
-
-             return batch
-
      # Create data collator
-     data_collator = SimpleDataCollator(tokenizer)
+     data_collator = SimpleDataCollator(tokenizer, dataset_config)

      # Simple logging callback
      class LoggingCallback(TrainerCallback):
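One detail the collator preserves across both versions: padded positions are labeled -100, the default `ignore_index` of PyTorch's cross-entropy loss, so padding never contributes to the training loss while real tokens do. A self-contained illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)               # (batch, seq_len, vocab_size)
labels = torch.tensor([[2, 7, -100, -100]])  # last two positions are padding

# Positions labeled -100 are skipped entirely; only the two real tokens are scored.
loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1), ignore_index=-100)
```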
transformers_config.json CHANGED
@@ -15,7 +15,7 @@
   "training": {
     "per_device_train_batch_size": 24,
     "gradient_accumulation_steps": 2,
-    "learning_rate": 3e-5,
+    "learning_rate": 2e-5,
     "num_train_epochs": 3,
     "max_steps": -1,
     "logging_steps": 10,
@@ -65,7 +65,7 @@
     "offload_params": false
   },
   "ddp_find_unused_parameters": false,
-  "dataloader_num_workers": 8
+  "dataloader_num_workers": 4
 },

 "logging": {