-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmulti_LLM_voting.py
More file actions
388 lines (304 loc) · 16.4 KB
/
multi_LLM_voting.py
File metadata and controls
388 lines (304 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
import json
import os
import requests
import re
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
# ==============================================================================
# --- User Configuration ---
# ==============================================================================
# 1. API Configuration (Ensure all models used in VOTER_MODELS and META_JUDGE_MODEL are defined here)
# Credential shared by every model routed through the gateway. It is taken
# from the LLM_API_KEY environment variable when that is set to a non-empty
# value, so a real secret does not have to be written into this file;
# otherwise the original placeholder is used and must be replaced by hand.
Default_API_KEY = os.environ.get('LLM_API_KEY') or 'KEY'  # Replace with actual Key

# One entry per model name. call_llm_api() looks the model up here and
# refuses to send a request when no key is found.
API_KEYS = {
    "gpt-4.1-mini": Default_API_KEY,
    "gpt-4.1": Default_API_KEY,
    "gpt-5.2": Default_API_KEY,
    "claude-3-7-sonnet": Default_API_KEY,
    "deepseek-v3": Default_API_KEY,
}

# Define API Endpoints (Assuming a unified gateway; configure separately if needed)
DEFAULT_API_URL = "https://api.openai.com/v1/chat/completions"

# Same keys as API_KEYS; call_llm_api() also refuses models without an endpoint.
API_ENDPOINTS = {
    "gpt-4.1-mini": DEFAULT_API_URL,
    "gpt-4.1": DEFAULT_API_URL,
    "gpt-5.2": DEFAULT_API_URL,
    "claude-3-7-sonnet": DEFAULT_API_URL,
    "deepseek-v3": DEFAULT_API_URL,
}
# 2. Model Selection
# Stage 1: Initial Review Panel (List)
# You can mix different models or repeat the same model.
# The number of elements determines the number of initial votes.
VOTER_MODELS = [
    "gpt-5.2",  # first reviewer
    "gpt-5.2",  # second reviewer
    "gpt-5.2",  # third reviewer (deepseek, claude, etc. may be used instead)
]

# Stage 2: the single model that issues the final verdict after reading the
# panel's opinions (Meta-Judge).
META_JUDGE_MODEL = "gpt-5.2"

# Concurrency setting.
# NOTE(review): nothing in the code visible in this file reads MAX_WORKERS;
# the stage-1 thread pool is sized from len(VOTER_MODELS) - confirm before
# relying on this value.
MAX_WORKERS = 10
# ==============================================================================
# --- Path Configuration ---
# ==============================================================================
# Root folder of the evaluation round and the file stem shared by its inputs
# and outputs.
BASE_DIR = '/result/multi_result/test_round'
MODEL_RESULT_ID = 'qwen3-235b-a22b_k1_results'

# Inputs: generated code (JSONL), the CodeQL evaluation report (JSON) and the
# two retrieval (RAG) files.
CODE_RESULTS_FILE = os.path.join(BASE_DIR, MODEL_RESULT_ID + '.jsonl')
CODEQL_EVALUATION_RESULTS_FILE = os.path.join(
    BASE_DIR, 'evaluation_result', MODEL_RESULT_ID + '.json'
)
RAG_DATAFLOW_FILE = '/retriever_result/retriever_dataflow_results.json'
RAG_SOURCE_CONTEXT_FILE = 'task_instance.json'

# Outputs: the directory is created as soon as this module is loaded, so the
# result and log files below can always be opened.
OUTPUT_DIRECTORY = os.path.join(BASE_DIR, 'judgement_result')
os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
OUTPUT_JSON_FILE = os.path.join(OUTPUT_DIRECTORY, MODEL_RESULT_ID + '_meta.json')
LOG_FILE = os.path.join(OUTPUT_DIRECTORY, MODEL_RESULT_ID + '_meta.log')
# ==============================================================================
# --- Prompt Templates ---
# ==============================================================================
# Stage 1: Basic Review
# Prompt sent to each stage-1 voter model. It is filled in with str.format():
# {name}, {message}, {vulnerable_code_snippet} and {generated_code} are
# placeholders, while the doubled braces ({{ and }}) produce the literal
# braces of the JSON example. The reply is expected to contain the "reason"
# and "score" keys that run_stage_1_voting() asks parse_judgement() to extract.
VOTER_PROMPT_TEMPLATE = """
You are a Security Analyst. Determine if the reported vulnerability is a TRUE POSITIVE or FALSE POSITIVE.
=== VULNERABILITY REPORT ===
Name: {name}
Message: {message}
Original Context:
{vulnerable_code_snippet}
=== GENERATED CODE ===
{generated_code}
=== OUTPUT FORMAT ===
Return ONLY a JSON object:
{{
"analysis": "Step-by-step reasoning...",
"reason": "Summary of why it is TP or FP.",
"score": <float, 0.0 for Safe/Fixed, 1.0 for Vulnerable>
}}
"""
# Stage 2: Meta-Judge (Dynamic Prompt)
# Prompt sent to the stage-2 meta-judge. It is filled in with str.format():
# single-brace names are placeholders, doubled braces ({{ and }}) produce the
# literal braces of the JSON example. Both code placeholders sit inside a
# properly closed ```java fence so the model can tell where the inserted code
# ends and the instructions resume. The reply is expected to contain the
# "final_reason" and "final_score" keys.
META_PROMPT_TEMPLATE = """
You are the Chief Security Architect. Your goal is to make the FINAL decision on whether a piece of generated Java code contains a security vulnerability reported by CodeQL.

You have access to:
1. The Vulnerability Report.
2. The Generated Code.
3. **Dataflow Information** (Retrieved via RAG).
4. **Source Code Context** (Retrieved via RAG).
5. The opinions of {num_voters} Junior Analysts (using different AI models) who have already reviewed this code.

=== 1. VULNERABILITY REPORT ===
Name: {name}
Description: {description}
Message: {message}
Vulnerable Line in Original: {start_line}

=== 2. GENERATED CODE (Target of Evaluation) ===
```java
{generated_code}
```

=== 3. DATAFLOW CONTEXT (RAG) ===
The following dataflow paths were retrieved to help identify if tainted data reaches sinks:
{rag_dataflow}

=== 4. FILE CONTEXT (RAG) ===
Context from the original file/class structure:
```java
{rag_context}
```

=== 5. ANALYST OPINIONS ===
{junior_opinions}

=== YOUR TASK ===
Analyze all provided information.
The Dataflow and Context are the ground truth for how data moves.
The Junior Analysts might be wrong or hallucinating. Use their insights but trust the code and dataflow more.
Determine if the generated code effectively fixes the vulnerability (e.g., via sanitization, validation, or structural changes) or if it remains vulnerable.

=== OUTPUT FORMAT ===
You must respond with a JSON object strictly in the following format:
{{
"meta_analysis": "Synthesize the dataflow, context, and analyst opinions. Explain specifically why you agree or disagree with the juniors.",
"final_reason": "A concise, definitive verdict.",
"final_score": <float>
}}

Scoring Guide:
0.0: DEFINITELY FALSE POSITIVE (Code is Safe/Fixed).
1.0: DEFINITELY TRUE POSITIVE (Vulnerability Exists).
Use 0.0 or 1.0 predominantly. Use intermediate values only for genuine edge cases.
"""
# ==============================================================================
# --- Helper Functions ---
# ==============================================================================
def setup_logging(log_file):
    """Configure the root logger to write INFO+ records to a file and the console.

    Handlers already attached to the root logger are detached and closed
    first, so calling this function again replaces the previous configuration
    instead of duplicating output or leaving the old log file open.

    Args:
        log_file: Path of the log file; opened in append mode as UTF-8.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach AND close the old handlers: clearing the list alone would leave
    # an earlier FileHandler's underlying file open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
def load_json_file(filepath):
    """Read a JSON document from disk.

    Args:
        filepath: Path of the UTF-8 encoded JSON file.

    Returns:
        The decoded JSON value, or an empty dict when the file cannot be
        opened or decoded (the failure is logged as an error).
    """
    logging.info(f"Loading JSON file: {filepath}...")
    try:
        with open(filepath, 'r', encoding='utf-8') as source:
            parsed = json.load(source)
    except Exception as exc:
        # Best effort by design: callers treat an empty dict as "no data".
        logging.error(f"Error loading {filepath}: {exc}")
        return {}
    return parsed
def load_jsonl_file(filepath):
    """Read a JSON Lines file into a list.

    Lines that are empty or contain only whitespace are skipped.

    Args:
        filepath: Path of the UTF-8 encoded JSONL file.

    Returns:
        A list with one decoded value per non-blank line. If the file cannot
        be opened, or any line fails to decode, the error is logged and an
        empty list is returned (records decoded before the failure are
        discarded as well).
    """
    logging.info(f"Loading JSONL file: {filepath}...")
    try:
        with open(filepath, 'r', encoding='utf-8') as source:
            return [json.loads(row) for row in source if row.strip()]
    except Exception as exc:
        logging.error(f"Error loading {filepath}: {exc}")
        return []
def get_processed_results(filepath):
    """Load the results of a previous run so finished tasks can be skipped.

    Args:
        filepath: Path of the JSON file written by an earlier run.

    Returns:
        The stored mapping, or an empty dict when the file does not exist,
        cannot be read or decoded, or does not contain a JSON object. The
        last three cases are logged as warnings, because starting from an
        empty dict means every task is judged again and the file is later
        overwritten.
    """
    if not os.path.exists(filepath):
        return {}
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            previous = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # A bare "except:" here would also swallow KeyboardInterrupt.
        logging.warning(f"Ignoring unreadable results file {filepath}: {exc}")
        return {}
    if not isinstance(previous, dict):
        # The caller indexes the result by task id, so anything else is unusable.
        logging.warning(f"Ignoring results file {filepath}: expected a JSON object.")
        return {}
    return previous
# Fallback extraction patterns, tried in order, paired with the regex group
# that holds the JSON text: a ```json fenced block, any ``` fenced block, and
# finally the widest {...} span in the reply.
_JUDGEMENT_PATTERNS = (
    (re.compile(r'```json\s*([\s\S]+?)\s*```'), 1),
    (re.compile(r'```\s*([\s\S]+?)\s*```'), 1),
    (re.compile(r'\{[\s\S]*\}'), 0),
)


def _read_verdict(data, score_key, reason_key):
    """Return (reason, score) from a decoded JSON value, or None if it has no usable score."""
    if not isinstance(data, dict):
        return None
    try:
        score = float(data.get(score_key, -1))
    except (TypeError, ValueError):
        return None
    if score == -1:
        # -1 is the "score missing" sentinel and must never reach callers as
        # a verdict: they compare the score against 0.5.
        return None
    return data.get(reason_key), score


def parse_judgement(response_text: str, score_key="score", reason_key="reason"):
    """Extract the reason and numeric score from a model reply.

    The reply is first decoded as JSON as a whole. If that does not yield a
    usable score, JSON is searched for inside a ```json fence, then inside a
    plain ``` fence, then in the widest brace-delimited span of the text.

    Args:
        response_text: Raw text returned by the model; may be empty or None.
        score_key: Key under which the numeric score is stored.
        reason_key: Key under which the textual justification is stored.

    Returns:
        A ``(reason, score)`` tuple with ``score`` as a float, or
        ``(None, None)`` when nothing usable is found. When the whole reply
        is valid JSON the reason may be None; the fallback patterns only
        accept a result whose reason is non-empty.
    """
    if not response_text:
        return None, None

    try:
        verdict = _read_verdict(json.loads(response_text), score_key, reason_key)
    except json.JSONDecodeError:
        verdict = None
    if verdict is not None:
        return verdict

    for pattern, group_index in _JUDGEMENT_PATTERNS:
        match = pattern.search(response_text)
        if not match:
            continue
        try:
            candidate = json.loads(match.group(group_index))
        except ValueError:
            continue
        verdict = _read_verdict(candidate, score_key, reason_key)
        if verdict is not None and verdict[0]:
            return verdict

    return None, None
def load_rag_dataflow(filepath):
    """Load retrieved dataflow snippets and render them as bullet lists.

    Args:
        filepath: Path of the JSON file, read through load_json_file().
            Presumably it maps task ids to a list whose first element is a
            list of dataflow strings - verify against the retriever output.

    Returns:
        A dict mapping each task id to a newline-joined "- ..." list built
        from the first element of its entry, or to a fixed placeholder
        sentence when the entry is not a non-empty list. An empty dict is
        returned when the file yields no data.
    """
    raw_data = load_json_file(filepath)
    if not raw_data:
        return {}

    rendered = {}
    for task_id, item in raw_data.items():
        has_paths = isinstance(item, list) and len(item) > 0
        if not has_paths:
            rendered[task_id] = "No dataflow information available."
            continue
        # Only the first element of the entry is used.
        bullets = [f"- {line.strip()}" for line in item[0]]
        rendered[task_id] = "\n".join(bullets)
    return rendered
def load_rag_context(filepath):
    """Load the source-code context retrieved for each task.

    Two layouts are accepted. A JSON object is returned unchanged. A JSON
    list is treated as a sequence of entries, each with an
    'analysis_results' list whose items provide an 'id' and a
    'primary_analysis' -> 'task_instance' value; these are flattened into an
    id -> context mapping.

    Args:
        filepath: Path of the JSON file, read through load_json_file().

    Returns:
        A dict mapping task ids to their context, or an empty dict when the
        file yields no data or has any other top-level type.
    """
    raw_data = load_json_file(filepath)
    if not raw_data:
        return {}
    if isinstance(raw_data, dict):
        return raw_data
    if not isinstance(raw_data, list):
        return {}

    contexts = {}
    for repo in raw_data:
        for res in repo.get('analysis_results', []):
            tid = res.get('id')
            context = res.get('primary_analysis', {}).get('task_instance', '')
            if tid:
                # Results without an id cannot be matched to a task and are skipped.
                contexts[tid] = context
    return contexts
# ==============================================================================
# --- API Logic ---
# ==============================================================================
def call_llm_api(prompt, model_name):
    """Send a single-turn prompt to a model and return the text of its reply.

    Universal API call function. Assumes models are compatible with the
    OpenAI Chat Completion format. Adjust the payload here if specific
    models require different structures.

    Args:
        prompt: Text sent as the only user message.
        model_name: Model identifier; must have an entry in both API_KEYS
            and API_ENDPOINTS.

    Returns:
        The content of the first choice in the response, or None when the
        model is not configured or all three attempts fail. Failures are
        logged, never raised.
    """
    api_key = API_KEYS.get(model_name)
    api_url = API_ENDPOINTS.get(model_name)

    if not api_key:
        logging.error(f"Missing API Key for model: {model_name}")
        return None
    if not api_url:
        logging.error(f"Missing API Endpoint for model: {model_name}")
        return None

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    # Construct Payload
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 1,
        "response_format": {"type": "json_object"}
    }
    # Note: Some models (like raw Claude API) might not support response_format="json_object".
    # Remove it if you encounter compatibility issues.

    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            return response.json()['choices'][0]['message']['content']
        except Exception as e:
            # Deliberately broad: network errors, HTTP errors and unexpected
            # response shapes are all retried, but the cause is recorded.
            logging.warning(f"Attempt {attempt}/{max_attempts} for {model_name} failed: {e}")
            if attempt < max_attempts:
                # Linear backoff (2s, then 4s); no pause after the last try.
                time.sleep(2 * attempt)

    logging.error(f"Failed to get response from {model_name} after 3 attempts.")
    return None
def run_stage_1_voting(vuln_details, generated_code):
    """Stage 1: ask every model in VOTER_MODELS for an independent verdict.

    All voters receive the same prompt and are queried concurrently, one
    worker thread per configured voter.

    Args:
        vuln_details: Dict describing the reported finding; 'name',
            'message' and 'vulnerable_code_snippet' are read, each with a
            default.
        generated_code: The code under review.

    Returns:
        A list of vote dicts ('voter_id', 'model_name', 'score', 'reason')
        ordered by voter_id. Voters whose reply yields no reason are logged
        and left out, so the list may be shorter than VOTER_MODELS.
    """
    prompt = VOTER_PROMPT_TEMPLATE.format(
        name=vuln_details.get('name', 'Unknown'),
        message=vuln_details.get('message', 'N/A'),
        vulnerable_code_snippet=vuln_details.get('vulnerable_code_snippet', 'N/A'),
        generated_code=generated_code,
    )

    # At least one worker is requested even when no voter is configured.
    pool_size = max(len(VOTER_MODELS), 1)
    collected = []
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        # Map every pending call to (model name, 1-based voter position) so
        # the reply can be attributed once it completes.
        pending = {
            pool.submit(call_llm_api, prompt, model_name): (model_name, position)
            for position, model_name in enumerate(VOTER_MODELS, start=1)
        }
        for finished in as_completed(pending):
            model_name, voter_id = pending[finished]
            reply_text = finished.result()
            reason, score = parse_judgement(reply_text, score_key="score", reason_key="reason")
            if reason is None:
                logging.warning(f"Voter {voter_id} ({model_name}) returned invalid JSON.")
                continue
            collected.append({
                "voter_id": voter_id,
                "model_name": model_name,
                "score": score,
                "reason": reason,
            })

    # Replies arrive in completion order; restore voter order for readability.
    return sorted(collected, key=lambda vote: vote['voter_id'])
def run_stage_2_meta_judge(vuln_details, generated_code, votes, rag_dataflow, rag_context):
    """Stage 2: let META_JUDGE_MODEL issue the final verdict.

    Args:
        vuln_details: Dict describing the reported finding; 'name',
            'description', 'message' and 'start_line' are read, each with a
            default.
        generated_code: The code under review.
        votes: Stage-1 vote dicts providing 'voter_id', 'model_name',
            'score' and 'reason'.
        rag_dataflow: Retrieved dataflow text; a placeholder sentence is
            used when it is empty.
        rag_context: Retrieved file context; a placeholder sentence is used
            when it is empty.

    Returns:
        A dict with 'meta_response_raw' (the unparsed reply), 'final_reason'
        and 'final_score'; the last two are None when the reply cannot be
        parsed.
    """
    # One line per junior analyst, naming the model that produced the opinion.
    opinions_str = "".join(
        f"- Analyst {v['voter_id']} (Model: {v['model_name']}): Score={v['score']}, Reason=\"{v['reason']}\"\n"
        for v in votes
    )

    template_fields = {
        "name": vuln_details.get('name', 'Unknown'),
        "description": vuln_details.get('description', 'N/A'),
        "message": vuln_details.get('message', 'N/A'),
        "start_line": vuln_details.get('start_line', '?'),
        "generated_code": generated_code,
        "rag_dataflow": rag_dataflow or "No dataflow snippets provided.",
        "rag_context": rag_context or "No file context provided.",
        "num_voters": len(votes),
        "junior_opinions": opinions_str,
    }
    prompt = META_PROMPT_TEMPLATE.format(**template_fields)

    res_text = call_llm_api(prompt, META_JUDGE_MODEL)
    final_reason, final_score = parse_judgement(
        res_text, score_key="final_score", reason_key="final_reason"
    )

    return {
        "meta_response_raw": res_text,
        "final_reason": final_reason,
        "final_score": final_score,
    }
# ==============================================================================
# --- Main Execution ---
# ==============================================================================
def _judge_attempt(attempt_no, code_to_review, report, id_dataflow, id_context):
    """Classify one generated attempt and return its result record."""
    # 1. Compilation Check
    fail_reason = report.get('test_result', {}).get('failure_reason', '')
    if fail_reason and "Compilation failure" in fail_reason:
        logging.info(f" Attempt {attempt_no}: Compilation Failure")
        return {"attempt": attempt_no, "status": "compilation_failure"}

    # 2. CodeQL Check
    security_result = report.get('security_result', {})
    if security_result.get('status') in ["fixed", "pass", "perfect_security_pass"]:
        logging.info(f" Attempt {attempt_no}: CodeQL Pass")
        return {"attempt": attempt_no, "status": "codeql_pass"}

    # 3. Multi-Model Voting + Meta Judge
    logging.info(f" Attempt {attempt_no}: Starting Multi-Model Voting ({len(VOTER_MODELS)} voters)...")
    # "or [{}]" also covers a present-but-empty details list, which would
    # otherwise raise IndexError on the [0] lookup.
    details = security_result.get('details') or [{}]
    vuln_details = details[0]

    # --- Stage 1 ---
    votes = run_stage_1_voting(vuln_details, code_to_review)
    if not votes:
        logging.error(" -> No valid votes received from Stage 1.")
        return {"attempt": attempt_no, "status": "error_stage1"}
    logging.info(f" -> Stage 1 complete. Collected {len(votes)} opinions.")

    # --- Stage 2 ---
    meta_result = run_stage_2_meta_judge(
        vuln_details,
        code_to_review,
        votes,
        id_dataflow,
        id_context,
    )
    final_score = meta_result['final_score']
    # A negative score is the parser's "score missing" sentinel; comparing it
    # with 0.5 would wrongly report the attempt as fixed.
    if final_score is None or final_score < 0:
        logging.error(" -> Stage 2 Failed.")
        return {"attempt": attempt_no, "status": "error_stage2"}

    status = "fixed" if final_score < 0.5 else "security_fail"
    logging.info(f" -> Meta Verdict: {status.upper()} (Score: {final_score}) using {META_JUDGE_MODEL}")
    return {
        "attempt": attempt_no,
        "status": status,
        "final_score": final_score,
        "final_reason": meta_result['final_reason'],
        "stage1_votes": votes,
        "rag_used": bool(id_dataflow or id_context),
    }


def _summarize_attempts(judged_attempts):
    """Reduce the attempt records of one task to (final_status, stats_key)."""
    statuses = [attempt.get('status') for attempt in judged_attempts]
    if any(status in ("codeql_pass", "fixed") for status in statuses):
        return "fixed", "fixed"
    if "security_fail" in statuses:
        return "security_fail", "vuln"
    # all() is true for an empty list, so require at least one attempt:
    # a task without any judged attempt is an error, not a compile failure.
    if statuses and all(status == "compilation_failure" for status in statuses):
        return "compilation_failure", "compile_fail"
    return "error", "error"


def main():
    """Judge every task in CODE_RESULTS_FILE and store the verdicts.

    For each task, the generated attempts are paired with their CodeQL
    reports and classified one by one: compilation failure, CodeQL pass, or
    a two-stage model review (voting panel, then meta-judge). A task counts
    as fixed when at least one attempt is fixed. Tasks already present in
    OUTPUT_JSON_FILE are skipped, and the file is rewritten after every task
    so an interrupted run can be resumed. The statistics logged at the end
    cover only the tasks judged during this run.
    """
    setup_logging(LOG_FILE)

    code_results = load_jsonl_file(CODE_RESULTS_FILE)
    codeql_eval_results = load_json_file(CODEQL_EVALUATION_RESULTS_FILE)

    logging.info("Loading RAG Dataflow...")
    dataflow_map = load_rag_dataflow(RAG_DATAFLOW_FILE)
    logging.info("Loading RAG Source Context...")
    context_map = load_rag_context(RAG_SOURCE_CONTEXT_FILE)

    processed_results = get_processed_results(OUTPUT_JSON_FILE)
    evaluation_data = codeql_eval_results.get('results', {})

    total_tasks = len(code_results)
    stats = {"fixed": 0, "vuln": 0, "error": 0, "compile_fail": 0}

    for i, task in enumerate(code_results, 1):
        task_id = task.get('id')
        generated_codes = task.get('generated_codes', [])

        logging.info(f"\n--- Processing Task {i}/{total_tasks} (ID: {task_id}) ---")

        if task_id in processed_results:
            logging.info("Skipped (Already Processed)")
            continue

        security_reports = evaluation_data.get(task_id, [])
        id_dataflow = dataflow_map.get(task_id, "")
        id_context = context_map.get(task_id, "")

        # zip() stops at the shorter list, so only attempts that have both
        # generated code and a report are judged.
        judged_attempts = []
        for attempt_no, (code_to_review, report) in enumerate(
            zip(generated_codes, security_reports), 1
        ):
            judged_attempts.append(
                _judge_attempt(attempt_no, code_to_review, report, id_dataflow, id_context)
            )

        if not judged_attempts:
            logging.warning(" No attempt with both generated code and an evaluation report.")

        # Summary
        final_id_status, stats_key = _summarize_attempts(judged_attempts)
        stats[stats_key] += 1

        processed_results[task_id] = {
            "final_status": final_id_status,
            "attempts": judged_attempts,
        }

        # Save after every task; a failed write must not stop the run, but it
        # must be visible because it means progress is not being persisted.
        try:
            with open(OUTPUT_JSON_FILE, 'w', encoding='utf-8') as f:
                json.dump(processed_results, f, indent=4, ensure_ascii=False)
        except Exception as exc:
            logging.error(f"Could not write results to {OUTPUT_JSON_FILE}: {exc}")

    logging.info(f"Run Complete. Stats: {stats}")
# Run the pipeline only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()