import json
import os
import sys
import asyncio
import inspect
import random
import time
from pathlib import Path
from typing import Any

from dotenv import load_dotenv
from openai import OpenAI
from openenv.core.containers.runtime.providers import LocalDockerProvider

# Make the sibling ``client`` / ``models`` modules importable when this file
# is run as a script; this must happen before the two imports below.
sys.path.insert(0, str(Path(__file__).parent))

from client import NetworkForensicsEnv
from models import NetworkForensicsAction

load_dotenv(Path(__file__).parent / ".env")

# --- Configuration read from the environment (after loading .env) ---
API_BASE_URL = os.getenv("API_BASE_URL")
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") or os.getenv("HF_TOKEN")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "network-forensics-env:latest")
ENV_MODE = (
    os.getenv("NETWORK_FORENSICS_ENV_MODE") or os.getenv("ENV_MODE") or "hf"
).lower()
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
HF_SPACE_ID = (
    os.getenv("HF_SPACE_ID") or os.getenv("SPACE_ID") or "WHOAM-EYE/network_forensics"
)
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://whoam-eye-network-forensics.hf.space")
DOCKER_READY_TIMEOUT_S = float(os.getenv("DOCKER_READY_TIMEOUT_S", "120"))
_ASYNC_LOOP: asyncio.AbstractEventLoop | None = None

# NOTE(review): the section markers ("###", "- ", "1.") suggest this prompt was
# authored as multi-line text, but the text visible here is separated by single
# spaces only. The value below preserves exactly that spacing -- confirm it
# against the intended prompt before relying on its layout.
SYSTEM_PROMPT = (
    "You are a senior Network Forensics Analyst. "
    "Your goal is to investigate malicious network traffic and achieve a 100% detection score. "
    "### SCORING RULES: "
    "- You MUST identify and `flag_as_suspicious` EVERY malicious packet to maximize RECALL (very important!). "
    "- Only grouped packets or flagged packets contribute towards your score. "
    "- If RECALL is < 0.5, your score will be 0.0. "
    "DO NOT stop until you have flagged/grouped at least 60% of visible malicious packets. "
    "- Entry point must be the EARLIEST packet that initiated the attack (often in first group). "
    "- For HARD tasks: wrong entry point = score 0. "
    "Always identify_entry_point before submitting. "
    "### WORKFLOW: "
    "1. **Explore**: `inspect_packet` on suspicious samples. "
    "2. **Flag**: `flag_as_suspicious` on ALL revealed malicious packets. "
    "3. **Correlate**: `group_into_session` with descriptive names. "
    "4. **Classify**: `tag_pattern` with a valid type. "
    "5. **Root Cause**: `identify_entry_point` with the earliest malicious packet. "
    "6. **Report**: `submit_report` ONLY when you have covered all visible malicious sessions. "
    "### VALID PATTERN TYPES: "
    "ddos, dos_slowloris, dos_slowhttptest, dos_goldeneye, dos_hulk, heartbleed, "
    "web_sql_injection, web_xss, web_bruteforce, c2, exfiltration, scan, lateral "
    "### JSON SCHEMA EXAMPLES (Use these exactly): "
    '- Inspect: {"action_type":"inspect_packet","packet_id":"pkt_0001"} '
    '- Flag: {"action_type":"flag_as_suspicious","packet_id":"pkt_0001"} '
    '- Group: {"action_type":"group_into_session","session_name":"DDoS_Burst_2","packet_ids":["pkt_0001","pkt_0002"]} '
    '- Tag: {"action_type":"tag_pattern","session_name":"DDoS_Burst_2","pattern_type":"ddos"} '
    '- Entry: {"action_type":"identify_entry_point","claimed_entry_point":"pkt_0001"} '
    '- Report: {"action_type":"submit_report","incident_summary":"Detailed incident summary here.","claimed_entry_point":"pkt_0001"}'
)

