#!/usr/bin/env python3
"""
V3 Parameter Optimizer — Grid search for optimal momentum strategy parameters.
Tests combinations of TOP_N, VIX threshold, stop-loss, rebalance frequency.
"""

import io
import json
import warnings
from datetime import datetime
from itertools import product

import numpy as np
import pandas as pd
import yfinance as yf

# Silence every warning category process-wide (presumably to quiet library
# deprecation noise during the long grid search).
warnings.simplefilter("ignore")

# ─── Fixed params ────────────────────────────────────────────────────────────
MOM_LOOKBACK = 6 * 21      # rows in the momentum scoring window (126, ~6 months of trading days)
MOM_SKIP = 21              # most recent rows excluded from the scoring window (~1 month)
MIN_PRICE = 10.0           # names whose latest close is below this are not ranked
SPY_SMA_PERIOD = 200       # length of the SPY moving average used by the regime filter
RISK_FREE = 0.04           # annual risk-free rate subtracted in Sharpe/Sortino
ACCOUNT = 100_000          # starting capital for every backtest

# ─── Parameter grid ──────────────────────────────────────────────────────────
# Every combination of these lists is backtested (5 * 4 * 4 * 3 = 240 runs).
GRID = dict(
    top_n=[5, 8, 10, 15, 20],               # number of names held after a rebalance
    vix_thresh=[22, 25, 28, 30],            # VIX level above which risk-off can trigger
    sl_pct=[-0.07, -0.10, -0.12, -0.15],    # per-position stop-loss (fractional return)
    rebal_freq=[15, 21, 30],                # rows between rebalances
)

def get_sp500_tickers(url="https://en.wikipedia.org/wiki/List_of_S%26P_500_companies",
                      timeout=20):
    """Scrape the S&P 500 constituent symbols from Wikipedia.

    Symbols are rewritten to Yahoo Finance form by replacing "." with "-"
    (e.g. ``BRK.B`` -> ``BRK-B``).

    Args:
        url: Page whose first HTML table has a ``Symbol`` column.
        timeout: Socket timeout in seconds for the HTTP request.

    Returns:
        List of ticker strings, or ``[]`` if the page could not be fetched or
        parsed. The reason for a failure is printed to stderr.
    """
    import sys
    import urllib.request

    try:
        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            html = resp.read()
        tables = pd.read_html(io.BytesIO(html))
        symbols = tables[0]["Symbol"].tolist()
    except Exception as exc:
        # Best-effort boundary: the caller treats [] as "no universe". Catching
        # Exception (not a bare except) still lets Ctrl-C / SystemExit through.
        print(f"  ⚠️  Could not load S&P 500 tickers from {url}: {exc!r}", file=sys.stderr)
        return []
    return [str(s).replace(".", "-") for s in symbols]

def momentum_score(prices, lookback=126, skip=21):
    """Volatility-adjusted momentum of a price series.

    The scoring window is the ``lookback`` prices that end just before the
    most recent ``skip`` prices (those are excluded). The score is the window's
    total return (first price to last price) divided by the annualized standard
    deviation of the window's simple daily returns (``sqrt(252)`` scaling).

    Args:
        prices: 1-D sequence of prices, oldest first (list or array).
        lookback: Number of prices in the scoring window (must be >= 2).
        skip: Number of most-recent prices excluded from the window (>= 0).
            ``skip=0`` scores the window ending at the latest price.

    Returns:
        The score as a float, or NaN in any of these cases:
          - there is too little history;
          - ``lookback < 2`` or ``skip < 0``;
          - the window's volatility is zero.
    """
    p = np.asarray(prices, dtype=float)
    # Use explicit positive bounds: the old negative slice ``[-L:-skip]`` became
    # ``[-L:0]`` (empty) when skip == 0 and crashed on a shape mismatch.
    end = len(p) - skip
    start = end - lookback
    if lookback < 2 or skip < 0 or start < 0:
        return np.nan
    window = p[start:end]
    ret = window[-1] / window[0] - 1
    daily_ret = np.diff(window) / window[:-1]
    vol = np.std(daily_ret) * np.sqrt(252)
    return ret / vol if vol > 0 else np.nan

