#!/usr/bin/env python3
"""
Analyze memory usage of opencode and dart processes, mapping them to tmux windows.
"""
from __future__ import annotations

import subprocess
import re
from dataclasses import dataclass, field
from collections import defaultdict
from typing import Optional


@dataclass
class Process:
    """One tracked opencode/dart process together with its memory figures.

    Memory values are in megabytes. The parent fields are optional and
    default to ``None`` until something fills them in.
    """

    pid: int
    mem_mb: float          # top's MEM column, converted to MB
    compressed_mb: float   # top's CMPRS column, converted to MB
    command: str           # command name as reported by ``ps -o comm``
    parent_pid: Optional[int] = field(default=None)
    parent_cmd: Optional[str] = field(default=None)


@dataclass
class TmuxPane:
    """A single tmux pane plus the session/window it lives in."""

    session: str       # tmux session name
    window_idx: int    # window index within the session
    window_name: str   # window name as reported by tmux
    pane_pid: int      # PID tmux reports for the pane (#{pane_pid})


def run_cmd(cmd: str) -> str:
    """Run *cmd* through the shell and return what it wrote to stdout.

    stderr is captured and discarded, and the exit status is not checked, so
    a failing or missing command simply yields an empty (or partial) string.
    """
    return subprocess.run(
        cmd,
        shell=True,
        text=True,
        capture_output=True,
    ).stdout


def parse_mem(mem_str: str) -> float:
    """Parse a memory figure such as ``'1234M'``, ``'2G'`` or ``'512K+'`` to MB.

    Recognised unit suffixes (case-insensitive) are ``B`` (bytes), ``K``,
    ``M``, ``G`` and ``T``; a bare number is taken to already be in MB.
    Anything after the unit (e.g. a trailing ``+``/``-`` marker) is ignored.

    Returns 0.0 for empty input, ``'-'``, or text that does not start with a
    number.
    """
    mem_str = mem_str.strip()
    if not mem_str or mem_str == '-':
        return 0.0
    # Require at least one digit so inputs like '.' don't reach float() and raise.
    match = re.match(r'(\d+(?:\.\d+)?|\.\d+)([BKMGT]?)', mem_str, re.IGNORECASE)
    if not match:
        return 0.0
    value = float(match.group(1))
    unit = match.group(2).upper()
    # Multipliers to MB; 'B' is bytes (previously mis-read as MB), '' and 'M' are MB.
    to_mb = {
        'B': 1 / (1024 * 1024),
        'K': 1 / 1024,
        'G': 1024.0,
        'T': 1024.0 * 1024.0,
    }
    return value * to_mb.get(unit, 1.0)


def fmt_mem(mb: float) -> str:
    """Render a megabyte count compactly.

    Values of 1024 MB or more are shown as gigabytes with one decimal
    (``'1.5G'``); smaller values as whole megabytes (``'512M'``).
    """
    as_gigabytes = mb >= 1024
    return f"{mb / 1024:.1f}G" if as_gigabytes else f"{mb:.0f}M"


def get_all_process_memory() -> dict[int, tuple[float, float]]:
    """Map PID -> (MEM, CMPRS) in MB from a single ``top`` sample.

    Uses ``top -l 1 -o mem -n 500``, so only the (at most) 500 processes top
    lists, ordered by memory, are included. Lines whose first column is not a
    PID (headers, summaries) are skipped.
    """
    raw = run_cmd('top -l 1 -o mem -n 500 -stats pid,mem,cmprs 2>/dev/null')
    usage: dict[int, tuple[float, float]] = {}
    for row in raw.strip().split('\n'):
        cols = row.split()
        if len(cols) < 3 or not cols[0].isdigit():
            continue
        usage[int(cols[0])] = (parse_mem(cols[1]), parse_mem(cols[2]))
    return usage


def get_target_processes() -> list[Process]:
    """Return every process whose ``ps`` command name contains opencode or dart.

    Each result carries the MEM/CMPRS figures from ``get_all_process_memory``;
    a PID that ``top`` did not report gets (0.0, 0.0). ``parent_pid`` is the
    PPID from ``ps``. Rows with non-numeric PID/PPID are ignored.
    """
    listing = run_cmd("ps -eo pid,ppid,rss,comm | grep -E 'opencode|dart'")
    mem_data = get_all_process_memory()

    found: list[Process] = []
    for row in listing.strip().split('\n'):
        fields = row.split()
        if len(fields) < 4:
            continue  # blank line or truncated row
        try:
            pid, ppid = int(fields[0]), int(fields[1])
        except ValueError:
            continue
        comm = ' '.join(fields[3:])
        # grep matched the whole line; re-check that the match was in the command.
        if 'opencode' not in comm and 'dart' not in comm:
            continue
        mem, cmprs = mem_data.get(pid, (0.0, 0.0))
        found.append(Process(
            pid=pid,
            mem_mb=mem,
            compressed_mb=cmprs,
            command=comm,
            parent_pid=ppid,
        ))
    return found


