theniceboy/bin/ai-mem-usage
2026-03-27 21:04:47 -07:00

299 lines
9.4 KiB
Python
Executable file

#!/usr/bin/env python3
"""
Analyze memory usage of opencode and dart processes, mapping them to tmux windows.
"""
from __future__ import annotations
import subprocess
import re
from dataclasses import dataclass, field
from collections import defaultdict
from typing import Optional
@dataclass()
class Process:
    """One opencode/dart process plus the memory figures gathered for it."""

    # Process ID.
    pid: int
    # MEM column from `top`, converted to MB (0.0 when top did not report it).
    mem_mb: float
    # CMPRS (compressed memory) column from `top`, converted to MB.
    compressed_mb: float
    # Command name as printed by `ps ... comm`.
    command: str
    # Parent PID; filled in by the collectors, None until then.
    parent_pid: Optional[int] = field(default=None)
    # Command reported alongside the parent link; None until filled in.
    parent_cmd: Optional[str] = field(default=None)
@dataclass()
class TmuxPane:
    """A tmux pane, identified by its window and the PID of its shell."""

    # Name of the tmux session the pane belongs to.
    session: str
    # Numeric window index within the session.
    window_idx: int
    # Window name as reported by `#{window_name}`.
    window_name: str
    # PID of the process tmux started in the pane (`#{pane_pid}`).
    pane_pid: int
def run_cmd(cmd: str) -> str:
    """Run *cmd* through the shell and return its standard output as text.

    stderr is captured and thrown away, and the exit status is not checked,
    so a failing command simply yields whatever it printed to stdout.
    """
    completed = subprocess.run(
        cmd,
        shell=True,
        capture_output=True,
        text=True,
    )
    stdout_text = completed.stdout
    return stdout_text
def parse_mem(mem_str: str) -> float:
    """Parse a ``top``-style memory figure such as ``'1234M'`` or ``'2G'`` into MB.

    Recognised unit suffixes (case-insensitive): ``B`` (bytes), ``K``, ``M``,
    ``G`` and ``T``. A bare number is taken to already be in MB. Anything after
    the unit (e.g. a trailing ``+``/``-`` marker) is ignored.

    Returns 0.0 for an empty string, ``'-'``, or text that does not start
    with a number.
    """
    mem_str = mem_str.strip()
    if not mem_str or mem_str == '-':
        return 0.0
    # `\d*\.?\d+` accepts "12", "1.5" and ".5", and never captures a
    # malformed run like "1.2.3" that would make float() raise.
    match = re.match(r'(\d*\.?\d+)([BKMGT]?)', mem_str, re.IGNORECASE)
    if not match:
        return 0.0
    value = float(match.group(1))
    unit = match.group(2).upper()
    # 'B' must be scaled down explicitly; without it, byte counts that
    # `top` prints for tiny processes would be reported as megabytes.
    if unit == 'B':
        return value / (1024 * 1024)
    if unit == 'K':
        return value / 1024
    if unit == 'G':
        return value * 1024
    if unit == 'T':
        return value * 1024 * 1024
    return value
def fmt_mem(mb: float) -> str:
    """Render a megabyte count compactly.

    Values of 1024 MB and above become gigabytes with one decimal (``'1.5G'``).
    Anything smaller is shown as whole megabytes (``'512M'``).
    """
    in_gigabytes = mb >= 1024
    return f"{mb / 1024:.1f}G" if in_gigabytes else f"{mb:.0f}M"
def get_all_process_memory() -> dict[int, tuple[float, float]]:
    """Sample MEM and CMPRS for up to 500 processes with a single ``top`` run.

    Returns a mapping of PID to ``(mem_mb, compressed_mb)``. Only rows whose
    first column is a PID are kept. Processes outside the 500 that ``top``
    lists (ordered by ``-o mem``) are simply absent from the result.
    """
    raw = run_cmd('top -l 1 -o mem -n 500 -stats pid,mem,cmprs 2>/dev/null')
    usage = {}
    for row in raw.strip().split('\n'):
        columns = row.split()
        # Skip top's header/summary lines and anything too short.
        if len(columns) < 3 or not columns[0].isdigit():
            continue
        usage[int(columns[0])] = (parse_mem(columns[1]), parse_mem(columns[2]))
    return usage
def get_target_processes() -> list[Process]:
    """Collect every running process whose command mentions opencode or dart.

    PIDs, parent PIDs and command names come from ``ps``. Memory figures are
    attached from ``get_all_process_memory()`` and default to ``(0.0, 0.0)``
    when ``top`` did not report the PID.
    """
    # The rss column is requested but not used; memory comes from top instead.
    listing = run_cmd("ps -eo pid,ppid,rss,comm | grep -E 'opencode|dart'")
    mem_by_pid = get_all_process_memory()
    found = []
    for raw in listing.strip().split('\n'):
        if not raw.strip():
            continue
        cols = raw.split()
        if len(cols) < 4:
            continue
        try:
            pid, ppid = int(cols[0]), int(cols[1])
        except ValueError:
            continue
        comm = ' '.join(cols[3:])
        # grep matched the whole line; re-check the command column itself.
        if 'opencode' not in comm and 'dart' not in comm:
            continue
        mem, cmprs = mem_by_pid.get(pid, (0.0, 0.0))
        found.append(Process(
            pid=pid,
            mem_mb=mem,
            compressed_mb=cmprs,
            command=comm,
            parent_pid=ppid,
        ))
    return found
