WHOAM-EYE commited on
Commit
c894ea4
·
verified ·
1 Parent(s): 3d0eba6

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. inference.py +588 -90
  2. server/gradio_ui.py +0 -6
inference.py CHANGED
@@ -4,6 +4,7 @@ import sys
4
  import asyncio
5
  import inspect
6
  import random
 
7
  from pathlib import Path
8
  from typing import Any
9
 
@@ -37,28 +38,46 @@ _ASYNC_LOOP: asyncio.AbstractEventLoop | None = None
37
  SYSTEM_PROMPT = """You are a senior Network Forensics Analyst. Your goal is to investigate malicious network traffic and achieve a 100% detection score.
38
 
39
  ### SCORING RULES:
40
- - You MUST identify and `flag_as_suspicious` every malicious packet to increase RECALL.
41
  - Only grouped packets or flagged packets contribute towards your score.
42
- - If RECALL is < 0.5, your score will be 0.0. DO NOT stop until you have grouped at least 50% of the traffic.
 
 
43
 
44
  ### WORKFLOW:
45
  1. **Explore**: `inspect_packet` on suspicious samples.
46
- 2. **Correlate**: `group_into_session` with descriptive names.
47
- 3. **Classify**: `tag_pattern` with a valid type (ddos, web_sql_injection, heartbleed, etc.).
48
- 4. **Report**: `submit_report` ONLY when you have covered all visible malicious sessions.
 
 
 
 
 
49
 
50
  ### JSON SCHEMA EXAMPLES (Use these exactly):
51
  - Inspect: {"action_type":"inspect_packet","packet_id":"pkt_0001"}
52
  - Flag: {"action_type":"flag_as_suspicious","packet_id":"pkt_0001"}
53
  - Group: {"action_type":"group_into_session","session_name":"DDoS_Burst_2","packet_ids":["pkt_0001","pkt_0002"]}
54
  - Tag: {"action_type":"tag_pattern","session_name":"DDoS_Burst_2","pattern_type":"ddos"}
55
- - Report: {"action_type":"submit_report","incident_summary":"Brief summary here.","claimed_entry_point":"pkt_0001"}"""
 
56
 
57
  HISTORY_WINDOW = 20
58
  REPEAT_ACTION_LIMIT = 3
59
  CORRECTION_WINDOW = 5
60
- UNTAGGED_BACKLOG_LIMIT = 4
61
  INSPECT_SOFT_RATIO_THRESHOLD = 0.60
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  def build_client() -> OpenAI:
@@ -93,7 +112,7 @@ def format_action(action: NetworkForensicsAction) -> str:
93
 
94
 
95
  def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
96
- """Provide a structured text summary for the LLM to learn from."""
97
  packets = obs.visible_packets
98
  revealed = [p for p in packets if p.is_revealed]
99
  revealed_ids = [p.packet_id for p in revealed]
@@ -104,14 +123,31 @@ def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
104
  reward_feedback = agent_state.get("last_reward_feedback", "n/a")
105
  recent_corrections = agent_state.get("recent_corrections", [])[-CORRECTION_WINDOW:]
106
  strategy_hints = agent_state.get("strategy_hints", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  summary = [
109
  f"Step: {obs.step_number}/{obs.step_number + obs.steps_remaining}",
110
  f"Current Progress: {obs.current_score_estimate:.2f}",
111
- f"Recall Progress: {len(obs.flagged_packet_ids)} flagged / {len(obs.visible_packets)} visible",
 
 
112
  f"Last Step Reward: {last_reward:.2f}" if isinstance(last_reward, (int, float)) else "Last Step Reward: n/a",
113
  f"Last Reward Feedback: {reward_feedback}",
114
- f"ALREADY REVEALED: {', '.join(revealed_ids[-10:])} " + ("..." if len(revealed_ids) > 10 else ""),
115
  "\n### SESSIONS PENDING TAGGING:",
116
  ]
117
 
@@ -126,13 +162,13 @@ def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
126
  summary.append(f"- {hint}")
127
 
128
  if untagged_sessions:
129
- for s in untagged_sessions:
130
  summary.append(f"- {s} ({len(sessions[s])} packets)")
131
  else:
132
  summary.append("- [No pending sessions]")
133
 
134
  summary.append("\n### REVEALED INDICATORS:")
135
- for p in revealed[-8:]: # Show last 8 revealed for context
136
  payload = (p.full_payload or "")[:150]
137
  if payload:
138
  summary.append(f"- {p.packet_id}: {payload}")
@@ -207,30 +243,72 @@ def packet_payload_text(packet: Any) -> str:
207
 
208
  def keyword_to_pattern(payload: str) -> str | None:
209
  text = payload.lower()
 
210
  if "slowloris" in text:
211
  return "dos_slowloris"
212
- if "slowhttptest" in text:
213
  return "dos_slowhttptest"
214
- if "goldeneye" in text:
215
  return "dos_goldeneye"
216
  if "hulk" in text:
217
  return "dos_hulk"
218
- if "heartbeat" in text or "tls" in text:
219
  return "heartbleed"
220
- if "xss" in text or "<script>" in text or "<scrip" in text or "/search?q=" in text:
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  return "web_xss"
222
  if (
223
  "or 1=1" in text
224
  or "%20or" in text
225
  or "/items?id=" in text
226
  or "1=1" in text
 
 
 
 
 
 
227
  or "sql" in text
 
 
 
228
  ):
229
  return "web_sql_injection"
230
- if "login" in text or "username=admin" in text:
 
 
 
 
 
 
 
 
 
 
 
 
231
  return "web_bruteforce"
232
- if "flood" in text or "burst" in text:
233
- return "ddos"
 
 
 
 
 
 
 
234
  return None
235
 
236
 
@@ -245,9 +323,40 @@ def packet_signature(packet: Any, pattern: str) -> tuple[str, str, int, str]:
245
  return (packet.src_ip, packet.dst_ip, packet.dst_port, pattern)
246
 
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[Any]]]:
249
  grouped: dict[tuple[str, str, int, str], list[Any]] = {}
250
  attack_source_ports: dict[tuple[str, str, int, str], set[int]] = {}
 
251
  for packet in obs.visible_packets:
252
  pattern = keyword_to_pattern(packet_payload_text(packet))
253
  if pattern:
