---
license: apache-2.0
datasets:
- PKU-ML/Erdos-CoT
language:
- en
metrics:
- accuracy
base_model:
- Qwen/Qwen2.5-7B-Instruct
pipeline_tag: text-generation
tags:
- graph
- chat
library_name: transformers
---


# G1-CoT-SFT-7B

## Introduction

G1 is a series of large language models for solving graph reasoning tasks, trained on our benchmark [Erdős](https://huggingface.co/datasets/PKU-ML/Erdos) and based on Qwen2.5-Instruct.
We apply Group Relative Policy Optimization (GRPO) for reinforcement learning, with supervised fine-tuning as a preliminary step.

G1 brings the following improvements:

- **Significant improvement on graph reasoning**: G1 models achieve up to 46% improvement over baselines on Erdős, with the 7B variant matching OpenAI’s o3-mini and the 3B model surpassing Qwen2.5-72B-Instruct by notable margins.
- **Strong Generalization to unseen graph tasks**: G1 exhibits zero-shot generalization on unseen graph tasks, improving performance on *other graph reasoning benchmarks* (GraphWiz, GraphArena) and *real-world graphs* (Cora, PubMed).
- **NO Compromise on general reasoning**: Crucially, G1 preserves general reasoning ability (GSM8K, MATH, MMLU-Pro), proving its versatility.


**This repo contains the G1-CoT-SFT-7B model**, which has the following features:
- Type: Causal Language Models
- Training Stage: SFT
- Architecture: the same as Qwen2.5-Instruct
- Number of Parameters: 7.62B
- Context Length: 32,768 tokens (full context), with up to 8,192 generation tokens

For more details, please refer to our [paper](https://arxiv.org/pdf/2505.18499) and [GitHub](https://github.com/PKU-ML/G1/tree/main).


## Requirements

The model is trained based on Qwen/Qwen2.5-7B-Instruct. The code for Qwen2.5 is included in the latest Hugging Face `transformers`, and we advise you to use the latest version of `transformers`.

With `transformers<4.37.0`, you will encounter the following error:
```
KeyError: 'qwen2'
```
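To resolve this error, upgrade `transformers` to a version that includes the `qwen2` architecture:

```shell
pip install --upgrade "transformers>=4.37.0"
```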


## Quickstart

The following code snippet shows how to load the tokenizer and model and generate content using `apply_chat_template`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

INSTRUCTION_TEMPLATE = """
    {instruction}

    Solve the above problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
    """.strip()

model_name = "PKU-ML/G1-CoT-SFT-7B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "The task is to determine the degree centrality of a node in the graph.\n\n"\
        "Degree centrality for a node is the fraction of nodes it is connected to.\n\n"\
        "Here is an undirected graph containing nodes from 1 to 15. The edges are: (1, 15), (15, 11), (2, 3), (2, 6), (3, 6), (3, 7), (6, 7), (6, 8), (7, 8), (7, 14), (4, 10), (10, 5), (10, 12), (8, 14), (8, 9), (12, 11), (12, 13).\n\n"\
        "Question: What is the degree centrality of node 2 in the graph?\n\n"\
        "You need to format your answer as a float number."
messages = [
    {"role": "user", "content": INSTRUCTION_TEMPLATE.format(instruction=prompt)}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=4096,
    top_p=0.95,
    top_k=30,
    temperature=0.6
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
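As a sanity check for the example prompt above, the expected answer can be computed directly from the edge list. This is a minimal pure-Python sketch (not part of the model's official evaluation pipeline): degree centrality is the node's degree divided by `n - 1`.

```python
# Ground-truth check for the degree-centrality example prompt above.
edges = [(1, 15), (15, 11), (2, 3), (2, 6), (3, 6), (3, 7), (6, 7),
         (6, 8), (7, 8), (7, 14), (4, 10), (10, 5), (10, 12), (8, 14),
         (8, 9), (12, 11), (12, 13)]
n_nodes = 15

def degree_centrality(node):
    """Fraction of the other n - 1 nodes that `node` is connected to."""
    degree = sum(1 for u, v in edges if node in (u, v))
    return degree / (n_nodes - 1)

# Node 2 appears in edges (2, 3) and (2, 6), so its degree is 2.
print(round(degree_centrality(2), 4))  # 0.1429
```

The model's boxed answer for the example prompt should match this value (2/14 ≈ 0.1429).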


## Evaluation & Performance

Detailed evaluation results are reported in this [📑 paper](https://arxiv.org/pdf/2505.18499).


## Citation

If you find our work helpful, please consider citing our paper.

```
@article{guo2025g1,
  title={G1: Teaching LLMs to Reason on Graphs with Reinforcement Learning},
  author={Guo, Xiaojun and Li, Ang and Wang, Yifei and Jegelka, Stefanie and Wang, Yisen},
  journal={arXiv preprint arXiv:2505.18499},
  year={2025}
}
```