reissbaker commited on
Commit
3d06b91
·
1 Parent(s): 9ddbf48

Add unfat config

Browse files
Files changed (3) hide show
  1. .gitignore +5 -0
  2. main.py +157 -0
  3. pyproject.toml +10 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ /output
2
+ __pycache__
3
+ .vim
4
+ /dist
5
+ unfat.egg-info
main.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import cast, Any
4
+ from datasets import load_dataset, Dataset as HfDataset
5
+ from unfat.extract import Extractor
6
+ from unfat.client import OpenAiCompatClient
7
+ from unfat.datasets import Dataset, Prompts, hub_prompts, HubSplit
8
+ from unfat.together import llama_3_1_70b_together
9
+ from unfat.lora import LoraSettings
10
+
11
+ def gen_prompts(
12
+ ds_name: str,
13
+ text_field: str,
14
+ start_regex: re.Pattern | None = None,
15
+ end_regex: re.Pattern | None = None,
16
+ ):
17
+ ds = cast(HfDataset, load_dataset(ds_name, split="train"))
18
+ def items():
19
+ for row in ds:
20
+ casted = cast(dict[Any, Any], row)
21
+ text = casted[text_field]
22
+ if start_regex and end_regex:
23
+ yield end_regex.sub("", start_regex.sub("", text))
24
+ elif start_regex:
25
+ yield start_regex.sub("", text)
26
+ elif end_regex:
27
+ yield end_regex.sub("", text)
28
+ else:
29
+ yield text
30
+
31
+ return Prompts(
32
+ output_path=f"hub/{ds_name}.jsonl",
33
+ count=lambda: len(ds),
34
+ items=items,
35
+ )
36
+
37
+ def extract_prompts_from_convos(
38
+ ds_name: str,
39
+ messages_field: str,
40
+ role_field: str,
41
+ content_field: str,
42
+ user_role: str,
43
+ ):
44
+ ds = cast(HfDataset, load_dataset(ds_name, split="train"))
45
+ def items():
46
+ for row in ds:
47
+ casted = cast(dict[Any, Any], row)
48
+ for message in casted[messages_field]:
49
+ if message[role_field] == user_role:
50
+ yield message[content_field]
51
+ break
52
+ return Prompts(
53
+ output_path=f"hub/{ds_name}.jsonl",
54
+ count=lambda: len(ds),
55
+ items=items,
56
+ )
57
+
58
+ def main():
59
+ output_dir = "output"
60
+ rp_english = extract_prompts_from_convos(
61
+ ds_name="OdiaGenAI/roleplay_english",
62
+ messages_field="conversations",
63
+ role_field="from",
64
+ content_field="value",
65
+ user_role="user",
66
+ )
67
+ bluemoon = extract_prompts_from_convos(
68
+ ds_name="xDAN2099/RolePlay-Mixed-Bluemoon-Limarp",
69
+ messages_field="conversations",
70
+ role_field="from",
71
+ content_field="value",
72
+ user_role="human",
73
+ )
74
+ roleplay_prompts = gen_prompts(
75
+ ds_name="AlekseyKorshuk/roleplay-io",
76
+ text_field="input_text",
77
+ start_regex=re.compile(r'^User: '),
78
+ end_regex=re.compile(r'Bot:\s*$'),
79
+ )
80
+ roleplay_instr_prompts = gen_prompts(
81
+ ds_name="iamketan25/roleplay-instructions-dataset",
82
+ text_field="prompt",
83
+ start_regex=re.compile(r'^Human: '),
84
+ end_regex=re.compile(r'Assistant:\s*$'),
85
+ )
86
+
87
+ extractor = Extractor(
88
+ max_concurrent=50,
89
+ output_dir=output_dir,
90
+ client=OpenAiCompatClient(
91
+ base_url="https://glhf.chat/api/openai/v1",
92
+ api_key=os.environ["GLHF_API_KEY"],
93
+ model="hf:TheDrummer/Behemoth-123B-v1.2",
94
+ retries=20,
95
+ ),
96
+ dataset=Dataset(
97
+ train=[
98
+ hub_prompts(
99
+ name="mlabonne/harmful_behaviors",
100
+ text_field="text",
101
+ split="train",
102
+ ),
103
+ roleplay_instr_prompts,
104
+ roleplay_prompts,
105
+ rp_english,
106
+ bluemoon,
107
+ hub_prompts(
108
+ name="TheDrummer/AmoralQA-v2",
109
+ text_field="prompt",
110
+ split="train",
111
+ ),
112
+ hub_prompts(
113
+ name="vicgalle/OpenHermesPreferences-roleplay",
114
+ text_field="prompt",
115
+ split="train",
116
+ ),
117
+ hub_prompts(
118
+ name="mrcuddle/DPO_Pairs_Roleplay-Alpaca",
119
+ text_field="prompt",
120
+ split="train",
121
+ ),
122
+ hub_prompts(
123
+ name="ResplendentAI/theory_of_mind_fixed_output",
124
+ text_field="instruction",
125
+ split="train",
126
+ ),
127
+ hub_prompts(
128
+ name="mlabonne/harmless_alpaca",
129
+ text_field="text",
130
+ split=HubSplit(name="train", max_rows=1000),
131
+ ),
132
+ ],
133
+ ),
134
+ )
135
+ extractor.run()
136
+ dataset = extractor.output_dataset()
137
+ together_config = llama_3_1_70b_together(
138
+ output_dir=output_dir,
139
+ dataset=dataset,
140
+ api_key=os.environ["TOGETHER_API_KEY"],
141
+ settings=LoraSettings(
142
+ rank=32,
143
+ alpha=16,
144
+ dropout=0.01,
145
+ num_epochs=2,
146
+ learning_rate=4e-4,
147
+ evals_per_epoch=0,
148
+ wandb_project="behemoth-distill",
149
+ wandb_api_key=os.environ["WANDB_API_KEY"],
150
+ )
151
+ )
152
+ files = together_config.upload_files()
153
+ together_config.finetune(files)
154
+
155
+
156
+ if __name__ == "__main__":
157
+ main()
pyproject.toml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "behemoth-lora"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "datasets>=3.3.2",
9
+ "unfat>=0.0.13",
10
+ ]