adamo1139 committed
Commit 5cd4da8 · verified · 1 Parent(s): a4fab27

Upload yi-34b-aezakmi-sft-1-hf.py

Files changed (1)
  1. yi-34b-aezakmi-sft-1-hf.py +163 -0
yi-34b-aezakmi-sft-1-hf.py ADDED
@@ -0,0 +1,163 @@
+ from unsloth import FastLanguageModel
+ from datasets import Dataset, load_dataset
+ from dataclasses import dataclass, field
+ from typing import Dict, Optional
+ import torch
+ max_seq_length = 4096 # Choose any! We auto support RoPE Scaling internally!
+ dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
+ load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
+
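+ # Load the base model: a local Unsloth checkpoint ("yi-34b-rawrr-dpo-2-unsloth"),
+ # quantized to 4-bit, with FlashAttention 2 and a 4096-token context window.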
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name = "/run/.../yi-34b-rawrr-dpo-2-unsloth", # Choose ANY! eg mistralai/Mistral-7B-Instruct-v0.2
+     max_seq_length = max_seq_length,
+     attn_implementation = "flash_attention_2",
+     dtype = dtype,
+     load_in_4bit = load_in_4bit,
+     # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
+ )
+
+
+
+ #@title Alignment Handbook utils
+ import os
+ import re
+ from typing import List, Literal, Optional
+
+ from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
+ from datasets.builder import DatasetGenerationError
+
+
+ #DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
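+ # Set the tokenizer to a ChatML template: each turn is rendered as
+ # "<|im_start|>{role}\n{content}<|im_end|>\n", with a trailing
+ # "<|im_start|>assistant\n" appended when add_generation_prompt is requested.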
+ tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+
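+ # Build a single "text" field per example from the system, instruction and
+ # response columns, using the ChatML template defined above.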
+ def chatml_format(example):
+     # Format system
+     if len(example['system']) > 0:
+         message = {"role": "system", "content": example['system']}
+         system = tokenizer.apply_chat_template([message], tokenize=False)
+     else:
+         system = ""
+
+     # Format instruction
+     message = {"role": "user", "content": example['instruction']}
+     prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
+
+     # Format response
+     response = example['response'] + "<|im_end|>\n"
+
+     return {
+         "text": system + prompt + response,
+     }
+
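+ # With the template above, a formatted "text" field looks roughly like this
+ # (illustrative layout only, not taken from the dataset):
+ #   <|im_start|>system
+ #   {system}<|im_end|>
+ #   <|im_start|>user
+ #   {instruction}<|im_end|>
+ #   <|im_start|>assistant
+ #   {response}<|im_end|>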
+ # Load dataset
+ #dataset = load_dataset("adamo1139/AEZAKMI_v2", split="train")
+ dataset = load_dataset("json", data_files="/run/..../datasets/aezakmi_v2/aezakmi_v2.jsonl", split="train")
+
+
+ import pprint
+ pprint.pprint("""NOT a formatted dataset""")
+ pprint.pprint(dataset[25])
+ pprint.pprint(dataset[26])
+ pprint.pprint(dataset[27])
+ pprint.pprint(dataset[28])
+ pprint.pprint(dataset[29])
+ # Save columns
+ original_columns = dataset.column_names
+
+ # Format dataset
+ dataset = dataset.map(
+     chatml_format,
+     remove_columns=original_columns
+ )
+
+ # Print sample
+ pprint.pprint("""formatted dataset""")
+ pprint.pprint(dataset[25])
+ pprint.pprint(dataset[26])
+ pprint.pprint(dataset[27])
+ pprint.pprint(dataset[28])
+ pprint.pprint(dataset[29])
+
+
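+ # Attach LoRA adapters to the attention and MLP projection layers:
+ # rank 16, alpha 32, no dropout, with gradient checkpointing enabled.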
+ model = FastLanguageModel.get_peft_model(
+     model,
+     r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+     target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+                       "gate_proj", "up_proj", "down_proj",],
+     lora_alpha = 32,
+     lora_dropout = 0, # Currently only supports dropout = 0
+     bias = "none", # Currently only supports bias = "none"
+     use_gradient_checkpointing = True,
+     random_state = 3407,
+     use_rslora = False, # We support rank stabilized LoRA
+     loftq_config = None, # And LoftQ
+ )
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
+ from transformers.utils import logging
+ from trl import SFTTrainer
+
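+ # SFT run: packed sequences up to 2200 tokens, per-device batch size 1 with no
+ # gradient accumulation, ~1.4 epochs at lr 6e-5 on a cosine schedule
+ # (num_cycles 0.3), 8-bit AdamW, and checkpoints every 1000 steps under outputs3.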
+ sft_trainer = SFTTrainer(
+     model = model,
+     tokenizer = tokenizer,
+     train_dataset = dataset,
+     dataset_text_field = "text",
+     max_seq_length = 2200,
+     packing = True,
+     args = TrainingArguments(
+         evaluation_strategy = "no",
+         per_device_train_batch_size = 1,
+         gradient_accumulation_steps = 1,
+         num_train_epochs = 1.4,
+         warmup_steps = 100,
+         learning_rate = 0.00006,
+         fp16 = not torch.cuda.is_bf16_supported(),
+         bf16 = torch.cuda.is_bf16_supported(),
+         logging_steps = 1,
+         output_dir = "outputs3",
+         optim = "adamw_8bit",
+         weight_decay = 0.0,
+         lr_scheduler_type = "cosine",
+         lr_scheduler_kwargs = {
+             "num_cycles" : 0.3,
+         },
+         seed = 42,
+         save_strategy = "steps",
+         save_steps = 1000,
+         save_total_limit = 10,
+     ),
+ )
+
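+ # The DPOTrainer block below is kept commented out (a triple-quoted string)
+ # and is not executed in this SFT run.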
+ '''
+ dpo_trainer = DPOTrainer(
+     model = model,
+     ref_model = None,
+     args = TrainingArguments(
+         per_device_train_batch_size = 1,
+         gradient_accumulation_steps = 16,
+         warmup_ratio = 0.05,
+         num_train_epochs = 1,
+         learning_rate = 5e-5,
+         fp16 = not torch.cuda.is_bf16_supported(),
+         bf16 = torch.cuda.is_bf16_supported(),
+         logging_steps = 1,
+         optim = "adamw_8bit",
+         weight_decay = 0.0,
+         lr_scheduler_type = "linear",
+         seed = 42,
+         output_dir = "outputs2",
+     ),
+     beta = 0.1,
+     train_dataset = dataset,
+     # eval_dataset = raw_datasets["test"],
+     tokenizer = tokenizer,
+     max_length = 500,
+     max_prompt_length = 500,
+ )
+ '''
+
+
+ sft_trainer.train()
+ model.save_pretrained("yi-34b-200k-aezakmi-raw-unsloth-2") # Local saving
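+ # Minimal sketch (assumption, not part of the original script): the saved LoRA
+ # adapter directory could later be reloaded for inference with Unsloth, e.g.:
+ # model, tokenizer = FastLanguageModel.from_pretrained(
+ #     model_name = "yi-34b-200k-aezakmi-raw-unsloth-2",
+ #     max_seq_length = max_seq_length, dtype = dtype, load_in_4bit = load_in_4bit,
+ # )
+ # FastLanguageModel.for_inference(model)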