# --- Agent tuning constants ---
HISTORY_WINDOW = 20
REPEAT_ACTION_LIMIT = 3
CORRECTION_WINDOW = 5
UNTAGGED_BACKLOG_LIMIT = 6
INSPECT_SOFT_RATIO_THRESHOLD = 0.60
SOFT_STEP_BUDGETS = {"easy": 14, "medium": 28, "hard": 40}
HARD_STEP_CAPS = {"easy": 30, "medium": 50, "hard": 65}
TASK_SCORE_TARGETS = {"easy": 0.70, "medium": 0.68, "hard": 0.66}
TASK_COVERAGE_TARGETS = {"easy": 0.32, "medium": 0.24, "hard": 0.20}
MAX_TASK_SECONDS = float(os.getenv("MAX_TASK_SECONDS", "780"))
TASK_TIME_BUDGET_SECONDS = {
    "easy": float(os.getenv("EASY_MAX_SECONDS", "150")),
    "medium": float(os.getenv("MEDIUM_MAX_SECONDS", "220")),
    "hard": float(os.getenv("HARD_MAX_SECONDS", "320")),
}


def build_client() -> OpenAI:
    """Return an OpenAI client configured from API_BASE_URL and API_KEY."""
    return OpenAI(base_url=API_BASE_URL, api_key=API_KEY)


def validate_config() -> None:
    """Check the module-level configuration and fail fast if it is unusable.

    Raises:
        RuntimeError: if a required setting is empty (all missing names are
            reported in one message), or if ENV_MODE is not one of
            ``server``, ``docker`` or ``hf``.
    """
    missing = []
    if not API_BASE_URL:
        missing.append("API_BASE_URL")
    if not API_KEY:
        missing.append("OPENAI_API_KEY/API_KEY/HF_TOKEN")
    # HF_SPACE_ID always falls back to a non-empty literal where it is defined
    # above, so this condition cannot currently be true; it is kept as a guard
    # in case those defaults are removed.
    if ENV_MODE == "hf" and not (HF_SPACE_URL or HF_SPACE_ID):
        missing.append("HF_SPACE_URL or HF_SPACE_ID/SPACE_ID")
    if missing:
        raise RuntimeError(
            f"Missing required environment variables: {', '.join(missing)}"
        )
    if ENV_MODE not in {"server", "docker", "hf"}:
        raise RuntimeError(
            "NETWORK_FORENSICS_ENV_MODE must be one of: server, docker, hf"
        )


def format_action(action: NetworkForensicsAction) -> str:
    """Serialize *action* as compact JSON with empty fields removed.

    The dump excludes ``None`` and default-valued fields; the ``metadata`` key
    and any value equal to ``""``, ``[]`` or ``{}`` are then dropped as well.
    """
    payload = action.model_dump(exclude_none=True, exclude_defaults=True)
    payload.pop("metadata", None)
    payload = {
        key: value for key, value in payload.items() if value not in ("", [], {})
    }
    return json.dumps(payload, separators=(",", ":"))