@@ -255,6 +364,7 @@ def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[A
255
  grouped.setdefault(key, []).append(packet)
256
  attack_source_ports.setdefault(key, set()).add(packet.src_port)
257
 
 
258
  for key, source_ports in attack_source_ports.items():
259
  src_ip, dst_ip, dst_port, _pattern = key
260
  for packet in obs.visible_packets:
@@ -267,6 +377,30 @@ def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[A
267
  if is_reverse_response:
268
  grouped[key].append(packet)
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  candidates = [
271
  (
272
  key,
@@ -287,8 +421,19 @@ def required_tag_count(task_name: str, total_sessions: int) -> int:
287
  return 0
288
 
289
 
290
- def select_inspect_packet(obs: Any, inspected_ids: set[str]) -> str | None:
291
- unrevealed = [p for p in obs.visible_packets if not p.is_revealed]
 
 
 
 
 
 
 
 
 
 
 
292
  if not unrevealed:
293
  return None
294
 
@@ -314,6 +459,9 @@ def select_inspect_packet(obs: Any, inspected_ids: set[str]) -> str | None:
314
  def append_action_history(agent_state: dict[str, Any], action: NetworkForensicsAction) -> None:
315
  history = agent_state.setdefault("previous_actions", [])
316
  history.append(format_action(action))
 
 
 
317
  if len(history) > HISTORY_WINDOW:
318
  del history[:-HISTORY_WINDOW]
319
 
@@ -355,25 +503,29 @@ def group_meets_evidence_gate(
355
  candidate_packets, flagged_ids, visible_by_id
356
  )
357
  size = len(candidate_packets)
 
358
  if task_name == "easy":
359
  min_flagged = 1 if size >= 2 else 0
360
  elif task_name == "medium":
361
- min_flagged = 1 if size >= 3 else 0
362
  else:
363
- min_flagged = 2 if size >= 4 else 1
364
- if trusted_pattern and size >= 4:
365
  min_flagged = 1
366
  if flagged >= min_flagged:
367
  return True
368
  # Allow grouping with strong revealed malicious evidence.
369
- if malicious_revealed >= min_flagged and revealed >= min(3, size):
370
  return True
371
- # After a pattern has been confirmed by tagging, allow structure-first grouping.
372
- if trusted_pattern and size >= 5:
373
  return True
374
- if task_name == "easy" and malicious_revealed >= 1:
 
 
 
375
  return True
376
- if task_name == "medium" and malicious_revealed >= 1 and revealed >= 2:
 
377
  return True
378
  return False
379
 
@@ -415,10 +567,10 @@ def derive_strategy_hints(obs: Any, agent_state: dict[str, Any]) -> list[str]:
415
  )
416
 
417
  inspect_limit = {
418
- "easy": 2,
419
- "medium": 4,
420
- "hard": 6,
421
- }.get(agent_state.get("current_task_name", ""), 8)
422
  if len(previous_actions) >= inspect_limit and inspect_ratio >= INSPECT_SOFT_RATIO_THRESHOLD:
423
  hints.append(
424
  "You are over-inspecting. Shift to flagging, grouping, tagging, or report submission unless the next packet is clearly high-value."
@@ -426,10 +578,43 @@ def derive_strategy_hints(obs: Any, agent_state: dict[str, Any]) -> list[str]:
426
  return hints
427
 
428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  def build_fallback_action(
430
  task_name: str, obs: Any, agent_state: dict[str, Any]
431
  ) -> NetworkForensicsAction:
432
- """Smart workflow engine: Inspect -> Flag -> Group -> Tag -> Report."""
433
  inspected_ids = agent_state.setdefault("inspected_ids", set())
434
  flagged_ids = agent_state.setdefault("flagged_ids", set())
435
  session_map = agent_state.setdefault("sessions", {}) # key -> session_name
@@ -438,7 +623,7 @@ def build_fallback_action(
438
  visible_by_id = {p.packet_id: p for p in obs.visible_packets}
439
  trusted = trusted_patterns(session_map, tagged_sessions)
440
 
441
- if obs.steps_remaining <= 1:
442
  summary = _build_report_summary(obs, agent_state)
443
  return NetworkForensicsAction(
444
  action_type="submit_report",
@@ -446,21 +631,32 @@ def build_fallback_action(
446
  claimed_entry_point=claimed_entry,
447
  )
448
 
449
- # PHASE 1: Flag revealed malicious packets
 
 
450
  for packet in obs.visible_packets:
451
  if packet.is_revealed and packet.packet_id not in flagged_ids:
452
  payload = packet.full_payload or ""
453
  pattern = keyword_to_pattern(payload)
454
  if pattern:
455
- flagged_ids.add(packet.packet_id)
 
 
 
 
 
 
 
 
456
  return NetworkForensicsAction(
457
  action_type="flag_as_suspicious",
458
- packet_id=packet.packet_id,
459
  )
460
 
461
  # PHASE 2: Group flagged packets into sessions with evidence gate and backlog pacing.
 
462
  untagged_backlog = max(0, len(session_map) - len(tagged_sessions))
463
- if untagged_backlog <= UNTAGGED_BACKLOG_LIMIT:
464
  candidates = session_candidates(obs)
465
  for key, items in candidates:
466
  if key in session_map:
@@ -482,14 +678,54 @@ def build_fallback_action(
482
  packet_ids=packet_ids,
483
  )
484
 
485
- # PHASE 3: Tag ungrouped sessions.
486
- # Easy mode prioritizes coverage/recall and skips tagging to spend turns on recovery.
487
- allow_tagging = task_name != "easy"
488
- if allow_tagging:
489
- for key, session_name in session_map.items():
490
- if session_name in tagged_sessions:
491
- continue
492
- _src_ip, _dst_ip, _dst_port, pattern = key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  tagged_sessions.add(session_name)
494
  return NetworkForensicsAction(
495
  action_type="tag_pattern",
@@ -497,19 +733,52 @@ def build_fallback_action(
497
  pattern_type=pattern,
498
  )
499
 
500
- # PHASE 4: Identify entry point only when confidence is higher or near episode end.
501
- if not claimed_entry and flagged_ids and (
502
- len(tagged_sessions) >= 3 or obs.steps_remaining <= 8
503
- ):
504
- earliest = min(flagged_ids, key=lambda pid: packet_sort_key(pid))
505
- agent_state["claimed_entry_point"] = earliest
506
- return NetworkForensicsAction(
507
- action_type="identify_entry_point",
508
- claimed_entry_point=earliest,
509
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
 
511
- # PHASE 5: Inspect more unrevealed packets
512
- inspect_id = select_inspect_packet(obs, inspected_ids)
513
  if inspect_id is not None:
514
  return NetworkForensicsAction(action_type="inspect_packet", packet_id=inspect_id)
515
 
@@ -523,19 +792,50 @@ def build_fallback_action(
523
 
524
 
525
  def _build_report_summary(obs: Any, agent_state: dict[str, Any]) -> str:
526
- """Generate a meaningful incident summary for the report."""
527
  flagged = agent_state.get("flagged_ids", set())
528
  sessions = agent_state.get("sessions", {})
529
  tagged = agent_state.get("tagged_sessions", set())
530
- patterns = set()
531
- for key in sessions:
 
 
 
532
  if len(key) >= 4:
533
- patterns.add(key[3])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
  return (
535
- f"Incident report: Detected {len(flagged)} malicious packets across "
536
- f"{len(sessions)} attack sessions. Attack patterns observed: "
537
- f"{', '.join(patterns) if patterns else 'unknown'}. "
538
- f"{len(tagged)} sessions were classified."
 
 
 
 
539
  )
540
 
541
 
@@ -556,10 +856,10 @@ def should_override_action(
556
  inspect_count = sum(1 for a in previous_actions if '"inspect_packet"' in a)
557
  revealed_count = sum(1 for p in obs.visible_packets if p.is_revealed)
558
  inspect_limit = {
559
- "easy": 2,
560
- "medium": 4,
561
- "hard": 6,
562
- }.get(task_name, 8)
563
 
564
  if action.action_type not in {
565
  "inspect_packet",
@@ -580,13 +880,52 @@ def should_override_action(
580
  return "Missing packet_id for inspect_packet"
581
  if action.packet_id not in {p.packet_id for p in obs.visible_packets}:
582
  return f"Invalid packet_id {action.packet_id} - not in visible_packets"
 
 
 
583
  revealed_ids = {p.packet_id for p in obs.visible_packets if p.is_revealed}
584
  if action.packet_id in revealed_ids:
585
  return f"Packet {action.packet_id} is ALREADY revealed. Choose a HIDDEN packet."
586
- if inspect_count >= inspect_limit and (len(sessions) > 0 or len(flagged_ids) > 0 or revealed_count >= 4):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
  return (
588
- f"Inspection budget reached for {task_name}. Shift to flagging, grouping, tagging, or report submission."
589
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  if action.action_type == "flag_as_suspicious":
592
  if not action.packet_id:
@@ -606,6 +945,22 @@ def should_override_action(
606
  }
607
  if invalid_ids:
608
  return f"Invalid packet_ids in session: {invalid_ids}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
  untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
610
  if untagged_backlog > UNTAGGED_BACKLOG_LIMIT:
611
  return (
@@ -631,18 +986,61 @@ def should_override_action(
631
 
632
  if action.action_type == "submit_report":
633
  untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
634
- if obs.steps_remaining > 2 and obs.current_score_estimate < 0.60:
 
 
 
 
 
 
 
 
 
 
635
  return (
636
  "Premature report submission. Improve coverage and score estimate before submit_report."
637
  )
638
- if task_name != "easy" and obs.steps_remaining > 2 and untagged_backlog > 0:
 
 
 
 
 
 
 
 
 
 
639
  return "Premature report submission. Tag pending sessions before submitting report."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640
 
641
  if action.action_type == "tag_pattern":
642
  if not action.session_name:
643
  return "Missing session_name for tag_pattern"
644
  if not action.pattern_type:
645
  return "Missing pattern_type for tag_pattern"
 
 
 
 
646
  valid_patterns = {
647
  "ddos", "dos_slowloris", "dos_slowhttptest", "dos_goldeneye", "dos_hulk",
648
  "heartbleed", "web_sql_injection", "web_xss", "web_bruteforce",
@@ -654,7 +1052,9 @@ def should_override_action(
654
  if action.action_type == "identify_entry_point":
655
  if not action.claimed_entry_point:
656
  return "Missing claimed_entry_point for identify_entry_point"
657
- if obs.steps_remaining > 8 and len(flagged_ids) < 3:
 
 
658
  return (
659
  "Premature entry-point claim. Gather and flag more evidence before identify_entry_point."
660
  )
@@ -671,6 +1071,14 @@ def choose_action(
671
  ) -> NetworkForensicsAction:
672
  agent_state["current_task_name"] = task_name
673
  agent_state["strategy_hints"] = derive_strategy_hints(obs, agent_state)
 
 
 
 
 
 
 
 
674
  history = agent_state.get("previous_actions", [])[-HISTORY_WINDOW:]
675
  history_str = "\n".join([f"Step {i+1}: {a}" for i, a in enumerate(history)])
676
 
@@ -685,17 +1093,24 @@ def choose_action(
685
  "Follow the JSON schema in the system prompt."
686
  )
687
 
688
- response = client.chat.completions.create(
689
- model=model_name or MODEL_NAME,
690
- temperature=0.1,
691
- messages=[
692
- {"role": "system", "content": SYSTEM_PROMPT},
693
- {
694
- "role": "user",
695
- "content": f"TASK: {task_name}{correction_text}\n\n### RECENT HISTORY:\n{history_str}\n\n### CURRENT OBSERVATION:\n{summarize_observation(obs, agent_state)}",
696
- },
697
- ],
698
- )
 
 
 
 
 
 
 
699
  content = response.choices[0].message.content or ""
700
  try:
701
  action = sanitize_action(parse_action(content))
@@ -857,6 +1272,58 @@ def step_env(env: NetworkForensicsEnv, action: NetworkForensicsAction) -> Any:
857
  return result
858
 
859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
860
  def close_env(env: NetworkForensicsEnv | None) -> None:
861
  if env is None:
862
  return
@@ -889,19 +1356,50 @@ def run_task(task_name: str) -> None:
889
  obs = reset_result.observation
890
  sync_agent_state(obs, agent_state)
891
  max_steps = obs.steps_remaining or 50
 
 
 
 
892
 
893
- for _ in range(max_steps):
894
  if obs.done:
895
  break
896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897
  error = None
898
  try:
899
- action = choose_action(client, task_name, obs, agent_state)
 
 
 
 
 
 
 
900
  except Exception as exc:
901
  error = str(exc).replace("\n", " ")
902
  action = build_fallback_action(task_name, obs, agent_state)
903
 
904
- step_result = step_env(env, action)
 
 
 
 
 
 
905
  obs = step_result.observation
906
  sync_agent_state(obs, agent_state)
907
  step_reward = float(step_result.reward or 0.0)
 
4
  import asyncio
5
  import inspect
6
  import random
7
+ import time
8
  from pathlib import Path
9
  from typing import Any
10
 
 
38
  SYSTEM_PROMPT = """You are a senior Network Forensics Analyst. Your goal is to investigate malicious network traffic and achieve a 100% detection score.
39
 
40
  ### SCORING RULES:
41
+ - You MUST identify and `flag_as_suspicious` EVERY malicious packet to maximize RECALL (very important!).
42
  - Only grouped packets or flagged packets contribute towards your score.
43
+ - If RECALL is < 0.5, your score will be 0.0. DO NOT stop until you have flagged/grouped at least 60% of visible malicious packets.
44
+ - Entry point must be the EARLIEST packet that initiated the attack (often in first group).
45
+ - For HARD tasks: wrong entry point = score 0. Always identify_entry_point before submitting.
46
 
47
  ### WORKFLOW:
48
  1. **Explore**: `inspect_packet` on suspicious samples.
49
+ 2. **Flag**: `flag_as_suspicious` on ALL revealed malicious packets.
50
+ 3. **Correlate**: `group_into_session` with descriptive names.
51
+ 4. **Classify**: `tag_pattern` with a valid type.
52
+ 5. **Root Cause**: `identify_entry_point` with the earliest malicious packet.
53
+ 6. **Report**: `submit_report` ONLY when you have covered all visible malicious sessions.
54
+
55
+ ### VALID PATTERN TYPES:
56
+ ddos, dos_slowloris, dos_slowhttptest, dos_goldeneye, dos_hulk, heartbleed, web_sql_injection, web_xss, web_bruteforce, c2, exfiltration, scan, lateral
57
 
58
  ### JSON SCHEMA EXAMPLES (Use these exactly):
59
  - Inspect: {"action_type":"inspect_packet","packet_id":"pkt_0001"}
60
  - Flag: {"action_type":"flag_as_suspicious","packet_id":"pkt_0001"}
61
  - Group: {"action_type":"group_into_session","session_name":"DDoS_Burst_2","packet_ids":["pkt_0001","pkt_0002"]}
62
  - Tag: {"action_type":"tag_pattern","session_name":"DDoS_Burst_2","pattern_type":"ddos"}
63
+ - Entry: {"action_type":"identify_entry_point","claimed_entry_point":"pkt_0001"}
64
+ - Report: {"action_type":"submit_report","incident_summary":"Detailed incident summary here.","claimed_entry_point":"pkt_0001"}"""
65
 
66
  HISTORY_WINDOW = 20
67
  REPEAT_ACTION_LIMIT = 3
68
  CORRECTION_WINDOW = 5
69
+ UNTAGGED_BACKLOG_LIMIT = 6
70
  INSPECT_SOFT_RATIO_THRESHOLD = 0.60
71
+ SOFT_STEP_BUDGETS = {"easy": 14, "medium": 28, "hard": 40}
72
+ HARD_STEP_CAPS = {"easy": 30, "medium": 50, "hard": 65}
73
+ TASK_SCORE_TARGETS = {"easy": 0.70, "medium": 0.68, "hard": 0.66}
74
+ TASK_COVERAGE_TARGETS = {"easy": 0.32, "medium": 0.24, "hard": 0.20}
75
+ MAX_TASK_SECONDS = float(os.getenv("MAX_TASK_SECONDS", "780"))
76
+ TASK_TIME_BUDGET_SECONDS = {
77
+ "easy": float(os.getenv("EASY_MAX_SECONDS", "150")),
78
+ "medium": float(os.getenv("MEDIUM_MAX_SECONDS", "220")),
79
+ "hard": float(os.getenv("HARD_MAX_SECONDS", "320")),
80
+ }
81
 
82
 
83
  def build_client() -> OpenAI:
 
112
 
113
 
114
  def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
115
+ """Provide a compact structured summary for low-latency policy learning."""
116
  packets = obs.visible_packets
117
  revealed = [p for p in packets if p.is_revealed]
118
  revealed_ids = [p.packet_id for p in revealed]
 
123
  reward_feedback = agent_state.get("last_reward_feedback", "n/a")
124
  recent_corrections = agent_state.get("recent_corrections", [])[-CORRECTION_WINDOW:]
125
  strategy_hints = agent_state.get("strategy_hints", [])
126
+ task_name = agent_state.get("current_task_name", "")
127
+
128
+ flagged_count = len(obs.flagged_packet_ids)
129
+ total_visible = max(1, len(obs.visible_packets))
130
+ coverage = flagged_count / total_visible
131
+ coverage_target = TASK_COVERAGE_TARGETS.get(task_name, 0.25)
132
+ score_target = TASK_SCORE_TARGETS.get(task_name, 0.65)
133
+ grouped_count = len(sessions)
134
+ tagged_count = len(tags)
135
+ ready_to_submit = (
136
+ obs.current_score_estimate >= score_target
137
+ and coverage >= coverage_target
138
+ and (task_name == "easy" or grouped_count >= 2)
139
+ and (task_name == "easy" or tagged_count >= 1)
140
+ )
141
 
142
  summary = [
143
  f"Step: {obs.step_number}/{obs.step_number + obs.steps_remaining}",
144
  f"Current Progress: {obs.current_score_estimate:.2f}",
145
+ f"Coverage: {flagged_count}/{total_visible} ({coverage:.2%}) | target {coverage_target:.0%}",
146
+ f"Sessions: grouped={grouped_count}, tagged={tagged_count}",
147
+ f"Submit Readiness: {'READY' if ready_to_submit else 'KEEP INVESTIGATING'}",
148
  f"Last Step Reward: {last_reward:.2f}" if isinstance(last_reward, (int, float)) else "Last Step Reward: n/a",
149
  f"Last Reward Feedback: {reward_feedback}",
150
+ f"ALREADY REVEALED: {', '.join(revealed_ids[-6:])} " + ("..." if len(revealed_ids) > 6 else ""),
151
  "\n### SESSIONS PENDING TAGGING:",
152
  ]
153
 
 
162
  summary.append(f"- {hint}")
163
 
164
  if untagged_sessions:
165
+ for s in untagged_sessions[:6]:
166
  summary.append(f"- {s} ({len(sessions[s])} packets)")
167
  else:
168
  summary.append("- [No pending sessions]")
169
 
170
  summary.append("\n### REVEALED INDICATORS:")
171
+ for p in revealed[-4:]:
172
  payload = (p.full_payload or "")[:150]
173
  if payload:
174
  summary.append(f"- {p.packet_id}: {payload}")
 
243
 
244
  def keyword_to_pattern(payload: str) -> str | None:
245
  text = payload.lower()
246
+ # --- DoS / DDoS variants ---
247
  if "slowloris" in text:
248
  return "dos_slowloris"
249
+ if "slowhttptest" in text or "slow http" in text:
250
  return "dos_slowhttptest"
251
+ if "goldeneye" in text or "golden eye" in text:
252
  return "dos_goldeneye"
253
  if "hulk" in text:
254
  return "dos_hulk"
255
+ if "heartbeat" in text or "heartbleed" in text or ("tls" in text and "ext" in text):
256
  return "heartbleed"
257
+ if "flood" in text or "burst" in text or "ddos" in text:
258
+ return "ddos"
259
+ # HTTP flood indicators (repeated GET/POST to same endpoint)
260
+ if text.startswith("get /") or text.startswith("post /") or text.startswith("get http"):
261
+ if "accept-encoding" in text or "connection" in text or "keep-alive" in text:
262
+ return "ddos"
263
+ # SYN flood / connection flood
264
+ if "syn" in text and "ack" not in text and len(text) < 30:
265
+ return "ddos"
266
+ # ICMP flood
267
+ if "icmp" in text and ("echo" in text or "request" in text or len(text) < 20):
268
+ return "ddos"
269
+ # --- Web attacks ---
270
+ if "xss" in text or "<script>" in text or "<scrip" in text or "/search?q=" in text or "onerror" in text or "onload" in text or "javascript:" in text or "alert(" in text or "%3cscript" in text:
271
  return "web_xss"
272
  if (
273
  "or 1=1" in text
274
  or "%20or" in text
275
  or "/items?id=" in text
276
  or "1=1" in text
277
+ or "' or " in text
278
+ or "'--" in text
279
+ or "union select" in text
280
+ or "union all select" in text
281
+ or "drop table" in text
282
+ or "select * from" in text
283
  or "sql" in text
284
+ or "%27" in text # URL-encoded single quote
285
+ or "' and " in text
286
+ or "admin'--" in text
287
  ):
288
  return "web_sql_injection"
289
+ if (
290
+ "login" in text
291
+ or "username=admin" in text
292
+ or "password=" in text
293
+ or "passwd=" in text
294
+ or "user=admin" in text
295
+ or "brute" in text
296
+ or "/login" in text
297
+ or "/signin" in text
298
+ or "/auth" in text
299
+ or "post /login" in text
300
+ or "post /sign" in text
301
+ ):
302
  return "web_bruteforce"
303
+ # --- C2 / exfil / scan / lateral ---
304
+ if "c2" in text or "command" in text or "shell" in text or "cmd" in text or "/bin/" in text or "reverse" in text:
305
+ return "c2"
306
+ if "exfil" in text or "exfiltrat" in text or "data_leak" in text or "dns_tunnel" in text:
307
+ return "exfiltration"
308
+ if "scan" in text or "nmap" in text or "port_scan" in text or "recon" in text:
309
+ return "scan"
310
+ if "lateral" in text or "pivot" in text or "spread" in text or "propagat" in text:
311
+ return "lateral"
312
  return None
313
 
314
 
 
323
  return (packet.src_ip, packet.dst_ip, packet.dst_port, pattern)
324
 
325
 
326
+ SUSPICIOUS_PORTS = {22, 23, 445, 1433, 3306, 5432, 4444, 5555, 6666, 6667, 7777, 8888, 9999, 31337}
327
+ SUSPICIOUS_PROTOCOLS = {"ICMP"}
328
+
329
+
330
+ def _infer_flow_pattern(packet: Any, flow_size: int) -> str | None:
331
+ """Heuristic pattern inference from flow characteristics when keyword matching fails."""
332
+ dst_port = packet.dst_port
333
+ protocol = packet.protocol
334
+ flags = getattr(packet, "flags", []) or []
335
+ # High-density flows to web ports → likely DDoS
336
+ if flow_size >= 5 and dst_port in {80, 8080, 443, 8443}:
337
+ return "ddos"
338
+ # SYN-only flood
339
+ if flow_size >= 5 and flags == ["SYN"]:
340
+ return "ddos"
341
+ # Suspicious ports → C2 or lateral
342
+ if dst_port in SUSPICIOUS_PORTS:
343
+ if dst_port in {4444, 5555, 6666, 7777, 31337}:
344
+ return "c2"
345
+ if dst_port in {445, 1433, 3306, 5432}:
346
+ return "lateral"
347
+ # ICMP flood
348
+ if protocol in SUSPICIOUS_PROTOCOLS and flow_size >= 3:
349
+ return "ddos"
350
+ # High-density flow to non-standard port
351
+ if flow_size >= 8 and dst_port not in {53, 80, 443, 8080}:
352
+ return "scan"
353
+ return None
354
+
355
+
356
  def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[Any]]]:
357
  grouped: dict[tuple[str, str, int, str], list[Any]] = {}
358
  attack_source_ports: dict[tuple[str, str, int, str], set[int]] = {}
359
+ # Phase 1: keyword-based grouping (high confidence)
360
  for packet in obs.visible_packets:
361
  pattern = keyword_to_pattern(packet_payload_text(packet))
362
  if pattern:
 
364
  grouped.setdefault(key, []).append(packet)
365
  attack_source_ports.setdefault(key, set()).add(packet.src_port)
366
 
367
+ # Add reverse-response packets to keyword-matched sessions
368
  for key, source_ports in attack_source_ports.items():
369
  src_ip, dst_ip, dst_port, _pattern = key
370
  for packet in obs.visible_packets:
 
377
  if is_reverse_response:
378
  grouped[key].append(packet)
379
 
380
+ # Phase 2: flow-based grouping for packets without keyword match
381
+ # Group unclaimed packets by (src_ip, dst_ip, dst_port) and infer pattern
382
+ claimed_ids: set[str] = set()
383
+ for items in grouped.values():
384
+ for p in items:
385
+ claimed_ids.add(p.packet_id)
386
+
387
+ flow_buckets: dict[tuple[str, str, int], list[Any]] = {}
388
+ for packet in obs.visible_packets:
389
+ if packet.packet_id in claimed_ids:
390
+ continue
391
+ flow_key = (packet.src_ip, packet.dst_ip, packet.dst_port)
392
+ flow_buckets.setdefault(flow_key, []).append(packet)
393
+
394
+ for flow_key, items in flow_buckets.items():
395
+ if len(items) < 2:
396
+ continue
397
+ pattern = _infer_flow_pattern(items[0], len(items))
398
+ if pattern:
399
+ session_key = (*flow_key, pattern)
400
+ grouped.setdefault(session_key, []).extend(items)
401
+ for p in items:
402
+ claimed_ids.add(p.packet_id)
403
+
404
  candidates = [
405
  (
406
  key,
 
421
  return 0
422
 
423
 
424
+ def select_inspect_packet(
425
+ obs: Any,
426
+ inspected_ids: set[str],
427
+ flagged_ids: set[str] | None = None,
428
+ ) -> str | None:
429
+ flagged_ids = flagged_ids or set()
430
+ unrevealed = [
431
+ p
432
+ for p in obs.visible_packets
433
+ if (not p.is_revealed)
434
+ and (p.packet_id not in inspected_ids)
435
+ and (p.packet_id not in flagged_ids)
436
+ ]
437
  if not unrevealed:
438
  return None
439
 
 
459
  def append_action_history(agent_state: dict[str, Any], action: NetworkForensicsAction) -> None:
460
  history = agent_state.setdefault("previous_actions", [])
461
  history.append(format_action(action))
462
+ if action.action_type == "inspect_packet" and action.packet_id:
463
+ inspected_ids = agent_state.setdefault("inspected_ids", set())
464
+ inspected_ids.add(action.packet_id)
465
  if len(history) > HISTORY_WINDOW:
466
  del history[:-HISTORY_WINDOW]
467
 
 
503
  candidate_packets, flagged_ids, visible_by_id
504
  )
505
  size = len(candidate_packets)
506
+ # Lowered thresholds for more aggressive grouping
507
  if task_name == "easy":
508
  min_flagged = 1 if size >= 2 else 0
509
  elif task_name == "medium":
510
+ min_flagged = 1 if size >= 2 else 0
511
  else:
512
+ min_flagged = 1 if size >= 3 else 0
513
+ if trusted_pattern and size >= 3:
514
  min_flagged = 1
515
  if flagged >= min_flagged:
516
  return True
517
  # Allow grouping with strong revealed malicious evidence.
518
+ if task_name == "easy" and (malicious_revealed >= 1 or revealed >= 1):
519
  return True
520
+ if task_name == "medium" and malicious_revealed >= 1 and revealed >= 1:
 
521
  return True
522
+ if malicious_revealed >= 1 and revealed >= min(2, size):
523
+ return True
524
+ # After a pattern has been confirmed by tagging, allow structure-first grouping.
525
+ if trusted_pattern and size >= 3:
526
  return True
527
+ # Large flows are very likely attack sessions - allow with minimal evidence
528
+ if size >= 6 and (flagged >= 1 or revealed >= 2 or malicious_revealed >= 1):
529
  return True
530
  return False
531
 
 
567
  )
568
 
569
  inspect_limit = {
570
+ "easy": 18,
571
+ "medium": 20,
572
+ "hard": 25,
573
+ }.get(agent_state.get("current_task_name", ""), 15)
574
  if len(previous_actions) >= inspect_limit and inspect_ratio >= INSPECT_SOFT_RATIO_THRESHOLD:
575
  hints.append(
576
  "You are over-inspecting. Shift to flagging, grouping, tagging, or report submission unless the next packet is clearly high-value."
 
578
  return hints
579
 
580
 
581
+ def should_submit_early(task_name: str, obs: Any, agent_state: dict[str, Any]) -> bool:
582
+ flagged_count = len(obs.flagged_packet_ids)
583
+ total_visible = max(1, len(obs.visible_packets))
584
+ coverage = flagged_count / total_visible
585
+ score = float(obs.current_score_estimate)
586
+ sessions = obs.grouped_sessions or {}
587
+ tags = obs.tagged_patterns or {}
588
+
589
+ score_target = TASK_SCORE_TARGETS.get(task_name, 0.65)
590
+ coverage_target = TASK_COVERAGE_TARGETS.get(task_name, 0.25)
591
+
592
+ if task_name == "easy":
593
+ return (
594
+ coverage >= max(coverage_target * 0.7, 0.20)
595
+ and flagged_count >= 6
596
+ and len(sessions) >= 1
597
+ )
598
+ if task_name == "medium":
599
+ return (
600
+ score >= score_target * 0.8
601
+ and coverage >= coverage_target * 0.7
602
+ and len(sessions) >= 1
603
+ and len(tags) >= 1
604
+ )
605
+ return (
606
+ score >= score_target * 0.8
607
+ and coverage >= coverage_target * 0.7
608
+ and len(sessions) >= 2
609
+ and len(tags) >= 1
610
+ and bool(agent_state.get("claimed_entry_point") or obs.claimed_entry_point)
611
+ )
612
+
613
+
614
  def build_fallback_action(
615
  task_name: str, obs: Any, agent_state: dict[str, Any]
616
  ) -> NetworkForensicsAction:
617
+ """Smart workflow engine: Flag aggressive -> Group -> Tag -> Entry Point -> Report."""
618
  inspected_ids = agent_state.setdefault("inspected_ids", set())
619
  flagged_ids = agent_state.setdefault("flagged_ids", set())
620
  session_map = agent_state.setdefault("sessions", {}) # key -> session_name
 
623
  visible_by_id = {p.packet_id: p for p in obs.visible_packets}
624
  trusted = trusted_patterns(session_map, tagged_sessions)
625
 
626
+ if obs.steps_remaining <= 1 or should_submit_early(task_name, obs, agent_state):
627
  summary = _build_report_summary(obs, agent_state)
628
  return NetworkForensicsAction(
629
  action_type="submit_report",
 
631
  claimed_entry_point=claimed_entry,
632
  )
633
 
634
+ # PHASE 1: Aggressive flag of ALL revealed malicious packets
635
+ # This maximizes recall by comprehensively flagging known-bad traffic
636
+ unflagged_malicious = []
637
  for packet in obs.visible_packets:
638
  if packet.is_revealed and packet.packet_id not in flagged_ids:
639
  payload = packet.full_payload or ""
640
  pattern = keyword_to_pattern(payload)
641
  if pattern:
642
+ unflagged_malicious.append(packet.packet_id)
643
+
644
+ if unflagged_malicious:
645
+ # Flag up to 5 per turn for aggressive recall buildup
646
+ target = min(5, len(unflagged_malicious))
647
+ for _ in range(target):
648
+ if unflagged_malicious:
649
+ pid = unflagged_malicious.pop(0)
650
+ flagged_ids.add(pid)
651
  return NetworkForensicsAction(
652
  action_type="flag_as_suspicious",
653
+ packet_id=pid,
654
  )
655
 
656
  # PHASE 2: Group flagged packets into sessions with evidence gate and backlog pacing.
657
+ min_flagged_before_group = 1 if task_name == "easy" else 2
658
  untagged_backlog = max(0, len(session_map) - len(tagged_sessions))
659
+ if len(flagged_ids) >= min_flagged_before_group and untagged_backlog <= UNTAGGED_BACKLOG_LIMIT:
660
  candidates = session_candidates(obs)
661
  for key, items in candidates:
662
  if key in session_map:
 
678
  packet_ids=packet_ids,
679
  )
680
 
681
+ # PHASE 2.5: Recall sweep - flag packets that are already part of grouped sessions.
682
+ # This boosts recall quickly without requiring more inspections.
683
+ grouped_packets = []
684
+ for packet_ids in (obs.grouped_sessions or {}).values():
685
+ grouped_packets.extend(packet_ids)
686
+ for pid in sorted(set(grouped_packets), key=packet_sort_key):
687
+ if pid in flagged_ids:
688
+ continue
689
+ if pid in visible_by_id:
690
+ flagged_ids.add(pid)
691
+ return NetworkForensicsAction(
692
+ action_type="flag_as_suspicious",
693
+ packet_id=pid,
694
+ )
695
+
696
+ # PHASE 3: Tag ALL untagged sessions aggressively (critical for medium/hard logic_score).
697
+ # Tagging helps LLM report score and logic_score for all difficulties.
698
+ for key, session_name in session_map.items():
699
+ if session_name in tagged_sessions:
700
+ continue
701
+ _src_ip, _dst_ip, _dst_port, pattern = key
702
+ tagged_sessions.add(session_name)
703
+ return NetworkForensicsAction(
704
+ action_type="tag_pattern",
705
+ session_name=session_name,
706
+ pattern_type=pattern,
707
+ )
708
+ # Also tag any observed sessions not yet in our session_map
709
+ for session_name, session_data in (obs.grouped_sessions or {}).items():
710
+ if session_name in tagged_sessions:
711
+ continue
712
+ if session_name in (obs.tagged_patterns or {}):
713
+ tagged_sessions.add(session_name)
714
+ continue
715
+ # Infer pattern from session packets
716
+ pattern = None
717
+ for pid in session_data:
718
+ pkt = visible_by_id.get(pid)
719
+ if pkt and pkt.is_revealed:
720
+ pattern = keyword_to_pattern(packet_payload_text(pkt))
721
+ if pattern:
722
+ break
723
+ if not pattern:
724
+ # Try flow-based inference
725
+ pkt = visible_by_id.get(session_data[0]) if session_data else None
726
+ if pkt:
727
+ pattern = _infer_flow_pattern(pkt, len(session_data))
728
+ if pattern:
729
  tagged_sessions.add(session_name)
730
  return NetworkForensicsAction(
731
  action_type="tag_pattern",
 
733
  pattern_type=pattern,
734
  )
735
 
736
+ # PHASE 4: Identify entry point - CRITICAL for hard mode (score=0 without it)
737
+ if not claimed_entry:
738
+ entry_candidate = None
739
+ # Strategy 1: earliest packet in any grouped session from observation
740
+ try:
741
+ grouped_packets = set()
742
+ for session_name in session_map.values():
743
+ if obs.grouped_sessions and session_name in obs.grouped_sessions:
744
+ grouped_packets.update(obs.grouped_sessions[session_name])
745
+ if grouped_packets:
746
+ entry_candidate = min(grouped_packets, key=lambda pid: packet_sort_key(pid))
747
+ except Exception:
748
+ pass
749
+ # Strategy 2: earliest flagged packet (often the first discovered attack)
750
+ if not entry_candidate and flagged_ids:
751
+ entry_candidate = min(flagged_ids, key=lambda pid: packet_sort_key(pid))
752
+ # Strategy 3: earliest revealed malicious packet
753
+ if not entry_candidate:
754
+ revealed_malicious = [
755
+ p for p in obs.visible_packets
756
+ if p.is_revealed and keyword_to_pattern(packet_payload_text(p))
757
+ ]
758
+ if revealed_malicious:
759
+ entry_candidate = min(
760
+ revealed_malicious, key=lambda p: packet_sort_key(p.packet_id)
761
+ ).packet_id
762
+ # Strategy 4: earliest packet in session_candidates
763
+ if not entry_candidate:
764
+ all_session_packets = []
765
+ for key, items in session_candidates(obs):
766
+ for p in items:
767
+ all_session_packets.append(p.packet_id)
768
+ if all_session_packets:
769
+ entry_candidate = min(all_session_packets, key=packet_sort_key)
770
+ # Strategy 5: earliest flagged packet from observation
771
+ if not entry_candidate and obs.flagged_packet_ids:
772
+ entry_candidate = min(obs.flagged_packet_ids, key=packet_sort_key)
773
+ if entry_candidate:
774
+ agent_state["claimed_entry_point"] = entry_candidate
775
+ return NetworkForensicsAction(
776
+ action_type="identify_entry_point",
777
+ claimed_entry_point=entry_candidate,
778
+ )
779
 
780
+ # PHASE 5: Inspect more unrevealed packets (to discover more malicious traffic)
781
+ inspect_id = select_inspect_packet(obs, inspected_ids, flagged_ids)
782
  if inspect_id is not None:
783
  return NetworkForensicsAction(action_type="inspect_packet", packet_id=inspect_id)
784
 
 
792
 
793
 
794
  def _build_report_summary(obs: Any, agent_state: dict[str, Any]) -> str:
795
+ """Generate a detailed incident summary for high LLM judge scores."""
796
  flagged = agent_state.get("flagged_ids", set())
797
  sessions = agent_state.get("sessions", {})
798
  tagged = agent_state.get("tagged_sessions", set())
799
+ entry_point = agent_state.get("claimed_entry_point") or getattr(obs, "claimed_entry_point", None)
800
+ patterns_by_session: dict[str, str] = {}
801
+ src_ips_by_pattern: dict[str, set[str]] = {}
802
+ dst_ips_by_pattern: dict[str, set[str]] = {}
803
+ for key, session_name in sessions.items():
804
  if len(key) >= 4:
805
+ pattern = key[3]
806
+ patterns_by_session[session_name] = pattern
807
+ src_ips_by_pattern.setdefault(pattern, set()).add(key[0])
808
+ dst_ips_by_pattern.setdefault(pattern, set()).add(key[1])
809
+
810
+ # Build detailed per-pattern section
811
+ pattern_details = []
812
+ for pattern in sorted(set(patterns_by_session.values())):
813
+ srcs = ", ".join(sorted(src_ips_by_pattern.get(pattern, set()))[:5])
814
+ dsts = ", ".join(sorted(dst_ips_by_pattern.get(pattern, set()))[:5])
815
+ session_names = [n for n, p in patterns_by_session.items() if p == pattern]
816
+ pattern_details.append(
817
+ f" - {pattern}: {len(session_names)} session(s) from {srcs} targeting {dsts}"
818
+ )
819
+ pattern_section = "\n".join(pattern_details) if pattern_details else " - No patterns classified"
820
+
821
+ # Tagged pattern summary
822
+ tagged_details = []
823
+ for session_name in sorted(tagged):
824
+ pattern = patterns_by_session.get(session_name, "unknown")
825
+ tagged_details.append(f"{session_name}={pattern}")
826
+ tagged_section = "; ".join(tagged_details) if tagged_details else "none"
827
+
828
+ entry_section = f"Entry point: {entry_point}" if entry_point else "Entry point: not identified"
829
+
830
  return (
831
+ f"INCIDENT REPORT\n\n"
832
+ f"Summary: Detected {len(flagged)} malicious packets across "
833
+ f"{len(sessions)} attack sessions.\n\n"
834
+ f"Attack Patterns:\n{pattern_section}\n\n"
835
+ f"Tagged Sessions: {tagged_section}\n\n"
836
+ f"{entry_section}\n\n"
837
+ f"Total flagged: {len(flagged)} | Total sessions: {len(sessions)} | "
838
+ f"Classified sessions: {len(tagged)}"
839
  )
840
 
841
 
 
856
  inspect_count = sum(1 for a in previous_actions if '"inspect_packet"' in a)
857
  revealed_count = sum(1 for p in obs.visible_packets if p.is_revealed)
858
  inspect_limit = {
859
+ "easy": 25,
860
+ "medium": 18,
861
+ "hard": 25,
862
+ }.get(task_name, 15)
863
 
864
  if action.action_type not in {
865
  "inspect_packet",
 
880
  return "Missing packet_id for inspect_packet"
881
  if action.packet_id not in {p.packet_id for p in obs.visible_packets}:
882
  return f"Invalid packet_id {action.packet_id} - not in visible_packets"
883
+ inspected_ids = agent_state.setdefault("inspected_ids", set())
884
+ if action.packet_id in inspected_ids:
885
+ return f"Packet {action.packet_id} was already inspected. Choose a different hidden packet."
886
  revealed_ids = {p.packet_id for p in obs.visible_packets if p.is_revealed}
887
  if action.packet_id in revealed_ids:
888
  return f"Packet {action.packet_id} is ALREADY revealed. Choose a HIDDEN packet."
889
+ if action.packet_id in set(obs.flagged_packet_ids):
890
+ return (
891
+ f"Packet {action.packet_id} is already flagged. Inspect a new hidden unflagged packet instead."
892
+ )
893
+ revealed_unflagged_malicious = [
894
+ p.packet_id
895
+ for p in obs.visible_packets
896
+ if p.is_revealed
897
+ and p.packet_id not in set(obs.flagged_packet_ids)
898
+ and keyword_to_pattern(packet_payload_text(p))
899
+ ]
900
+ if revealed_unflagged_malicious:
901
+ return (
902
+ "Recall-first policy: revealed malicious packets exist and must be flagged before new inspection."
903
+ )
904
+ grouped_unflagged = [
905
+ pid
906
+ for packet_ids in (obs.grouped_sessions or {}).values()
907
+ for pid in packet_ids
908
+ if pid not in set(obs.flagged_packet_ids)
909
+ ]
910
+ if grouped_unflagged:
911
  return (
912
+ "Recall-first policy: grouped session packets remain unflagged. Flag them before further inspection."
913
  )
914
+ if task_name == "easy" and len(flagged_ids) >= 4:
915
+ grouped_session_names = set((obs.grouped_sessions or {}).keys())
916
+ for key, items in session_candidates(obs):
917
+ if key in sessions:
918
+ continue
919
+ if len(items) >= 4:
920
+ return (
921
+ "Exploit mode: enough evidence exists. Group high-confidence attack flows before more inspection."
922
+ )
923
+ if inspect_count >= inspect_limit and (len(sessions) > 0 or len(flagged_ids) > 0 or revealed_count >= 4):
924
+ # Only block inspections for medium/hard modes; easy mode needs discovery
925
+ if task_name != "easy":
926
+ return (
927
+ f"Inspection budget reached for {task_name}. Shift to flagging, grouping, tagging, or report submission."
928
+ )
929
 
930
  if action.action_type == "flag_as_suspicious":
931
  if not action.packet_id:
 
945
  }
946
  if invalid_ids:
947
  return f"Invalid packet_ids in session: {invalid_ids}"
948
+ if action.session_name in sessions.values():
949
+ return f"Session name {action.session_name} is already used."
950
+ min_flagged_before_group = 1 if task_name == "easy" else 1
951
+ if len(flagged_ids) < min_flagged_before_group:
952
+ return (
953
+ f"Group blocked until enough evidence is flagged ({len(flagged_ids)}/{min_flagged_before_group}). "
954
+ "Inspect and flag suspicious packets first."
955
+ )
956
+ new_group_ids = set(action.packet_ids)
957
+ for existing_ids in (obs.grouped_sessions or {}).values():
958
+ existing_set = set(existing_ids)
959
+ if not existing_set:
960
+ continue
961
+ overlap = len(new_group_ids & existing_set) / max(1, len(new_group_ids))
962
+ if overlap >= 0.8:
963
+ return "This grouping heavily overlaps an existing session. Prioritize new evidence."
964
  untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
965
  if untagged_backlog > UNTAGGED_BACKLOG_LIMIT:
966
  return (
 
986
 
987
  if action.action_type == "submit_report":
988
  untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
989
+ total_visible = max(1, len(obs.visible_packets))
990
+ flagged_count = len(obs.flagged_packet_ids)
991
+ coverage = flagged_count / total_visible
992
+ min_cov = TASK_COVERAGE_TARGETS.get(task_name, 0.25) * 0.6
993
+ min_flags = 4 if task_name == "easy" else (3 if task_name == "medium" else 4)
994
+ min_groups = 1 if task_name == "easy" else (2 if task_name == "medium" else 2)
995
+ if (
996
+ obs.steps_remaining > 2
997
+ and obs.current_score_estimate < 0.40
998
+ and not should_submit_early(task_name, obs, agent_state)
999
+ ):
1000
  return (
1001
  "Premature report submission. Improve coverage and score estimate before submit_report."
1002
  )
1003
+ if obs.steps_remaining > 1 and (coverage < min_cov or flagged_count < min_flags):
1004
+ return (
1005
+ f"Premature report submission. Need stronger recall coverage before submit_report "
1006
+ f"(coverage {coverage:.0%}/{min_cov:.0%}, flags {flagged_count}/{min_flags})."
1007
+ )
1008
+ if obs.steps_remaining > 1 and len(sessions) < min_groups:
1009
+ return (
1010
+ f"Premature report submission. Need stronger session evidence before submit_report "
1011
+ f"(grouped {len(sessions)}/{min_groups})."
1012
+ )
1013
+ if task_name == "hard" and obs.steps_remaining > 3 and untagged_backlog > 0:
1014
  return "Premature report submission. Tag pending sessions before submitting report."
1015
+ # CRITICAL: Hard mode zero-out if no entry point identified
1016
+ if task_name == "hard" and not (agent_state.get("claimed_entry_point") or obs.claimed_entry_point):
1017
+ return (
1018
+ "FATAL: Hard mode requires identify_entry_point before submit_report. "
1019
+ "No entry point claimed yet — score will be 0.0 without it. "
1020
+ "Use identify_entry_point with the earliest malicious packet first."
1021
+ )
1022
+ # Medium mode: need entry point for good logic_score
1023
+ if task_name == "medium" and obs.steps_remaining > 5 and not (agent_state.get("claimed_entry_point") or obs.claimed_entry_point):
1024
+ return (
1025
+ "Missing entry point. Use identify_entry_point before submit_report for higher score."
1026
+ )
1027
+ # Require minimum tagging coverage for medium/hard
1028
+ min_tagged = 1 if task_name == "medium" else 2
1029
+ if task_name in {"medium", "hard"} and len(tagged_sessions) < min_tagged and obs.steps_remaining > 3:
1030
+ return (
1031
+ f"Premature report submission. Need at least {min_tagged} tagged session(s) before submit_report "
1032
+ f"(currently {len(tagged_sessions)})."
1033
+ )
1034
 
1035
  if action.action_type == "tag_pattern":
1036
  if not action.session_name:
1037
  return "Missing session_name for tag_pattern"
1038
  if not action.pattern_type:
1039
  return "Missing pattern_type for tag_pattern"
1040
+ if action.session_name in set((obs.tagged_patterns or {}).keys()):
1041
+ return f"Session {action.session_name} is already tagged."
1042
+ if task_name == "easy" and obs.steps_remaining > 8:
1043
+ return "For easy mode, prioritize recall actions (inspect/flag/group) before tagging."
1044
  valid_patterns = {
1045
  "ddos", "dos_slowloris", "dos_slowhttptest", "dos_goldeneye", "dos_hulk",
1046
  "heartbleed", "web_sql_injection", "web_xss", "web_bruteforce",
 
1052
  if action.action_type == "identify_entry_point":
1053
  if not action.claimed_entry_point:
1054
  return "Missing claimed_entry_point for identify_entry_point"
1055
+ # Lenient gating for easy mode
1056
+ min_flags_needed = 1 if task_name == "easy" else (2 if task_name == "medium" else 2)
1057
+ if obs.steps_remaining > 8 and len(flagged_ids) < min_flags_needed:
1058
  return (
1059
  "Premature entry-point claim. Gather and flag more evidence before identify_entry_point."
1060
  )
 
1071
  ) -> NetworkForensicsAction:
1072
  agent_state["current_task_name"] = task_name
1073
  agent_state["strategy_hints"] = derive_strategy_hints(obs, agent_state)
1074
+ if should_submit_early(task_name, obs, agent_state):
1075
+ action = NetworkForensicsAction(
1076
+ action_type="submit_report",
1077
+ incident_summary=_build_report_summary(obs, agent_state),
1078
+ claimed_entry_point=agent_state.get("claimed_entry_point") or obs.claimed_entry_point,
1079
+ )
1080
+ append_action_history(agent_state, action)
1081
+ return action
1082
  history = agent_state.get("previous_actions", [])[-HISTORY_WINDOW:]
1083
  history_str = "\n".join([f"Step {i+1}: {a}" for i, a in enumerate(history)])
1084
 
 
1093
  "Follow the JSON schema in the system prompt."
1094
  )
1095
 
1096
+ try:
1097
+ response = client.chat.completions.create(
1098
+ model=model_name or MODEL_NAME,
1099
+ temperature=0.1,
1100
+ timeout=LLM_TIMEOUT_S,
1101
+ messages=[
1102
+ {"role": "system", "content": SYSTEM_PROMPT},
1103
+ {
1104
+ "role": "user",
1105
+ "content": f"TASK: {task_name}{correction_text}\n\n### RECENT HISTORY:\n{history_str}\n\n### CURRENT OBSERVATION:\n{summarize_observation(obs, agent_state)}",
1106
+ },
1107
+ ],
1108
+ )
1109
+ except Exception as llm_exc:
1110
+ print(f"[WARN] LLM call failed/timed out: {llm_exc}")
1111
+ fallback = build_fallback_action(task_name, obs, agent_state)
1112
+ append_action_history(agent_state, fallback)
1113
+ return fallback
1114
  content = response.choices[0].message.content or ""
1115
  try:
1116
  action = sanitize_action(parse_action(content))
 
1272
  return result
1273
 
1274
 
1275
+ WS_RETRY_COUNT = 3
1276
+ WS_RETRY_DELAY_S = 2.0
1277
+ LLM_TIMEOUT_S = 45.0
1278
+
1279
+
1280
+ def step_env_with_retry(
1281
+ env: NetworkForensicsEnv,
1282
+ action: NetworkForensicsAction,
1283
+ task_name: str,
1284
+ agent_state: dict[str, Any],
1285
+ ) -> tuple[Any, NetworkForensicsEnv | None]:
1286
+ """Try step_env with retries on WebSocket timeout.
1287
+
1288
+ Returns (step_result, new_env_or_None).
1289
+ If the WebSocket connection drops, reconnects and retries.
1290
+ """
1291
+ last_exc = None
1292
+ for attempt in range(1, WS_RETRY_COUNT + 1):
1293
+ try:
1294
+ result = step_env(env, action)
1295
+ return result, None
1296
+ except Exception as exc:
1297
+ last_exc = exc
1298
+ exc_str = str(exc).lower()
1299
+ is_ws_timeout = any(
1300
+ kw in exc_str
1301
+ for kw in ("keepalive", "ping timeout", "1011", "websocket", "connection")
1302
+ )
1303
+ if not is_ws_timeout:
1304
+ raise
1305
+ print(
1306
+ f"[WARN] WebSocket timeout on attempt {attempt}/{WS_RETRY_COUNT}: {exc}"
1307
+ )
1308
+ if attempt < WS_RETRY_COUNT:
1309
+ time.sleep(WS_RETRY_DELAY_S * attempt)
1310
+ # Try reconnecting
1311
+ try:
1312
+ close_env(env)
1313
+ except Exception:
1314
+ pass
1315
+ try:
1316
+ env = create_env()
1317
+ reset_result = reset_env(env, task_name)
1318
+ obs = reset_result.observation
1319
+ sync_agent_state(obs, agent_state)
1320
+ print(f"[INFO] Reconnected to environment, resuming task={task_name}")
1321
+ except Exception as reconnect_exc:
1322
+ print(f"[WARN] Reconnect failed: {reconnect_exc}")
1323
+ continue
1324
+ raise last_exc # type: ignore[misc]
1325
+
1326
+
1327
  def close_env(env: NetworkForensicsEnv | None) -> None:
1328
  if env is None:
1329
  return
 
1356
  obs = reset_result.observation
1357
  sync_agent_state(obs, agent_state)
1358
  max_steps = obs.steps_remaining or 50
1359
+ soft_budget = min(max_steps, SOFT_STEP_BUDGETS.get(task_name, max_steps))
1360
+ hard_budget = min(max_steps, HARD_STEP_CAPS.get(task_name, max_steps))
1361
+ start_ts = time.monotonic()
1362
+ task_time_budget = min(MAX_TASK_SECONDS, TASK_TIME_BUDGET_SECONDS.get(task_name, MAX_TASK_SECONDS))
1363
 
1364
+ for _ in range(hard_budget):
1365
  if obs.done:
1366
  break
1367
 
1368
+ elapsed = time.monotonic() - start_ts
1369
+ total_visible = max(1, len(obs.visible_packets))
1370
+ current_coverage = len(obs.flagged_packet_ids) / total_visible
1371
+ min_cov = TASK_COVERAGE_TARGETS.get(task_name, 0.25)
1372
+ ready_for_budget_submit = (
1373
+ obs.step_number >= soft_budget
1374
+ and should_submit_early(task_name, obs, agent_state)
1375
+ )
1376
+ forced_at_hard_cap = (
1377
+ obs.step_number >= max(1, hard_budget - 1)
1378
+ and (should_submit_early(task_name, obs, agent_state) or task_name != "easy")
1379
+ )
1380
+ nearing_time_limit = elapsed >= max(20.0, task_time_budget - 12.0)
1381
+
1382
  error = None
1383
  try:
1384
+ if forced_at_hard_cap or nearing_time_limit or ready_for_budget_submit:
1385
+ action = NetworkForensicsAction(
1386
+ action_type="submit_report",
1387
+ incident_summary=_build_report_summary(obs, agent_state),
1388
+ claimed_entry_point=agent_state.get("claimed_entry_point") or obs.claimed_entry_point,
1389
+ )
1390
+ else:
1391
+ action = choose_action(client, task_name, obs, agent_state)
1392
  except Exception as exc:
1393
  error = str(exc).replace("\n", " ")
1394
  action = build_fallback_action(task_name, obs, agent_state)
1395
 
1396
+ try:
1397
+ step_result, new_env = step_env_with_retry(env, action, task_name, agent_state)
1398
+ if new_env is not None:
1399
+ env = new_env
1400
+ except Exception as exc:
1401
+ print(f"[WARN] step failure on task={task_name}: {exc}")
1402
+ break
1403
  obs = step_result.observation
1404
  sync_agent_state(obs, agent_state)
1405
  step_reward = float(step_result.reward or 0.0)
server/gradio_ui.py CHANGED
@@ -453,12 +453,6 @@ def create_demo() -> gr.Blocks:
453
 
454
  with gr.Blocks(
455
  title="NetForensics-RL · Analyst Console",
456
- theme=gr.themes.Base(
457
- primary_hue="blue",
458
- neutral_hue="slate",
459
- font=gr.themes.GoogleFont("Inter"),
460
- ),
461
- css=css,
462
  ) as demo:
463
  with gr.Column(elem_classes=["app-shell"]):
464
  gr.HTML(f"<style>{css}</style>")
 
453
 
454
  with gr.Blocks(
455
  title="NetForensics-RL · Analyst Console",
 
 
 
 
 
 
456
  ) as demo:
457
  with gr.Column(elem_classes=["app-shell"]):
458
  gr.HTML(f"<style>{css}</style>")