Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +39 -13
inference.py
CHANGED
|
@@ -56,25 +56,40 @@ def warroom_step(att_type, att_res, scn_type, scn_tip, scn_res, rem_type, rem_ti
|
|
| 56 |
|
| 57 |
# ββ LLM: Single-agent action ββ
|
| 58 |
def get_single_action(task_id: str, obs: dict) -> dict:
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
Task: {task_id}
|
| 61 |
-
Terminal: {
|
| 62 |
-
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
- medium-access: command="revoke_admin", target_id="user-dev-01"
|
| 67 |
-
- hard-breach step1: command="block_ip", target_id="attacker-ip"
|
| 68 |
-
- hard-breach step2: command="close_port", target_id="web-server"
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
| 71 |
|
| 72 |
resp = client.chat.completions.create(
|
| 73 |
model=MODEL_NAME,
|
| 74 |
messages=[{"role": "user", "content": prompt}],
|
| 75 |
-
max_tokens=64,
|
|
|
|
| 76 |
)
|
| 77 |
-
raw = resp.choices[0].message.content.strip()
|
|
|
|
| 78 |
return json.loads(raw)
|
| 79 |
|
| 80 |
# ββ LLM: Multi-agent actions ββ
|
|
@@ -218,6 +233,7 @@ def run_warroom_task() -> None:
|
|
| 218 |
f"rewards={rewards_str}", flush=True)
|
| 219 |
|
| 220 |
# ββ Main ββ
|
|
|
|
| 221 |
def main():
|
| 222 |
try:
|
| 223 |
requests.get(f"{SPACE_URL}/health", timeout=15).raise_for_status()
|
|
@@ -225,8 +241,18 @@ def main():
|
|
| 225 |
print("[END] success=false steps=0 rewards=0.05", flush=True)
|
| 226 |
raise SystemExit(1)
|
| 227 |
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
print("[END] success=false steps=0 rewards=0.05", flush=True)
|
| 231 |
raise SystemExit(1)
|
| 232 |
|
|
|
|
| 56 |
|
| 57 |
# ββ LLM: Single-agent action ββ
|
| 58 |
def get_single_action(task_id: str, obs: dict) -> dict:
|
| 59 |
+
terminal = obs.get("terminal_output", "")
|
| 60 |
+
inventory = json.dumps(obs.get("inventory", []), indent=2)
|
| 61 |
+
|
| 62 |
+
rules = {
|
| 63 |
+
"easy-lockdown": 'command="lockdown", target_id="s3-vault"',
|
| 64 |
+
"easy-secrets": 'Step 1: command="audit", target_id="api-key-01" β Step 2: command="revoke", target_id="api-key-01" β Step 3: command="rotate", target_id="api-key-01"',
|
| 65 |
+
"medium-access": 'command="revoke_admin", target_id="user-dev-01"',
|
| 66 |
+
"medium-mfa": 'Step 1: command="audit", target_id="iam-users" β Then: command="enforce_mfa", target_id="admin-alice" (repeat for admin-bob and admin-carol)',
|
| 67 |
+
"hard-breach": 'Step 1: command="block_ip", target_id="attacker-ip" β Step 2: command="close_port", target_id="web-server"',
|
| 68 |
+
"critical-ransomware": 'Step 1: command="isolate", target_id="db-server" β Step 2: command="revoke_sessions", target_id="active-sessions" β Step 3: command="restore_backup", target_id="db-server"',
|
| 69 |
+
"expert-apt": 'Step 1: command="detect_c2", target_id="c2-beacon" β Step 2: command="block_outbound", target_id="outbound-fw" β Step 3: command="isolate_host", target_id="infected-host" β Step 4: command="patch_vulnerability", target_id="vuln-cve-2024"',
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
prompt = f"""You are a cybersecurity AI agent for Sentinel-Env.
|
| 73 |
+
|
| 74 |
Task: {task_id}
|
| 75 |
+
Terminal output: {terminal}
|
| 76 |
+
Current inventory: {inventory}
|
| 77 |
|
| 78 |
+
Correct action sequence for this task:
|
| 79 |
+
{rules.get(task_id, 'Follow the terminal instructions.')}
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
Based on the terminal output, choose the NEXT correct action.
|
| 82 |
+
Respond ONLY with a JSON object. No markdown. No explanation.
|
| 83 |
+
Example: {{"command": "lockdown", "target_id": "s3-vault"}}"""
|
| 84 |
|
| 85 |
resp = client.chat.completions.create(
|
| 86 |
model=MODEL_NAME,
|
| 87 |
messages=[{"role": "user", "content": prompt}],
|
| 88 |
+
max_tokens=64,
|
| 89 |
+
temperature=0,
|
| 90 |
)
|
| 91 |
+
raw = resp.choices[0].message.content.strip()
|
| 92 |
+
raw = raw.replace("```json", "").replace("```", "").strip()
|
| 93 |
return json.loads(raw)
|
| 94 |
|
| 95 |
# ββ LLM: Multi-agent actions ββ
|
|
|
|
| 233 |
f"rewards={rewards_str}", flush=True)
|
| 234 |
|
| 235 |
# ββ Main ββ
|
| 236 |
+
# ββ Main β runs ONLY the task in TASK_NAME ββ
|
| 237 |
def main():
|
| 238 |
try:
|
| 239 |
requests.get(f"{SPACE_URL}/health", timeout=15).raise_for_status()
|
|
|
|
| 241 |
print("[END] success=false steps=0 rewards=0.05", flush=True)
|
| 242 |
raise SystemExit(1)
|
| 243 |
|
| 244 |
+
valid_tasks = [
|
| 245 |
+
"easy-lockdown",
|
| 246 |
+
"easy-secrets",
|
| 247 |
+
"medium-access",
|
| 248 |
+
"medium-mfa",
|
| 249 |
+
"hard-breach",
|
| 250 |
+
"critical-ransomware",
|
| 251 |
+
"expert-apt",
|
| 252 |
+
"red-vs-blue",
|
| 253 |
+
]
|
| 254 |
+
|
| 255 |
+
if TASK_NAME not in valid_tasks:
|
| 256 |
print("[END] success=false steps=0 rewards=0.05", flush=True)
|
| 257 |
raise SystemExit(1)
|
| 258 |
|