def get_tmux_panes() -> list[TmuxPane]:
    """Return every pane of every tmux session together with its shell PID.

    Sessions come from ``tmux list-sessions``. Panes are listed per session
    with ``list-panes -s``, and each output line has the form
    ``index:name:pid``. Returns an empty list when tmux is not running.
    """
    import shlex

    sessions = run_cmd("tmux list-sessions -F '#{session_name}' 2>/dev/null").strip().split('\n')
    panes = []
    for session in sessions:
        if not session:
            continue
        # Quote the session name so characters such as quotes or spaces cannot
        # break (or inject into) the shell command.
        output = run_cmd(
            f"tmux list-panes -s -t {shlex.quote(session)} "
            "-F '#{window_index}:#{window_name}:#{pane_pid}' 2>/dev/null"
        )
        for line in output.strip().split('\n'):
            # The window name may itself contain ':'. Take the index from the
            # front and the PID from the back, and treat the rest as the name.
            # Splitting on every ':' misparsed such panes and dropped them.
            index_str, sep_front, rest = line.partition(':')
            window_name, sep_back, pid_str = rest.rpartition(':')
            if not sep_front or not sep_back:
                continue
            try:
                panes.append(TmuxPane(
                    session=session,
                    window_idx=int(index_str),
                    window_name=window_name,
                    pane_pid=int(pid_str),
                ))
            except ValueError:
                continue
    return panes
def get_descendants(pid: int, max_depth: int = 5) -> set[int]:
    """Return the PIDs of all descendants of *pid*, at most *max_depth* levels down.

    The whole process table is read once with ``ps`` and walked in memory.
    Walking it with one ``pgrep -P`` subprocess per visited process cost a
    fork/exec for every descendant, which ``main()`` repeats for every tmux
    pane. *pid* itself is not included in the result.
    """
    # Build a parent -> children index from a single snapshot.
    children: dict[int, list[int]] = defaultdict(list)
    for line in run_cmd("ps -A -o pid=,ppid= 2>/dev/null").splitlines():
        cols = line.split()
        if len(cols) == 2 and cols[0].isdigit() and cols[1].isdigit():
            children[int(cols[1])].append(int(cols[0]))

    descendants: set[int] = set()
    frontier = [pid]
    # Breadth-first, one tree level per iteration, bounded by max_depth.
    for _ in range(max_depth):
        if not frontier:
            break
        next_frontier = []
        for parent in frontier:
            for child in children.get(parent, ()):
                if child not in descendants:
                    descendants.add(child)
                    next_frontier.append(child)
        frontier = next_frontier
    return descendants
def get_parent_chain(pid: int, max_depth: int = 10) -> list[tuple[int, str]]:
    """Return the ancestors of *pid*, nearest first, as ``(pid, command)`` pairs.

    The walk follows parent links from a single ``ps`` snapshot. It stops after
    the first ancestor whose PID is <= 1 (launchd, included as the last entry),
    or after *max_depth* entries. An empty list means *pid* was not found in
    the process table. The command is ``""`` if an ancestor is missing from
    the snapshot.
    """
    # pid -> (ppid, comm) for every process, read once.
    table: dict[int, tuple[int, str]] = {}
    for line in run_cmd("ps -A -o pid=,ppid=,comm= 2>/dev/null").splitlines():
        cols = line.split(None, 2)
        if len(cols) < 3:
            continue
        try:
            table[int(cols[0])] = (int(cols[1]), cols[2].strip())
        except ValueError:
            continue

    chain = []
    current = pid
    for _ in range(max_depth):
        entry = table.get(current)
        if entry is None:
            break
        ppid = entry[0]
        # Pair the ancestor's PID with that same ancestor's command. Using the
        # child's command here made chain[0][1] the process's own name.
        ancestor = table.get(ppid)
        chain.append((ppid, ancestor[1] if ancestor else ""))
        if ppid <= 1:
            break
        current = ppid
    return chain
def is_orphaned_opencode(proc: Process) -> bool:
    """Report whether *proc* is an opencode process that looks orphaned.

    Processes whose command does not contain ``opencode`` are never orphans.
    Otherwise the process counts as orphaned when:

    * ``get_parent_chain`` returns nothing for it, or
    * the first chain entry's PID is 1 (launchd), or
    * the first chain entry's command contains ``node`` and the next entry's
      PID is 1.
    """
    if 'opencode' not in proc.command:
        return False
    ancestry = get_parent_chain(proc.pid)
    if not ancestry:
        return True
    first_pid, first_cmd = ancestry[0]
    if first_pid == 1:
        return True
    return 'node' in first_cmd and len(ancestry) > 1 and ancestry[1][0] == 1
