Rulga committed on
Commit
c4364db
·
1 Parent(s): 6e1f5f4

Enhance finetune_from_chat_history function to improve chat history loading, QA pair extraction, and add temporary file cleanup for training data

Browse files
Files changed (1) hide show
  1. src/training/fine_tuner.py +26 -4
src/training/fine_tuner.py CHANGED
@@ -402,20 +402,42 @@ def finetune_from_chat_history(epochs: int = 3,
402
  """
403
  # Analyze chats and prepare data
404
  analyzer = ChatAnalyzer()
405
- report = analyzer.generate_analytics_report()
406
 
407
- # Check if there's enough data
408
- if report["qa_pairs_count"] < 10:
409
- return False, f"Insufficient data for fine-tuning. Only {report['qa_pairs_count']} QA pairs found."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
  # Create and start fine-tuning process
412
  tuner = FineTuner()
413
  success, message = tuner.prepare_and_train(
 
414
  num_train_epochs=epochs,
415
  per_device_train_batch_size=batch_size,
416
  learning_rate=learning_rate
417
  )
418
 
 
 
 
 
419
  return success, message
420
 
421
  if __name__ == "__main__":
 
402
  """
403
  # Analyze chats and prepare data
404
  analyzer = ChatAnalyzer()
405
+ report = analyzer.analyze_chats()
406
 
407
+ if not report or "Failed to load chat history" in report:
408
+ return False, "Failed to load chat history for training"
409
+
410
+ # Extract QA pairs for training
411
+ qa_pairs = analyzer.extract_question_answer_pairs()
412
+
413
+ if len(qa_pairs) < 10:
414
+ return False, f"Insufficient data for fine-tuning. Only {len(qa_pairs)} QA pairs found."
415
+
416
+ # Create temporary file for training data
417
+ with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.jsonl') as f:
418
+ for pair in qa_pairs:
419
+ json.dump({
420
+ "messages": [
421
+ {"role": "user", "content": pair["question"]},
422
+ {"role": "assistant", "content": pair["answer"]}
423
+ ]
424
+ }, f, ensure_ascii=False)
425
+ f.write('\n')
426
+ training_data_path = f.name
427
 
428
  # Create and start fine-tuning process
429
  tuner = FineTuner()
430
  success, message = tuner.prepare_and_train(
431
+ training_data_path=training_data_path,
432
  num_train_epochs=epochs,
433
  per_device_train_batch_size=batch_size,
434
  learning_rate=learning_rate
435
  )
436
 
437
+ # Cleanup
438
+ if os.path.exists(training_data_path):
439
+ os.remove(training_data_path)
440
+
441
  return success, message
442
 
443
  if __name__ == "__main__":