-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdivide_dataset.py
More file actions
262 lines (205 loc) · 9.89 KB
/
divide_dataset.py
File metadata and controls
262 lines (205 loc) · 9.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import sys
import json
import os
import random
import argparse
from antlr4 import *
from antlr.CLexer import CLexer
from antlr.CParser import CParser
from antlr.CVisitor import CVisitor
# Mutable module-level state shared between substitute_decompiled() and
# SubstituteFunctionNameVisitor; reset at the start of each substitution run.
lines = []  # source lines of the code currently being rewritten (with line endings)
column_offset = {}  # line index -> cumulative column shift from earlier renames on that line
function_map = {}  # original 'FUN_*' identifier -> normalized sequential replacement name
function_count = 0  # next sequential suffix to use for a normalized 'FUN_' name
class SubstituteFunctionNameVisitor(CVisitor):
    """ANTLR visitor that renames Ghidra placeholder call targets.

    Whenever a postfix expression is a call whose callee identifier starts
    with 'FUN_', the identifier is replaced in the module-level ``lines``
    buffer by a sequential alias (FUN_0, FUN_1, ...). ``column_offset``
    records how much each line has already grown or shrunk, so that later
    tokens on the same line are spliced at the correct column.
    """

    def visitPostfixExpression(self, ctx:CParser.PostfixExpressionContext):
        global lines
        global column_offset
        global function_count
        global function_map
        # Only call expressions with a plain identifier callee are of interest.
        if ctx.LeftParen() is None or ctx.primaryExpression() is None:
            return self.visitChildren(ctx)
        identifier = ctx.primaryExpression().Identifier()
        if identifier is None:
            return self.visitChildren(ctx)
        token = identifier.getSymbol()
        if not token.text.startswith('FUN_'):
            return self.visitChildren(ctx)
        # Reuse the alias if this placeholder was seen before; otherwise mint one.
        if token.text in function_map:
            replacement = function_map[token.text]
        else:
            replacement = 'FUN_' + str(function_count)
            function_map[token.text] = replacement
            function_count += 1
        row = token.line - 1  # ANTLR token lines are 1-based
        start = token.column + column_offset[row]
        end = start + len(token.text)
        lines[row] = lines[row][:start] + replacement + lines[row][end:]
        column_offset[row] += len(replacement) - len(token.text)
        return self.visitChildren(ctx)
def substitute_decompiled(code):
    """Rename every 'FUN_*' call site in *code* to sequential names FUN_0, FUN_1, ...

    Parses *code* with the ANTLR C grammar and rewrites the call-site
    identifiers via SubstituteFunctionNameVisitor, then returns the
    rewritten source as a single string.

    Fix over the previous version: the code is processed entirely in memory
    instead of round-tripping through a hard-coded temporary file
    ('code.txt'), which leaked the file on any exception and raced when
    several processes ran in the same directory.
    """
    global lines
    global column_offset
    global function_map
    global function_count
    # The old file round trip read the text back in universal-newline mode,
    # which folded '\r\n' and '\r' into '\n'; keep that normalization.
    code = code.replace('\r\n', '\n').replace('\r', '\n')
    # Split into lines but keep the trailing '\n' on each one (mirroring
    # readlines()), so joining the edited lines reproduces the layout exactly.
    parts = code.split('\n')
    lines = [part + '\n' for part in parts[:-1]] + parts[-1:]
    column_offset = {i: 0 for i in range(len(lines))}
    function_map = {}
    function_count = 0
    lexer = CLexer(InputStream(code))
    parser = CParser(CommonTokenStream(lexer))
    tree = parser.compilationUnit()
    SubstituteFunctionNameVisitor().visit(tree)
    return "".join(lines)
# Prompt prepended to every sample (identical across all three splits).
_INSTRUCTION = (
    "Suppose you are an expert in software reverse engineering. Here is a piece of "
    "decompiled code, you should infer code semantics and tell me the original function "
    "name from the contents of the function to replace [MASK]. Now the decompiled codes "
    "are as follows:"
)


