from typing import Dict, List, Optional
from datetime import datetime

from vnpy.app.portfolio_strategy import StrategyTemplate, StrategyEngine
from vnpy.trader.utility import BarGenerator, ArrayManager
from vnpy.trader.object import TickData, BarData
from vnpy.trader.constant import Interval


class DemoStrategy(StrategyTemplate):
    """
    Multi-symbol ATR/RSI entry strategy with a percentage trailing stop.

    Data flow: ticks -> per-symbol 1-minute bars (on_tick) -> on_bars ->
    per-symbol BarGenerator windows -> on_window_bar (collects the bars of
    one window) -> on_window_bars (signals and orders).
    """

    author = "KeKe"

    # Declared for configuration; neither value is read by the logic below.
    price_add_percent = 0.05     # price overshoot when ordering (5%)
    fixed_pos_value = 1000000    # notional per contract (original note said 100k)

    atr_window = 22
    atr_ma_window = 10
    rsi_window = 5
    rsi_entry = 16               # distance of the RSI thresholds from 50
    trailing_percent = 0.8       # trailing stop distance, in percent
    fixed_size = 1
    price_add = 5                # absolute price overshoot added to orders

    # Derived in on_init as 50 +/- rsi_entry.
    rsi_buy = 0
    rsi_sell = 0

    # Class-level declarations only; the mutable containers are rebound per
    # instance in __init__ so that strategies never share state.
    signal_ts = {}
    signal_total = {}
    last_tick_time: Optional[datetime] = None
    last_bar_time: Optional[datetime] = None
    trade_day = 0
    targets_pos = {}
    show_pos = {}
    symbol_cap = {}
    window_bars = {}
    today = None

    # Every name listed here must exist as an attribute, otherwise reading
    # the parameters via getattr fails.
    parameters = [
        "price_add_percent",
        "fixed_pos_value",
        "atr_window",
        "atr_ma_window",
        "rsi_window",
        "rsi_entry",
        "trailing_percent",
        "fixed_size",
        "price_add"
    ]
    variables = [
        "signal_ts",
        "signal_total",
        "trade_day",
        "targets_pos"
    ]

    def __init__(
        self,
        strategy_engine: StrategyEngine,
        strategy_name: str,
        vt_symbols: List[str],
        setting: dict
    ):
        """Create per-symbol bar generators, array managers and state."""
        super().__init__(strategy_engine, strategy_name, vt_symbols, setting)

        # Fresh containers per instance (the class attributes are shared).
        self.signal_ts = {}
        self.signal_total = {}
        self.targets_pos = {}
        self.show_pos = {}
        self.symbol_cap = {}
        self.window_bars: Dict[str, BarData] = {}

        self.bgs: Dict[str, BarGenerator] = {}
        self.ams: Dict[str, ArrayManager] = {}

        self.rsi_data: Dict[str, float] = {}
        self.atr_data: Dict[str, float] = {}
        self.atr_ma: Dict[str, float] = {}
        self.intra_trade_high: Dict[str, float] = {}
        self.intra_trade_low: Dict[str, float] = {}

        self.targets: Dict[str, int] = {}

        # One generator (window argument 5) and one array manager per symbol.
        for vt_symbol in self.vt_symbols:
            self.bgs[vt_symbol] = BarGenerator(
                self.on_bar, 5, self.on_window_bar
            )
            self.ams[vt_symbol] = ArrayManager()

    def on_init(self):
        """
        Callback when strategy is inited.
        """
        self.write_log("策略初始化")

        # Must be set before history is replayed: with both thresholds left
        # at 0 the RSI filter would always signal "long".
        self.rsi_buy = 50 + self.rsi_entry
        self.rsi_sell = 50 - self.rsi_entry

        self.load_bars(days=20, interval=Interval.MINUTE)

    def on_start(self):
        """
        Callback when strategy is started.
        """
        self.write_log("策略启动")

    def on_stop(self):
        """
        Callback when strategy is stopped.
        """
        self.write_log("策略停止")

    def on_bar(self, bar: BarData):
        """
        Minute-bar callback handed to each BarGenerator.

        Intentionally a no-op: minute bars are routed through on_bars.
        It has to accept ``self`` because it is passed as a bound method.
        """
        pass

    def on_window_bar(self, bar: BarData):
        """
        Collect window bars of all symbols that belong to the same minute.

        When a bar with a different minute arrives, the bars collected so
        far are processed as one batch and a new batch is started.
        """
        if (
            self.last_bar_time
            and self.last_bar_time.minute != bar.datetime.minute
        ):
            self.on_window_bars(self.window_bars)
            self.window_bars = {}

        self.window_bars[bar.vt_symbol] = bar
        self.last_bar_time = bar.datetime

    def on_window_bars(self, bars: Dict[str, BarData]):
        """
        Update indicators, compute target positions and send orders.

        bars maps vt_symbol to the window bar of that symbol; symbols
        without a bar in this window are simply skipped.
        """
        self.cancel_all()

        # Feed every array manager first, so that one symbol that is still
        # warming up cannot starve the others of data.
        for vt_symbol, bar in bars.items():
            self.ams[vt_symbol].update_bar(bar)

        # No signals until all symbols of this batch have enough history.
        if not all(self.ams[vt_symbol].inited for vt_symbol in bars):
            return

        for vt_symbol, bar in bars.items():
            self._update_target(vt_symbol, bar)

        self._send_target_orders(bars)

        self.put_event()

    def _update_target(self, vt_symbol: str, bar: BarData):
        """Recompute indicators and the target position of one symbol."""
        am: ArrayManager = self.ams[vt_symbol]

        atr_array = am.atr(self.atr_window, array=True)
        atr_value = atr_array[-1]
        atr_mean = atr_array[-self.atr_ma_window:].mean()
        rsi_value = am.rsi(self.rsi_window)

        self.atr_data[vt_symbol] = atr_value
        self.atr_ma[vt_symbol] = atr_mean
        self.rsi_data[vt_symbol] = rsi_value

        current_pos = self.get_pos(vt_symbol)

        if current_pos == 0:
            self.intra_trade_high[vt_symbol] = bar.high_price
            self.intra_trade_low[vt_symbol] = bar.low_price

            # Entries only while ATR is above its own moving average.
            if atr_value > atr_mean:
                if rsi_value > self.rsi_buy:
                    self.targets[vt_symbol] = self.fixed_size
                elif rsi_value < self.rsi_sell:
                    self.targets[vt_symbol] = -self.fixed_size
                else:
                    self.targets[vt_symbol] = 0

        elif current_pos > 0:
            # .get guards against a position that predates this session.
            self.intra_trade_high[vt_symbol] = max(
                self.intra_trade_high.get(vt_symbol, bar.high_price),
                bar.high_price
            )
            self.intra_trade_low[vt_symbol] = bar.low_price

            long_stop = self.intra_trade_high[vt_symbol] * (
                1 - self.trailing_percent / 100
            )
            if bar.close_price <= long_stop:
                self.targets[vt_symbol] = 0

        else:
            self.intra_trade_low[vt_symbol] = min(
                self.intra_trade_low.get(vt_symbol, bar.low_price),
                bar.low_price
            )
            self.intra_trade_high[vt_symbol] = bar.high_price

            short_stop = self.intra_trade_low[vt_symbol] * (
                1 + self.trailing_percent / 100
            )
            if bar.close_price >= short_stop:
                self.targets[vt_symbol] = 0

    def _send_target_orders(self, bars: Dict[str, BarData]):
        """Send the orders that move each position towards its target."""
        for vt_symbol in self.vt_symbols:
            target_pos = self.targets.get(vt_symbol)
            bar = bars.get(vt_symbol)

            # A target of 0 is a real instruction (flatten); only a missing
            # target or a missing bar means there is nothing to do.
            if target_pos is None or bar is None:
                continue

            current_pos = self.get_pos(vt_symbol)
            pos_diff = target_pos - current_pos
            volume = abs(pos_diff)

            if pos_diff > 0:
                price = bar.close_price + self.price_add
                if current_pos < 0:
                    self.cover(vt_symbol, price, volume)
                else:
                    self.buy(vt_symbol, price, volume)
            elif pos_diff < 0:
                price = bar.close_price - self.price_add
                if current_pos > 0:
                    self.sell(vt_symbol, price, volume)
                else:
                    self.short(vt_symbol, price, volume)

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.

        When the minute of the incoming tick differs from the previous
        tick's, the pending bar of every symbol is generated and the batch
        is forwarded to on_bars.
        """
        if (
            self.last_tick_time
            and self.last_tick_time.minute != tick.datetime.minute
        ):
            bars = {}
            for vt_symbol, bg in self.bgs.items():
                bar = bg.generate()
                # Skip symbols for which the generator produced no bar.
                if bar is not None:
                    bars[vt_symbol] = bar

            if bars:
                self.on_bars(bars)

        bg: BarGenerator = self.bgs[tick.vt_symbol]
        bg.update_tick(tick)

        self.last_tick_time = tick.datetime

    def on_bars(self, bars: Dict[str, BarData]):
        """
        Feed minute bars into the per-symbol window generators.

        Only symbols present in bars are updated, so a batch that lacks
        some symbols does not raise.
        """
        for vt_symbol, bar in bars.items():
            bg = self.bgs.get(vt_symbol)
            if bg is None or bar is None:
                continue
            bg.update_bar(bar)