def get_tmux_panes() -> list[TmuxPane]:
    """Return every pane of every session on the default tmux server.

    A single ``tmux list-panes -a`` call lists all panes, so no session name
    has to be interpolated into a shell command (names containing quotes used
    to break the per-session call). The window name is placed LAST in the
    format and each line is split at most three times, so window names that
    themselves contain ':' are kept intact instead of making the pane vanish
    (tmux does not permit ':' in session names, and index/pid are integers).

    Returns an empty list when no tmux server is running (tmux's error goes
    to /dev/null and stdout is empty). Lines that don't parse are skipped.
    """
    fmt = '#{session_name}:#{window_index}:#{pane_pid}:#{window_name}'
    output = run_cmd(f"tmux list-panes -a -F '{fmt}' 2>/dev/null")

    panes: list[TmuxPane] = []
    for line in output.splitlines():
        parts = line.split(':', 3)
        if len(parts) != 4:
            continue
        session, window_idx, pane_pid, window_name = parts
        if not session:
            continue
        try:
            panes.append(TmuxPane(
                session=session,
                window_idx=int(window_idx),
                window_name=window_name,
                pane_pid=int(pane_pid),
            ))
        except ValueError:
            continue
    return panes


def get_descendants(pid: int, max_depth: int = 5) -> set[int]:
    """Collect PIDs below *pid* in the process tree, one generation at a time.

    Children are discovered with one ``pgrep -P`` call per visited process.
    The walk covers at most *max_depth* generations, so deeper descendants
    are not returned. Only PIDs reported by pgrep are collected; each is
    visited once.
    """
    found: set[int] = set()
    frontier = [pid]
    for _ in range(max_depth):
        if not frontier:
            break
        discovered: list[int] = []
        for parent in frontier:
            listing = run_cmd(f"pgrep -P {parent} 2>/dev/null")
            for token in listing.strip().split('\n'):
                if not token.isdigit():
                    continue
                child = int(token)
                if child in found:
                    continue
                found.add(child)
                discovered.append(child)
        frontier = discovered
    return found


def get_parent_chain(pid: int, max_depth: int = 10) -> list[tuple[int, str]]:
    """Walk upwards from *pid* with ``ps``, one process per step.

    Each entry is ``(ppid, comm)`` of the process VISITED at that step, so
    ``comm`` names the visited process, not ``ppid``:
    entry 0 is ``(parent pid of pid, pid's own command)``,
    entry 1 is ``(grandparent pid, parent's command)``, and so on.

    The walk stops after recording a ppid <= 1 (i.e. reaching launchd), after
    *max_depth* steps, or as soon as ``ps`` output is missing or unparseable.
    """
    links: list[tuple[int, str]] = []
    cursor = pid
    remaining = max_depth
    while remaining > 0:
        remaining -= 1
        line = run_cmd(f"ps -p {cursor} -o ppid=,comm= 2>/dev/null").strip()
        fields = line.split(None, 1)
        if len(fields) != 2:
            break  # process gone, or no command column
        ppid_text, comm = fields
        try:
            ppid = int(ppid_text)
        except ValueError:
            break
        links.append((ppid, comm))
        if ppid <= 1:
            break
        cursor = ppid
    return links


def is_orphaned_opencode(proc: Process) -> bool:
    """Return True if *proc* is an opencode process with no live launcher.

    ``get_parent_chain`` yields ``(ppid, comm)`` where ``comm`` belongs to the
    process being walked, not to ``ppid``: ``chain[0]`` is
    ``(parent pid, proc's own command)`` and ``chain[1]`` is
    ``(grandparent pid, parent's command)``. The parent's command therefore
    comes from ``chain[1]`` (the old code read ``chain[0][1]``, i.e. the
    process's own name, so the node-wrapper check never looked at the parent).

    An opencode process counts as orphaned when:
      * its parent chain could not be read at all,
      * its parent is pid 1, or
      * its parent's command contains ``node`` (presumably the npm launcher
        wrapper) and that parent's own parent is pid 1.

    Non-opencode processes always return False.
    """
    if 'opencode' not in proc.command:
        return False
    chain = get_parent_chain(proc.pid)
    if not chain:
        return True
    parent_pid = chain[0][0]
    if parent_pid == 1:
        return True
    if len(chain) > 1:
        grandparent_pid, parent_cmd = chain[1]
        if 'node' in parent_cmd and grandparent_pid == 1:
            return True
    return False


def _proc_type(proc: Process) -> str:
    """Label a tracked process: 'opencode' if its command says so, else 'dart'."""
    return "opencode" if "opencode" in proc.command else "dart"