def rank_momentum(price_data, idx, tickers):
    """Rank tickers by momentum score using closes up to and including row ``idx``.

    A ticker is scored only if:
      - it is a column of ``price_data``;
      - it has at least ``MOM_LOOKBACK + MOM_SKIP + 10`` non-NaN closes through ``idx``;
      - its latest such close is at least ``MIN_PRICE``.
    Tickers whose score is NaN are dropped.

    Returns:
        List of ``(ticker, score)`` tuples, highest score first. Ties keep the
        order of ``tickers``.
    """
    min_history = MOM_LOOKBACK + MOM_SKIP + 10
    scored = []
    for ticker in (t for t in tickers if t in price_data.columns):
        history = price_data[ticker].iloc[: idx + 1].dropna().values
        # Length is checked first, so history[-1] is never read on an empty array.
        if len(history) < min_history or history[-1] < MIN_PRICE:
            continue
        score = momentum_score(history, MOM_LOOKBACK, MOM_SKIP)
        if np.isnan(score):
            continue
        scored.append((ticker, score))
    return sorted(scored, key=lambda item: item[1], reverse=True)

def _record_trade(pnl, win_pnls, loss_pnls):
    """File a closed trade's fractional P&L (stored in percent) as a win (> 0) or a loss."""
    (win_pnls if pnl > 0 else loss_pnls).append(pnl * 100)


def _summarize_performance(equity_curve, final, years, spy_start, spy_end):
    """Compute headline metrics for a daily equity curve.

    Returns:
        Tuple ``(cagr, alpha, max_dd, sharpe, sortino, calmar)``:
          - ``cagr``, ``alpha`` and ``max_dd`` are in percentage points;
          - ``alpha`` is the CAGR minus SPY buy-and-hold CAGR over the same span;
          - ``max_dd`` is a positive number;
          - ``calmar`` is 999 when there was no drawdown.
    """
    eq = pd.Series(equity_curve, dtype=float)
    cagr = ((final / ACCOUNT) ** (1 / years) - 1) * 100
    bh_cagr = ((spy_end / spy_start) ** (1 / years) - 1) * 100

    peak = eq.cummax()
    max_dd = abs(((eq - peak) / peak * 100).min())

    daily_ret = eq.pct_change().dropna()
    excess = daily_ret.mean() * 252 - RISK_FREE
    sharpe = 0
    if len(daily_ret) > 10:
        ann_vol = daily_ret.std() * np.sqrt(252)
        sharpe = excess / ann_vol if ann_vol > 0 else 0

    # Sortino uses downside deviation: the RMS of returns clipped at 0, taken over
    # *all* days. (The std of only the negative days, used previously, measures
    # dispersion among losing days rather than downside risk.)
    if len(daily_ret) > 0:
        down_vol = np.sqrt(np.mean(np.minimum(daily_ret.to_numpy(), 0.0) ** 2)) * np.sqrt(252)
    else:
        down_vol = 0.0
    sortino = excess / down_vol if down_vol > 0 else 0

    calmar = cagr / max_dd if max_dd > 0 else 999
    return cagr, cagr - bh_cagr, max_dd, sharpe, sortino, calmar


