George-API committed on
Commit
356ee13
·
verified ·
1 Parent(s): b3a8a7a

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. dataset_config.json +4 -4
  2. run_transformers_training.py +116 -8
dataset_config.json CHANGED
@@ -3,8 +3,7 @@
3
  "name": "George-API/cognitive-data",
4
  "split": "train",
5
  "column_mapping": {
6
- "text": "conversations",
7
- "id": "id"
8
  },
9
  "processing": {
10
  "sort_by_id": true,
@@ -17,7 +16,8 @@
17
  "roles": {
18
  "system": "System: {content}\n\n",
19
  "human": "Human: {content}\n\n",
20
- "assistant": "Assistant: {content}\n\n"
 
21
  },
22
  "metadata_handling": {
23
  "include_paper_id": true,
@@ -29,7 +29,7 @@
29
  "batch_size": 24,
30
  "shuffle": false,
31
  "drop_last": false,
32
- "num_workers": 8,
33
  "pin_memory": true,
34
  "prefetch_factor": 4
35
  },
 
3
  "name": "George-API/cognitive-data",
4
  "split": "train",
5
  "column_mapping": {
6
+ "conversations": "text"
 
7
  },
8
  "processing": {
9
  "sort_by_id": true,
 
16
  "roles": {
17
  "system": "System: {content}\n\n",
18
  "human": "Human: {content}\n\n",
19
+ "assistant": "Assistant: {content}\n\n",
20
+ "user": "Human: {content}\n\n"
21
  },
22
  "metadata_handling": {
23
  "include_paper_id": true,
 
29
  "batch_size": 24,
30
  "shuffle": false,
31
  "drop_last": false,
32
+ "num_workers": 4,
33
  "pin_memory": true,
34
  "prefetch_factor": 4
35
  },
run_transformers_training.py CHANGED
@@ -208,15 +208,51 @@ def load_dataset_with_mapping(dataset_config):
208
  logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
209
  dataset = load_dataset(dataset_name, split=dataset_split)
210
 
211
- # Map columns if specified
212
  column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
213
  if column_mapping:
214
- logger.info(f"Applying column mapping: {column_mapping}")
215
 
216
- # Rename columns according to mapping
 
217
  for target, source in column_mapping.items():
218
  if source in dataset.column_names:
219
- dataset = dataset.rename_column(source, target)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  # Sort dataset if required
222
  sort_by_id = dataset_config.get("dataset", {}).get("processing", {}).get("sort_by_id", False)
@@ -227,8 +263,14 @@ def load_dataset_with_mapping(dataset_config):
227
  # Log the first few IDs to verify sorting
228
  sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
229
  logger.info(f"First few IDs after sorting: {sample_ids}")
 
 
 
 
 
230
 
231
  logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
 
232
  return dataset
233
 
234
  except Exception as e:
@@ -243,11 +285,13 @@ def format_phi_chat(messages, dataset_config):
243
  roles = dataset_config.get("data_formatting", {}).get("roles", {
244
  "system": "System: {content}\n\n",
245
  "human": "Human: {content}\n\n",
 
246
  "assistant": "Assistant: {content}\n\n"
247
  })
248
 
249
  # Handle research introduction metadata first
250
- metadata = next((msg for msg in messages if "[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
 
251
  if metadata:
252
  system_template = roles.get("system", "System: {content}\n\n")
253
  formatted_chat = system_template.format(content=metadata['content'])
@@ -255,20 +299,29 @@ def format_phi_chat(messages, dataset_config):
255
 
256
  # Process remaining messages
257
  for message in messages:
 
 
 
 
258
  role = message.get("role", "").lower()
259
  content = message.get("content", "")
260
 
261
  # Format based on role
262
  if role == "human" or role == "user":
263
- template = roles.get("human", "Human: {content}\n\n")
264
  formatted_chat += template.format(content=content)
265
- elif role == "assistant":
266
  template = roles.get("assistant", "Assistant: {content}\n\n")
267
  formatted_chat += template.format(content=content)
268
  elif role == "system":
269
  # For system messages, prepend them
270
  template = roles.get("system", "System: {content}\n\n")
271
  formatted_chat = template.format(content=content) + formatted_chat
 
 
 
 
 
272
 
273
  return formatted_chat.strip()
274
 
@@ -284,8 +337,56 @@ class SimpleDataCollator:
284
  self.include_metadata = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_paper_id", True)
285
  self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
286
  self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
 
287
  logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  def __call__(self, features):
290
  batch = {"input_ids": [], "attention_mask": [], "labels": []}
291
 
@@ -293,7 +394,12 @@ class SimpleDataCollator:
293
  try:
294
  # Get ID and conversation fields
295
  paper_id = example.get("id", "")
296
- conversation = example.get("conversations", [])
 
 
 
 
 
297
 
298
  if not conversation:
299
  self.stats["skipped"] += 1
@@ -346,10 +452,12 @@ class SimpleDataCollator:
346
  logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
347
  logger.info(f"Token count: {len(inputs['input_ids'])}")
348
  logger.info(f"Content preview:\n{formatted_content[:500]}...")
 
349
  else:
350
  self.stats["skipped"] += 1
351
  except Exception as e:
352
  logger.warning(f"Error processing example: {str(e)[:100]}...")
 
353
  self.stats["skipped"] += 1
354
  continue
355
 
 
208
  logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
209
  dataset = load_dataset(dataset_name, split=dataset_split)
210
 
211
+ # Map columns if specified - with checks to avoid conflicts
212
  column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
213
  if column_mapping:
214
+ logger.info(f"Checking column mapping: {column_mapping}")
215
 
216
+ # Only apply mappings for columns that need renaming and don't already exist
217
+ safe_mappings = {}
218
  for target, source in column_mapping.items():
219
  if source in dataset.column_names:
220
+ # Skip if target already exists and is not the same as source
221
+ if target in dataset.column_names and target != source:
222
+ logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists")
223
+ else:
224
+ safe_mappings[source] = target
225
+
226
+ # Apply safe renames
227
+ if safe_mappings:
228
+ logger.info(f"Applying safe column mapping: {safe_mappings}")
229
+ for source, target in safe_mappings.items():
230
+ if source != target: # Only rename if names are different
231
+ dataset = dataset.rename_column(source, target)
232
+
233
+ # Verify expected columns exist
234
+ expected_columns = {"id", "conversations"}
235
+ for col in expected_columns:
236
+ if col not in dataset.column_names:
237
+ # If "conversations" is missing but "text" exists, it might need conversion
238
+ if col == "conversations" and "text" in dataset.column_names:
239
+ logger.info("Converting 'text' field to 'conversations' format")
240
+
241
+ def convert_text_to_conversations(example):
242
+ # Check if text is already a list of conversation turns
243
+ if isinstance(example.get("text"), list):
244
+ return {"conversations": example["text"]}
245
+ # Otherwise, create a simple conversation with the text as user message
246
+ else:
247
+ return {
248
+ "conversations": [
249
+ {"role": "user", "content": str(example.get("text", ""))}
250
+ ]
251
+ }
252
+
253
+ dataset = dataset.map(convert_text_to_conversations)
254
+ else:
255
+ logger.warning(f"Expected column '{col}' not found in dataset")
256
 
257
  # Sort dataset if required
258
  sort_by_id = dataset_config.get("dataset", {}).get("processing", {}).get("sort_by_id", False)
 
263
  # Log the first few IDs to verify sorting
264
  sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
265
  logger.info(f"First few IDs after sorting: {sample_ids}")
266
+
267
+ # Log example of conversations structure to verify format
268
+ if "conversations" in dataset.column_names:
269
+ sample_conv = dataset["conversations"][0] if len(dataset) > 0 else []
270
+ logger.info(f"Example conversation structure: {sample_conv}")
271
 
272
  logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
273
+ logger.info(f"Dataset columns: {dataset.column_names}")
274
  return dataset
275
 
276
  except Exception as e:
 
285
  roles = dataset_config.get("data_formatting", {}).get("roles", {
286
  "system": "System: {content}\n\n",
287
  "human": "Human: {content}\n\n",
288
+ "user": "Human: {content}\n\n",
289
  "assistant": "Assistant: {content}\n\n"
290
  })
291
 
292
  # Handle research introduction metadata first
293
+ metadata = next((msg for msg in messages if isinstance(msg, dict) and
294
+ "[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
295
  if metadata:
296
  system_template = roles.get("system", "System: {content}\n\n")
297
  formatted_chat = system_template.format(content=metadata['content'])
 
299
 
300
  # Process remaining messages
301
  for message in messages:
302
+ if not isinstance(message, dict) or "content" not in message:
303
+ logger.warning(f"Skipping invalid message format: {message}")
304
+ continue
305
+
306
  role = message.get("role", "").lower()
307
  content = message.get("content", "")
308
 
309
  # Format based on role
310
  if role == "human" or role == "user":
311
+ template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
312
  formatted_chat += template.format(content=content)
313
+ elif role == "assistant" or role == "bot":
314
  template = roles.get("assistant", "Assistant: {content}\n\n")
315
  formatted_chat += template.format(content=content)
316
  elif role == "system":
317
  # For system messages, prepend them
318
  template = roles.get("system", "System: {content}\n\n")
319
  formatted_chat = template.format(content=content) + formatted_chat
320
+ else:
321
+ # Default to system for unknown roles
322
+ logger.warning(f"Unknown role '{role}' - treating as system message")
323
+ template = roles.get("system", "System: {content}\n\n")
324
+ formatted_chat += template.format(content=content)
325
 
326
  return formatted_chat.strip()
327
 
 
337
  self.include_metadata = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_paper_id", True)
338
  self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
339
  self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
340
+ self.roles = dataset_config.get("data_formatting", {}).get("roles", {})
341
  logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
342
 
343
+ def normalize_conversation(self, conversation):
344
+ """Normalize conversation format to ensure consistent structure."""
345
+ normalized = []
346
+
347
+ # Handle non-list or empty inputs
348
+ if not isinstance(conversation, list):
349
+ logger.warning(f"Conversation is not a list: {type(conversation)}")
350
+ if hasattr(conversation, 'items'): # It's a dict-like object
351
+ conversation = [conversation]
352
+ else:
353
+ return []
354
+
355
+ for turn in conversation:
356
+ # Skip empty or None entries
357
+ if not turn:
358
+ continue
359
+
360
+ # Handle string entries (convert to user message)
361
+ if isinstance(turn, str):
362
+ normalized.append({"role": "user", "content": turn})
363
+ continue
364
+
365
+ # Handle dict-like entries
366
+ if not isinstance(turn, dict) and hasattr(turn, 'get'):
367
+ # Convert to dict
368
+ turn = {k: turn.get(k) for k in ['role', 'content'] if hasattr(turn, 'get') and turn.get(k) is not None}
369
+
370
+ # Ensure both role and content exist
371
+ if not isinstance(turn, dict) or 'role' not in turn or 'content' not in turn:
372
+ logger.warning(f"Skipping malformatted conversation turn: {turn}")
373
+ continue
374
+
375
+ # Normalize role field
376
+ role = turn.get('role', '').lower()
377
+ if role == 'user' or role == 'human':
378
+ role = 'user'
379
+ elif role == 'assistant' or role == 'bot':
380
+ role = 'assistant'
381
+
382
+ # Add normalized turn
383
+ normalized.append({
384
+ "role": role,
385
+ "content": str(turn.get('content', ''))
386
+ })
387
+
388
+ return normalized
389
+
390
  def __call__(self, features):
391
  batch = {"input_ids": [], "attention_mask": [], "labels": []}
392
 
 
394
  try:
395
  # Get ID and conversation fields
396
  paper_id = example.get("id", "")
397
+
398
+ # Handle conversation field - could be under 'conversations' or 'text'
399
+ conversation = example.get("conversations", example.get("text", []))
400
+
401
+ # Normalize conversation format
402
+ conversation = self.normalize_conversation(conversation)
403
 
404
  if not conversation:
405
  self.stats["skipped"] += 1
 
452
  logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
453
  logger.info(f"Token count: {len(inputs['input_ids'])}")
454
  logger.info(f"Content preview:\n{formatted_content[:500]}...")
455
+ logger.info(f"Conversation structure: {conversation[:2]}...")
456
  else:
457
  self.stats["skipped"] += 1
458
  except Exception as e:
459
  logger.warning(f"Error processing example: {str(e)[:100]}...")
460
+ logger.warning(f"Problematic example: {str(example)[:200]}...")
461
  self.stats["skipped"] += 1
462
  continue
463