def main():
    """Print a memory report for opencode/dart processes grouped by tmux window.

    Each process is charged ``MEM + CMPRS`` as sampled by ``top``. A process is
    listed under the first tmux window whose pane's process tree contains it.
    Processes that are, or descend from, an orphaned opencode instance, and
    processes that belong to no pane, are listed under ORPHANS. That section
    ends with a ready-made ``kill`` command for them.
    """
    processes = get_target_processes()
    panes = get_tmux_panes()

    # Attribute each process to the first pane whose descendant set holds it.
    pane_descendants: dict[int, set[int]] = {
        pane.pane_pid: get_descendants(pane.pane_pid) for pane in panes
    }
    tmux_mapping: dict[int, TmuxPane] = {}
    for proc in processes:
        for pane in panes:
            if proc.pid in pane_descendants[pane.pane_pid]:
                tmux_mapping[proc.pid] = pane
                break

    # is_orphaned_opencode() already returns False for non-opencode processes.
    orphaned_opencode_pids: set[int] = {
        proc.pid for proc in processes if is_orphaned_opencode(proc)
    }

    # Windows are keyed by a (session, index, name) tuple rather than a
    # ':'-joined string. Re-splitting such a string truncated window names
    # containing ':', and tuple keys also order windows numerically by index.
    by_window: dict[tuple[str, int, str], list[Process]] = defaultdict(list)
    orphans: list[Process] = []
    for proc in processes:
        chain = get_parent_chain(proc.pid)
        if chain:
            proc.parent_pid, proc.parent_cmd = chain[0]
        descends_from_orphan = any(
            ancestor_pid in orphaned_opencode_pids for ancestor_pid, _ in chain
        )
        pane = tmux_mapping.get(proc.pid)
        if proc.pid in orphaned_opencode_pids or descends_from_orphan or pane is None:
            orphans.append(proc)
        else:
            by_window[(pane.session, pane.window_idx, pane.window_name)].append(proc)

    print("\n" + "=" * 70)
    print("AI MEMORY USAGE (opencode + dart)")
    print("=" * 70)

    window_totals = []
    for key in sorted(by_window):
        procs = by_window[key]
        total_mem = sum(p.mem_mb for p in procs)
        total_cmprs = sum(p.compressed_mb for p in procs)
        window_totals.append((key, total_mem, total_cmprs, procs))
    # Stable sort: windows with equal totals keep their (session, index) order.
    window_totals.sort(key=lambda x: x[1] + x[2], reverse=True)
    for (session, window_idx, window_name), total_mem, total_cmprs, procs in window_totals:
        total_virt = total_mem + total_cmprs
        print(f"\n[{session}] Window {window_idx}: {window_name} ({fmt_mem(total_virt)} total)")
        procs.sort(key=lambda p: p.mem_mb + p.compressed_mb, reverse=True)
        for proc in procs:
            proc_type = "opencode" if "opencode" in proc.command else "dart"
            virt = proc.mem_mb + proc.compressed_mb
            print(f" {proc.pid:5} {fmt_mem(virt):>6} ({fmt_mem(proc.mem_mb):>5} + {fmt_mem(proc.compressed_mb):>5} cmprs) {proc_type}")

    total_orphan_mem = sum(p.mem_mb for p in orphans)
    total_orphan_cmprs = sum(p.compressed_mb for p in orphans)
    if orphans:
        total_orphan_virt = total_orphan_mem + total_orphan_cmprs
        print("\n" + "=" * 70)
        print(f"ORPHANS ({fmt_mem(total_orphan_virt)} total)")
        print("=" * 70)
        orphans.sort(key=lambda p: p.mem_mb + p.compressed_mb, reverse=True)
        for proc in orphans:
            proc_type = "opencode" if "opencode" in proc.command else "dart"
            virt = proc.mem_mb + proc.compressed_mb
            parent_info = proc.parent_cmd.split('/')[-1] if proc.parent_cmd else "?"
            print(f" {proc.pid:5} {fmt_mem(virt):>6} ({fmt_mem(proc.mem_mb):>5} + {fmt_mem(proc.compressed_mb):>5} cmprs) {proc_type} <- {parent_info}")
        print(f"\nkill {' '.join(str(p.pid) for p in orphans)}")

    all_mem = sum(p.mem_mb for p in processes)
    all_cmprs = sum(p.compressed_mb for p in processes)
    all_virt = all_mem + all_cmprs
    in_tmux_virt = (all_mem - total_orphan_mem) + (all_cmprs - total_orphan_cmprs)
    print("\n" + "-" * 70)
    print(f"Total: {fmt_mem(all_virt)} (tmux: {fmt_mem(in_tmux_virt)}", end="")
    if orphans:
        print(f", orphans: {fmt_mem(total_orphan_mem + total_orphan_cmprs)}", end="")
    print(")")
# Entry point: produce the report only when run as a script, not on import.
if __name__ == '__main__':
    main()