// Package backtest replays historical bars through a trading strategy and
// simulates the resulting trades.
package backtest

import (
	"time"

	"alpaca-bot/internal/model"
	"alpaca-bot/internal/strategy"

	talib "github.com/markcheno/go-talib"
)

// Engine simulates a single-symbol, long-only trading account. Entries are
// sized from ATR, and every position carries an ATR-based stop-loss and
// take-profit.
type Engine struct {
	// Strategy is asked for a signal (Buy/Sell) on every bar.
	Strategy strategy.Strategy

	// Cash is the uninvested balance. It is reduced when a position is
	// bought and increased when it is sold.
	Cash float64

	// Position is the number of shares currently held; 0 means flat.
	Position int

	// Trades holds every fill recorded by Run, in the order they occurred.
	Trades []model.Trade

	// Symbol is copied into every recorded Trade.
	Symbol string

	// RiskPercent is the fraction of cash to risk per trade.
	RiskPercent float64

	// ATRPeriod is the number of true ranges averaged into ATR.
	ATRPeriod int

	// ATRMultiplier scales ATR into the stop distance below the entry price.
	ATRMultiplier float64

	// TakeProfitMultiplier scales the stop distance into the take-profit
	// distance above the entry price (e.g., 2 for 2:1 reward/risk).
	TakeProfitMultiplier float64

	// Price history used for ATR. Each slice gains one element per bar and
	// persists across Run calls.
	highs, lows, closes []float64

	// Levels for the currently open position, set when it is entered.
	currentStop       float64
	currentTakeProfit float64
	entryPrice        float64
}

// NewEngine returns an Engine for symbol with the given starting cash.
// risk is the fraction of cash to risk per trade. atrPeriod, atrMult and
// tpMult configure the ATR window, the stop distance (ATR * atrMult) and
// the take-profit distance (stop distance * tpMult).
func NewEngine(symbol string, strat strategy.Strategy, cash float64, risk float64, atrPeriod int, atrMult float64, tpMult float64) *Engine {
	return &Engine{
		Symbol:               symbol,
		Strategy:             strat,
		Cash:                 cash,
		RiskPercent:          risk,
		ATRPeriod:            atrPeriod,
		ATRMultiplier:        atrMult,
		TakeProfitMultiplier: tpMult,
	}
}

// Run replays bars in order and manages one long position at a time with
// ATR-based stop-loss and take-profit levels.
//
// For every bar, Run does three things:
//  1. It appends the bar to the price history used for ATR.
//  2. If a position is open, it exits at the stop price when bar.Low reaches
//     it. Otherwise it exits at the take-profit price when bar.High reaches
//     it. The stop is checked first, so a bar touching both counts as a stop.
//  3. It passes the bar to Strategy.OnBar and acts on the returned signal.
//
// Steps 2 and 3 run on every bar, including while ATR is still warming up,
// so the strategy sees the full price series and open positions keep their
// protection. Only new Buy entries wait for a usable ATR, because ATR sets
// the position size.
//
// History persists on the Engine, so successive Run calls continue where
// the previous one stopped. A position still open when bars run out stays
// open; it is not closed or marked to market.
func (e *Engine) Run(bars []model.Bar) {
	for _, bar := range bars {
		e.appendHistory(bar)

		if e.Position > 0 {
			e.checkExits(bar)
		}

		switch e.Strategy.OnBar(bar) {
		case model.Buy:
			// Only enter a new position when flat.
			if e.Position == 0 {
				e.enterLong(bar)
			}
		case model.Sell:
			// Manual exit signal: sell everything at the bar's close.
			if e.Position > 0 {
				e.Cash += float64(e.Position) * bar.Close
				e.record("SELL_SIGNAL", bar.Close, bar.Time, e.Position, e.currentStop, e.currentTakeProfit)
				e.Position = 0
			}
		}
	}
}