def run_backtest(price_data, spy_prices, vix_prices, tickers,
                 top_n, vix_thresh, sl_pct, rebal_freq):
    """Run one backtest with the given parameters and return a metrics dict.

    Simulates a long-only, equal-weight momentum rotation from the later of
    2020-01-01 and row ``MOM_LOOKBACK + MOM_SKIP + 50``. On each row of
    ``price_data``:
      1. the portfolio is marked to market and appended to the equity curve;
      2. any holding whose close is at or below ``entry * (1 + sl_pct)`` is sold;
      3. every ``rebal_freq`` rows, either everything is liquidated (SPY below its
         ``SPY_SMA_PERIOD``-row average AND VIX above ``vix_thresh``) or the book
         rotates into the ``top_n`` names returned by ``rank_momentum``.

    Args:
        price_data: Close prices, one column per ticker, rows in date order.
        spy_prices, vix_prices: Series read by integer position, so they must
            be row-aligned with ``price_data``.
        tickers: Candidate universe passed to ``rank_momentum``.
        top_n, vix_thresh, sl_pct, rebal_freq: Strategy parameters (see GRID).

    Returns:
        Dict with these keys:
          - parameters: ``top_n``, ``vix_thresh``, ``sl_pct``, ``rebal_freq``;
          - rounded metrics: ``cagr``, ``alpha``, ``max_dd``, ``sharpe``,
            ``sortino``, ``calmar``, ``win_rate``;
          - trade and equity summary: ``total_trades``, ``final_equity``,
            ``circuit_breaker``, ``stopped_out``, ``avg_win``, ``avg_loss``.

    Raises:
        ValueError: If ``price_data`` ends before the backtest start row.
    """
    bt_start_idx = max(price_data.index.searchsorted(pd.Timestamp("2020-01-01")),
                       MOM_LOOKBACK + MOM_SKIP + 50)
    if bt_start_idx >= len(price_data):
        raise ValueError("price_data ends before the backtest start row "
                         f"({bt_start_idx} >= {len(price_data)} rows)")

    # Raw closes decide whether a name can be *bought* today (it needs a print).
    # Forward-filled closes value and exit existing positions, so a held ticker
    # missing a print on some day is marked/sold at its last close instead of
    # being dropped from the book with zero proceeds.
    raw_close = {t: col.to_numpy(dtype=float) for t, col in price_data.items()}
    mark = {t: col.to_numpy(dtype=float) for t, col in price_data.ffill().items()}
    spy = spy_prices.to_numpy(dtype=float)
    # Latest VIX reading at or before each row (NaN until the first reading).
    vix_latest = vix_prices.ffill().to_numpy(dtype=float)

    cash = float(ACCOUNT)
    holdings = {}  # ticker -> {"shares": int, "entry": float}
    equity_curve = []
    last_rebalance = 0
    circuit_breaker_acts = 0
    stopped_out = 0
    win_pnls, loss_pnls = [], []

    for i in range(bt_start_idx, len(price_data)):
        # 1. Mark to market. Every held name was bought on a valid close, so its
        #    forward-filled mark is never NaN from then on.
        equity_curve.append(
            cash + sum(pos["shares"] * mark[t][i] for t, pos in holdings.items()))

        # 2. Stop-losses against today's close.
        for t in list(holdings):
            cp = mark[t][i]
            pnl = cp / holdings[t]["entry"] - 1
            if pnl <= sl_pct:
                cash += holdings.pop(t)["shares"] * cp
                stopped_out += 1
                loss_pnls.append(pnl * 100)

        # 3. Everything below only happens on rebalance rows.
        if i - last_rebalance < rebal_freq:
            continue
        last_rebalance = i

        if i + 1 >= SPY_SMA_PERIOD:
            sma = spy[i + 1 - SPY_SMA_PERIOD:i + 1].mean()
        else:
            sma = spy[i]
        vix_val = vix_latest[i] if not np.isnan(vix_latest[i]) else 20
        is_risk_off = spy[i] < sma and vix_val > vix_thresh

        # Circuit breaker: go fully to cash until the next rebalance row.
        if is_risk_off:
            circuit_breaker_acts += 1
            for t, pos in holdings.items():
                cp = mark[t][i]
                cash += pos["shares"] * cp
                _record_trade(cp / pos["entry"] - 1, win_pnls, loss_pnls)
            holdings = {}
            continue

        rankings = rank_momentum(price_data, i, tickers)
        if not rankings:
            continue
        target = [t for t, _ in rankings[:top_n]]

        # Sell holdings that dropped out of the target list.
        for t in [t for t in holdings if t not in target]:
            pos = holdings.pop(t)
            cp = mark[t][i]
            cash += pos["shares"] * cp
            _record_trade(cp / pos["entry"] - 1, win_pnls, loss_pnls)

        # Buy new names at an equal share of current equity, limited by cash.
        current_eq = cash + sum(pos["shares"] * mark[t][i] for t, pos in holdings.items())
        target_per = current_eq / top_n
        for t in target:
            if t in holdings or t not in raw_close:
                continue
            cp = raw_close[t][i]
            if np.isnan(cp) or cp <= 0:
                continue
            shares = int(min(target_per, cash) / cp)
            if shares <= 0:
                continue
            cash -= shares * cp
            holdings[t] = {"shares": shares, "entry": cp}

    final = cash + sum(pos["shares"] * mark[t][-1] for t, pos in holdings.items())
    years = max((price_data.index[-1] - price_data.index[bt_start_idx]).days / 365.25, 0.1)
    cagr, alpha, max_dd, sharpe, sortino, calmar = _summarize_performance(
        equity_curve, final, years, spy[bt_start_idx], spy[-1])

    wins, losses = len(win_pnls), len(loss_pnls)
    total_trades = wins + losses
    win_rate = wins / total_trades * 100 if total_trades > 0 else 0
    avg_win = np.mean(win_pnls) if win_pnls else 0
    avg_loss = np.mean(loss_pnls) if loss_pnls else 0

    return {
        "top_n": top_n, "vix_thresh": vix_thresh,
        "sl_pct": sl_pct, "rebal_freq": rebal_freq,
        "cagr": round(cagr, 2), "alpha": round(alpha, 2),
        "max_dd": round(max_dd, 2), "sharpe": round(sharpe, 2),
        "sortino": round(sortino, 2), "calmar": round(calmar, 2),
        "win_rate": round(win_rate, 1), "total_trades": total_trades,
        "final_equity": round(final, 0),
        "circuit_breaker": circuit_breaker_acts,
        "stopped_out": stopped_out,
        "avg_win": round(avg_win, 1), "avg_loss": round(avg_loss, 1),
    }