def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
    """Build a compact, sectioned text summary of the current observation.

    Args:
        obs: observation object; the attributes read here are
            ``visible_packets``, ``grouped_sessions``, ``tagged_patterns``,
            ``flagged_packet_ids``, ``current_score_estimate``,
            ``step_number`` and ``steps_remaining``.
        agent_state: mutable agent bookkeeping; the keys read here are
            ``last_step_reward``, ``last_reward_feedback``,
            ``recent_corrections``, ``strategy_hints`` and
            ``current_task_name``.

    Returns:
        A newline-joined string containing progress counters, optional
        corrections and hints, sessions still lacking a tag, payload excerpts
        of the most recently revealed packets, and up to ten unrevealed
        packets.
    """
    packets = obs.visible_packets
    revealed = [p for p in packets if p.is_revealed]
    revealed_ids = [p.packet_id for p in revealed]
    sessions = obs.grouped_sessions or {}
    tags = obs.tagged_patterns or {}
    untagged_sessions = [s for s in sessions if s not in tags]

    last_reward = agent_state.get("last_step_reward")
    reward_feedback = agent_state.get("last_reward_feedback", "n/a")
    recent_corrections = agent_state.get("recent_corrections", [])[-CORRECTION_WINDOW:]
    strategy_hints = agent_state.get("strategy_hints", [])
    task_name = agent_state.get("current_task_name", "")

    # Guard against a missing flag list the same way sessions/tags are guarded.
    flagged_count = len(obs.flagged_packet_ids or [])
    # Clamp to 1 so the coverage ratio never divides by zero.
    total_visible = max(1, len(packets))
    coverage = flagged_count / total_visible
    coverage_target = TASK_COVERAGE_TARGETS.get(task_name, 0.25)
    score_target = TASK_SCORE_TARGETS.get(task_name, 0.65)
    grouped_count = len(sessions)
    tagged_count = len(tags)
    # Non-easy tasks additionally need at least two sessions and one tag.
    ready_to_submit = (
        obs.current_score_estimate >= score_target
        and coverage >= coverage_target
        and (task_name == "easy" or grouped_count >= 2)
        and (task_name == "easy" or tagged_count >= 1)
    )

    summary = [
        f"Step: {obs.step_number}/{obs.step_number + obs.steps_remaining}",
        f"Current Progress: {obs.current_score_estimate:.2f}",
        f"Coverage: {flagged_count}/{total_visible} ({coverage:.2%}) | target {coverage_target:.0%}",
        f"Sessions: grouped={grouped_count}, tagged={tagged_count}",
        f"Submit Readiness: {'READY' if ready_to_submit else 'KEEP INVESTIGATING'}",
        f"Last Step Reward: {last_reward:.2f}"
        if isinstance(last_reward, (int, float))
        else "Last Step Reward: n/a",
        f"Last Reward Feedback: {reward_feedback}",
        f"ALREADY REVEALED: {', '.join(revealed_ids[-6:])} "
        + ("..." if len(revealed_ids) > 6 else ""),
    ]

    if recent_corrections:
        summary.append("\n### RECENT CORRECTIONS:")
        for reason in recent_corrections:
            summary.append(f"- {reason}")

    if strategy_hints:
        summary.append("\n### STRATEGY HINTS:")
        for hint in strategy_hints:
            summary.append(f"- {hint}")

    # The header is emitted right before its bullets. Previously it was added
    # ahead of the corrections/hints sections, so the session bullets ended up
    # listed under the wrong heading.
    summary.append("\n### SESSIONS PENDING TAGGING:")
    if untagged_sessions:
        for s in untagged_sessions[:6]:
            summary.append(f"- {s} ({len(sessions[s])} packets)")
    else:
        summary.append("- [No pending sessions]")

    summary.append("\n### REVEALED INDICATORS:")
    for p in revealed[-4:]:
        payload = (p.full_payload or "")[:150]
        if payload:
            summary.append(f"- {p.packet_id}: {payload}")

    summary.append("\n### UNKNOWN PACKETS (Must Inspect):")
    unknown = [p for p in packets if not p.is_revealed][:10]
    for p in unknown:
        summary.append(f"- {p.packet_id} | {p.src_ip} -> {p.dst_ip} | Proto: {p.protocol}")

    return "\n".join(summary)


