---
license: apache-2.0
datasets:
- PKU-ML/Erdos-CoT
language:
- en
metrics:
- accuracy
base_model:
- Qwen/Qwen2.5-7B-Instruct
pipeline_tag: text-generation
tags:
- graph
- chat
library_name: transformers
---
# G1-CoT-SFT-7B
## Introduction
G1 is a series of large language models trained on our benchmark [Erdos](https://huggingface.co/datasets/PKU-ML/Erdos) for solving graph reasoning tasks, based on Qwen2.5-Instruct.
We apply Group Relative Policy Optimization (GRPO) for reinforcement learning, with supervised fine-tuning as a preliminary step.
G1 brings the following improvements:
- **Significant improvement on graph reasoning**: G1 models achieve up to 46% improvement over baselines on Erdős, with the 7B variant matching OpenAI’s o3-mini and the 3B model surpassing Qwen2.5-72B-Instruct by notable margins.
- **Strong generalization to unseen graph tasks**: G1 exhibits zero-shot generalization on unseen graph tasks, improving performance on *other graph reasoning benchmarks* (GraphWiz, GraphArena) and *real-world graphs* (Cora, PubMed).
- **No compromise on general reasoning**: Crucially, G1 preserves general reasoning ability (GSM8K, MATH, MMLU-Pro), demonstrating its versatility.
**This repo contains the G1-CoT-SFT-7B model**, which has the following features:
- Type: Causal Language Models
- Training Stage: SFT
- Architecture: Same as Qwen2.5-Instruct
- Number of Parameters: 7.62B
- Context Length: 32,768 tokens in full, with generation of up to 8,192 tokens
For more details, please refer to our [paper](https://arxiv.org/pdf/2505.18499) and [GitHub](https://github.com/PKU-ML/G1/tree/main).
## Requirements
The model is trained from Qwen/Qwen2.5-7B-Instruct. Qwen2.5 support is included in recent releases of Hugging Face `transformers`, and we advise you to use the latest version of `transformers`.
With `transformers<4.37.0`, you will encounter the following error:
```
KeyError: 'qwen2'
```
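To catch this before loading the model, the installed version can be checked up front. The following is a minimal sketch using only the standard library; the helper names `parse_release` and `supports_qwen2` are hypothetical, and the `4.37.0` threshold is taken from the note above. Pre-release suffixes such as `.dev0` are ignored by this simplified parser.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum release that recognizes the 'qwen2' model type (per the note above).
MIN_VERSION = (4, 37, 0)


def parse_release(version_string: str) -> tuple:
    """Return the leading numeric release, e.g. '4.40' -> (4, 40, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version_string)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version_string!r}")
    return tuple(int(part) if part else 0 for part in match.groups())


def supports_qwen2(version_string: str) -> bool:
    """True if the given transformers version is at least MIN_VERSION."""
    return parse_release(version_string) >= MIN_VERSION


try:
    installed = version("transformers")
    status = "OK" if supports_qwen2(installed) else "too old, please upgrade"
    print(f"transformers {installed}: {status}")
except PackageNotFoundError:
    print("transformers is not installed")
```

If the check reports an old version, upgrading with `pip install -U transformers` should resolve the `KeyError`.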
## Quickstart
The following code snippet uses `apply_chat_template` to show how to load the tokenizer and model and how to generate a response.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prompt wrapper asking the model to reason step by step and box its final answer.
INSTRUCTION_TEMPLATE = """
{instruction}
Solve the above problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
""".strip()

model_name = "PKU-ML/G1-CoT-SFT-7B"

# Load the model and tokenizer; dtype and device placement are chosen automatically.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example graph reasoning task.
prompt = "The task is to determine the degree centrality of a node in the graph.\n\n"\
    "Degree centrality for a node is the fraction of nodes it is connected to.\n\n"\
    "Here is an undirected graph containing nodes from 1 to 15. The edges are: (1, 15), (15, 11), (2, 3), (2, 6), (3, 6), (3, 7), (6, 7), (6, 8), (7, 8), (7, 14), (4, 10), (10, 5), (10, 12), (8, 14), (8, 9), (12, 11), (12, 13).\n\n"\
    "Question: What is the degree centrality of node 2 in the graph?\n\n"\
    "You need to format your answer as a float number."

messages = [
    {"role": "user", "content": INSTRUCTION_TEMPLATE.format(instruction=prompt)}
]

# Render the chat messages into the prompt format the model expects.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=4096,
    do_sample=True,  # sampling must be enabled for top_p/top_k/temperature to take effect
    top_p=0.95,
    top_k=30,
    temperature=0.6
)

# Drop the prompt tokens so that only the newly generated tokens are decoded.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
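Because the prompt asks for the final answer inside `\boxed{...}`, the response can be parsed and compared against a reference value. The sketch below is illustrative only: the helper names, the sample response string, and the `1e-3` tolerance are assumptions and not part of the G1 evaluation code. It also assumes the common normalization of degree centrality by `n - 1`, which is one reading of "the fraction of nodes it is connected to".

```python
import re
from typing import Optional

# Edge list and node count copied from the Quickstart prompt.
EDGES = [(1, 15), (15, 11), (2, 3), (2, 6), (3, 6), (3, 7), (6, 7), (6, 8),
         (7, 8), (7, 14), (4, 10), (10, 5), (10, 12), (8, 14), (8, 9),
         (12, 11), (12, 13)]
NUM_NODES = 15


def degree_centrality(node: int, edges, num_nodes: int) -> float:
    """Degree of `node` divided by (num_nodes - 1) in an undirected graph."""
    neighbors = set()
    for u, v in edges:
        if u == node:
            neighbors.add(v)
        elif v == node:
            neighbors.add(u)
    return len(neighbors) / (num_nodes - 1)


def extract_boxed_answer(response: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in the response, if any."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", response)
    return matches[-1].strip() if matches else None


# Hypothetical model output, used here only to exercise the parser.
sample_response = r"Therefore, the final answer is: $\boxed{0.1429}$. I hope it is correct"

reference = degree_centrality(2, EDGES, NUM_NODES)  # node 2 has neighbors 3 and 6
answer = extract_boxed_answer(sample_response)
is_correct = answer is not None and abs(float(answer) - reference) < 1e-3
print(f"reference={reference:.4f}, parsed={answer}, correct={is_correct}")
```

In practice, `sample_response` would be replaced by the `response` string produced in the Quickstart; since generation uses sampling, the output may differ between runs.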
## Evaluation & Performance
Detailed evaluation results are reported in this [📑 paper](https://arxiv.org/pdf/2505.18499).
## Citation
If you find our work helpful, please cite our paper.
```
@article{guo2025g1,
title={G1: Teaching LLMs to Reason on Graphs with Reinforcement Learning},
author={Guo, Xiaojun and Li, Ang and Wang, Yifei and Jegelka, Stefanie and Wang, Yisen},
journal={arXiv preprint arXiv:2505.18499},
year={2025}
}
```