Med-REFL-Llama-3.1-8B-lora
Introduction
Med-REFL (Medical Reasoning Enhancement via self-corrected Fine-grained refLection) is a novel framework designed to enhance the complex reasoning capabilities of Large Language Models (LLMs) in the medical domain.
Instead of focusing solely on the final answer, Med-REFL improves the model's intermediate reasoning process. It leverages a Tree-of-Thought (ToT) methodology to explore diverse reasoning paths and automatically constructs Direct Preference Optimization (DPO) data. This trains the model to identify and correct its own reasoning errors, leading to more accurate and trustworthy outputs.
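The DPO data described above is used to train the model on pairs of preferred and dispreferred outputs. The sketch below computes the standard DPO loss for a single pair, for example a corrected (chosen) reasoning trace versus an erroneous (rejected) one. It illustrates the generic DPO objective only. It is not the Med-REFL training code, and the `beta` default and the example log-probabilities are illustrative assumptions.

```python
import math


def _softplus(x: float) -> float:
    """log(1 + e^x), rearranged so that large |x| cannot overflow."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """Standard DPO loss for one (chosen, rejected) response pair.

    Each *_logp is the summed token log-probability of a whole response under
    the trainable policy or the frozen reference model. beta=0.1 is an
    illustrative default (assumption), not a published Med-REFL hyperparameter.
    """
    # Policy-vs-reference log-ratio per response (the implicit reward divided by beta).
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)) is exactly softplus(-margin).
    return _softplus(-margin)
```

When the policy and reference assign identical log-probabilities the loss equals log 2 (about 0.693). It falls below that value once the policy favours the chosen trace more strongly than the reference model does, and rises above it otherwise.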
This repository contains the LoRA weights produced by the Med-REFL framework for various base models.
Available Weights
The Med-REFL LoRA weights can be applied to the following base models to enhance their medical reasoning abilities.
| LoRA for Base Model | Backbone | Hugging Face Link |
|---|---|---|
| Med-REFL for Llama-3.1-8B | Llama-3.1-8B | HF Link |
| Med-REFL for Qwen2.5-7B | Qwen2.5-7B | HF Link |
| Med-REFL for Huatuo-o1-8B | Huatuo-o1-8B | HF Link |
| Med-REFL for MedReason-8B | MedReason-8B | HF Link |
Usage
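You can deploy the model with tools such as vLLM, or load the adapter directly with Hugging Face Transformers and PEFT as in the Transformers example below. For more usage examples, please refer to our GitHub page.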
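Since vLLM is mentioned as a deployment option, here is a minimal sketch of serving the base model together with the LoRA adapter through vLLM's offline `LLM` API. It assumes a recent vLLM release that supports `LLM.chat` with a `lora_request`, and that the adapter lives in the `Llama3.1-Med-REFL-LoraAdapter` subfolder referenced by the Transformers example in this section. The helper names `build_messages` and `generate_with_vllm`, and the `max_lora_rank` and `max_model_len` values, are illustrative assumptions, not settings published by the authors.

```python
import os


def build_messages(system_prompt: str, user_prompt: str) -> list[dict]:
    """Chat messages in the same system/user layout as the Transformers example."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def generate_with_vllm(system_prompt: str, user_prompt: str) -> str:
    """Serve the base model with the Med-REFL LoRA adapter via vLLM (needs a GPU)."""
    # Imported lazily so this module can be imported without vLLM installed.
    from huggingface_hub import snapshot_download
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # vLLM loads adapters from a local directory, so fetch only the adapter subfolder.
    repo_root = snapshot_download(
        repo_id="HANI-LAB/Med-REFL-Llama-3.1-8B-lora",
        allow_patterns=["Llama3.1-Med-REFL-LoraAdapter/*"],
    )
    adapter_dir = os.path.join(repo_root, "Llama3.1-Med-REFL-LoraAdapter")

    # Assumptions: the adapter rank is <= 64, and an 8k context is enough for
    # the prompt plus 4096 new tokens. Adjust both if vLLM reports an error.
    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        enable_lora=True,
        max_lora_rank=64,
        max_model_len=8192,
    )
    # Sampling settings mirror the Transformers generate() call.
    params = SamplingParams(
        temperature=0.2, top_p=0.7, max_tokens=4096, repetition_penalty=1.0
    )
    outputs = llm.chat(
        build_messages(system_prompt, user_prompt),
        sampling_params=params,
        lora_request=LoRARequest("med-refl", 1, adapter_dir),
    )
    return outputs[0].outputs[0].text
```

Calling `generate_with_vllm(system_prompt, user_prompt)` with the prompts defined in the Transformers example should return the model's raw text reply. Keeping the adapter separate (rather than merging it) lets vLLM serve the plain base model and the Med-REFL variant from one engine.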
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Define the base model and the LoRA adapter on the Hugging Face Hub.
# The adapter sits in a subfolder of the repository, so it is passed via `subfolder`.
base_model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
lora_repo = "HANI-LAB/Med-REFL-Llama-3.1-8B-lora"
lora_subfolder = "Llama3.1-Med-REFL-LoraAdapter"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Load the base model in bfloat16, placing layers on the available devices
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Attach the LoRA adapter to the base model.
# (Call model.merge_and_unload() if you want the weights merged permanently.)
model = PeftModel.from_pretrained(base_model, lora_repo, subfolder=lora_subfolder)

# Prepare the prompt
system_prompt = '''You are a helpful medical expert specializing in USMLE exam questions, and your task is to answer a multi-choice medical question. Please first think step-by-step and then choose the answer from the provided options. Your responses will be used for research purposes only, so please have a definite answer.\nProvide your response in the following JSON format:\n{"reason": "Step-by-step explanation of your thought process","answer": "Chosen answer from the given options"}\n'''
user_prompt = "A 67-year-old man with transitional cell carcinoma of the bladder comes to the physician because of a 2-day history of ringing sensation in his ear. He received this first course of neoadjuvant chemotherapy 1 week ago. Pure tone audiometry shows a sensorineural hearing loss of 45 dB. The expected beneficial effect of the drug that caused this patient's symptoms is most likely due to which of the following actions?\nOptions:\nA: Inhibition of thymidine synthesis\nB: Inhibition of proteasome\nC: Hyperstabilization of microtubules\nD: Generation of free radicals\nE: Cross-linking of DNA"
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt},
]

# Convert the chat messages into input token IDs
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# Generate the response
outputs = model.generate(
    input_ids,
    max_new_tokens=4096,
    do_sample=True,
    temperature=0.2,
    top_p=0.7,
    repetition_penalty=1.0,
)

# Decode only the newly generated tokens and print them
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```
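Because the system prompt asks for a JSON reply with `reason` and `answer` fields, the decoded text usually needs light post-processing before it can be scored. The helper below is a hypothetical sketch, not part of the Med-REFL release. It assumes options labelled A to E, as in the sample question, and falls back to a regular expression when the model emits malformed JSON.

```python
import json
import re
from typing import Optional


def extract_answer(response: str) -> Optional[str]:
    """Return the option letter (A-E) from a {"reason": ..., "answer": ...} reply, or None."""
    # Try strict JSON first: take the span from the first "{" to the last "}"
    # so that surrounding prose or a Markdown code fence is ignored.
    answer = ""
    start, end = response.find("{"), response.rfind("}")
    if start != -1 and end > start:
        try:
            answer = str(json.loads(response[start : end + 1]).get("answer", ""))
        except json.JSONDecodeError:
            answer = ""
    # Fallback for malformed JSON, e.g. unescaped quotes inside "reason".
    if not answer:
        field = re.search(r'"answer"\s*:\s*"([^"]*)"', response)
        answer = field.group(1) if field else ""
    # Accept "E", "E: Cross-linking of DNA" or "(E) ..." and keep only the letter.
    letter = re.match(r"\s*\(?([A-E])\b", answer)
    return letter.group(1) if letter else None
```

For example, `extract_answer(tokenizer.decode(response, skip_special_tokens=True))` returns the chosen option letter, or `None` when no letter can be recovered (for instance, when the model answers with the option text alone).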
Citation
If you use these weights or the Med-REFL framework in your research, please cite our paper:
Model tree for HANI-LAB/Med-REFL-Llama-3.1-8B-lora
Base model
meta-llama/Llama-3.1-8B