ADD: added first version
This commit is contained in:
121
internal/backtest/engine.go
Normal file
121
internal/backtest/engine.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package backtest
|
||||
|
||||
import (
|
||||
"alpaca-bot/internal/model"
|
||||
"alpaca-bot/internal/strategy"
|
||||
"time"
|
||||
|
||||
talib "github.com/markcheno/go-talib"
|
||||
)
|
||||
|
||||
type Engine struct {
|
||||
Strategy strategy.Strategy
|
||||
Cash float64
|
||||
Position int
|
||||
Trades []model.Trade
|
||||
Symbol string
|
||||
RiskPercent float64 // fraction of cash to risk per trade
|
||||
ATRPeriod int
|
||||
ATRMultiplier float64
|
||||
TakeProfitMultiplier float64 // e.g., 2 for 2:1 reward/risk
|
||||
highs, lows, closes []float64
|
||||
currentStop float64
|
||||
currentTakeProfit float64
|
||||
entryPrice float64
|
||||
}
|
||||
|
||||
func NewEngine(symbol string, strat strategy.Strategy, cash float64, risk float64, atrPeriod int, atrMult float64, tpMult float64) *Engine {
|
||||
return &Engine{
|
||||
Symbol: symbol,
|
||||
Strategy: strat,
|
||||
Cash: cash,
|
||||
RiskPercent: risk,
|
||||
ATRPeriod: atrPeriod,
|
||||
ATRMultiplier: atrMult,
|
||||
TakeProfitMultiplier: tpMult,
|
||||
}
|
||||
}
|
||||
|
||||
// Run executes the backtest with ATR stops and take profit
|
||||
func (e *Engine) Run(bars []model.Bar) {
|
||||
for _, bar := range bars {
|
||||
// collect prices for ATR
|
||||
e.highs = append(e.highs, bar.High)
|
||||
e.lows = append(e.lows, bar.Low)
|
||||
e.closes = append(e.closes, bar.Close)
|
||||
|
||||
// skip until enough bars for ATR
|
||||
if len(e.closes) <= e.ATRPeriod {
|
||||
continue
|
||||
}
|
||||
|
||||
// calculate ATR safely
|
||||
highsSlice := e.highs[len(e.highs)-e.ATRPeriod-1:]
|
||||
lowsSlice := e.lows[len(e.lows)-e.ATRPeriod-1:]
|
||||
closesSlice := e.closes[len(e.closes)-e.ATRPeriod-1:]
|
||||
|
||||
atrValues := talib.Atr(highsSlice, lowsSlice, closesSlice, e.ATRPeriod)
|
||||
if len(atrValues) == 0 {
|
||||
continue
|
||||
}
|
||||
atr := atrValues[len(atrValues)-1]
|
||||
if atr <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stopDistance := atr * e.ATRMultiplier
|
||||
|
||||
// check existing position for stop loss / take profit
|
||||
if e.Position > 0 {
|
||||
// stop loss hit
|
||||
if bar.Low <= e.currentStop {
|
||||
e.Cash += float64(e.Position) * e.currentStop
|
||||
e.record("SELL_STOP", e.currentStop, bar.Time, e.Position, e.currentStop, 0)
|
||||
e.Position = 0
|
||||
} else if bar.High >= e.currentTakeProfit {
|
||||
e.Cash += float64(e.Position) * e.currentTakeProfit
|
||||
e.record("SELL_TP", e.currentTakeProfit, bar.Time, e.Position, e.currentStop, e.currentTakeProfit)
|
||||
e.Position = 0
|
||||
}
|
||||
}
|
||||
|
||||
// get new signal
|
||||
signal := e.Strategy.OnBar(bar)
|
||||
switch signal {
|
||||
case model.Buy:
|
||||
if e.Position == 0 { // only enter new position if flat
|
||||
riskCash := e.Cash * e.RiskPercent
|
||||
qty := int(riskCash / stopDistance)
|
||||
if qty <= 0 || e.Cash < float64(qty)*bar.Close {
|
||||
continue
|
||||
}
|
||||
e.Position = qty
|
||||
e.Cash -= float64(qty) * bar.Close
|
||||
e.entryPrice = bar.Close
|
||||
e.currentStop = bar.Close - stopDistance
|
||||
e.currentTakeProfit = bar.Close + stopDistance*e.TakeProfitMultiplier
|
||||
e.record("BUY", bar.Close, bar.Time, qty, e.currentStop, e.currentTakeProfit)
|
||||
}
|
||||
case model.Sell:
|
||||
if e.Position > 0 { // manual exit signal
|
||||
e.Cash += float64(e.Position) * bar.Close
|
||||
e.record("SELL_SIGNAL", bar.Close, bar.Time, e.Position, e.currentStop, e.currentTakeProfit)
|
||||
e.Position = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// record logs each trade
|
||||
func (e *Engine) record(side string, price float64, t time.Time, qty int, stopPrice float64, takeProfit float64) {
|
||||
e.Trades = append(e.Trades, model.Trade{
|
||||
Time: t,
|
||||
Symbol: e.Symbol,
|
||||
Side: side,
|
||||
Price: price,
|
||||
Qty: qty,
|
||||
Cash: e.Cash,
|
||||
StopPrice: stopPrice,
|
||||
TakeProfit: takeProfit,
|
||||
})
|
||||
}
|
||||
32
internal/backtest/result.go
Normal file
32
internal/backtest/result.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package backtest
|
||||
|
||||
import (
|
||||
"alpaca-bot/internal/model"
|
||||
"fmt"
|
||||
|
||||
"github.com/xuri/excelize/v2"
|
||||
)
|
||||
|
||||
func ExportToExcel(trades []model.Trade, path string) error {
|
||||
f := excelize.NewFile()
|
||||
sheet := "Backtest"
|
||||
f.NewSheet(sheet)
|
||||
|
||||
headers := []string{"Time", "Symbol", "Side", "Price", "Qty", "Cash"}
|
||||
for i, h := range headers {
|
||||
cell, _ := excelize.CoordinatesToCellName(i+1, 1)
|
||||
f.SetCellValue(sheet, cell, h)
|
||||
}
|
||||
|
||||
for i, t := range trades {
|
||||
row := i + 2
|
||||
f.SetCellValue(sheet, fmt.Sprintf("A%d", row), t.Time)
|
||||
f.SetCellValue(sheet, fmt.Sprintf("B%d", row), t.Symbol)
|
||||
f.SetCellValue(sheet, fmt.Sprintf("C%d", row), t.Side)
|
||||
f.SetCellValue(sheet, fmt.Sprintf("D%d", row), t.Price)
|
||||
f.SetCellValue(sheet, fmt.Sprintf("E%d", row), t.Qty)
|
||||
f.SetCellValue(sheet, fmt.Sprintf("F%d", row), t.Cash)
|
||||
}
|
||||
|
||||
return f.SaveAs(path)
|
||||
}
|
||||
Reference in New Issue
Block a user