LMT-tuning: Llama-3-8B Fine-tuned for Differential Equations

This model is a fine-tuned version of meta-llama/Meta-Llama-3-8B-Instruct, specialized for solving university-level differential equation problems.

The model was trained using the Learning by Teaching (LbT) paradigm combined with Direct Preference Optimization (DPO). This approach aims to improve a "teacher" model's reasoning capabilities by having it teach a "student" model and learning from the student's performance.

Model Description

The core idea of the training process was to create a high-quality preference dataset where the "better" response was not just more correct, but also a better piece of teaching material.

The pipeline involved:

  1. Data Augmentation: A raw corpus of ~1,500 differential equation problems was flattened and structured into a training set (1,200 problems) and a held-out test set (~300 problems).
  2. Teacher Generation: The base Llama-3-8B model generated 32 step-by-step solutions (rationales) for each of the 1200 training problems.
  3. Student Examination (LbT Scoring): For each of the ~39,000 generated rationales, a "student" model (also Llama-3-8B) was taught using that rationale as a one-shot example. The student then took a similarity-based exam, and its performance yielded an "LbT score" for the rationale.
  4. Preference Creation: Rationales were scored based on a combination of correctness and their LbT score. High-scoring rationales were paired with low-scoring ones to create a preference dataset of (prompt, chosen, rejected) triplets.
  5. DPO Fine-tuning: The base Llama-3-8B model was fine-tuned on this preference dataset using trl's DPOTrainer and QLoRA.
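Steps 3 and 4 above can be sketched in plain Python. This is a hypothetical illustration, not the code used to build this model. The card does not give the exact scoring formula, so the sketch makes three assumptions. First, the LbT score is taken to be the student's exam accuracy. Second, it is combined with correctness as a simple sum (correctness + lbt_weight * lbt_score, with a hypothetical weight of 1.0). Third, for each problem the best-ranked rationales are paired with the worst-ranked ones. All names (Rationale, lbt_score, rationale_score, make_preference_pairs) are invented for illustration.

```python
from dataclasses import dataclass


@dataclass
class Rationale:
    text: str
    is_correct: bool             # did the teacher's final answer match the reference?
    student_results: list[bool]  # hypothetical: per-question results of the student's exam


def lbt_score(r: Rationale) -> float:
    """Assumed LbT score: fraction of exam questions the student answered correctly."""
    if not r.student_results:
        return 0.0
    return sum(r.student_results) / len(r.student_results)


def rationale_score(r: Rationale, lbt_weight: float = 1.0) -> float:
    """Assumed combination: correctness (0 or 1) plus a weighted LbT score."""
    return float(r.is_correct) + lbt_weight * lbt_score(r)


def make_preference_pairs(prompt: str, rationales: list[Rationale], num_pairs: int = 1) -> list[dict]:
    """Pair top-ranked rationales with bottom-ranked ones as (prompt, chosen, rejected) records."""
    ranked = sorted(rationales, key=rationale_score, reverse=True)
    pairs = []
    for i in range(min(num_pairs, len(ranked) // 2)):
        chosen, rejected = ranked[i], ranked[-(i + 1)]
        # Skip ties: a pair with equal scores carries no preference signal
        if rationale_score(chosen) > rationale_score(rejected):
            pairs.append({"prompt": prompt, "chosen": chosen.text, "rejected": rejected.text})
    return pairs


# Toy example with four rationales for one problem (illustrative data only)
rationales = [
    Rationale("A", True, [True, True, False]),
    Rationale("B", True, [False, False, False]),
    Rationale("C", False, [True, False, False]),
    Rationale("D", False, [False, False, False]),
]
pairs = make_preference_pairs("Solve y' = 2y.", rationales, num_pairs=2)
for p in pairs:
    print(p["chosen"], ">", p["rejected"])
```

Here a correct rationale always outranks an incorrect one, and the LbT score breaks ties among rationales with the same correctness. A real pipeline might weight the two signals differently.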

Intended Use

This model is primarily intended for:

  • Solving differential equations problems: Providing step-by-step reasoning and a final answer.
  • Educational purposes: Serving as a tool for students to check their work and understand problem-solving steps.
  • Research: Acting as a baseline for further fine-tuning on specialized mathematical domains.

Note: This is a specialist model. While it has been fine-tuned for differential equations, its capabilities on general-purpose chat or other reasoning tasks may have degraded.

How to Use

You can use this model with the transformers library pipeline. It is crucial to use the Llama 3 chat template for best results.

import torch
from transformers import pipeline

# Load the model and tokenizer
pipe = pipeline(
    "text-generation",
    model="Sandesh-Zenteiq/LMT-tuning",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Your differential equations problem
problem = "Solve the initial value problem: y' - 2y = 0, with y(0) = 3."

# This is the full instruction set the model was trained on
instruction_text = (
    "Your task is to answer the last question below. "
    "Give step by step reasoning before you answer. "
    "When you're ready to answer, please wrap your answer and conclude using the format\n"
    "'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n"
)
exam_template = (
    "[[Question]]:\n{question}\n\n"
    "[[Solution]]:\nLet's think step by step.\n\n"
)

# Format the prompt using the Llama 3 chat template
prompt = (
    f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    f"{instruction_text}{exam_template.format(question=problem)}"
    f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)

# Generate the response.
# return_full_text=False makes the pipeline return only the newly generated
# text instead of the prompt followed by the completion.
response = pipe(
    prompt,
    max_new_tokens=1024,
    do_sample=False,  # greedy decoding; for sampled answers set do_sample=True, temperature=0.7, top_p=0.9
    return_full_text=False,
)

# The pipeline returns a list with one dict per generated sequence
assistant_response = response[0]["generated_text"]
print(assistant_response)
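The training prompt asks the model to finish with a "[[Final Answer]]:" block wrapped in triple single quotes. Because of this, the final answer can usually be extracted from assistant_response with a regular expression. The following is a minimal sketch. The helper name extract_final_answer is hypothetical, and the pattern only works when the model actually follows the requested format, which a generated response does not always do. The sample string is illustrative and is not real model output. For the example problem above, the correct solution is y = 3e^{2x}.

```python
import re
from typing import Optional

# Matches the "[[Final Answer]]:" marker and captures everything after it,
# up to a closing ''' fence or the end of the text.
FINAL_ANSWER_RE = re.compile(r"\[\[Final Answer\]\]:\s*(.*?)\s*(?:'''|$)", re.DOTALL)


def extract_final_answer(text: str) -> Optional[str]:
    """Return the last final answer found in the text, or None if the marker is missing."""
    matches = FINAL_ANSWER_RE.findall(text)
    return matches[-1] if matches else None


# Illustrative response text, shaped like the format requested in the prompt
sample = "Step 1: separate variables...\n'''\n[[Final Answer]]:\ny = 3e^{2x}\n'''"
print(extract_final_answer(sample))
```

In practice, call extract_final_answer(assistant_response) and treat a None result as a formatting failure rather than a wrong answer.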

Training Details

Base Model: meta-llama/Meta-Llama-3-8B-Instruct

Framework: trl.DPOTrainer with QLoRA

Hardware: NVIDIA A6000 / H200 class GPUs

Key Hyperparameters:

learning_rate: 2e-5

num_epochs: 1

lora_r: 128

lora_alpha: 256

gradient_accumulation_steps: 16
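The objective behind the DPO step fits in a few lines. The sketch below writes out the standard sigmoid DPO loss for a single preference pair, for illustration only. It is not trl's implementation, and it takes pre-computed sequence log-probabilities as plain floats. beta = 0.1 is an assumption: it is trl's default, but this card does not report the value used.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Sigmoid DPO loss for a single (prompt, chosen, rejected) pair.

    Each *_logp argument is the summed log-probability of a complete response,
    scored either by the model being trained (policy) or by the frozen
    reference model.
    """
    # How much more (or less) likely each response became relative to the reference
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)), computed in a numerically stable way
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# With identical log-probabilities the margin is 0 and the loss is log(2)
print(dpo_loss(-10.0, -10.0, -10.0, -10.0))
# The policy favours the chosen response more than the reference does, so the loss drops
print(dpo_loss(-10.0, -20.0, -15.0, -15.0))
```

The loss decreases as the fine-tuned model raises the likelihood of the chosen rationale relative to the rejected one, measured against the reference model. beta controls how strongly the model is held close to that reference.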

Evaluation

The model was evaluated on a held-out test set of 305 differential equation problems that were not seen during training. The metric is Pass@1 accuracy: the fraction of problems for which the model's first generated answer is correct.

Model	Pass@1 accuracy
meta-llama/Meta-Llama-3-8B-Instruct (Base)	10.16%
LMT-tuning (This Model)	16.07%

This represents a +5.90 percentage-point absolute improvement and a ~58% relative improvement on this specialized task.
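The reported gains can be checked with a short calculation. The correct-answer counts used below (31 and 49 out of 305) do not appear on this card. They are inferred, because they are the only integer counts that reproduce the rounded percentages. They also explain why the absolute gain is +5.90 points rather than the 5.91 you get by subtracting the rounded table values.

```python
TEST_SET_SIZE = 305
# Inferred (not reported) numbers of correctly solved problems
base_correct, tuned_correct = 31, 49

base_acc = 100 * base_correct / TEST_SET_SIZE    # Pass@1 accuracy in percent
tuned_acc = 100 * tuned_correct / TEST_SET_SIZE
absolute_gain = tuned_acc - base_acc              # percentage points
relative_gain = 100 * absolute_gain / base_acc    # percent of the base accuracy

print(f"base {base_acc:.2f}%, tuned {tuned_acc:.2f}%")
print(f"absolute +{absolute_gain:.2f} pts, relative +{relative_gain:.1f}%")
```

In count terms, 18 additional problems were solved, and 18 / 31 is roughly a 58% relative improvement.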

Model fine-tuned by Sandesh-Zenteiq. The methodology is based on the paper "Can LLMs Learn by Teaching for Better Reasoning?"
Model size: 8.03B params (Safetensors, BF16 tensors)