def _proc_row(proc: Process) -> str:
    """Format the shared report columns: pid, total, (MEM + CMPRS), type."""
    virt = proc.mem_mb + proc.compressed_mb
    return (
        f"  {proc.pid:5}  {fmt_mem(virt):>6}  "
        f"({fmt_mem(proc.mem_mb):>5} + {fmt_mem(proc.compressed_mb):>5} cmprs)  "
        f"{_proc_type(proc)}"
    )


def main():
    """Print a per-tmux-window memory report for opencode/dart processes.

    A process is attributed to the tmux pane whose ``pane_pid`` is the process
    itself or one of its ancestors (walked via ``get_parent_chain``). Orphaned
    opencode processes (per ``is_orphaned_opencode``), anything descended from
    them, and processes not under any pane are listed as orphans, followed by
    a ``kill`` command line for them.
    """
    processes = get_target_processes()
    panes = get_tmux_panes()
    pane_by_pid: dict[int, TmuxPane] = {pane.pane_pid: pane for pane in panes}

    orphaned_opencode_pids: set[int] = {
        proc.pid for proc in processes
        if 'opencode' in proc.command and is_orphaned_opencode(proc)
    }

    by_window: dict[tuple[str, int, str], list[Process]] = defaultdict(list)
    orphans: list[Process] = []

    for proc in processes:
        # chain[i] is (ppid, comm) of the i-th process visited walking up from
        # proc: chain[0][1] is proc's OWN command, chain[1][1] its parent's.
        chain = get_parent_chain(proc.pid)
        if chain:
            proc.parent_pid = chain[0][0]
            proc.parent_cmd = chain[1][1] if len(chain) > 1 else None
        ancestors = [ppid for ppid, _ in chain]

        if proc.pid in orphaned_opencode_pids or any(
            a in orphaned_opencode_pids for a in ancestors
        ):
            orphans.append(proc)
            continue

        # Check proc itself first: a window started directly on the command
        # has pane_pid == proc.pid. Walking up (rather than down from each
        # pane) also avoids misclassifying deep descendants as orphans.
        pane = next(
            (pane_by_pid[p] for p in (proc.pid, *ancestors) if p in pane_by_pid),
            None,
        )
        if pane is None:
            orphans.append(proc)
        else:
            by_window[(pane.session, pane.window_idx, pane.window_name)].append(proc)

    print("\n" + "=" * 70)
    print("AI MEMORY USAGE (opencode + dart)")
    print("=" * 70)

    window_totals = []
    for key in sorted(by_window):
        procs = by_window[key]
        total_mem = sum(p.mem_mb for p in procs)
        total_cmprs = sum(p.compressed_mb for p in procs)
        window_totals.append((key, total_mem + total_cmprs, procs))
    # Heaviest window first; the sort is stable, so ties keep key order.
    window_totals.sort(key=lambda item: item[1], reverse=True)

    for (session, window_idx, window_name), total_virt, procs in window_totals:
        print(f"\n[{session}] Window {window_idx}: {window_name}  ({fmt_mem(total_virt)} total)")
        procs.sort(key=lambda p: p.mem_mb + p.compressed_mb, reverse=True)
        for proc in procs:
            print(_proc_row(proc))

    total_orphan_mem = sum(p.mem_mb for p in orphans)
    total_orphan_cmprs = sum(p.compressed_mb for p in orphans)
    total_orphan_virt = total_orphan_mem + total_orphan_cmprs

    if orphans:
        print("\n" + "=" * 70)
        print(f"ORPHANS ({fmt_mem(total_orphan_virt)} total)")
        print("=" * 70)

        orphans.sort(key=lambda p: p.mem_mb + p.compressed_mb, reverse=True)
        for proc in orphans:
            if proc.parent_cmd:
                parent_info = proc.parent_cmd.split('/')[-1]
            elif proc.parent_pid is not None:
                # Parent's command unknown (e.g. parent is pid 1): show its pid.
                parent_info = f"pid {proc.parent_pid}"
            else:
                parent_info = "?"
            print(f"{_proc_row(proc)}  <- {parent_info}")

        print(f"\nkill {' '.join(str(p.pid) for p in orphans)}")

    all_mem = sum(p.mem_mb for p in processes)
    all_cmprs = sum(p.compressed_mb for p in processes)
    all_virt = all_mem + all_cmprs
    in_tmux_virt = (all_mem - total_orphan_mem) + (all_cmprs - total_orphan_cmprs)

    print("\n" + "-" * 70)
    print(f"Total: {fmt_mem(all_virt)}  (tmux: {fmt_mem(in_tmux_virt)}", end="")
    if orphans:
        print(f", orphans: {fmt_mem(total_orphan_virt)}", end="")
    print(")")


if __name__ == '__main__':
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())
