225 lines
6.7 KiB
Python
225 lines
6.7 KiB
Python
from typing import List, Dict
|
|
from datetime import datetime
|
|
|
|
from vnpy.app.portfolio_strategy import StrategyTemplate, StrategyEngine
|
|
from vnpy.trader.utility import BarGenerator, ArrayManager
|
|
from vnpy.trader.object import TickData, BarData
|
|
|
|
from vnpy.trader.constant import Interval
|
|
|
|
|
|
class DemoStrategy(StrategyTemplate):
    """
    Portfolio ATR/RSI trend-following demo strategy.

    Data flow:
    * Ticks are aggregated into 1-minute bars per symbol (``on_tick``).
    * The 1-minute bars are fed into per-symbol BarGenerators, which build
      5-minute window bars (``on_bars``).
    * Window bars are buffered per window (``on_window_bar``). The buffered
      batch is processed by ``on_window_bars`` when the first bar of the next
      window arrives, so signals are evaluated one window late.

    Signals: while flat, enter long/short when ATR is above its recent average
    and RSI breaks out of the band ``50 +/- rsi_entry``. Open positions are
    closed with a percentage trailing stop.
    """

    author = "KeKe"

    # Intended order-price offset of 5% beyond the market.
    # NOTE(review): not referenced in this class; orders use the absolute `price_add`.
    price_add_percent = 0.05
    # Notional value to trade per contract.
    # NOTE(review): the original comment said 100k but the value is 1,000,000,
    # and it is not referenced in this class -- confirm intent.
    fixed_pos_value = 1000000
    atr_window = 22          # ATR lookback length
    atr_ma_window = 10       # number of most recent ATR values averaged
    rsi_window = 5           # RSI lookback length
    rsi_entry = 16           # half-width of the RSI breakout band around 50
    trailing_percent = 0.8   # trailing-stop distance, in percent
    fixed_size = 1           # target position size for a new entry
    price_add = 5            # absolute offset added to/subtracted from close for orders

    # RSI thresholds, derived from rsi_entry in on_init.
    rsi_buy = 0
    rsi_sell = 0

    # Strategy variables. The mutable ones are re-created per instance in __init__.
    signal_ts = {}
    signal_total = {}
    last_tick_time: datetime = None
    last_bar_time: datetime = None
    trade_day = 0
    targets_pos = {}
    show_pos = {}
    symbol_cap = {}
    window_bars = {}
    today = None

    parameters = [
        "price_add_percent",
        "fixed_pos_value",
        "atr_window",
        "atr_ma_window",
        "rsi_window",
        "rsi_entry",
        "trailing_percent",
        "fixed_size",
        "price_add",
    ]

    variables = [
        "signal_ts", "signal_total",
        "trade_day", "targets_pos"
    ]

    def __init__(
        self,
        strategy_engine: StrategyEngine,
        strategy_name: str,
        vt_symbols: List[str],
        setting: dict
    ):
        """Set up per-symbol bar generators, array managers and per-instance state."""
        super().__init__(strategy_engine, strategy_name, vt_symbols, setting)

        # Give each instance its own containers; the class-level dicts above
        # would otherwise be shared (and mutated) by every instance.
        self.signal_ts = {}
        self.signal_total = {}
        self.targets_pos = {}
        self.show_pos = {}
        self.symbol_cap = {}
        self.window_bars: Dict[str, BarData] = {}

        self.bgs: Dict[str, BarGenerator] = {}
        self.ams: Dict[str, ArrayManager] = {}

        self.rsi_data: Dict[str, float] = {}
        self.atr_data: Dict[str, float] = {}
        self.atr_ma: Dict[str, float] = {}
        self.intra_trade_high: Dict[str, float] = {}
        self.intra_trade_low: Dict[str, float] = {}

        # Target position per symbol; a symbol without an entry has no signal yet.
        self.targets: Dict[str, int] = {}

        for vt_symbol in self.vt_symbols:
            # 1-minute bars go to on_bar; every 5 of them form a window bar
            # delivered to on_window_bar.
            self.bgs[vt_symbol] = BarGenerator(
                self.on_bar,
                5,
                self.on_window_bar
            )
            self.ams[vt_symbol] = ArrayManager()

    def on_init(self):
        """
        Callback when strategy is inited.

        Derives the RSI entry thresholds and loads 20 days of 1-minute
        history to warm up the indicators.
        """
        self.write_log("策略初始化")

        # Must be set before history is replayed, because the replay already
        # evaluates signals against these thresholds.
        self.rsi_buy = 50 + self.rsi_entry
        self.rsi_sell = 50 - self.rsi_entry

        self.load_bars(days=20, interval=Interval.MINUTE)

    def on_start(self):
        """
        Callback when strategy is started.
        """
        self.write_log("策略启动")

    def on_stop(self):
        """
        Callback when strategy is stopped.
        """
        self.write_log("策略停止")

    def on_bar(self, bar: BarData):
        """
        Receive a finished 1-minute bar from a BarGenerator.

        Does nothing: 1-minute bars reach the window generators through
        ``on_bars`` instead.
        """
        pass

    def on_window_bar(self, bar: BarData):
        """
        Buffer one finished 5-minute window bar.

        When a bar belonging to a different window (a different timestamp)
        arrives, the buffered batch of the previous window is handed to
        ``on_window_bars`` and a new buffer is started with this bar.
        """
        if (
            self.last_bar_time is not None
            and self.last_bar_time != bar.datetime
        ):
            self.on_window_bars(self.window_bars)
            self.window_bars = {}

        self.window_bars[bar.vt_symbol] = bar
        self.last_bar_time = bar.datetime

    def on_window_bars(self, bars: Dict[str, BarData]):
        """
        Evaluate signals on one batch of window bars and send orders.

        Parameters
        ----------
        bars:
            Window bars of a single window keyed by vt_symbol. The batch may
            not contain every symbol of the strategy.
        """
        self.cancel_all()

        # Update every symbol's bar history first, so all array managers stay
        # in step even while the indicators are still warming up.
        for vt_symbol, bar in bars.items():
            self.ams[vt_symbol].update_bar(bar)

        # Only trade once every symbol in this batch has enough history.
        if not all(self.ams[vt_symbol].inited for vt_symbol in bars):
            return

        for vt_symbol, bar in bars.items():
            self._update_target(vt_symbol, bar)

        self._send_orders(bars)

        self.put_event()

    def _update_target(self, vt_symbol: str, bar: BarData):
        """Refresh the ATR/RSI values of one symbol and update its target position."""
        am: ArrayManager = self.ams[vt_symbol]

        atr_array = am.atr(self.atr_window, array=True)
        self.atr_data[vt_symbol] = atr_array[-1]
        self.atr_ma[vt_symbol] = atr_array[-self.atr_ma_window:].mean()
        self.rsi_data[vt_symbol] = am.rsi(self.rsi_window)

        current_pos = self.get_pos(vt_symbol)

        if current_pos == 0:
            # Flat: reset the trailing-stop extremes and look for an entry.
            self.intra_trade_high[vt_symbol] = bar.high_price
            self.intra_trade_low[vt_symbol] = bar.low_price

            if self.atr_data[vt_symbol] > self.atr_ma[vt_symbol]:
                if self.rsi_data[vt_symbol] > self.rsi_buy:
                    self.targets[vt_symbol] = self.fixed_size
                elif self.rsi_data[vt_symbol] < self.rsi_sell:
                    self.targets[vt_symbol] = -self.fixed_size
                else:
                    self.targets[vt_symbol] = 0

        elif current_pos > 0:
            # Long: trail the stop below the highest high since entry.
            # .get() covers a position opened before this symbol's extremes
            # were recorded here.
            self.intra_trade_high[vt_symbol] = max(
                self.intra_trade_high.get(vt_symbol, bar.high_price),
                bar.high_price,
            )
            self.intra_trade_low[vt_symbol] = bar.low_price

            long_stop = self.intra_trade_high[vt_symbol] * (1 - self.trailing_percent / 100)
            if bar.close_price <= long_stop:
                self.targets[vt_symbol] = 0

        elif current_pos < 0:
            # Short: trail the stop above the lowest low since entry.
            self.intra_trade_low[vt_symbol] = min(
                self.intra_trade_low.get(vt_symbol, bar.low_price),
                bar.low_price,
            )
            self.intra_trade_high[vt_symbol] = bar.high_price

            short_stop = self.intra_trade_low[vt_symbol] * (1 + self.trailing_percent / 100)
            if bar.close_price >= short_stop:
                self.targets[vt_symbol] = 0

    def _send_orders(self, bars: Dict[str, BarData]):
        """Send orders moving each symbol from its current position to its target."""
        for vt_symbol in self.vt_symbols:
            target_pos = self.targets.get(vt_symbol)
            # None means "no signal yet". A target of 0 means "close the
            # position" and must still be acted on.
            if target_pos is None:
                continue

            bar = bars.get(vt_symbol)
            if bar is None:
                # No window bar for this symbol in this batch, so no price to quote.
                continue

            current_pos = self.get_pos(vt_symbol)
            pos_diff = target_pos - current_pos
            volume = abs(pos_diff)

            if pos_diff > 0:
                # Pay up by price_add over the close to get filled.
                price = bar.close_price + self.price_add
                if current_pos < 0:
                    self.cover(vt_symbol, price, volume)
                else:
                    self.buy(vt_symbol, price, volume)
            elif pos_diff < 0:
                price = bar.close_price - self.price_add
                if current_pos > 0:
                    self.sell(vt_symbol, price, volume)
                else:
                    self.short(vt_symbol, price, volume)

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.

        When this tick's minute differs from the previous tick's, the pending
        1-minute bar of every symbol is closed with ``generate()`` and the
        batch goes to ``on_bars``. Only then is this tick fed to its symbol's
        generator.
        """
        if (
            self.last_tick_time is not None
            and self.last_tick_time.minute != tick.datetime.minute
        ):
            bars: Dict[str, BarData] = {}
            for vt_symbol, bg in self.bgs.items():
                bar = bg.generate()
                # NOTE(review): generate() presumably returns None when a symbol
                # had no tick this minute; such symbols are left out of the batch.
                if bar is not None:
                    bars[vt_symbol] = bar
            self.on_bars(bars)

        bg: BarGenerator = self.bgs[tick.vt_symbol]
        bg.update_tick(tick)

        self.last_tick_time = tick.datetime

    def on_bars(self, bars: Dict[str, BarData]):
        """
        Feed a batch of 1-minute bars into the per-symbol window generators.

        Symbols absent from ``bars`` (or mapped to None) are skipped, so a
        partial batch is handled instead of raising.
        """
        for vt_symbol, bg in self.bgs.items():
            bar = bars.get(vt_symbol)
            if bar is not None:
                bg.update_bar(bar)