122 lines
3.4 KiB
Go
122 lines
3.4 KiB
Go
package backtest
|
|
|
|
import (
|
|
"alpaca-bot/internal/model"
|
|
"alpaca-bot/internal/strategy"
|
|
"time"
|
|
|
|
talib "github.com/markcheno/go-talib"
|
|
)
|
|
|
|
type Engine struct {
|
|
Strategy strategy.Strategy
|
|
Cash float64
|
|
Position int
|
|
Trades []model.Trade
|
|
Symbol string
|
|
RiskPercent float64 // fraction of cash to risk per trade
|
|
ATRPeriod int
|
|
ATRMultiplier float64
|
|
TakeProfitMultiplier float64 // e.g., 2 for 2:1 reward/risk
|
|
highs, lows, closes []float64
|
|
currentStop float64
|
|
currentTakeProfit float64
|
|
entryPrice float64
|
|
}
|
|
|
|
func NewEngine(symbol string, strat strategy.Strategy, cash float64, risk float64, atrPeriod int, atrMult float64, tpMult float64) *Engine {
|
|
return &Engine{
|
|
Symbol: symbol,
|
|
Strategy: strat,
|
|
Cash: cash,
|
|
RiskPercent: risk,
|
|
ATRPeriod: atrPeriod,
|
|
ATRMultiplier: atrMult,
|
|
TakeProfitMultiplier: tpMult,
|
|
}
|
|
}
|
|
|
|
// Run executes the backtest with ATR stops and take profit
|
|
func (e *Engine) Run(bars []model.Bar) {
|
|
for _, bar := range bars {
|
|
// collect prices for ATR
|
|
e.highs = append(e.highs, bar.High)
|
|
e.lows = append(e.lows, bar.Low)
|
|
e.closes = append(e.closes, bar.Close)
|
|
|
|
// skip until enough bars for ATR
|
|
if len(e.closes) <= e.ATRPeriod {
|
|
continue
|
|
}
|
|
|
|
// calculate ATR safely
|
|
highsSlice := e.highs[len(e.highs)-e.ATRPeriod-1:]
|
|
lowsSlice := e.lows[len(e.lows)-e.ATRPeriod-1:]
|
|
closesSlice := e.closes[len(e.closes)-e.ATRPeriod-1:]
|
|
|
|
atrValues := talib.Atr(highsSlice, lowsSlice, closesSlice, e.ATRPeriod)
|
|
if len(atrValues) == 0 {
|
|
continue
|
|
}
|
|
atr := atrValues[len(atrValues)-1]
|
|
if atr <= 0 {
|
|
continue
|
|
}
|
|
|
|
stopDistance := atr * e.ATRMultiplier
|
|
|
|
// check existing position for stop loss / take profit
|
|
if e.Position > 0 {
|
|
// stop loss hit
|
|
if bar.Low <= e.currentStop {
|
|
e.Cash += float64(e.Position) * e.currentStop
|
|
e.record("SELL_STOP", e.currentStop, bar.Time, e.Position, e.currentStop, 0)
|
|
e.Position = 0
|
|
} else if bar.High >= e.currentTakeProfit {
|
|
e.Cash += float64(e.Position) * e.currentTakeProfit
|
|
e.record("SELL_TP", e.currentTakeProfit, bar.Time, e.Position, e.currentStop, e.currentTakeProfit)
|
|
e.Position = 0
|
|
}
|
|
}
|
|
|
|
// get new signal
|
|
signal := e.Strategy.OnBar(bar)
|
|
switch signal {
|
|
case model.Buy:
|
|
if e.Position == 0 { // only enter new position if flat
|
|
riskCash := e.Cash * e.RiskPercent
|
|
qty := int(riskCash / stopDistance)
|
|
if qty <= 0 || e.Cash < float64(qty)*bar.Close {
|
|
continue
|
|
}
|
|
e.Position = qty
|
|
e.Cash -= float64(qty) * bar.Close
|
|
e.entryPrice = bar.Close
|
|
e.currentStop = bar.Close - stopDistance
|
|
e.currentTakeProfit = bar.Close + stopDistance*e.TakeProfitMultiplier
|
|
e.record("BUY", bar.Close, bar.Time, qty, e.currentStop, e.currentTakeProfit)
|
|
}
|
|
case model.Sell:
|
|
if e.Position > 0 { // manual exit signal
|
|
e.Cash += float64(e.Position) * bar.Close
|
|
e.record("SELL_SIGNAL", bar.Close, bar.Time, e.Position, e.currentStop, e.currentTakeProfit)
|
|
e.Position = 0
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// record logs each trade
|
|
func (e *Engine) record(side string, price float64, t time.Time, qty int, stopPrice float64, takeProfit float64) {
|
|
e.Trades = append(e.Trades, model.Trade{
|
|
Time: t,
|
|
Symbol: e.Symbol,
|
|
Side: side,
|
|
Price: price,
|
|
Qty: qty,
|
|
Cash: e.Cash,
|
|
StopPrice: stopPrice,
|
|
TakeProfit: takeProfit,
|
|
})
|
|
}
|