LMT-tuning: Llama-3-8B Fine-tuned for Differential Equations
This model is a fine-tuned version of meta-llama/Meta-Llama-3-8B-Instruct, specialized for solving university-level differential equations problems.
The model was trained using the Learning by Teaching (LbT) paradigm combined with Direct Preference Optimization (DPO). This approach aims to improve a "teacher" model's reasoning capabilities by having it teach a "student" model and learning from the student's performance.
Model Description
The core idea of the training process was to create a high-quality preference dataset where the "better" response was not just more correct, but also a better piece of teaching material.
The pipeline involved:
- Data Augmentation: A raw corpus of 1500 differential equations problems was flattened and structured into a training set (1200 problems) and a test set (~300 problems).
- Teacher Generation: The base Llama-3-8B model generated 32 step-by-step solutions (rationales) for each of the 1200 training problems.
- Student Examination (LbT Scoring): For each of the ~39,000 generated rationales, a "student" model (also Llama-3-8B) was taught using that rationale as a one-shot example. The student then took a similarity-based exam, and its performance yielded an "LbT score" for the rationale.
- Preference Creation: Rationales were scored based on a combination of correctness and their LbT score. High-scoring rationales were paired with low-scoring ones to create a preference dataset of (prompt, chosen, rejected) triplets.
- DPO Fine-tuning: The base Llama-3-8B model was fine-tuned on this preference dataset using trl's DPOTrainer and QLoRA.
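The preference-creation step above can be sketched as follows. This is a minimal illustration, not the author's pipeline: the `Rationale` class, the `correctness_weight` mixing rule, and the `min_gap` threshold are all hypothetical, since the card only says that correctness and the LbT score were combined. The output keys follow the `prompt` / `chosen` / `rejected` layout named above.

```python
from dataclasses import dataclass
from itertools import product


@dataclass
class Rationale:
    text: str          # step-by-step solution generated by the teacher
    is_correct: bool   # whether its final answer matched the reference
    lbt_score: float   # student exam performance after seeing it, assumed in [0, 1]


def combined_score(r: Rationale, correctness_weight: float = 0.5) -> float:
    """Hypothetical rule: weighted mix of correctness and LbT score."""
    return correctness_weight * float(r.is_correct) + (1 - correctness_weight) * r.lbt_score


def build_preference_pairs(prompt: str, rationales: list, min_gap: float = 0.3) -> list:
    """Pair every rationale with each one that scores at least min_gap lower."""
    scored = [(combined_score(r), r) for r in rationales]
    pairs = []
    for (s_hi, hi), (s_lo, lo) in product(scored, repeat=2):
        if s_hi - s_lo >= min_gap:
            pairs.append({"prompt": prompt, "chosen": hi.text, "rejected": lo.text})
    return pairs


# Toy data for one problem (made up for illustration)
rationales = [
    Rationale("A", is_correct=True, lbt_score=0.9),
    Rationale("B", is_correct=True, lbt_score=0.2),
    Rationale("C", is_correct=False, lbt_score=0.1),
]
pairs = build_preference_pairs("Solve y' = y.", rationales)
```

With a gap threshold, two correct rationales can still form a pair when one teaches the student much better than the other, which is the point of mixing the LbT score into the ranking.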
Intended Use
This model is primarily intended for:
- Solving differential equations problems: Providing step-by-step reasoning and a final answer.
- Educational purposes: Serving as a tool for students to check their work and understand problem-solving steps.
- Research: Acting as a baseline for further fine-tuning on specialized mathematical domains.
Note: This is a specialist model. While it has been fine-tuned for differential equations, its capabilities on general-purpose chat or other reasoning tasks may have degraded.
How to Use
You can use this model with the transformers library pipeline. For best results, format the prompt with the Llama 3 chat template.
import torch
from transformers import pipeline

# Load the model and tokenizer
pipe = pipeline(
    "text-generation",
    model="Sandesh-Zenteiq/LMT-tuning",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Your differential equations problem
problem = "Solve the initial value problem: y' - 2y = 0, with y(0) = 3."

# This is the full instruction set the model was trained on
instruction_text = (
    "Your task is to answer the last question below. "
    "Give step by step reasoning before you answer. "
    "When you're ready to answer, please wrap your answer and conclude using the format\n"
    "'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n"
)
exam_template = (
    "[[Question]]:\n{question}\n\n"
    "[[Solution]]:\nLet's think step by step.\n\n"
)

# Format the prompt using the Llama 3 chat template
prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    f"{instruction_text}{exam_template.format(question=problem)}"
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)

# Generate the response with greedy decoding (deterministic).
# To sample instead, pass do_sample=True together with e.g. temperature=0.7, top_p=0.9;
# those two settings are ignored when do_sample=False.
response = pipe(
    prompt,
    max_new_tokens=1024,
    do_sample=False,
)

# The pipeline returns a list with one dict per generated sequence
generated_text = response[0]["generated_text"]

# By default the generated text includes the prompt, so slice it off to keep only the model's answer
assistant_response = generated_text[len(prompt):]
print(assistant_response)
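Because the instruction asks the model to wrap its result in a `[[Final Answer]]:` block, the answer can be pulled out of the response with a small parser. The sketch below is an assumption about how one might do this, not part of the model's tooling; `extract_final_answer` is a hypothetical helper and the sample string is made up for illustration.

```python
import re

# Capture whatever follows "[[Final Answer]]:" up to a closing quote or
# backtick fence, or to the end of the text if the model omits the fence.
FINAL_ANSWER_RE = re.compile(
    r"\[\[Final Answer\]\]:\s*(.+?)\s*(?:'{3}|`{3}|$)",
    re.DOTALL,
)


def extract_final_answer(text: str):
    """Return the final answer without surrounding $ signs, or None if absent."""
    match = FINAL_ANSWER_RE.search(text)
    if match is None:
        return None
    return match.group(1).strip().strip("$").strip()


sample = (
    "Separating variables gives y = C e^{2x}; y(0) = 3 fixes C = 3.\n"
    "'''\n[[Final Answer]]:\n$y = 3e^{2x}$\n'''"
)
answer = extract_final_answer(sample)
```

Returning `None` when the marker is missing lets the caller tell a formatting failure apart from a wrong answer.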
Training Details
Base Model: meta-llama/Meta-Llama-3-8B-Instruct
Framework: trl.DPOTrainer with QLoRA
Hardware: NVIDIA A6000 / H200 class GPUs
Key Hyperparameters:
learning_rate: 2e-5
num_epochs: 1
lora_r: 128
lora_alpha: 256
gradient_accumulation_steps: 16
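As a rough guide to reproducing this setup, the listed values can be mapped onto the keyword names commonly used by `peft.LoraConfig` and by `trl.DPOConfig` / `transformers.TrainingArguments`. The mapping below is an assumption about how the card's names correspond to those arguments; it uses plain dictionaries and does not reproduce the author's training script. The per-device batch size, target modules, and quantization settings are not stated in the card and are left out.

```python
# Hyperparameters exactly as listed in the card
card_hparams = {
    "learning_rate": 2e-5,
    "num_epochs": 1,
    "lora_r": 128,
    "lora_alpha": 256,
    "gradient_accumulation_steps": 16,
}

# Assumed mapping onto peft.LoraConfig keyword names
lora_kwargs = {
    "r": card_hparams["lora_r"],
    "lora_alpha": card_hparams["lora_alpha"],
}

# Assumed mapping onto trl.DPOConfig / transformers.TrainingArguments keyword names
training_kwargs = {
    "learning_rate": card_hparams["learning_rate"],
    "num_train_epochs": card_hparams["num_epochs"],
    "gradient_accumulation_steps": card_hparams["gradient_accumulation_steps"],
}

# Standard LoRA scales the low-rank update by lora_alpha / r
lora_scaling = lora_kwargs["lora_alpha"] / lora_kwargs["r"]
```

Under standard LoRA scaling, these values give a scaling factor of 2.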
Evaluation
The model was evaluated on a held-out test set of 305 differential equations problems that were not seen during training. The metric is Pass@1 accuracy.
| Model | Accuracy |
| --- | --- |
| meta-llama/Meta-Llama-3-8B-Instruct (Base) | 10.16% |
| LMT-tuning (This Model) | 16.07% |
This represents a +5.90 percentage-point absolute improvement and a ~58% relative improvement in Pass@1 accuracy on this specialized task.
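The arithmetic behind these figures can be checked with a few lines. The card does not report raw counts; the values 31 and 49 correct answers out of 305 used below are an inference from the published percentages, and they are the counts that reproduce both accuracies and the +5.90 figure.

```python
TEST_SET_SIZE = 305  # held-out problems, as stated in the card


def pass_at_1(num_correct: int, total: int = TEST_SET_SIZE) -> float:
    """Pass@1 accuracy in percent: one attempt per problem."""
    return 100.0 * num_correct / total


# Assumption: counts inferred from the reported percentages, not given in the card
base = pass_at_1(31)
tuned = pass_at_1(49)

absolute_gain = tuned - base                  # in percentage points
relative_gain = 100.0 * absolute_gain / base  # in percent of the base accuracy

print(f"base={base:.2f}% tuned={tuned:.2f}% "
      f"abs=+{absolute_gain:.2f} pts rel=+{relative_gain:.0f}%")
```

Subtracting the rounded percentages directly (16.07 - 10.16) gives 5.91, so the reported 5.90 suggests the gain was computed from unrounded accuracies.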
Model fine-tuned by Sandesh-Zenteiq. The methodology is based on the paper "Can LLMs Learn by Teaching for Better Reasoning?"