rudrapatel-1908 commited on
Commit
484482f
Β·
verified Β·
1 Parent(s): 35d1b92

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +39 -13
inference.py CHANGED
@@ -56,25 +56,40 @@ def warroom_step(att_type, att_res, scn_type, scn_tip, scn_res, rem_type, rem_ti
56
 
57
  # ── LLM: Single-agent action ──
58
  def get_single_action(task_id: str, obs: dict) -> dict:
59
- prompt = f"""You are a cybersecurity AI agent.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  Task: {task_id}
61
- Terminal: {obs.get('terminal_output','')}
62
- Inventory: {json.dumps(obs.get('inventory',[]),indent=2)}
63
 
64
- Rules:
65
- - easy-lockdown: command="lockdown", target_id="s3-vault"
66
- - medium-access: command="revoke_admin", target_id="user-dev-01"
67
- - hard-breach step1: command="block_ip", target_id="attacker-ip"
68
- - hard-breach step2: command="close_port", target_id="web-server"
69
 
70
- Respond ONLY with JSON: {{"command":"...","target_id":"..."}}"""
 
 
71
 
72
  resp = client.chat.completions.create(
73
  model=MODEL_NAME,
74
  messages=[{"role": "user", "content": prompt}],
75
- max_tokens=64, temperature=0,
 
76
  )
77
- raw = resp.choices[0].message.content.strip().replace("```json","").replace("```","").strip()
 
78
  return json.loads(raw)
79
 
80
  # ── LLM: Multi-agent actions ──
@@ -218,6 +233,7 @@ def run_warroom_task() -> None:
218
  f"rewards={rewards_str}", flush=True)
219
 
220
  # ── Main ──
 
221
  def main():
222
  try:
223
  requests.get(f"{SPACE_URL}/health", timeout=15).raise_for_status()
@@ -225,8 +241,18 @@ def main():
225
  print("[END] success=false steps=0 rewards=0.05", flush=True)
226
  raise SystemExit(1)
227
 
228
- valid = ["easy-lockdown", "medium-access", "hard-breach", "red-vs-blue"]
229
- if TASK_NAME not in valid:
 
 
 
 
 
 
 
 
 
 
230
  print("[END] success=false steps=0 rewards=0.05", flush=True)
231
  raise SystemExit(1)
232
 
 
56
 
57
  # ── LLM: Single-agent action ──
58
  def get_single_action(task_id: str, obs: dict) -> dict:
59
+ terminal = obs.get("terminal_output", "")
60
+ inventory = json.dumps(obs.get("inventory", []), indent=2)
61
+
62
+ rules = {
63
+ "easy-lockdown": 'command="lockdown", target_id="s3-vault"',
64
+ "easy-secrets": 'Step 1: command="audit", target_id="api-key-01" β†’ Step 2: command="revoke", target_id="api-key-01" β†’ Step 3: command="rotate", target_id="api-key-01"',
65
+ "medium-access": 'command="revoke_admin", target_id="user-dev-01"',
66
+ "medium-mfa": 'Step 1: command="audit", target_id="iam-users" β†’ Then: command="enforce_mfa", target_id="admin-alice" (repeat for admin-bob and admin-carol)',
67
+ "hard-breach": 'Step 1: command="block_ip", target_id="attacker-ip" β†’ Step 2: command="close_port", target_id="web-server"',
68
+ "critical-ransomware": 'Step 1: command="isolate", target_id="db-server" β†’ Step 2: command="revoke_sessions", target_id="active-sessions" β†’ Step 3: command="restore_backup", target_id="db-server"',
69
+ "expert-apt": 'Step 1: command="detect_c2", target_id="c2-beacon" β†’ Step 2: command="block_outbound", target_id="outbound-fw" β†’ Step 3: command="isolate_host", target_id="infected-host" β†’ Step 4: command="patch_vulnerability", target_id="vuln-cve-2024"',
70
+ }
71
+
72
+ prompt = f"""You are a cybersecurity AI agent for Sentinel-Env.
73
+
74
  Task: {task_id}
75
+ Terminal output: {terminal}
76
+ Current inventory: {inventory}
77
 
78
+ Correct action sequence for this task:
79
+ {rules.get(task_id, 'Follow the terminal instructions.')}
 
 
 
80
 
81
+ Based on the terminal output, choose the NEXT correct action.
82
+ Respond ONLY with a JSON object. No markdown. No explanation.
83
+ Example: {{"command": "lockdown", "target_id": "s3-vault"}}"""
84
 
85
  resp = client.chat.completions.create(
86
  model=MODEL_NAME,
87
  messages=[{"role": "user", "content": prompt}],
88
+ max_tokens=64,
89
+ temperature=0,
90
  )
91
+ raw = resp.choices[0].message.content.strip()
92
+ raw = raw.replace("```json", "").replace("```", "").strip()
93
  return json.loads(raw)
94
 
95
  # ── LLM: Multi-agent actions ──
 
233
  f"rewards={rewards_str}", flush=True)
234
 
235
  # ── Main ──
236
+ # ── Main β€” runs ONLY the task in TASK_NAME ──
237
  def main():
238
  try:
239
  requests.get(f"{SPACE_URL}/health", timeout=15).raise_for_status()
 
241
  print("[END] success=false steps=0 rewards=0.05", flush=True)
242
  raise SystemExit(1)
243
 
244
+ valid_tasks = [
245
+ "easy-lockdown",
246
+ "easy-secrets",
247
+ "medium-access",
248
+ "medium-mfa",
249
+ "hard-breach",
250
+ "critical-ransomware",
251
+ "expert-apt",
252
+ "red-vs-blue",
253
+ ]
254
+
255
+ if TASK_NAME not in valid_tasks:
256
  print("[END] success=false steps=0 rewards=0.05", flush=True)
257
  raise SystemExit(1)
258