Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- inference.py +588 -90
- server/gradio_ui.py +0 -6
inference.py
CHANGED
|
@@ -4,6 +4,7 @@ import sys
|
|
| 4 |
import asyncio
|
| 5 |
import inspect
|
| 6 |
import random
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Any
|
| 9 |
|
|
@@ -37,28 +38,46 @@ _ASYNC_LOOP: asyncio.AbstractEventLoop | None = None
|
|
| 37 |
SYSTEM_PROMPT = """You are a senior Network Forensics Analyst. Your goal is to investigate malicious network traffic and achieve a 100% detection score.
|
| 38 |
|
| 39 |
### SCORING RULES:
|
| 40 |
-
- You MUST identify and `flag_as_suspicious`
|
| 41 |
- Only grouped packets or flagged packets contribute towards your score.
|
| 42 |
-
- If RECALL is < 0.5, your score will be 0.0. DO NOT stop until you have grouped at least
|
|
|
|
|
|
|
| 43 |
|
| 44 |
### WORKFLOW:
|
| 45 |
1. **Explore**: `inspect_packet` on suspicious samples.
|
| 46 |
-
2. **
|
| 47 |
-
3. **
|
| 48 |
-
4. **
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
### JSON SCHEMA EXAMPLES (Use these exactly):
|
| 51 |
- Inspect: {"action_type":"inspect_packet","packet_id":"pkt_0001"}
|
| 52 |
- Flag: {"action_type":"flag_as_suspicious","packet_id":"pkt_0001"}
|
| 53 |
- Group: {"action_type":"group_into_session","session_name":"DDoS_Burst_2","packet_ids":["pkt_0001","pkt_0002"]}
|
| 54 |
- Tag: {"action_type":"tag_pattern","session_name":"DDoS_Burst_2","pattern_type":"ddos"}
|
| 55 |
-
-
|
|
|
|
| 56 |
|
| 57 |
HISTORY_WINDOW = 20
|
| 58 |
REPEAT_ACTION_LIMIT = 3
|
| 59 |
CORRECTION_WINDOW = 5
|
| 60 |
-
UNTAGGED_BACKLOG_LIMIT =
|
| 61 |
INSPECT_SOFT_RATIO_THRESHOLD = 0.60
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
def build_client() -> OpenAI:
|
|
@@ -93,7 +112,7 @@ def format_action(action: NetworkForensicsAction) -> str:
|
|
| 93 |
|
| 94 |
|
| 95 |
def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
|
| 96 |
-
"""Provide a structured
|
| 97 |
packets = obs.visible_packets
|
| 98 |
revealed = [p for p in packets if p.is_revealed]
|
| 99 |
revealed_ids = [p.packet_id for p in revealed]
|
|
@@ -104,14 +123,31 @@ def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
|
|
| 104 |
reward_feedback = agent_state.get("last_reward_feedback", "n/a")
|
| 105 |
recent_corrections = agent_state.get("recent_corrections", [])[-CORRECTION_WINDOW:]
|
| 106 |
strategy_hints = agent_state.get("strategy_hints", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
summary = [
|
| 109 |
f"Step: {obs.step_number}/{obs.step_number + obs.steps_remaining}",
|
| 110 |
f"Current Progress: {obs.current_score_estimate:.2f}",
|
| 111 |
-
f"
|
|
|
|
|
|
|
| 112 |
f"Last Step Reward: {last_reward:.2f}" if isinstance(last_reward, (int, float)) else "Last Step Reward: n/a",
|
| 113 |
f"Last Reward Feedback: {reward_feedback}",
|
| 114 |
-
f"ALREADY REVEALED: {', '.join(revealed_ids[-
|
| 115 |
"\n### SESSIONS PENDING TAGGING:",
|
| 116 |
]
|
| 117 |
|
|
@@ -126,13 +162,13 @@ def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
|
|
| 126 |
summary.append(f"- {hint}")
|
| 127 |
|
| 128 |
if untagged_sessions:
|
| 129 |
-
for s in untagged_sessions:
|
| 130 |
summary.append(f"- {s} ({len(sessions[s])} packets)")
|
| 131 |
else:
|
| 132 |
summary.append("- [No pending sessions]")
|
| 133 |
|
| 134 |
summary.append("\n### REVEALED INDICATORS:")
|
| 135 |
-
for p in revealed[-
|
| 136 |
payload = (p.full_payload or "")[:150]
|
| 137 |
if payload:
|
| 138 |
summary.append(f"- {p.packet_id}: {payload}")
|
|
@@ -207,30 +243,72 @@ def packet_payload_text(packet: Any) -> str:
|
|
| 207 |
|
| 208 |
def keyword_to_pattern(payload: str) -> str | None:
|
| 209 |
text = payload.lower()
|
|
|
|
| 210 |
if "slowloris" in text:
|
| 211 |
return "dos_slowloris"
|
| 212 |
-
if "slowhttptest" in text:
|
| 213 |
return "dos_slowhttptest"
|
| 214 |
-
if "goldeneye" in text:
|
| 215 |
return "dos_goldeneye"
|
| 216 |
if "hulk" in text:
|
| 217 |
return "dos_hulk"
|
| 218 |
-
if "heartbeat" in text or "tls" in text:
|
| 219 |
return "heartbleed"
|
| 220 |
-
if "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
return "web_xss"
|
| 222 |
if (
|
| 223 |
"or 1=1" in text
|
| 224 |
or "%20or" in text
|
| 225 |
or "/items?id=" in text
|
| 226 |
or "1=1" in text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
or "sql" in text
|
|
|
|
|
|
|
|
|
|
| 228 |
):
|
| 229 |
return "web_sql_injection"
|
| 230 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
return "web_bruteforce"
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
return None
|
| 235 |
|
| 236 |
|
|
@@ -245,9 +323,40 @@ def packet_signature(packet: Any, pattern: str) -> tuple[str, str, int, str]:
|
|
| 245 |
return (packet.src_ip, packet.dst_ip, packet.dst_port, pattern)
|
| 246 |
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[Any]]]:
|
| 249 |
grouped: dict[tuple[str, str, int, str], list[Any]] = {}
|
| 250 |
attack_source_ports: dict[tuple[str, str, int, str], set[int]] = {}
|
|
|
|
| 251 |
for packet in obs.visible_packets:
|
| 252 |
pattern = keyword_to_pattern(packet_payload_text(packet))
|
| 253 |
if pattern:
|
|
@@ -255,6 +364,7 @@ def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[A
|
|
| 255 |
grouped.setdefault(key, []).append(packet)
|
| 256 |
attack_source_ports.setdefault(key, set()).add(packet.src_port)
|
| 257 |
|
|
|
|
| 258 |
for key, source_ports in attack_source_ports.items():
|
| 259 |
src_ip, dst_ip, dst_port, _pattern = key
|
| 260 |
for packet in obs.visible_packets:
|
|
@@ -267,6 +377,30 @@ def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[A
|
|
| 267 |
if is_reverse_response:
|
| 268 |
grouped[key].append(packet)
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
candidates = [
|
| 271 |
(
|
| 272 |
key,
|
|
@@ -287,8 +421,19 @@ def required_tag_count(task_name: str, total_sessions: int) -> int:
|
|
| 287 |
return 0
|
| 288 |
|
| 289 |
|
| 290 |
-
def select_inspect_packet(
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
if not unrevealed:
|
| 293 |
return None
|
| 294 |
|
|
@@ -314,6 +459,9 @@ def select_inspect_packet(obs: Any, inspected_ids: set[str]) -> str | None:
|
|
| 314 |
def append_action_history(agent_state: dict[str, Any], action: NetworkForensicsAction) -> None:
|
| 315 |
history = agent_state.setdefault("previous_actions", [])
|
| 316 |
history.append(format_action(action))
|
|
|
|
|
|
|
|
|
|
| 317 |
if len(history) > HISTORY_WINDOW:
|
| 318 |
del history[:-HISTORY_WINDOW]
|
| 319 |
|
|
@@ -355,25 +503,29 @@ def group_meets_evidence_gate(
|
|
| 355 |
candidate_packets, flagged_ids, visible_by_id
|
| 356 |
)
|
| 357 |
size = len(candidate_packets)
|
|
|
|
| 358 |
if task_name == "easy":
|
| 359 |
min_flagged = 1 if size >= 2 else 0
|
| 360 |
elif task_name == "medium":
|
| 361 |
-
min_flagged = 1 if size >=
|
| 362 |
else:
|
| 363 |
-
min_flagged =
|
| 364 |
-
if trusted_pattern and size >=
|
| 365 |
min_flagged = 1
|
| 366 |
if flagged >= min_flagged:
|
| 367 |
return True
|
| 368 |
# Allow grouping with strong revealed malicious evidence.
|
| 369 |
-
if malicious_revealed >=
|
| 370 |
return True
|
| 371 |
-
|
| 372 |
-
if trusted_pattern and size >= 5:
|
| 373 |
return True
|
| 374 |
-
if
|
|
|
|
|
|
|
|
|
|
| 375 |
return True
|
| 376 |
-
|
|
|
|
| 377 |
return True
|
| 378 |
return False
|
| 379 |
|
|
@@ -415,10 +567,10 @@ def derive_strategy_hints(obs: Any, agent_state: dict[str, Any]) -> list[str]:
|
|
| 415 |
)
|
| 416 |
|
| 417 |
inspect_limit = {
|
| 418 |
-
"easy":
|
| 419 |
-
"medium":
|
| 420 |
-
"hard":
|
| 421 |
-
}.get(agent_state.get("current_task_name", ""),
|
| 422 |
if len(previous_actions) >= inspect_limit and inspect_ratio >= INSPECT_SOFT_RATIO_THRESHOLD:
|
| 423 |
hints.append(
|
| 424 |
"You are over-inspecting. Shift to flagging, grouping, tagging, or report submission unless the next packet is clearly high-value."
|
|
@@ -426,10 +578,43 @@ def derive_strategy_hints(obs: Any, agent_state: dict[str, Any]) -> list[str]:
|
|
| 426 |
return hints
|
| 427 |
|
| 428 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
def build_fallback_action(
|
| 430 |
task_name: str, obs: Any, agent_state: dict[str, Any]
|
| 431 |
) -> NetworkForensicsAction:
|
| 432 |
-
"""Smart workflow engine:
|
| 433 |
inspected_ids = agent_state.setdefault("inspected_ids", set())
|
| 434 |
flagged_ids = agent_state.setdefault("flagged_ids", set())
|
| 435 |
session_map = agent_state.setdefault("sessions", {}) # key -> session_name
|
|
@@ -438,7 +623,7 @@ def build_fallback_action(
|
|
| 438 |
visible_by_id = {p.packet_id: p for p in obs.visible_packets}
|
| 439 |
trusted = trusted_patterns(session_map, tagged_sessions)
|
| 440 |
|
| 441 |
-
if obs.steps_remaining <= 1:
|
| 442 |
summary = _build_report_summary(obs, agent_state)
|
| 443 |
return NetworkForensicsAction(
|
| 444 |
action_type="submit_report",
|
|
@@ -446,21 +631,32 @@ def build_fallback_action(
|
|
| 446 |
claimed_entry_point=claimed_entry,
|
| 447 |
)
|
| 448 |
|
| 449 |
-
# PHASE 1:
|
|
|
|
|
|
|
| 450 |
for packet in obs.visible_packets:
|
| 451 |
if packet.is_revealed and packet.packet_id not in flagged_ids:
|
| 452 |
payload = packet.full_payload or ""
|
| 453 |
pattern = keyword_to_pattern(payload)
|
| 454 |
if pattern:
|
| 455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
return NetworkForensicsAction(
|
| 457 |
action_type="flag_as_suspicious",
|
| 458 |
-
packet_id=
|
| 459 |
)
|
| 460 |
|
| 461 |
# PHASE 2: Group flagged packets into sessions with evidence gate and backlog pacing.
|
|
|
|
| 462 |
untagged_backlog = max(0, len(session_map) - len(tagged_sessions))
|
| 463 |
-
if untagged_backlog <= UNTAGGED_BACKLOG_LIMIT:
|
| 464 |
candidates = session_candidates(obs)
|
| 465 |
for key, items in candidates:
|
| 466 |
if key in session_map:
|
|
@@ -482,14 +678,54 @@ def build_fallback_action(
|
|
| 482 |
packet_ids=packet_ids,
|
| 483 |
)
|
| 484 |
|
| 485 |
-
# PHASE
|
| 486 |
-
#
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
tagged_sessions.add(session_name)
|
| 494 |
return NetworkForensicsAction(
|
| 495 |
action_type="tag_pattern",
|
|
@@ -497,19 +733,52 @@ def build_fallback_action(
|
|
| 497 |
pattern_type=pattern,
|
| 498 |
)
|
| 499 |
|
| 500 |
-
# PHASE 4: Identify entry point
|
| 501 |
-
if not claimed_entry
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
|
| 511 |
-
# PHASE 5: Inspect more unrevealed packets
|
| 512 |
-
inspect_id = select_inspect_packet(obs, inspected_ids)
|
| 513 |
if inspect_id is not None:
|
| 514 |
return NetworkForensicsAction(action_type="inspect_packet", packet_id=inspect_id)
|
| 515 |
|
|
@@ -523,19 +792,50 @@ def build_fallback_action(
|
|
| 523 |
|
| 524 |
|
| 525 |
def _build_report_summary(obs: Any, agent_state: dict[str, Any]) -> str:
|
| 526 |
-
"""Generate a
|
| 527 |
flagged = agent_state.get("flagged_ids", set())
|
| 528 |
sessions = agent_state.get("sessions", {})
|
| 529 |
tagged = agent_state.get("tagged_sessions", set())
|
| 530 |
-
|
| 531 |
-
|
|
|
|
|
|
|
|
|
|
| 532 |
if len(key) >= 4:
|
| 533 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
return (
|
| 535 |
-
f"
|
| 536 |
-
f"{len(
|
| 537 |
-
f"{
|
| 538 |
-
f"{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
)
|
| 540 |
|
| 541 |
|
|
@@ -556,10 +856,10 @@ def should_override_action(
|
|
| 556 |
inspect_count = sum(1 for a in previous_actions if '"inspect_packet"' in a)
|
| 557 |
revealed_count = sum(1 for p in obs.visible_packets if p.is_revealed)
|
| 558 |
inspect_limit = {
|
| 559 |
-
"easy":
|
| 560 |
-
"medium":
|
| 561 |
-
"hard":
|
| 562 |
-
}.get(task_name,
|
| 563 |
|
| 564 |
if action.action_type not in {
|
| 565 |
"inspect_packet",
|
|
@@ -580,13 +880,52 @@ def should_override_action(
|
|
| 580 |
return "Missing packet_id for inspect_packet"
|
| 581 |
if action.packet_id not in {p.packet_id for p in obs.visible_packets}:
|
| 582 |
return f"Invalid packet_id {action.packet_id} - not in visible_packets"
|
|
|
|
|
|
|
|
|
|
| 583 |
revealed_ids = {p.packet_id for p in obs.visible_packets if p.is_revealed}
|
| 584 |
if action.packet_id in revealed_ids:
|
| 585 |
return f"Packet {action.packet_id} is ALREADY revealed. Choose a HIDDEN packet."
|
| 586 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
return (
|
| 588 |
-
|
| 589 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
|
| 591 |
if action.action_type == "flag_as_suspicious":
|
| 592 |
if not action.packet_id:
|
|
@@ -606,6 +945,22 @@ def should_override_action(
|
|
| 606 |
}
|
| 607 |
if invalid_ids:
|
| 608 |
return f"Invalid packet_ids in session: {invalid_ids}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
|
| 610 |
if untagged_backlog > UNTAGGED_BACKLOG_LIMIT:
|
| 611 |
return (
|
|
@@ -631,18 +986,61 @@ def should_override_action(
|
|
| 631 |
|
| 632 |
if action.action_type == "submit_report":
|
| 633 |
untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
|
| 634 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
return (
|
| 636 |
"Premature report submission. Improve coverage and score estimate before submit_report."
|
| 637 |
)
|
| 638 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
return "Premature report submission. Tag pending sessions before submitting report."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
|
| 641 |
if action.action_type == "tag_pattern":
|
| 642 |
if not action.session_name:
|
| 643 |
return "Missing session_name for tag_pattern"
|
| 644 |
if not action.pattern_type:
|
| 645 |
return "Missing pattern_type for tag_pattern"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
valid_patterns = {
|
| 647 |
"ddos", "dos_slowloris", "dos_slowhttptest", "dos_goldeneye", "dos_hulk",
|
| 648 |
"heartbleed", "web_sql_injection", "web_xss", "web_bruteforce",
|
|
@@ -654,7 +1052,9 @@ def should_override_action(
|
|
| 654 |
if action.action_type == "identify_entry_point":
|
| 655 |
if not action.claimed_entry_point:
|
| 656 |
return "Missing claimed_entry_point for identify_entry_point"
|
| 657 |
-
|
|
|
|
|
|
|
| 658 |
return (
|
| 659 |
"Premature entry-point claim. Gather and flag more evidence before identify_entry_point."
|
| 660 |
)
|
|
@@ -671,6 +1071,14 @@ def choose_action(
|
|
| 671 |
) -> NetworkForensicsAction:
|
| 672 |
agent_state["current_task_name"] = task_name
|
| 673 |
agent_state["strategy_hints"] = derive_strategy_hints(obs, agent_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
history = agent_state.get("previous_actions", [])[-HISTORY_WINDOW:]
|
| 675 |
history_str = "\n".join([f"Step {i+1}: {a}" for i, a in enumerate(history)])
|
| 676 |
|
|
@@ -685,17 +1093,24 @@ def choose_action(
|
|
| 685 |
"Follow the JSON schema in the system prompt."
|
| 686 |
)
|
| 687 |
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
"role": "
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 699 |
content = response.choices[0].message.content or ""
|
| 700 |
try:
|
| 701 |
action = sanitize_action(parse_action(content))
|
|
@@ -857,6 +1272,58 @@ def step_env(env: NetworkForensicsEnv, action: NetworkForensicsAction) -> Any:
|
|
| 857 |
return result
|
| 858 |
|
| 859 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 860 |
def close_env(env: NetworkForensicsEnv | None) -> None:
|
| 861 |
if env is None:
|
| 862 |
return
|
|
@@ -889,19 +1356,50 @@ def run_task(task_name: str) -> None:
|
|
| 889 |
obs = reset_result.observation
|
| 890 |
sync_agent_state(obs, agent_state)
|
| 891 |
max_steps = obs.steps_remaining or 50
|
|
|
|
|
|
|
|
|
|
|
|
|
| 892 |
|
| 893 |
-
for _ in range(
|
| 894 |
if obs.done:
|
| 895 |
break
|
| 896 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
error = None
|
| 898 |
try:
|
| 899 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 900 |
except Exception as exc:
|
| 901 |
error = str(exc).replace("\n", " ")
|
| 902 |
action = build_fallback_action(task_name, obs, agent_state)
|
| 903 |
|
| 904 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 905 |
obs = step_result.observation
|
| 906 |
sync_agent_state(obs, agent_state)
|
| 907 |
step_reward = float(step_result.reward or 0.0)
|
|
|
|
| 4 |
import asyncio
|
| 5 |
import inspect
|
| 6 |
import random
|
| 7 |
+
import time
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Any
|
| 10 |
|
|
|
|
| 38 |
SYSTEM_PROMPT = """You are a senior Network Forensics Analyst. Your goal is to investigate malicious network traffic and achieve a 100% detection score.
|
| 39 |
|
| 40 |
### SCORING RULES:
|
| 41 |
+
- You MUST identify and `flag_as_suspicious` EVERY malicious packet to maximize RECALL (very important!).
|
| 42 |
- Only grouped packets or flagged packets contribute towards your score.
|
| 43 |
+
- If RECALL is < 0.5, your score will be 0.0. DO NOT stop until you have flagged/grouped at least 60% of visible malicious packets.
|
| 44 |
+
- Entry point must be the EARLIEST packet that initiated the attack (often in first group).
|
| 45 |
+
- For HARD tasks: wrong entry point = score 0. Always identify_entry_point before submitting.
|
| 46 |
|
| 47 |
### WORKFLOW:
|
| 48 |
1. **Explore**: `inspect_packet` on suspicious samples.
|
| 49 |
+
2. **Flag**: `flag_as_suspicious` on ALL revealed malicious packets.
|
| 50 |
+
3. **Correlate**: `group_into_session` with descriptive names.
|
| 51 |
+
4. **Classify**: `tag_pattern` with a valid type.
|
| 52 |
+
5. **Root Cause**: `identify_entry_point` with the earliest malicious packet.
|
| 53 |
+
6. **Report**: `submit_report` ONLY when you have covered all visible malicious sessions.
|
| 54 |
+
|
| 55 |
+
### VALID PATTERN TYPES:
|
| 56 |
+
ddos, dos_slowloris, dos_slowhttptest, dos_goldeneye, dos_hulk, heartbleed, web_sql_injection, web_xss, web_bruteforce, c2, exfiltration, scan, lateral
|
| 57 |
|
| 58 |
### JSON SCHEMA EXAMPLES (Use these exactly):
|
| 59 |
- Inspect: {"action_type":"inspect_packet","packet_id":"pkt_0001"}
|
| 60 |
- Flag: {"action_type":"flag_as_suspicious","packet_id":"pkt_0001"}
|
| 61 |
- Group: {"action_type":"group_into_session","session_name":"DDoS_Burst_2","packet_ids":["pkt_0001","pkt_0002"]}
|
| 62 |
- Tag: {"action_type":"tag_pattern","session_name":"DDoS_Burst_2","pattern_type":"ddos"}
|
| 63 |
+
- Entry: {"action_type":"identify_entry_point","claimed_entry_point":"pkt_0001"}
|
| 64 |
+
- Report: {"action_type":"submit_report","incident_summary":"Detailed incident summary here.","claimed_entry_point":"pkt_0001"}"""
|
| 65 |
|
| 66 |
HISTORY_WINDOW = 20
|
| 67 |
REPEAT_ACTION_LIMIT = 3
|
| 68 |
CORRECTION_WINDOW = 5
|
| 69 |
+
UNTAGGED_BACKLOG_LIMIT = 6
|
| 70 |
INSPECT_SOFT_RATIO_THRESHOLD = 0.60
|
| 71 |
+
SOFT_STEP_BUDGETS = {"easy": 14, "medium": 28, "hard": 40}
|
| 72 |
+
HARD_STEP_CAPS = {"easy": 30, "medium": 50, "hard": 65}
|
| 73 |
+
TASK_SCORE_TARGETS = {"easy": 0.70, "medium": 0.68, "hard": 0.66}
|
| 74 |
+
TASK_COVERAGE_TARGETS = {"easy": 0.32, "medium": 0.24, "hard": 0.20}
|
| 75 |
+
MAX_TASK_SECONDS = float(os.getenv("MAX_TASK_SECONDS", "780"))
|
| 76 |
+
TASK_TIME_BUDGET_SECONDS = {
|
| 77 |
+
"easy": float(os.getenv("EASY_MAX_SECONDS", "150")),
|
| 78 |
+
"medium": float(os.getenv("MEDIUM_MAX_SECONDS", "220")),
|
| 79 |
+
"hard": float(os.getenv("HARD_MAX_SECONDS", "320")),
|
| 80 |
+
}
|
| 81 |
|
| 82 |
|
| 83 |
def build_client() -> OpenAI:
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
|
| 115 |
+
"""Provide a compact structured summary for low-latency policy learning."""
|
| 116 |
packets = obs.visible_packets
|
| 117 |
revealed = [p for p in packets if p.is_revealed]
|
| 118 |
revealed_ids = [p.packet_id for p in revealed]
|
|
|
|
| 123 |
reward_feedback = agent_state.get("last_reward_feedback", "n/a")
|
| 124 |
recent_corrections = agent_state.get("recent_corrections", [])[-CORRECTION_WINDOW:]
|
| 125 |
strategy_hints = agent_state.get("strategy_hints", [])
|
| 126 |
+
task_name = agent_state.get("current_task_name", "")
|
| 127 |
+
|
| 128 |
+
flagged_count = len(obs.flagged_packet_ids)
|
| 129 |
+
total_visible = max(1, len(obs.visible_packets))
|
| 130 |
+
coverage = flagged_count / total_visible
|
| 131 |
+
coverage_target = TASK_COVERAGE_TARGETS.get(task_name, 0.25)
|
| 132 |
+
score_target = TASK_SCORE_TARGETS.get(task_name, 0.65)
|
| 133 |
+
grouped_count = len(sessions)
|
| 134 |
+
tagged_count = len(tags)
|
| 135 |
+
ready_to_submit = (
|
| 136 |
+
obs.current_score_estimate >= score_target
|
| 137 |
+
and coverage >= coverage_target
|
| 138 |
+
and (task_name == "easy" or grouped_count >= 2)
|
| 139 |
+
and (task_name == "easy" or tagged_count >= 1)
|
| 140 |
+
)
|
| 141 |
|
| 142 |
summary = [
|
| 143 |
f"Step: {obs.step_number}/{obs.step_number + obs.steps_remaining}",
|
| 144 |
f"Current Progress: {obs.current_score_estimate:.2f}",
|
| 145 |
+
f"Coverage: {flagged_count}/{total_visible} ({coverage:.2%}) | target {coverage_target:.0%}",
|
| 146 |
+
f"Sessions: grouped={grouped_count}, tagged={tagged_count}",
|
| 147 |
+
f"Submit Readiness: {'READY' if ready_to_submit else 'KEEP INVESTIGATING'}",
|
| 148 |
f"Last Step Reward: {last_reward:.2f}" if isinstance(last_reward, (int, float)) else "Last Step Reward: n/a",
|
| 149 |
f"Last Reward Feedback: {reward_feedback}",
|
| 150 |
+
f"ALREADY REVEALED: {', '.join(revealed_ids[-6:])} " + ("..." if len(revealed_ids) > 6 else ""),
|
| 151 |
"\n### SESSIONS PENDING TAGGING:",
|
| 152 |
]
|
| 153 |
|
|
|
|
| 162 |
summary.append(f"- {hint}")
|
| 163 |
|
| 164 |
if untagged_sessions:
|
| 165 |
+
for s in untagged_sessions[:6]:
|
| 166 |
summary.append(f"- {s} ({len(sessions[s])} packets)")
|
| 167 |
else:
|
| 168 |
summary.append("- [No pending sessions]")
|
| 169 |
|
| 170 |
summary.append("\n### REVEALED INDICATORS:")
|
| 171 |
+
for p in revealed[-4:]:
|
| 172 |
payload = (p.full_payload or "")[:150]
|
| 173 |
if payload:
|
| 174 |
summary.append(f"- {p.packet_id}: {payload}")
|
|
|
|
| 243 |
|
| 244 |
def keyword_to_pattern(payload: str) -> str | None:
|
| 245 |
text = payload.lower()
|
| 246 |
+
# --- DoS / DDoS variants ---
|
| 247 |
if "slowloris" in text:
|
| 248 |
return "dos_slowloris"
|
| 249 |
+
if "slowhttptest" in text or "slow http" in text:
|
| 250 |
return "dos_slowhttptest"
|
| 251 |
+
if "goldeneye" in text or "golden eye" in text:
|
| 252 |
return "dos_goldeneye"
|
| 253 |
if "hulk" in text:
|
| 254 |
return "dos_hulk"
|
| 255 |
+
if "heartbeat" in text or "heartbleed" in text or ("tls" in text and "ext" in text):
|
| 256 |
return "heartbleed"
|
| 257 |
+
if "flood" in text or "burst" in text or "ddos" in text:
|
| 258 |
+
return "ddos"
|
| 259 |
+
# HTTP flood indicators (repeated GET/POST to same endpoint)
|
| 260 |
+
if text.startswith("get /") or text.startswith("post /") or text.startswith("get http"):
|
| 261 |
+
if "accept-encoding" in text or "connection" in text or "keep-alive" in text:
|
| 262 |
+
return "ddos"
|
| 263 |
+
# SYN flood / connection flood
|
| 264 |
+
if "syn" in text and "ack" not in text and len(text) < 30:
|
| 265 |
+
return "ddos"
|
| 266 |
+
# ICMP flood
|
| 267 |
+
if "icmp" in text and ("echo" in text or "request" in text or len(text) < 20):
|
| 268 |
+
return "ddos"
|
| 269 |
+
# --- Web attacks ---
|
| 270 |
+
if "xss" in text or "<script>" in text or "<scrip" in text or "/search?q=" in text or "onerror" in text or "onload" in text or "javascript:" in text or "alert(" in text or "%3cscript" in text:
|
| 271 |
return "web_xss"
|
| 272 |
if (
|
| 273 |
"or 1=1" in text
|
| 274 |
or "%20or" in text
|
| 275 |
or "/items?id=" in text
|
| 276 |
or "1=1" in text
|
| 277 |
+
or "' or " in text
|
| 278 |
+
or "'--" in text
|
| 279 |
+
or "union select" in text
|
| 280 |
+
or "union all select" in text
|
| 281 |
+
or "drop table" in text
|
| 282 |
+
or "select * from" in text
|
| 283 |
or "sql" in text
|
| 284 |
+
or "%27" in text # URL-encoded single quote
|
| 285 |
+
or "' and " in text
|
| 286 |
+
or "admin'--" in text
|
| 287 |
):
|
| 288 |
return "web_sql_injection"
|
| 289 |
+
if (
|
| 290 |
+
"login" in text
|
| 291 |
+
or "username=admin" in text
|
| 292 |
+
or "password=" in text
|
| 293 |
+
or "passwd=" in text
|
| 294 |
+
or "user=admin" in text
|
| 295 |
+
or "brute" in text
|
| 296 |
+
or "/login" in text
|
| 297 |
+
or "/signin" in text
|
| 298 |
+
or "/auth" in text
|
| 299 |
+
or "post /login" in text
|
| 300 |
+
or "post /sign" in text
|
| 301 |
+
):
|
| 302 |
return "web_bruteforce"
|
| 303 |
+
# --- C2 / exfil / scan / lateral ---
|
| 304 |
+
if "c2" in text or "command" in text or "shell" in text or "cmd" in text or "/bin/" in text or "reverse" in text:
|
| 305 |
+
return "c2"
|
| 306 |
+
if "exfil" in text or "exfiltrat" in text or "data_leak" in text or "dns_tunnel" in text:
|
| 307 |
+
return "exfiltration"
|
| 308 |
+
if "scan" in text or "nmap" in text or "port_scan" in text or "recon" in text:
|
| 309 |
+
return "scan"
|
| 310 |
+
if "lateral" in text or "pivot" in text or "spread" in text or "propagat" in text:
|
| 311 |
+
return "lateral"
|
| 312 |
return None
|
| 313 |
|
| 314 |
|
|
|
|
| 323 |
return (packet.src_ip, packet.dst_ip, packet.dst_port, pattern)
|
| 324 |
|
| 325 |
|
| 326 |
+
SUSPICIOUS_PORTS = {22, 23, 445, 1433, 3306, 5432, 4444, 5555, 6666, 6667, 7777, 8888, 9999, 31337}
|
| 327 |
+
SUSPICIOUS_PROTOCOLS = {"ICMP"}
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def _infer_flow_pattern(packet: Any, flow_size: int) -> str | None:
|
| 331 |
+
"""Heuristic pattern inference from flow characteristics when keyword matching fails."""
|
| 332 |
+
dst_port = packet.dst_port
|
| 333 |
+
protocol = packet.protocol
|
| 334 |
+
flags = getattr(packet, "flags", []) or []
|
| 335 |
+
# High-density flows to web ports → likely DDoS
|
| 336 |
+
if flow_size >= 5 and dst_port in {80, 8080, 443, 8443}:
|
| 337 |
+
return "ddos"
|
| 338 |
+
# SYN-only flood
|
| 339 |
+
if flow_size >= 5 and flags == ["SYN"]:
|
| 340 |
+
return "ddos"
|
| 341 |
+
# Suspicious ports → C2 or lateral
|
| 342 |
+
if dst_port in SUSPICIOUS_PORTS:
|
| 343 |
+
if dst_port in {4444, 5555, 6666, 7777, 31337}:
|
| 344 |
+
return "c2"
|
| 345 |
+
if dst_port in {445, 1433, 3306, 5432}:
|
| 346 |
+
return "lateral"
|
| 347 |
+
# ICMP flood
|
| 348 |
+
if protocol in SUSPICIOUS_PROTOCOLS and flow_size >= 3:
|
| 349 |
+
return "ddos"
|
| 350 |
+
# High-density flow to non-standard port
|
| 351 |
+
if flow_size >= 8 and dst_port not in {53, 80, 443, 8080}:
|
| 352 |
+
return "scan"
|
| 353 |
+
return None
|
| 354 |
+
|
| 355 |
+
|
| 356 |
def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[Any]]]:
|
| 357 |
grouped: dict[tuple[str, str, int, str], list[Any]] = {}
|
| 358 |
attack_source_ports: dict[tuple[str, str, int, str], set[int]] = {}
|
| 359 |
+
# Phase 1: keyword-based grouping (high confidence)
|
| 360 |
for packet in obs.visible_packets:
|
| 361 |
pattern = keyword_to_pattern(packet_payload_text(packet))
|
| 362 |
if pattern:
|
|
|
|
| 364 |
grouped.setdefault(key, []).append(packet)
|
| 365 |
attack_source_ports.setdefault(key, set()).add(packet.src_port)
|
| 366 |
|
| 367 |
+
# Add reverse-response packets to keyword-matched sessions
|
| 368 |
for key, source_ports in attack_source_ports.items():
|
| 369 |
src_ip, dst_ip, dst_port, _pattern = key
|
| 370 |
for packet in obs.visible_packets:
|
|
|
|
| 377 |
if is_reverse_response:
|
| 378 |
grouped[key].append(packet)
|
| 379 |
|
| 380 |
+
# Phase 2: flow-based grouping for packets without keyword match
|
| 381 |
+
# Group unclaimed packets by (src_ip, dst_ip, dst_port) and infer pattern
|
| 382 |
+
claimed_ids: set[str] = set()
|
| 383 |
+
for items in grouped.values():
|
| 384 |
+
for p in items:
|
| 385 |
+
claimed_ids.add(p.packet_id)
|
| 386 |
+
|
| 387 |
+
flow_buckets: dict[tuple[str, str, int], list[Any]] = {}
|
| 388 |
+
for packet in obs.visible_packets:
|
| 389 |
+
if packet.packet_id in claimed_ids:
|
| 390 |
+
continue
|
| 391 |
+
flow_key = (packet.src_ip, packet.dst_ip, packet.dst_port)
|
| 392 |
+
flow_buckets.setdefault(flow_key, []).append(packet)
|
| 393 |
+
|
| 394 |
+
for flow_key, items in flow_buckets.items():
|
| 395 |
+
if len(items) < 2:
|
| 396 |
+
continue
|
| 397 |
+
pattern = _infer_flow_pattern(items[0], len(items))
|
| 398 |
+
if pattern:
|
| 399 |
+
session_key = (*flow_key, pattern)
|
| 400 |
+
grouped.setdefault(session_key, []).extend(items)
|
| 401 |
+
for p in items:
|
| 402 |
+
claimed_ids.add(p.packet_id)
|
| 403 |
+
|
| 404 |
candidates = [
|
| 405 |
(
|
| 406 |
key,
|
|
|
|
| 421 |
return 0
|
| 422 |
|
| 423 |
|
| 424 |
+
def select_inspect_packet(
|
| 425 |
+
obs: Any,
|
| 426 |
+
inspected_ids: set[str],
|
| 427 |
+
flagged_ids: set[str] | None = None,
|
| 428 |
+
) -> str | None:
|
| 429 |
+
flagged_ids = flagged_ids or set()
|
| 430 |
+
unrevealed = [
|
| 431 |
+
p
|
| 432 |
+
for p in obs.visible_packets
|
| 433 |
+
if (not p.is_revealed)
|
| 434 |
+
and (p.packet_id not in inspected_ids)
|
| 435 |
+
and (p.packet_id not in flagged_ids)
|
| 436 |
+
]
|
| 437 |
if not unrevealed:
|
| 438 |
return None
|
| 439 |
|
|
|
|
| 459 |
def append_action_history(agent_state: dict[str, Any], action: NetworkForensicsAction) -> None:
|
| 460 |
history = agent_state.setdefault("previous_actions", [])
|
| 461 |
history.append(format_action(action))
|
| 462 |
+
if action.action_type == "inspect_packet" and action.packet_id:
|
| 463 |
+
inspected_ids = agent_state.setdefault("inspected_ids", set())
|
| 464 |
+
inspected_ids.add(action.packet_id)
|
| 465 |
if len(history) > HISTORY_WINDOW:
|
| 466 |
del history[:-HISTORY_WINDOW]
|
| 467 |
|
|
|
|
| 503 |
candidate_packets, flagged_ids, visible_by_id
|
| 504 |
)
|
| 505 |
size = len(candidate_packets)
|
| 506 |
+
# Lowered thresholds for more aggressive grouping
|
| 507 |
if task_name == "easy":
|
| 508 |
min_flagged = 1 if size >= 2 else 0
|
| 509 |
elif task_name == "medium":
|
| 510 |
+
min_flagged = 1 if size >= 2 else 0
|
| 511 |
else:
|
| 512 |
+
min_flagged = 1 if size >= 3 else 0
|
| 513 |
+
if trusted_pattern and size >= 3:
|
| 514 |
min_flagged = 1
|
| 515 |
if flagged >= min_flagged:
|
| 516 |
return True
|
| 517 |
# Allow grouping with strong revealed malicious evidence.
|
| 518 |
+
if task_name == "easy" and (malicious_revealed >= 1 or revealed >= 1):
|
| 519 |
return True
|
| 520 |
+
if task_name == "medium" and malicious_revealed >= 1 and revealed >= 1:
|
|
|
|
| 521 |
return True
|
| 522 |
+
if malicious_revealed >= 1 and revealed >= min(2, size):
|
| 523 |
+
return True
|
| 524 |
+
# After a pattern has been confirmed by tagging, allow structure-first grouping.
|
| 525 |
+
if trusted_pattern and size >= 3:
|
| 526 |
return True
|
| 527 |
+
# Large flows are very likely attack sessions - allow with minimal evidence
|
| 528 |
+
if size >= 6 and (flagged >= 1 or revealed >= 2 or malicious_revealed >= 1):
|
| 529 |
return True
|
| 530 |
return False
|
| 531 |
|
|
|
|
| 567 |
)
|
| 568 |
|
| 569 |
inspect_limit = {
|
| 570 |
+
"easy": 18,
|
| 571 |
+
"medium": 20,
|
| 572 |
+
"hard": 25,
|
| 573 |
+
}.get(agent_state.get("current_task_name", ""), 15)
|
| 574 |
if len(previous_actions) >= inspect_limit and inspect_ratio >= INSPECT_SOFT_RATIO_THRESHOLD:
|
| 575 |
hints.append(
|
| 576 |
"You are over-inspecting. Shift to flagging, grouping, tagging, or report submission unless the next packet is clearly high-value."
|
|
|
|
| 578 |
return hints
|
| 579 |
|
| 580 |
|
| 581 |
+
def should_submit_early(task_name: str, obs: Any, agent_state: dict[str, Any]) -> bool:
|
| 582 |
+
flagged_count = len(obs.flagged_packet_ids)
|
| 583 |
+
total_visible = max(1, len(obs.visible_packets))
|
| 584 |
+
coverage = flagged_count / total_visible
|
| 585 |
+
score = float(obs.current_score_estimate)
|
| 586 |
+
sessions = obs.grouped_sessions or {}
|
| 587 |
+
tags = obs.tagged_patterns or {}
|
| 588 |
+
|
| 589 |
+
score_target = TASK_SCORE_TARGETS.get(task_name, 0.65)
|
| 590 |
+
coverage_target = TASK_COVERAGE_TARGETS.get(task_name, 0.25)
|
| 591 |
+
|
| 592 |
+
if task_name == "easy":
|
| 593 |
+
return (
|
| 594 |
+
coverage >= max(coverage_target * 0.7, 0.20)
|
| 595 |
+
and flagged_count >= 6
|
| 596 |
+
and len(sessions) >= 1
|
| 597 |
+
)
|
| 598 |
+
if task_name == "medium":
|
| 599 |
+
return (
|
| 600 |
+
score >= score_target * 0.8
|
| 601 |
+
and coverage >= coverage_target * 0.7
|
| 602 |
+
and len(sessions) >= 1
|
| 603 |
+
and len(tags) >= 1
|
| 604 |
+
)
|
| 605 |
+
return (
|
| 606 |
+
score >= score_target * 0.8
|
| 607 |
+
and coverage >= coverage_target * 0.7
|
| 608 |
+
and len(sessions) >= 2
|
| 609 |
+
and len(tags) >= 1
|
| 610 |
+
and bool(agent_state.get("claimed_entry_point") or obs.claimed_entry_point)
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
def build_fallback_action(
|
| 615 |
task_name: str, obs: Any, agent_state: dict[str, Any]
|
| 616 |
) -> NetworkForensicsAction:
|
| 617 |
+
"""Smart workflow engine: Flag aggressive -> Group -> Tag -> Entry Point -> Report."""
|
| 618 |
inspected_ids = agent_state.setdefault("inspected_ids", set())
|
| 619 |
flagged_ids = agent_state.setdefault("flagged_ids", set())
|
| 620 |
session_map = agent_state.setdefault("sessions", {}) # key -> session_name
|
|
|
|
| 623 |
visible_by_id = {p.packet_id: p for p in obs.visible_packets}
|
| 624 |
trusted = trusted_patterns(session_map, tagged_sessions)
|
| 625 |
|
| 626 |
+
if obs.steps_remaining <= 1 or should_submit_early(task_name, obs, agent_state):
|
| 627 |
summary = _build_report_summary(obs, agent_state)
|
| 628 |
return NetworkForensicsAction(
|
| 629 |
action_type="submit_report",
|
|
|
|
| 631 |
claimed_entry_point=claimed_entry,
|
| 632 |
)
|
| 633 |
|
| 634 |
+
# PHASE 1: Aggressive flag of ALL revealed malicious packets
|
| 635 |
+
# This maximizes recall by comprehensively flagging known-bad traffic
|
| 636 |
+
unflagged_malicious = []
|
| 637 |
for packet in obs.visible_packets:
|
| 638 |
if packet.is_revealed and packet.packet_id not in flagged_ids:
|
| 639 |
payload = packet.full_payload or ""
|
| 640 |
pattern = keyword_to_pattern(payload)
|
| 641 |
if pattern:
|
| 642 |
+
unflagged_malicious.append(packet.packet_id)
|
| 643 |
+
|
| 644 |
+
if unflagged_malicious:
|
| 645 |
+
# Flag up to 5 per turn for aggressive recall buildup
|
| 646 |
+
target = min(5, len(unflagged_malicious))
|
| 647 |
+
for _ in range(target):
|
| 648 |
+
if unflagged_malicious:
|
| 649 |
+
pid = unflagged_malicious.pop(0)
|
| 650 |
+
flagged_ids.add(pid)
|
| 651 |
return NetworkForensicsAction(
|
| 652 |
action_type="flag_as_suspicious",
|
| 653 |
+
packet_id=pid,
|
| 654 |
)
|
| 655 |
|
| 656 |
# PHASE 2: Group flagged packets into sessions with evidence gate and backlog pacing.
|
| 657 |
+
min_flagged_before_group = 1 if task_name == "easy" else 2
|
| 658 |
untagged_backlog = max(0, len(session_map) - len(tagged_sessions))
|
| 659 |
+
if len(flagged_ids) >= min_flagged_before_group and untagged_backlog <= UNTAGGED_BACKLOG_LIMIT:
|
| 660 |
candidates = session_candidates(obs)
|
| 661 |
for key, items in candidates:
|
| 662 |
if key in session_map:
|
|
|
|
| 678 |
packet_ids=packet_ids,
|
| 679 |
)
|
| 680 |
|
| 681 |
+
# PHASE 2.5: Recall sweep - flag packets that are already part of grouped sessions.
|
| 682 |
+
# This boosts recall quickly without requiring more inspections.
|
| 683 |
+
grouped_packets = []
|
| 684 |
+
for packet_ids in (obs.grouped_sessions or {}).values():
|
| 685 |
+
grouped_packets.extend(packet_ids)
|
| 686 |
+
for pid in sorted(set(grouped_packets), key=packet_sort_key):
|
| 687 |
+
if pid in flagged_ids:
|
| 688 |
+
continue
|
| 689 |
+
if pid in visible_by_id:
|
| 690 |
+
flagged_ids.add(pid)
|
| 691 |
+
return NetworkForensicsAction(
|
| 692 |
+
action_type="flag_as_suspicious",
|
| 693 |
+
packet_id=pid,
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# PHASE 3: Tag ALL untagged sessions aggressively (critical for medium/hard logic_score).
|
| 697 |
+
# Tagging helps LLM report score and logic_score for all difficulties.
|
| 698 |
+
for key, session_name in session_map.items():
|
| 699 |
+
if session_name in tagged_sessions:
|
| 700 |
+
continue
|
| 701 |
+
_src_ip, _dst_ip, _dst_port, pattern = key
|
| 702 |
+
tagged_sessions.add(session_name)
|
| 703 |
+
return NetworkForensicsAction(
|
| 704 |
+
action_type="tag_pattern",
|
| 705 |
+
session_name=session_name,
|
| 706 |
+
pattern_type=pattern,
|
| 707 |
+
)
|
| 708 |
+
# Also tag any observed sessions not yet in our session_map
|
| 709 |
+
for session_name, session_data in (obs.grouped_sessions or {}).items():
|
| 710 |
+
if session_name in tagged_sessions:
|
| 711 |
+
continue
|
| 712 |
+
if session_name in (obs.tagged_patterns or {}):
|
| 713 |
+
tagged_sessions.add(session_name)
|
| 714 |
+
continue
|
| 715 |
+
# Infer pattern from session packets
|
| 716 |
+
pattern = None
|
| 717 |
+
for pid in session_data:
|
| 718 |
+
pkt = visible_by_id.get(pid)
|
| 719 |
+
if pkt and pkt.is_revealed:
|
| 720 |
+
pattern = keyword_to_pattern(packet_payload_text(pkt))
|
| 721 |
+
if pattern:
|
| 722 |
+
break
|
| 723 |
+
if not pattern:
|
| 724 |
+
# Try flow-based inference
|
| 725 |
+
pkt = visible_by_id.get(session_data[0]) if session_data else None
|
| 726 |
+
if pkt:
|
| 727 |
+
pattern = _infer_flow_pattern(pkt, len(session_data))
|
| 728 |
+
if pattern:
|
| 729 |
tagged_sessions.add(session_name)
|
| 730 |
return NetworkForensicsAction(
|
| 731 |
action_type="tag_pattern",
|
|
|
|
| 733 |
pattern_type=pattern,
|
| 734 |
)
|
| 735 |
|
| 736 |
+
# PHASE 4: Identify entry point - CRITICAL for hard mode (score=0 without it)
|
| 737 |
+
if not claimed_entry:
|
| 738 |
+
entry_candidate = None
|
| 739 |
+
# Strategy 1: earliest packet in any grouped session from observation
|
| 740 |
+
try:
|
| 741 |
+
grouped_packets = set()
|
| 742 |
+
for session_name in session_map.values():
|
| 743 |
+
if obs.grouped_sessions and session_name in obs.grouped_sessions:
|
| 744 |
+
grouped_packets.update(obs.grouped_sessions[session_name])
|
| 745 |
+
if grouped_packets:
|
| 746 |
+
entry_candidate = min(grouped_packets, key=lambda pid: packet_sort_key(pid))
|
| 747 |
+
except Exception:
|
| 748 |
+
pass
|
| 749 |
+
# Strategy 2: earliest flagged packet (often the first discovered attack)
|
| 750 |
+
if not entry_candidate and flagged_ids:
|
| 751 |
+
entry_candidate = min(flagged_ids, key=lambda pid: packet_sort_key(pid))
|
| 752 |
+
# Strategy 3: earliest revealed malicious packet
|
| 753 |
+
if not entry_candidate:
|
| 754 |
+
revealed_malicious = [
|
| 755 |
+
p for p in obs.visible_packets
|
| 756 |
+
if p.is_revealed and keyword_to_pattern(packet_payload_text(p))
|
| 757 |
+
]
|
| 758 |
+
if revealed_malicious:
|
| 759 |
+
entry_candidate = min(
|
| 760 |
+
revealed_malicious, key=lambda p: packet_sort_key(p.packet_id)
|
| 761 |
+
).packet_id
|
| 762 |
+
# Strategy 4: earliest packet in session_candidates
|
| 763 |
+
if not entry_candidate:
|
| 764 |
+
all_session_packets = []
|
| 765 |
+
for key, items in session_candidates(obs):
|
| 766 |
+
for p in items:
|
| 767 |
+
all_session_packets.append(p.packet_id)
|
| 768 |
+
if all_session_packets:
|
| 769 |
+
entry_candidate = min(all_session_packets, key=packet_sort_key)
|
| 770 |
+
# Strategy 5: earliest flagged packet from observation
|
| 771 |
+
if not entry_candidate and obs.flagged_packet_ids:
|
| 772 |
+
entry_candidate = min(obs.flagged_packet_ids, key=packet_sort_key)
|
| 773 |
+
if entry_candidate:
|
| 774 |
+
agent_state["claimed_entry_point"] = entry_candidate
|
| 775 |
+
return NetworkForensicsAction(
|
| 776 |
+
action_type="identify_entry_point",
|
| 777 |
+
claimed_entry_point=entry_candidate,
|
| 778 |
+
)
|
| 779 |
|
| 780 |
+
# PHASE 5: Inspect more unrevealed packets (to discover more malicious traffic)
|
| 781 |
+
inspect_id = select_inspect_packet(obs, inspected_ids, flagged_ids)
|
| 782 |
if inspect_id is not None:
|
| 783 |
return NetworkForensicsAction(action_type="inspect_packet", packet_id=inspect_id)
|
| 784 |
|
|
|
|
| 792 |
|
| 793 |
|
| 794 |
def _build_report_summary(obs: Any, agent_state: dict[str, Any]) -> str:
|
| 795 |
+
"""Generate a detailed incident summary for high LLM judge scores."""
|
| 796 |
flagged = agent_state.get("flagged_ids", set())
|
| 797 |
sessions = agent_state.get("sessions", {})
|
| 798 |
tagged = agent_state.get("tagged_sessions", set())
|
| 799 |
+
entry_point = agent_state.get("claimed_entry_point") or getattr(obs, "claimed_entry_point", None)
|
| 800 |
+
patterns_by_session: dict[str, str] = {}
|
| 801 |
+
src_ips_by_pattern: dict[str, set[str]] = {}
|
| 802 |
+
dst_ips_by_pattern: dict[str, set[str]] = {}
|
| 803 |
+
for key, session_name in sessions.items():
|
| 804 |
if len(key) >= 4:
|
| 805 |
+
pattern = key[3]
|
| 806 |
+
patterns_by_session[session_name] = pattern
|
| 807 |
+
src_ips_by_pattern.setdefault(pattern, set()).add(key[0])
|
| 808 |
+
dst_ips_by_pattern.setdefault(pattern, set()).add(key[1])
|
| 809 |
+
|
| 810 |
+
# Build detailed per-pattern section
|
| 811 |
+
pattern_details = []
|
| 812 |
+
for pattern in sorted(set(patterns_by_session.values())):
|
| 813 |
+
srcs = ", ".join(sorted(src_ips_by_pattern.get(pattern, set()))[:5])
|
| 814 |
+
dsts = ", ".join(sorted(dst_ips_by_pattern.get(pattern, set()))[:5])
|
| 815 |
+
session_names = [n for n, p in patterns_by_session.items() if p == pattern]
|
| 816 |
+
pattern_details.append(
|
| 817 |
+
f" - {pattern}: {len(session_names)} session(s) from {srcs} targeting {dsts}"
|
| 818 |
+
)
|
| 819 |
+
pattern_section = "\n".join(pattern_details) if pattern_details else " - No patterns classified"
|
| 820 |
+
|
| 821 |
+
# Tagged pattern summary
|
| 822 |
+
tagged_details = []
|
| 823 |
+
for session_name in sorted(tagged):
|
| 824 |
+
pattern = patterns_by_session.get(session_name, "unknown")
|
| 825 |
+
tagged_details.append(f"{session_name}={pattern}")
|
| 826 |
+
tagged_section = "; ".join(tagged_details) if tagged_details else "none"
|
| 827 |
+
|
| 828 |
+
entry_section = f"Entry point: {entry_point}" if entry_point else "Entry point: not identified"
|
| 829 |
+
|
| 830 |
return (
|
| 831 |
+
f"INCIDENT REPORT\n\n"
|
| 832 |
+
f"Summary: Detected {len(flagged)} malicious packets across "
|
| 833 |
+
f"{len(sessions)} attack sessions.\n\n"
|
| 834 |
+
f"Attack Patterns:\n{pattern_section}\n\n"
|
| 835 |
+
f"Tagged Sessions: {tagged_section}\n\n"
|
| 836 |
+
f"{entry_section}\n\n"
|
| 837 |
+
f"Total flagged: {len(flagged)} | Total sessions: {len(sessions)} | "
|
| 838 |
+
f"Classified sessions: {len(tagged)}"
|
| 839 |
)
|
| 840 |
|
| 841 |
|
|
|
|
| 856 |
inspect_count = sum(1 for a in previous_actions if '"inspect_packet"' in a)
|
| 857 |
revealed_count = sum(1 for p in obs.visible_packets if p.is_revealed)
|
| 858 |
inspect_limit = {
|
| 859 |
+
"easy": 25,
|
| 860 |
+
"medium": 18,
|
| 861 |
+
"hard": 25,
|
| 862 |
+
}.get(task_name, 15)
|
| 863 |
|
| 864 |
if action.action_type not in {
|
| 865 |
"inspect_packet",
|
|
|
|
| 880 |
return "Missing packet_id for inspect_packet"
|
| 881 |
if action.packet_id not in {p.packet_id for p in obs.visible_packets}:
|
| 882 |
return f"Invalid packet_id {action.packet_id} - not in visible_packets"
|
| 883 |
+
inspected_ids = agent_state.setdefault("inspected_ids", set())
|
| 884 |
+
if action.packet_id in inspected_ids:
|
| 885 |
+
return f"Packet {action.packet_id} was already inspected. Choose a different hidden packet."
|
| 886 |
revealed_ids = {p.packet_id for p in obs.visible_packets if p.is_revealed}
|
| 887 |
if action.packet_id in revealed_ids:
|
| 888 |
return f"Packet {action.packet_id} is ALREADY revealed. Choose a HIDDEN packet."
|
| 889 |
+
if action.packet_id in set(obs.flagged_packet_ids):
|
| 890 |
+
return (
|
| 891 |
+
f"Packet {action.packet_id} is already flagged. Inspect a new hidden unflagged packet instead."
|
| 892 |
+
)
|
| 893 |
+
revealed_unflagged_malicious = [
|
| 894 |
+
p.packet_id
|
| 895 |
+
for p in obs.visible_packets
|
| 896 |
+
if p.is_revealed
|
| 897 |
+
and p.packet_id not in set(obs.flagged_packet_ids)
|
| 898 |
+
and keyword_to_pattern(packet_payload_text(p))
|
| 899 |
+
]
|
| 900 |
+
if revealed_unflagged_malicious:
|
| 901 |
+
return (
|
| 902 |
+
"Recall-first policy: revealed malicious packets exist and must be flagged before new inspection."
|
| 903 |
+
)
|
| 904 |
+
grouped_unflagged = [
|
| 905 |
+
pid
|
| 906 |
+
for packet_ids in (obs.grouped_sessions or {}).values()
|
| 907 |
+
for pid in packet_ids
|
| 908 |
+
if pid not in set(obs.flagged_packet_ids)
|
| 909 |
+
]
|
| 910 |
+
if grouped_unflagged:
|
| 911 |
return (
|
| 912 |
+
"Recall-first policy: grouped session packets remain unflagged. Flag them before further inspection."
|
| 913 |
)
|
| 914 |
+
if task_name == "easy" and len(flagged_ids) >= 4:
|
| 915 |
+
grouped_session_names = set((obs.grouped_sessions or {}).keys())
|
| 916 |
+
for key, items in session_candidates(obs):
|
| 917 |
+
if key in sessions:
|
| 918 |
+
continue
|
| 919 |
+
if len(items) >= 4:
|
| 920 |
+
return (
|
| 921 |
+
"Exploit mode: enough evidence exists. Group high-confidence attack flows before more inspection."
|
| 922 |
+
)
|
| 923 |
+
if inspect_count >= inspect_limit and (len(sessions) > 0 or len(flagged_ids) > 0 or revealed_count >= 4):
|
| 924 |
+
# Only block inspections for medium/hard modes; easy mode needs discovery
|
| 925 |
+
if task_name != "easy":
|
| 926 |
+
return (
|
| 927 |
+
f"Inspection budget reached for {task_name}. Shift to flagging, grouping, tagging, or report submission."
|
| 928 |
+
)
|
| 929 |
|
| 930 |
if action.action_type == "flag_as_suspicious":
|
| 931 |
if not action.packet_id:
|
|
|
|
| 945 |
}
|
| 946 |
if invalid_ids:
|
| 947 |
return f"Invalid packet_ids in session: {invalid_ids}"
|
| 948 |
+
if action.session_name in sessions.values():
|
| 949 |
+
return f"Session name {action.session_name} is already used."
|
| 950 |
+
min_flagged_before_group = 1 if task_name == "easy" else 1
|
| 951 |
+
if len(flagged_ids) < min_flagged_before_group:
|
| 952 |
+
return (
|
| 953 |
+
f"Group blocked until enough evidence is flagged ({len(flagged_ids)}/{min_flagged_before_group}). "
|
| 954 |
+
"Inspect and flag suspicious packets first."
|
| 955 |
+
)
|
| 956 |
+
new_group_ids = set(action.packet_ids)
|
| 957 |
+
for existing_ids in (obs.grouped_sessions or {}).values():
|
| 958 |
+
existing_set = set(existing_ids)
|
| 959 |
+
if not existing_set:
|
| 960 |
+
continue
|
| 961 |
+
overlap = len(new_group_ids & existing_set) / max(1, len(new_group_ids))
|
| 962 |
+
if overlap >= 0.8:
|
| 963 |
+
return "This grouping heavily overlaps an existing session. Prioritize new evidence."
|
| 964 |
untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
|
| 965 |
if untagged_backlog > UNTAGGED_BACKLOG_LIMIT:
|
| 966 |
return (
|
|
|
|
| 986 |
|
| 987 |
if action.action_type == "submit_report":
|
| 988 |
untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
|
| 989 |
+
total_visible = max(1, len(obs.visible_packets))
|
| 990 |
+
flagged_count = len(obs.flagged_packet_ids)
|
| 991 |
+
coverage = flagged_count / total_visible
|
| 992 |
+
min_cov = TASK_COVERAGE_TARGETS.get(task_name, 0.25) * 0.6
|
| 993 |
+
min_flags = 4 if task_name == "easy" else (3 if task_name == "medium" else 4)
|
| 994 |
+
min_groups = 1 if task_name == "easy" else (2 if task_name == "medium" else 2)
|
| 995 |
+
if (
|
| 996 |
+
obs.steps_remaining > 2
|
| 997 |
+
and obs.current_score_estimate < 0.40
|
| 998 |
+
and not should_submit_early(task_name, obs, agent_state)
|
| 999 |
+
):
|
| 1000 |
return (
|
| 1001 |
"Premature report submission. Improve coverage and score estimate before submit_report."
|
| 1002 |
)
|
| 1003 |
+
if obs.steps_remaining > 1 and (coverage < min_cov or flagged_count < min_flags):
|
| 1004 |
+
return (
|
| 1005 |
+
f"Premature report submission. Need stronger recall coverage before submit_report "
|
| 1006 |
+
f"(coverage {coverage:.0%}/{min_cov:.0%}, flags {flagged_count}/{min_flags})."
|
| 1007 |
+
)
|
| 1008 |
+
if obs.steps_remaining > 1 and len(sessions) < min_groups:
|
| 1009 |
+
return (
|
| 1010 |
+
f"Premature report submission. Need stronger session evidence before submit_report "
|
| 1011 |
+
f"(grouped {len(sessions)}/{min_groups})."
|
| 1012 |
+
)
|
| 1013 |
+
if task_name == "hard" and obs.steps_remaining > 3 and untagged_backlog > 0:
|
| 1014 |
return "Premature report submission. Tag pending sessions before submitting report."
|
| 1015 |
+
# CRITICAL: Hard mode zero-out if no entry point identified
|
| 1016 |
+
if task_name == "hard" and not (agent_state.get("claimed_entry_point") or obs.claimed_entry_point):
|
| 1017 |
+
return (
|
| 1018 |
+
"FATAL: Hard mode requires identify_entry_point before submit_report. "
|
| 1019 |
+
"No entry point claimed yet — score will be 0.0 without it. "
|
| 1020 |
+
"Use identify_entry_point with the earliest malicious packet first."
|
| 1021 |
+
)
|
| 1022 |
+
# Medium mode: need entry point for good logic_score
|
| 1023 |
+
if task_name == "medium" and obs.steps_remaining > 5 and not (agent_state.get("claimed_entry_point") or obs.claimed_entry_point):
|
| 1024 |
+
return (
|
| 1025 |
+
"Missing entry point. Use identify_entry_point before submit_report for higher score."
|
| 1026 |
+
)
|
| 1027 |
+
# Require minimum tagging coverage for medium/hard
|
| 1028 |
+
min_tagged = 1 if task_name == "medium" else 2
|
| 1029 |
+
if task_name in {"medium", "hard"} and len(tagged_sessions) < min_tagged and obs.steps_remaining > 3:
|
| 1030 |
+
return (
|
| 1031 |
+
f"Premature report submission. Need at least {min_tagged} tagged session(s) before submit_report "
|
| 1032 |
+
f"(currently {len(tagged_sessions)})."
|
| 1033 |
+
)
|
| 1034 |
|
| 1035 |
if action.action_type == "tag_pattern":
|
| 1036 |
if not action.session_name:
|
| 1037 |
return "Missing session_name for tag_pattern"
|
| 1038 |
if not action.pattern_type:
|
| 1039 |
return "Missing pattern_type for tag_pattern"
|
| 1040 |
+
if action.session_name in set((obs.tagged_patterns or {}).keys()):
|
| 1041 |
+
return f"Session {action.session_name} is already tagged."
|
| 1042 |
+
if task_name == "easy" and obs.steps_remaining > 8:
|
| 1043 |
+
return "For easy mode, prioritize recall actions (inspect/flag/group) before tagging."
|
| 1044 |
valid_patterns = {
|
| 1045 |
"ddos", "dos_slowloris", "dos_slowhttptest", "dos_goldeneye", "dos_hulk",
|
| 1046 |
"heartbleed", "web_sql_injection", "web_xss", "web_bruteforce",
|
|
|
|
| 1052 |
if action.action_type == "identify_entry_point":
|
| 1053 |
if not action.claimed_entry_point:
|
| 1054 |
return "Missing claimed_entry_point for identify_entry_point"
|
| 1055 |
+
# Lenient gating for easy mode
|
| 1056 |
+
min_flags_needed = 1 if task_name == "easy" else (2 if task_name == "medium" else 2)
|
| 1057 |
+
if obs.steps_remaining > 8 and len(flagged_ids) < min_flags_needed:
|
| 1058 |
return (
|
| 1059 |
"Premature entry-point claim. Gather and flag more evidence before identify_entry_point."
|
| 1060 |
)
|
|
|
|
| 1071 |
) -> NetworkForensicsAction:
|
| 1072 |
agent_state["current_task_name"] = task_name
|
| 1073 |
agent_state["strategy_hints"] = derive_strategy_hints(obs, agent_state)
|
| 1074 |
+
if should_submit_early(task_name, obs, agent_state):
|
| 1075 |
+
action = NetworkForensicsAction(
|
| 1076 |
+
action_type="submit_report",
|
| 1077 |
+
incident_summary=_build_report_summary(obs, agent_state),
|
| 1078 |
+
claimed_entry_point=agent_state.get("claimed_entry_point") or obs.claimed_entry_point,
|
| 1079 |
+
)
|
| 1080 |
+
append_action_history(agent_state, action)
|
| 1081 |
+
return action
|
| 1082 |
history = agent_state.get("previous_actions", [])[-HISTORY_WINDOW:]
|
| 1083 |
history_str = "\n".join([f"Step {i+1}: {a}" for i, a in enumerate(history)])
|
| 1084 |
|
|
|
|
| 1093 |
"Follow the JSON schema in the system prompt."
|
| 1094 |
)
|
| 1095 |
|
| 1096 |
+
try:
|
| 1097 |
+
response = client.chat.completions.create(
|
| 1098 |
+
model=model_name or MODEL_NAME,
|
| 1099 |
+
temperature=0.1,
|
| 1100 |
+
timeout=LLM_TIMEOUT_S,
|
| 1101 |
+
messages=[
|
| 1102 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 1103 |
+
{
|
| 1104 |
+
"role": "user",
|
| 1105 |
+
"content": f"TASK: {task_name}{correction_text}\n\n### RECENT HISTORY:\n{history_str}\n\n### CURRENT OBSERVATION:\n{summarize_observation(obs, agent_state)}",
|
| 1106 |
+
},
|
| 1107 |
+
],
|
| 1108 |
+
)
|
| 1109 |
+
except Exception as llm_exc:
|
| 1110 |
+
print(f"[WARN] LLM call failed/timed out: {llm_exc}")
|
| 1111 |
+
fallback = build_fallback_action(task_name, obs, agent_state)
|
| 1112 |
+
append_action_history(agent_state, fallback)
|
| 1113 |
+
return fallback
|
| 1114 |
content = response.choices[0].message.content or ""
|
| 1115 |
try:
|
| 1116 |
action = sanitize_action(parse_action(content))
|
|
|
|
| 1272 |
return result
|
| 1273 |
|
| 1274 |
|
| 1275 |
+
WS_RETRY_COUNT = 3
|
| 1276 |
+
WS_RETRY_DELAY_S = 2.0
|
| 1277 |
+
LLM_TIMEOUT_S = 45.0
|
| 1278 |
+
|
| 1279 |
+
|
| 1280 |
+
def step_env_with_retry(
|
| 1281 |
+
env: NetworkForensicsEnv,
|
| 1282 |
+
action: NetworkForensicsAction,
|
| 1283 |
+
task_name: str,
|
| 1284 |
+
agent_state: dict[str, Any],
|
| 1285 |
+
) -> tuple[Any, NetworkForensicsEnv | None]:
|
| 1286 |
+
"""Try step_env with retries on WebSocket timeout.
|
| 1287 |
+
|
| 1288 |
+
Returns (step_result, new_env_or_None).
|
| 1289 |
+
If the WebSocket connection drops, reconnects and retries.
|
| 1290 |
+
"""
|
| 1291 |
+
last_exc = None
|
| 1292 |
+
for attempt in range(1, WS_RETRY_COUNT + 1):
|
| 1293 |
+
try:
|
| 1294 |
+
result = step_env(env, action)
|
| 1295 |
+
return result, None
|
| 1296 |
+
except Exception as exc:
|
| 1297 |
+
last_exc = exc
|
| 1298 |
+
exc_str = str(exc).lower()
|
| 1299 |
+
is_ws_timeout = any(
|
| 1300 |
+
kw in exc_str
|
| 1301 |
+
for kw in ("keepalive", "ping timeout", "1011", "websocket", "connection")
|
| 1302 |
+
)
|
| 1303 |
+
if not is_ws_timeout:
|
| 1304 |
+
raise
|
| 1305 |
+
print(
|
| 1306 |
+
f"[WARN] WebSocket timeout on attempt {attempt}/{WS_RETRY_COUNT}: {exc}"
|
| 1307 |
+
)
|
| 1308 |
+
if attempt < WS_RETRY_COUNT:
|
| 1309 |
+
time.sleep(WS_RETRY_DELAY_S * attempt)
|
| 1310 |
+
# Try reconnecting
|
| 1311 |
+
try:
|
| 1312 |
+
close_env(env)
|
| 1313 |
+
except Exception:
|
| 1314 |
+
pass
|
| 1315 |
+
try:
|
| 1316 |
+
env = create_env()
|
| 1317 |
+
reset_result = reset_env(env, task_name)
|
| 1318 |
+
obs = reset_result.observation
|
| 1319 |
+
sync_agent_state(obs, agent_state)
|
| 1320 |
+
print(f"[INFO] Reconnected to environment, resuming task={task_name}")
|
| 1321 |
+
except Exception as reconnect_exc:
|
| 1322 |
+
print(f"[WARN] Reconnect failed: {reconnect_exc}")
|
| 1323 |
+
continue
|
| 1324 |
+
raise last_exc # type: ignore[misc]
|
| 1325 |
+
|
| 1326 |
+
|
| 1327 |
def close_env(env: NetworkForensicsEnv | None) -> None:
|
| 1328 |
if env is None:
|
| 1329 |
return
|
|
|
|
| 1356 |
obs = reset_result.observation
|
| 1357 |
sync_agent_state(obs, agent_state)
|
| 1358 |
max_steps = obs.steps_remaining or 50
|
| 1359 |
+
soft_budget = min(max_steps, SOFT_STEP_BUDGETS.get(task_name, max_steps))
|
| 1360 |
+
hard_budget = min(max_steps, HARD_STEP_CAPS.get(task_name, max_steps))
|
| 1361 |
+
start_ts = time.monotonic()
|
| 1362 |
+
task_time_budget = min(MAX_TASK_SECONDS, TASK_TIME_BUDGET_SECONDS.get(task_name, MAX_TASK_SECONDS))
|
| 1363 |
|
| 1364 |
+
for _ in range(hard_budget):
|
| 1365 |
if obs.done:
|
| 1366 |
break
|
| 1367 |
|
| 1368 |
+
elapsed = time.monotonic() - start_ts
|
| 1369 |
+
total_visible = max(1, len(obs.visible_packets))
|
| 1370 |
+
current_coverage = len(obs.flagged_packet_ids) / total_visible
|
| 1371 |
+
min_cov = TASK_COVERAGE_TARGETS.get(task_name, 0.25)
|
| 1372 |
+
ready_for_budget_submit = (
|
| 1373 |
+
obs.step_number >= soft_budget
|
| 1374 |
+
and should_submit_early(task_name, obs, agent_state)
|
| 1375 |
+
)
|
| 1376 |
+
forced_at_hard_cap = (
|
| 1377 |
+
obs.step_number >= max(1, hard_budget - 1)
|
| 1378 |
+
and (should_submit_early(task_name, obs, agent_state) or task_name != "easy")
|
| 1379 |
+
)
|
| 1380 |
+
nearing_time_limit = elapsed >= max(20.0, task_time_budget - 12.0)
|
| 1381 |
+
|
| 1382 |
error = None
|
| 1383 |
try:
|
| 1384 |
+
if forced_at_hard_cap or nearing_time_limit or ready_for_budget_submit:
|
| 1385 |
+
action = NetworkForensicsAction(
|
| 1386 |
+
action_type="submit_report",
|
| 1387 |
+
incident_summary=_build_report_summary(obs, agent_state),
|
| 1388 |
+
claimed_entry_point=agent_state.get("claimed_entry_point") or obs.claimed_entry_point,
|
| 1389 |
+
)
|
| 1390 |
+
else:
|
| 1391 |
+
action = choose_action(client, task_name, obs, agent_state)
|
| 1392 |
except Exception as exc:
|
| 1393 |
error = str(exc).replace("\n", " ")
|
| 1394 |
action = build_fallback_action(task_name, obs, agent_state)
|
| 1395 |
|
| 1396 |
+
try:
|
| 1397 |
+
step_result, new_env = step_env_with_retry(env, action, task_name, agent_state)
|
| 1398 |
+
if new_env is not None:
|
| 1399 |
+
env = new_env
|
| 1400 |
+
except Exception as exc:
|
| 1401 |
+
print(f"[WARN] step failure on task={task_name}: {exc}")
|
| 1402 |
+
break
|
| 1403 |
obs = step_result.observation
|
| 1404 |
sync_agent_state(obs, agent_state)
|
| 1405 |
step_reward = float(step_result.reward or 0.0)
|
server/gradio_ui.py
CHANGED
|
@@ -453,12 +453,6 @@ def create_demo() -> gr.Blocks:
|
|
| 453 |
|
| 454 |
with gr.Blocks(
|
| 455 |
title="NetForensics-RL · Analyst Console",
|
| 456 |
-
theme=gr.themes.Base(
|
| 457 |
-
primary_hue="blue",
|
| 458 |
-
neutral_hue="slate",
|
| 459 |
-
font=gr.themes.GoogleFont("Inter"),
|
| 460 |
-
),
|
| 461 |
-
css=css,
|
| 462 |
) as demo:
|
| 463 |
with gr.Column(elem_classes=["app-shell"]):
|
| 464 |
gr.HTML(f"<style>{css}</style>")
|
|
|
|
| 453 |
|
| 454 |
with gr.Blocks(
|
| 455 |
title="NetForensics-RL · Analyst Console",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
) as demo:
|
| 457 |
with gr.Column(elem_classes=["app-shell"]):
|
| 458 |
gr.HTML(f"<style>{css}</style>")
|