def _collect_samples(binaries, path_map, seen_names, seen_bodies, use_unstripped_input=False):
    """Build instruction-tuning samples from the JSON files of *binaries*.

    Each binary's JSON maps function name -> {'stripped': ..., 'unstripped': ...}.
    Functions with meaningless Ghidra placeholder names ('FUN_*'), duplicate
    names, or duplicate (normalized) bodies are skipped; *seen_names* and
    *seen_bodies* are shared across splits so de-duplication is global.

    use_unstripped_input: when True (training set) the sample 'input' is the
    'unstripped' decompilation; otherwise the 'stripped' one is used.
    """
    samples = []
    for binary in binaries:
        with open(path_map[binary]) as f:
            data = json.load(f)
        for function_name in data:
            # Remove meaningless names like 'FUN_00000f70'.
            if 'FUN_' in function_name:
                continue
            # Delete duplicate names (globally, across all splits).
            if function_name in seen_names:
                continue
            seen_names.add(function_name)
            source_key = 'unstripped' if use_unstripped_input else 'stripped'
            decompiled_code = data[function_name][source_key]
            if decompiled_code is None:
                continue
            # Normalize 'FUN_*' call sites before body de-duplication so the
            # same function compiled into different binaries collides.
            modified_code = substitute_decompiled(data[function_name]['stripped'])
            if modified_code in seen_bodies:
                continue
            seen_bodies.add(modified_code)
            samples.append({
                "instruction": _INSTRUCTION,
                "input": decompiled_code,
                "output": 'The predicted function name is ' + function_name,
            })
    return samples


def main(args):
    """Divide decompiled-code JSON files into train/test/validation datasets.

    Binaries (not individual functions) are shuffled and split 80/10/10 so
    that every function of a binary lands in the same split, then each split
    is serialized as a JSON list of instruction/input/output samples.
    """
    input_dir = args.input_dir
    output_dir = args.output_dir
    ### Divide Binary (the remainder after train+test goes to validation)
    train_part = 0.8
    test_part = 0.1
    binary_names = []
    path_map = {}
    for root, _dirs, files in os.walk(input_dir):
        for file in files:
            binary_names.append(file)
            path_map[file] = os.path.join(root, file)
    random.shuffle(binary_names)
    total = len(binary_names)
    train_end = int(train_part * total)
    test_end = int((train_part + test_part) * total)
    train_binary = binary_names[:train_end]
    test_binary = binary_names[train_end:test_end]
    validation_binary = binary_names[test_end:]
    # NOTE(review): to make the division reproducible across runs, the three
    # lists above can be dumped to / reloaded from a division_binary.json file.
    print("[+] Training Set", train_binary)
    print("[+] Test Set", test_binary)
    print("[+] Validation Set", validation_binary)
    # Sets instead of lists: membership checks were O(n) per function before,
    # making the whole pass quadratic in the number of functions.
    existed_function_name = set()
    existed_function_body = set()
    print("[+] Process Training Set Binary")
    train = _collect_samples(train_binary, path_map,
                             existed_function_name, existed_function_body,
                             use_unstripped_input=True)
    print("[+] Process Test Set Binary")
    test = _collect_samples(test_binary, path_map,
                            existed_function_name, existed_function_body)
    print("[+] Process Validation Set Binary")  # fixed typo "Valiation"
    validation = _collect_samples(validation_binary, path_map,
                                  existed_function_name, existed_function_body)
    os.makedirs(output_dir, exist_ok=True)
    # Filenames and the printed paths now agree (the old code wrote
    # training_set.json but printed train_set.json).
    outputs = (
        ('training set', 'training_set.json', train),
        ('test set', 'test_set.json', test),
        ('validation set', 'validation_set.json', validation),
    )
    for label, filename, samples in outputs:
        path = os.path.join(output_dir, filename)
        with open(path, 'w') as f:
            json.dump(samples, f, indent=4)
        print("[+] Save " + label + " to", path)
if __name__ == '__main__':
    # Command-line entry point: parse the two required directories and run.
    argument_parser = argparse.ArgumentParser(
        description='Divide data into training, test and validation set.')
    argument_parser.add_argument(
        '-i', '--input_dir', type=str, required=True,
        help='Directory containing the combined decompiled code.')
    argument_parser.add_argument(
        '-o', '--output_dir', type=str, required=True,
        help='Directory to save the divided dataset.')
    main(argument_parser.parse_args())