LLM4Binary committed (verified)
Commit eee0ccf · Parent(s): 0af4a0b

Update README.md

Files changed (1):
  1. README.md +206 -1

README.md CHANGED
@@ -7,4 +7,209 @@ SK²Decompile is a novel two-phase framework for binary decompilation using Large Language Models

Phase 1 Structure Recovery (Skeleton): Transform binary/pseudo-code into obfuscated intermediate representations (current model)

Phase 2 Identifier Naming (Skin): Generate human-readable source code with meaningful identifiers 🤗 [HF Link](https://huggingface.co/LLM4Binary/sk2decompile-ident-6.7)
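
For a single function, the two phases can also be chained directly with 🤗 `transformers`. The sketch below is illustrative only: the skeleton model id is a placeholder for this repository's checkpoint, and the prompt templates simply mirror the usage script further down.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

SKELETON_MODEL = "LLM4Binary/sk2decompile-struct-6.7"  # placeholder: use this repository's model id
IDENT_MODEL = "LLM4Binary/sk2decompile-ident-6.7"

def generate(model_id, prompt, max_new_tokens=4096):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

# Replace with the decompiler output (pseudo-code or assembly) of the function you want to decompile
pseudo_code = "__int64 sub_0(__int64 a1) { return a1 + 1; }"

# Phase 1 (Skeleton): recover the obfuscated intermediate representation
skeleton = generate(SKELETON_MODEL, "# This is the assembly code:\n" + pseudo_code + "\n# What is the source code?\n")

# Phase 2 (Skin): rename identifiers to produce readable source code
source = generate(IDENT_MODEL, "# This is the normalized code:\n" + skeleton + "\n# What is the source code?\n")
print(source)
```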

Usage:
```python
from llm_server import llm_inference
from transformers import AutoTokenizer
import json
import argparse
import shutil
import os
from tqdm import tqdm

opts = ["O0", "O1", "O2", "O3"]
current_dir = os.path.dirname(os.path.abspath(__file__))

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model_path", type=str, default="LLM4Binary/llm4decompile-1.3b-v1.5")
    arg_parser.add_argument("--dataset_path", type=str, default='../data/exebench_test_normsrcpseudo_io.json')
    arg_parser.add_argument("--decompiler", type=str, default='asm')
    arg_parser.add_argument("--gpus", type=int, default=1)
    arg_parser.add_argument("--max_num_seqs", type=int, default=1)
    arg_parser.add_argument("--gpu_memory_utilization", type=float, default=0.8)
    arg_parser.add_argument("--temperature", type=float, default=0)
    arg_parser.add_argument("--max_total_tokens", type=int, default=32768)
    arg_parser.add_argument("--max_new_tokens", type=int, default=4096)
    arg_parser.add_argument("--stop_sequences", type=str, default=None)
    arg_parser.add_argument("--recover_model_path", type=str, default=None, help="Path to the model to recover from, if any.")
    arg_parser.add_argument("--output_path", type=str, default='../result/exebench-1.3b-v2')
    arg_parser.add_argument("--only_save", type=int, default=0)
    arg_parser.add_argument("--strip", type=int, default=1)
    arg_parser.add_argument("--language", type=str, default='c')
    args = arg_parser.parse_args()

    # Prompt template for the first (structure-recovery) model
    before = "# This is the assembly code:\n"
    after = "\n# What is the source code?\n"

    # Load the evaluation samples (JSON array or JSON Lines)
    if args.dataset_path.endswith('.json'):
        with open(args.dataset_path, "r") as f:
            print("===========")
            print(f"Loading dataset from {args.dataset_path}")
            print("===========")
            samples = json.load(f)
    elif args.dataset_path.endswith('.jsonl'):
        samples = []
        with open(args.dataset_path, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    samples.append(json.loads(line))

    if args.language == 'c':
        samples = [sample for sample in samples if sample['language'] == 'c']
    elif args.language == 'cpp':
        samples = [sample for sample in samples if sample['language'] == 'cpp']

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.stop_sequences is None:
        args.stop_sequences = [tokenizer.eos_token]

    # Drop samples whose prompt would exceed the context budget
    filtered_samples = []
    for sample in tqdm(samples, desc="Filtering samples by token length"):
        if 'ida_strip_pseudo_norm' in sample:
            prompt = before + sample['ida_strip_pseudo_norm'] + after
            tokens = tokenizer.encode(prompt)
            if len(tokens) <= 12000:
                filtered_samples.append(sample)
            else:
                print(f"Discarded sample with {len(tokens)} tokens")
        else:
            filtered_samples.append(sample)

    samples = filtered_samples
    print(f"Filtered samples: {len(samples)} remaining")

    # Build the Phase 1 prompts
    inputs = []
    infos = []
    for sample in samples:
        prompt = before + sample[args.decompiler].strip() + after
        sample['prompt_model1'] = prompt
        inputs.append(prompt)
        infos.append({
            "opt": sample["opt"],
            "language": sample["language"],
            "index": sample["index"],
            "func_name": sample["func_name"]
        })

    # Phase 1: structure recovery (skeleton model)
    print("Starting first model inference...")
    gen_results = llm_inference(inputs, args.model_path,
                                args.gpus,
                                args.max_total_tokens,
                                args.gpu_memory_utilization,
                                args.temperature,
                                args.max_new_tokens,
                                args.stop_sequences)
    gen_results = [gen_result[0] for gen_result in gen_results]

    for idx in range(len(gen_results)):
        samples[idx]['gen_result_model1'] = gen_results[idx]

    # Build the Phase 2 prompts from the Phase 1 outputs
    inputs_recovery = []
    before_recovery = "# This is the normalized code:\n"
    after_recovery = "\n# What is the source code?\n"

    for idx, gen_result in enumerate(gen_results):
        prompt_recovery = before_recovery + gen_result.strip() + after_recovery
        samples[idx]['prompt_model2'] = prompt_recovery
        inputs_recovery.append(prompt_recovery)

    # Phase 2: identifier naming (skin model)
    print("Starting recovery model inference...")
    gen_results_recovery = llm_inference(inputs_recovery, args.recover_model_path,
                                         args.gpus,
                                         args.max_total_tokens,
                                         args.gpu_memory_utilization,
                                         args.temperature,
                                         args.max_new_tokens,
                                         args.stop_sequences)
    gen_results_recovery = [gen_result[0] for gen_result in gen_results_recovery]

    for idx in range(len(gen_results_recovery)):
        samples[idx]['gen_result_model2'] = gen_results_recovery[idx]

    # Prepare one output directory per optimization level
    if args.output_path:
        if os.path.exists(args.output_path):
            shutil.rmtree(args.output_path)
        for opt in opts:
            os.makedirs(os.path.join(args.output_path, opt))

    # Replace the generated function name with the original one
    if args.strip:
        print("Processing function name stripping...")
        for idx in range(len(gen_results_recovery)):
            one = gen_results_recovery[idx]
            func_name_in_gen = one.split('(')[0].split(' ')[-1].strip()
            if func_name_in_gen.strip() and func_name_in_gen[0:2] == '**':
                func_name_in_gen = func_name_in_gen[2:]
            elif func_name_in_gen.strip() and func_name_in_gen[0] == '*':
                func_name_in_gen = func_name_in_gen[1:]

            original_func_name = samples[idx]["func_name"]
            gen_results_recovery[idx] = one.replace(func_name_in_gen, original_func_name)
            samples[idx]["gen_result_model2_stripped"] = gen_results_recovery[idx]

    # Write one source file plus a JSON log per sample
    print("Saving inference results and logs...")
    for idx_sample, final_result in enumerate(gen_results_recovery):
        opt = infos[idx_sample]['opt']
        language = infos[idx_sample]['language']
        original_index = samples[idx_sample]['index']

        save_path = os.path.join(args.output_path, opt, f"{original_index}_{opt}.{language}")
        with open(save_path, "w") as f:
            f.write(final_result)

        log_path = save_path + ".log"
        log_data = {
            "index": original_index,
            "opt": opt,
            "language": language,
            "func_name": samples[idx_sample]["func_name"],
            "decompiler": args.decompiler,
            "input_asm": samples[idx_sample][args.decompiler].strip(),
            "prompt_model1": samples[idx_sample]['prompt_model1'],
            "gen_result_model1": samples[idx_sample]['gen_result_model1'],
            "prompt_model2": samples[idx_sample]['prompt_model2'],
            "gen_result_model2": samples[idx_sample]['gen_result_model2'],
            "final_result": final_result,
            "stripped": args.strip
        }

        if args.strip and "gen_result_model2_stripped" in samples[idx_sample]:
            log_data["gen_result_model2_stripped"] = samples[idx_sample]["gen_result_model2_stripped"]

        with open(log_path, "w") as f:
            json.dump(log_data, f, indent=2, ensure_ascii=False)

    # Aggregate all samples (prompts and outputs) into a single JSONL file
    json_path = os.path.join(args.output_path, 'inference_results.jsonl')
    with open(json_path, 'w') as f:
        for sample in samples:
            f.write(json.dumps(sample) + '\n')

    # Write a small summary of the run
    stats_path = os.path.join(args.output_path, 'inference_stats.txt')
    with open(stats_path, 'w') as f:
        f.write(f"Total samples processed: {len(samples)}\n")
        f.write(f"Model path: {args.model_path}\n")
        f.write(f"Recovery model path: {args.recover_model_path}\n")
        f.write(f"Dataset path: {args.dataset_path}\n")
        f.write(f"Language: {args.language}\n")
        f.write(f"Decompiler: {args.decompiler}\n")
        f.write(f"Strip function names: {bool(args.strip)}\n")

        opt_counts = {"O0": 0, "O1": 0, "O2": 0, "O3": 0}
        for sample in samples:
            opt_counts[sample['opt']] += 1

        f.write("\nSamples per optimization level:\n")
        for opt, count in opt_counts.items():
            f.write(f"  {opt}: {count}\n")

    print(f"Inference completed! Results saved to {args.output_path}")
    print(f"Total {len(samples)} samples processed.")
```
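
When run end to end, the script writes one decompiled source file plus a `.log` JSON per function under `--output_path/<opt>/`, an aggregated `inference_results.jsonl`, and an `inference_stats.txt` summary. A small sketch for reading the aggregated results back (the path assumes the default `--output_path`):

```python
import json

# Read back the aggregated results written by the usage script above
with open("../result/exebench-1.3b-v2/inference_results.jsonl") as f:
    results = [json.loads(line) for line in f if line.strip()]

sample = results[0]
print(sample["gen_result_model1"])  # Phase 1 output: obfuscated intermediate representation
print(sample["gen_result_model2"])  # Phase 2 output: source code with recovered identifiers
```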