def _print_candidate(rank, r):
    """Print the three-line summary of one result row, followed by a blank line."""
    print(f"  {rank}. TopN={r['top_n']} VIX={r['vix_thresh']} SL={r['sl_pct']*100:.0f}% Reb={r['rebal_freq']}d")
    print(f"     CAGR: {r['cagr']:+.1f}% | Alpha: {r['alpha']:+.1f}% | MaxDD: {r['max_dd']:.1f}% | Sharpe: {r['sharpe']:.2f} | Sortino: {r['sortino']:.2f}")
    print(f"     WinRate: {r['win_rate']:.1f}% | Trades: {r['total_trades']} | Final: ${r['final_equity']:,.0f}")
    print()


def main(output_path=None):
    """Download data, backtest every GRID combination, print the best, save all as JSON.

    Args:
        output_path: Where to write the full results. Defaults to
            ``v3_optimizer_results.json`` in the platform temp directory.
    """
    import tempfile
    from pathlib import Path

    if output_path is None:
        output_path = Path(tempfile.gettempdir()) / "v3_optimizer_results.json"

    print("=" * 70)
    print("  V3 PARAMETER OPTIMIZER")
    print("  Grid search for optimal momentum strategy parameters")
    print("=" * 70)

    print("\n  Loading data...", flush=True)
    tickers = get_sp500_tickers()
    if not tickers:
        print("  ❌ Could not get tickers")
        return
    print(f"  {len(tickers)} tickers", flush=True)

    universe = tickers + ["SPY"]
    end = datetime.now().strftime("%Y-%m-%d")
    print("  Downloading prices (~30s)...", flush=True)
    raw = yf.download(universe, start="2019-01-01", end=end,
                      interval="1d", progress=False, auto_adjust=True,
                      group_by="ticker", threads=True)

    # Collect closes into a dict and build the frame once; inserting columns one
    # at a time fragments the DataFrame.
    closes = {}
    if isinstance(raw.columns, pd.MultiIndex):
        downloaded = set(raw.columns.get_level_values(0))
        for t in universe:
            if t in downloaded and "Close" in raw[t].columns:
                closes[t] = raw[t]["Close"]
    price_data = pd.DataFrame(closes).dropna(how="all")
    if "SPY" not in price_data.columns:
        print("  ❌ Could not download SPY prices")
        return

    print("  Downloading VIX...", flush=True)
    vix_raw = yf.download("^VIX", start="2019-01-01", end=end,
                          interval="1d", progress=False, auto_adjust=True)
    if isinstance(vix_raw.columns, pd.MultiIndex):
        vix_raw.columns = vix_raw.columns.get_level_values(0)
    if "Close" not in vix_raw.columns:
        print("  ❌ Could not download VIX prices")
        return
    vix_prices = vix_raw["Close"].reindex(price_data.index, method="ffill")
    spy_prices = price_data["SPY"]

    # Generate grid
    combos = list(product(GRID["top_n"], GRID["vix_thresh"],
                          GRID["sl_pct"], GRID["rebal_freq"]))
    print(f"\n  Running {len(combos)} parameter combinations...\n", flush=True)

    results = []
    failures = []
    for idx, (tn, vt, sl, rf) in enumerate(combos):
        if (idx + 1) % 20 == 0:
            print(f"  [{idx+1}/{len(combos)}] ...", flush=True)
        try:
            results.append(run_backtest(price_data, spy_prices, vix_prices, tickers,
                                        tn, vt, sl, rf))
        except Exception as exc:
            # Keep the sweep going, but record the failure instead of hiding it.
            failures.append(((tn, vt, sl, rf), exc))

    print(f"\n  Completed {len(results)} backtests.\n")
    if failures:
        params, exc = failures[0]
        print(f"  ⚠️  {len(failures)} backtests failed; first: "
              f"TopN/VIX/SL/Reb={params} -> {exc!r}\n")

    # Composite score: reward alpha, Sharpe (x5) and Sortino; subtract 2 points per
    # percentage point of max drawdown beyond 20%.
    for r in results:
        dd_penalty = max(0, r["max_dd"] - 20) * 2
        r["composite"] = r["alpha"] + r["sharpe"] * 5 - dd_penalty + r["sortino"]

    results.sort(key=lambda x: x["composite"], reverse=True)

    # Show top 15
    print("  ══ TOP 15 PARAMETER COMBINATIONS ═════════════════════════════════")
    print(f"  {'#':<3} {'TopN':>4} {'VIX':>4} {'SL%':>5} {'Reb':>4} │ {'CAGR':>7} {'Alpha':>7} {'MaxDD':>7} {'Sharpe':>7} {'Sortino':>8} {'WinR%':>6} {'$Final':>10}")
    print(f"  {'-'*3} {'-'*4} {'-'*4} {'-'*5} {'-'*4} ┼ {'-'*7} {'-'*7} {'-'*7} {'-'*7} {'-'*8} {'-'*6} {'-'*10}")

    for i, r in enumerate(results[:15], 1):
        alpha_icon = "✅" if r["alpha"] > 0 else "  "
        dd_icon = "✅" if r["max_dd"] <= 20 else "  "
        sh_icon = "✅" if r["sharpe"] >= 1.0 else "  "
        print(f"  {i:<3} {r['top_n']:>4} {r['vix_thresh']:>4} {r['sl_pct']*100:>4.0f}% {r['rebal_freq']:>4} │ "
              f"{r['cagr']:>+6.1f}% {r['alpha']:>+6.1f}%{alpha_icon}{r['max_dd']:>6.1f}%{dd_icon}"
              f"{r['sharpe']:>6.2f}{sh_icon}{r['sortino']:>7.2f} {r['win_rate']:>5.1f}% ${r['final_equity']:>9,.0f}")

    # Show best that meets ALL targets (label matches the <= / >= filters below).
    print("\n  ══ BEST THAT MEETS ALL TARGETS (CAGR>12, Alpha>0, DD≤20, Sharpe≥1) ═══")
    ideal = [r for r in results if r["cagr"] > 12 and r["alpha"] > 0
             and r["max_dd"] <= 20 and r["sharpe"] >= 1.0]

    if ideal:
        ideal.sort(key=lambda x: x["alpha"], reverse=True)
        for i, r in enumerate(ideal[:5], 1):
            _print_candidate(i, r)
    else:
        print("  No combination meets ALL four targets simultaneously.")
        print("  Showing best near-misses:\n")
        near = [r for r in results if r["alpha"] > 0 and r["max_dd"] <= 22]
        near.sort(key=lambda x: x["composite"], reverse=True)
        for i, r in enumerate(near[:5], 1):
            _print_candidate(i, r)

    # Save all results
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"  Full results saved to {output_path}")
    print("=" * 70)


if __name__ == "__main__":
    main()