// appendHistory records bar's high, low and close for ATR calculation.
func (e *Engine) appendHistory(bar model.Bar) {
	e.highs = append(e.highs, bar.High)
	e.lows = append(e.lows, bar.Low)
	e.closes = append(e.closes, bar.Close)
}

// atr returns the ATR over the most recent ATRPeriod+1 bars of history. The
// boolean is false when ATRPeriod < 1, when there is not yet enough history,
// or when the result is not a positive number (including NaN).
//
// NOTE(review): with a window of exactly ATRPeriod+1 bars, talib's ATR
// presumably reduces to a simple average of the last ATRPeriod true ranges,
// with no Wilder smoothing carried from earlier history. Confirm intended.
func (e *Engine) atr() (float64, bool) {
	window := e.ATRPeriod + 1
	if e.ATRPeriod < 1 || len(e.closes) < window {
		return 0, false
	}
	start := len(e.closes) - window
	values := talib.Atr(e.highs[start:], e.lows[start:], e.closes[start:], e.ATRPeriod)
	if len(values) == 0 {
		return 0, false
	}
	v := values[len(values)-1]
	// Written as !(v > 0) rather than v <= 0 so NaN is rejected too.
	if !(v > 0) {
		return 0, false
	}
	return v, true
}

// checkExits closes the open position at the stop price if bar.Low reached
// it, otherwise at the take-profit price if bar.High reached it.
//
// NOTE(review): fills happen exactly at the stop/take-profit level even if
// the bar gapped through it, which is optimistic for stops.
func (e *Engine) checkExits(bar model.Bar) {
	switch {
	case bar.Low <= e.currentStop:
		e.Cash += float64(e.Position) * e.currentStop
		// Stop exits record a take-profit of 0.
		e.record("SELL_STOP", e.currentStop, bar.Time, e.Position, e.currentStop, 0)
		e.Position = 0
	case bar.High >= e.currentTakeProfit:
		e.Cash += float64(e.Position) * e.currentTakeProfit
		e.record("SELL_TP", e.currentTakeProfit, bar.Time, e.Position, e.currentStop, e.currentTakeProfit)
		e.Position = 0
	}
}

// enterLong buys at bar.Close, sized so that being stopped out loses about
// RiskPercent of current cash.
//
// The entry is skipped in any of these cases:
//   - no usable ATR is available yet;
//   - the stop distance is not positive;
//   - the size rounds down to zero shares;
//   - the position would cost more than the available cash.
func (e *Engine) enterLong(bar model.Bar) {
	atr, ok := e.atr()
	if !ok {
		return
	}
	stopDistance := atr * e.ATRMultiplier
	// Guards against division by zero and negative stop distances when
	// ATRMultiplier <= 0.
	if !(stopDistance > 0) {
		return
	}

	riskCash := e.Cash * e.RiskPercent
	qty := int(riskCash / stopDistance)
	cost := float64(qty) * bar.Close
	if qty <= 0 || e.Cash < cost {
		return
	}

	e.Position = qty
	e.Cash -= cost
	e.entryPrice = bar.Close
	e.currentStop = bar.Close - stopDistance
	e.currentTakeProfit = bar.Close + stopDistance*e.TakeProfitMultiplier
	e.record("BUY", bar.Close, bar.Time, qty, e.currentStop, e.currentTakeProfit)
}

// record appends a Trade for e.Symbol to e.Trades. It snapshots e.Cash as it
// is at call time. Callers in this file update Cash before calling record,
// so the snapshot is the balance after the fill.
func (e *Engine) record(side string, price float64, t time.Time, qty int, stopPrice float64, takeProfit float64) {
	e.Trades = append(e.Trades, model.Trade{
		Time:       t,
		Symbol:     e.Symbol,
		Side:       side,
		Price:      price,
		Qty:        qty,
		Cash:       e.Cash,
		StopPrice:  stopPrice,
		TakeProfit: takeProfit,
	})
}