from typing import Dict, List, Optional
from datetime import datetime

from vnpy.app.portfolio_strategy import StrategyTemplate, StrategyEngine
from vnpy.trader.utility import BarGenerator, ArrayManager
from vnpy.trader.object import TickData, BarData


class DoubleMaStrategy(StrategyTemplate):
    """
    Portfolio dual moving-average crossover strategy.

    For every symbol in ``vt_symbols`` an ``ArrayManager`` buffers bars and a
    fast and a slow SMA are compared on each new bar:

    * fast crosses above slow -> open a long of volume 1 (closing any short
      position first);
    * fast crosses below slow -> open a short of volume 1 (closing any long
      position first).

    After each batch of bars a snapshot of positions and close prices is
    stored, keyed by bar date, in ``daily_pos`` / ``daily_close``.
    """

    author = "KeKe"

    fast_window = 8
    slow_window = 24
    # NOTE(review): not referenced anywhere in this class and not listed in
    # ``parameters``; kept so existing settings/consumers do not break.
    price_add = 5

    # Date ("%Y-%m-%d") of the most recently processed bar.
    today = ""
    # Class-level defaults kept only for backward compatibility; __init__
    # gives every instance its own dicts so strategies never share state.
    daily_pos = {}
    daily_close = {}

    parameters = [
        "fast_window",
        "slow_window",
    ]

    def __init__(
        self,
        strategy_engine: StrategyEngine,
        strategy_name: str,
        vt_symbols: List[str],
        setting: dict
    ):
        """
        Create per-symbol bar generators, indicator buffers and history state.

        :param strategy_engine: engine hosting this strategy.
        :param strategy_name: name of this strategy instance.
        :param vt_symbols: symbols traded by this strategy.
        :param setting: parameter overrides handed to the base class.
        """
        super().__init__(strategy_engine, strategy_name, vt_symbols, setting)

        self.bgs: Dict[str, BarGenerator] = {}
        self.ams: Dict[str, ArrayManager] = {}

        self.last_tick_time: Optional[datetime] = None

        # Per-instance history (shadows the shared class-level dicts):
        # date -> str({vt_symbol: pos}) and date -> {vt_symbol: close_price}.
        self.daily_pos: Dict[str, str] = {}
        self.daily_close: Dict[str, Dict[str, float]] = {}

        def on_bar(bar: BarData):
            """No-op: on_tick collects finished bars from generate() itself."""
            pass

        # One bar generator and one indicator buffer per traded symbol.
        for vt_symbol in self.vt_symbols:
            self.bgs[vt_symbol] = BarGenerator(on_bar)
            self.ams[vt_symbol] = ArrayManager()

    def on_init(self):
        """
        Callback when strategy is inited.
        """
        self.write_log("策略初始化")
        # Warm up the indicators with history (100 presumably counts days —
        # confirm against StrategyTemplate.load_bars).
        self.load_bars(100)

    def on_start(self):
        """
        Callback when strategy is started.
        """
        self.write_log("策略启动")

    def on_stop(self):
        """
        Callback when strategy is stopped.
        """
        self.write_log("策略停止")

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.

        When the tick belongs to a new minute, the current bar of every symbol
        is finalised and the completed bars are sent to on_bars as one batch.
        The tick is then fed to its own symbol's bar generator.
        """
        # Compare whole minutes (date + hour + minute), not only the minute
        # field: e.g. 11:30 -> 13:30 across a break has the same minute value
        # but must start a new bar.
        if self.last_tick_time and (
            self.last_tick_time.replace(second=0, microsecond=0)
            != tick.datetime.replace(second=0, microsecond=0)
        ):
            bars: Dict[str, BarData] = {}
            for vt_symbol, generator in self.bgs.items():
                bar = generator.generate()
                # generate() yields None for a symbol with no bar in progress
                # (no tick during that minute); never pass None to on_bars.
                if bar is not None:
                    bars[vt_symbol] = bar
            if bars:
                self.on_bars(bars)

        bg: BarGenerator = self.bgs[tick.vt_symbol]
        bg.update_tick(tick)

        self.last_tick_time = tick.datetime

    def on_bars(self, bars: Dict[str, BarData]):
        """
        Callback of a new batch of bars, keyed by vt_symbol.

        Updates each symbol's ArrayManager and, once it is initialised, trades
        on fast/slow SMA crossovers. Afterwards records a position/close
        snapshot for the current date and pushes a UI event.
        """
        # Update bar calculations
        for vt_symbol, bar in bars.items():
            self.today = bar.datetime.strftime("%Y-%m-%d")

            am: ArrayManager = self.ams[vt_symbol]
            am.update_bar(bar)
            if not am.inited:
                continue

            fast_ma = am.sma(self.fast_window, array=True)
            fast_ma0 = fast_ma[-1]
            fast_ma1 = fast_ma[-2]

            slow_ma = am.sma(self.slow_window, array=True)
            slow_ma0 = slow_ma[-1]
            slow_ma1 = slow_ma[-2]

            cross_over = fast_ma0 > slow_ma0 and fast_ma1 < slow_ma1
            cross_below = fast_ma0 < slow_ma0 and fast_ma1 > slow_ma1

            pos = self.get_pos(vt_symbol)
            price = bar.close_price

            if cross_over:
                if pos == 0:
                    self.buy(vt_symbol, price, 1)
                elif pos < 0:
                    # Close the entire short, then open a fresh long.
                    self.cover(vt_symbol, price, abs(pos))
                    self.buy(vt_symbol, price, 1)
            elif cross_below:
                if pos == 0:
                    self.short(vt_symbol, price, 1)
                elif pos > 0:
                    # Close the entire long, then open a fresh short.
                    self.sell(vt_symbol, price, abs(pos))
                    self.short(vt_symbol, price, 1)

        self.record_current_pos(bars)
        self.put_event()

    def record_current_pos(self, bars: Dict[str, BarData]):
        """
        Store a snapshot of positions and close prices for ``self.today``.

        ``daily_pos[today]`` is set to ``str({vt_symbol: pos})`` and
        ``daily_close[today]`` to ``{vt_symbol: close_price}`` for the symbols
        in ``bars``. A later call on the same date replaces the snapshot.
        """
        # An empty batch would otherwise clobber today's snapshot with "{}".
        if not bars:
            return

        current_pos = {}
        current_close = {}
        for bar in bars.values():
            current_pos[bar.vt_symbol] = self.get_pos(bar.vt_symbol)
            current_close[bar.vt_symbol] = bar.close_price

        today = self.today
        self.daily_pos[today] = str(current_pos)
        self.daily_close[today] = current_close