def parse_action(raw_text: str) -> NetworkForensicsAction:
    """Extract the JSON object embedded in *raw_text* and build an action.

    The span from the first ``{`` to the last ``}`` is parsed as JSON. The
    ``metadata`` key is discarded, as are empty-string values for
    ``session_name`` / ``pattern_type`` / ``claimed_entry_point`` and an empty
    ``packet_ids`` list.

    Raises:
        ValueError: if no brace-delimited span exists, or the span is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    text = raw_text.strip()
    start = text.find("{")
    end = text.rfind("}")
    # ``end < start`` also covers ``end == -1`` and the case where the only
    # "}" comes before the first "{", which would otherwise yield an empty
    # slice and an unhelpful decode error.
    if start == -1 or end < start:
        raise ValueError("model did not return JSON")
    data = json.loads(text[start : end + 1])
    data.pop("metadata", None)
    for key in ("session_name", "pattern_type", "claimed_entry_point"):
        if data.get(key) == "":
            data.pop(key, None)
    if data.get("packet_ids") == []:
        data.pop("packet_ids", None)
    return NetworkForensicsAction(**data)


def sanitize_action(action: NetworkForensicsAction) -> NetworkForensicsAction:
    """Return a copy of *action* carrying only the fields its type uses.

    Fields that are falsy, or that do not belong to the given ``action_type``,
    are omitted from the rebuilt action.
    """
    payload: dict[str, Any] = {"action_type": action.action_type}
    if (
        action.action_type in {"inspect_packet", "flag_as_suspicious"}
        and action.packet_id
    ):
        payload["packet_id"] = action.packet_id
    elif action.action_type == "group_into_session":
        if action.session_name:
            payload["session_name"] = action.session_name
        if action.packet_ids:
            payload["packet_ids"] = action.packet_ids
    elif action.action_type == "tag_pattern":
        if action.session_name:
            payload["session_name"] = action.session_name
        if action.pattern_type:
            payload["pattern_type"] = action.pattern_type
    elif action.action_type == "identify_entry_point" and action.claimed_entry_point:
        payload["claimed_entry_point"] = action.claimed_entry_point
    elif action.action_type == "submit_report":
        if action.incident_summary:
            payload["incident_summary"] = action.incident_summary
        if action.claimed_entry_point:
            payload["claimed_entry_point"] = action.claimed_entry_point
    return NetworkForensicsAction(**payload)


def decode_payload_preview(payload_preview: str) -> str:
    """Return *payload_preview* decoded from hex when it parses as hex.

    Whitespace is removed before the hex attempt. If the result has even
    length, parses as hex and decodes (undecodable bytes ignored) to a
    non-empty string, that string is returned; otherwise the stripped
    original preview is returned.
    """
    preview = (payload_preview or "").strip()
    compact = "".join(preview.split())
    if compact and len(compact) % 2 == 0:
        try:
            decoded = bytes.fromhex(compact).decode("utf-8", errors="ignore").strip()
            if decoded:
                return decoded
        except ValueError:
            # Not hex after all; fall through to the raw preview.
            pass
    return preview


def packet_payload_text(packet: Any) -> str:
    """Return the packet's full payload, or its decoded preview if empty."""
    return packet.full_payload or decode_payload_preview(packet.payload_preview)


def keyword_to_pattern(payload: str) -> str | None:
    text = payload.lower()
    # --- DoS / DDoS variants ---
    if "slowloris" in text:
        return "dos_slowloris"
    if "slowhttptest" in text or "slow http" in text:
        return "dos_slowhttptest"
    if "goldeneye" in text or "golden eye" in text:
        return "dos_goldeneye"
    if "hulk" in text:
        return "dos_hulk"
    if "heartbeat" in text or "heartbleed" in text or ("tls" in text and "ext" in text):
        return "heartbleed"
    if "flood" in text or "burst" in text or "ddos" in text:
        return "ddos"
    # HTTP flood indicators (repeated GET/POST to same endpoint)
    if text.startswith("get /") or text.startswith("post /") or text.startswith("get http"):
        if "accept-encoding" in text or "connection" in text or "keep-alive" in text:
            return "ddos"
    # SYN flood / connection
flood if "syn" in text and "ack" not in text and len(text) < 30: return "ddos" # ICMP flood if "icmp" in text and ("echo" in text or "request" in text or len(text) < 20): return "ddos" # --- Web attacks --- if "xss" in text or "