from abc import ABC
from typing import Union, Type, TYPE_CHECKING
from collections import defaultdict

from vnpy.trader.constant import Interval, Direction, Offset
from vnpy.trader.object import BarData, TickData, OrderData, TradeData, ContractData
from vnpy.trader.utility import virtual

from .object import PortfolioData

if TYPE_CHECKING:
    from .engine import StrategyEngine


class StrategyTemplate(ABC):
    """Base template for option strategies."""

    author: str = ""

    def __init__(self, engine: "StrategyEngine", name: str) -> None:
        """
        Create a strategy instance bound to a strategy engine.

        :param engine: engine used to send/cancel orders, load data, log, etc.
        :param name: name of this strategy instance
        """
        self.engine: "StrategyEngine" = engine
        self.name: str = name

        self.vt_symbols: set[str] = set()

        self.inited: bool = False
        self.trading: bool = False

        # Subclasses may declare `parameters` / `variables` as class
        # attributes; fall back to empty lists when they don't.
        if not hasattr(self, "parameters"):
            self.parameters: list[str] = []

        if not hasattr(self, "variables"):
            self.variables: list[str] = []

        # Built-in state is always listed ahead of subclass variables.
        # A new list is built here, so a class-level list is not mutated.
        self.variables = ["inited", "trading", "pos_data", "target_data"] + self.variables

        # Order cache
        self.orders: dict[str, OrderData] = {}
        self.active_orderids: set[str] = set()

        # Position bookkeeping, keyed by vt_symbol
        self.pos_data: dict[str, int] = defaultdict(int)      # actual position
        self.target_data: dict[str, int] = defaultdict(int)   # target position

        # Option portfolios, keyed by portfolio name
        self.portfolios: dict[str, PortfolioData] = {}

    def update_setting(self, setting: dict) -> None:
        """Copy values for declared parameters from `setting` onto the strategy."""
        for name in self.parameters:
            if name in setting:
                setattr(self, name, setting[name])

    def get_parameters(self) -> dict:
        """Return a dict of declared parameter names to their current values."""
        strategy_parameters: dict = {}
        for name in self.parameters:
            strategy_parameters[name] = getattr(self, name)
        return strategy_parameters

    def get_variables(self) -> dict:
        """Return a dict of declared variable names to their current values."""
        strategy_variables: dict = {}
        for name in self.variables:
            strategy_variables[name] = getattr(self, name)
        return strategy_variables

    def get_data(self) -> dict:
        """Return a snapshot of the strategy's identity, parameters and variables."""
        strategy_data: dict = {
            "strategy_name": self.name,
            "vt_symbols": self.vt_symbols,
            "class_name": self.__class__.__name__,
            "author": self.author,
            "parameters": self.get_parameters(),
            "variables": self.get_variables(),
        }
        return strategy_data

    @virtual
    def on_init(self) -> None:
        """Callback when the strategy is initialized."""
        pass

    @virtual
    def on_start(self) -> None:
        """Callback when the strategy is started."""
        pass

    @virtual
    def on_stop(self) -> None:
        """Callback when the strategy is stopped."""
        pass

    @virtual
    def on_tick(self, tick: TickData) -> None:
        """Callback for a new tick."""
        pass

    @virtual
    def on_bars(self, bars: dict[str, BarData]) -> None:
        """Callback for a new slice of bars, keyed by vt_symbol."""
        pass

    def update_trade(self, trade: TradeData) -> None:
        """
        Apply a fill to the net position.

        LONG fills increase the position and any other direction decreases
        it. Offset is not considered: the position is tracked as a net value.
        """
        if trade.direction == Direction.LONG:
            self.pos_data[trade.vt_symbol] += trade.volume
        else:
            self.pos_data[trade.vt_symbol] -= trade.volume

    def update_order(self, order: OrderData) -> None:
        """Cache the latest order state and drop it from the active set once inactive."""
        self.orders[order.vt_orderid] = order

        if not order.is_active() and order.vt_orderid in self.active_orderids:
            self.active_orderids.remove(order.vt_orderid)

    def buy(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a LONG / OPEN order."""
        return self.send_order(vt_symbol, Direction.LONG, Offset.OPEN, price, volume)

    def sell(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a SHORT / CLOSE order."""
        return self.send_order(vt_symbol, Direction.SHORT, Offset.CLOSE, price, volume)

    def short(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a SHORT / OPEN order."""
        return self.send_order(vt_symbol, Direction.SHORT, Offset.OPEN, price, volume)

    def cover(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Send a LONG / CLOSE order."""
        return self.send_order(vt_symbol, Direction.LONG, Offset.CLOSE, price, volume)

    def send_order(
        self,
        vt_symbol: str,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
    ) -> list[str]:
        """
        Send an order through the engine.

        :return: the vt_orderids returned by the engine, which are also
            recorded as active; an empty list when trading is disabled.
        """
        if not self.trading:
            return []

        vt_orderids = self.engine.send_order(
            self, vt_symbol, direction, offset, price, volume
        )

        self.active_orderids.update(vt_orderids)

        return vt_orderids

    def cancel_order(self, vt_orderid: str) -> None:
        """Request cancellation of one order; does nothing when trading is disabled."""
        if not self.trading:
            return

        self.engine.cancel_order(self, vt_orderid)

    def cancel_all(self) -> None:
        """Request cancellation of every order currently tracked as active."""
        # Iterate over a snapshot: the active set may be modified (via
        # update_order) while cancellations are being processed.
        for vt_orderid in list(self.active_orderids):
            self.cancel_order(vt_orderid)

    def get_pos(self, vt_symbol: str) -> int:
        """Return the current net position of `vt_symbol` (0 if unknown)."""
        return self.pos_data.get(vt_symbol, 0)

    def get_target(self, vt_symbol: str) -> int:
        """Return the target position of `vt_symbol` (0 if unset)."""
        # Use .get() rather than indexing so that reading a target does not
        # insert a zero entry into the defaultdict (target_data is exposed
        # through get_variables()).
        return self.target_data.get(vt_symbol, 0)

    def set_target(self, vt_symbol: str, target: int) -> None:
        """Set the target position of `vt_symbol`."""
        self.target_data[vt_symbol] = target

    def clear_targets(self) -> None:
        """Remove all target positions."""
        self.target_data.clear()

    def execute_trading(self, price_data: dict[str, float], percent_add: float) -> None:
        """
        Rebalance positions towards their targets.

        All active orders are cancelled first. Then, for every symbol in
        `price_data`, the gap between target and current position is traded.
        An existing opposite position is closed before any new position is
        opened.

        :param price_data: reference price per vt_symbol; only these symbols are traded
        :param percent_add: relative price offset; buys are priced at
            price * (1 + percent_add), sells at price * (1 - percent_add)
        """
        self.cancel_all()

        # Only send orders for symbols that have a price in this bar slice
        for vt_symbol, price in price_data.items():
            target: int = self.get_target(vt_symbol)
            pos: int = self.get_pos(vt_symbol)
            diff: int = target - pos

            if diff > 0:
                # Need to buy: close any short first, then open long
                order_price: float = price * (1 + percent_add)

                cover_volume: int = 0
                buy_volume: int = 0

                if pos < 0:
                    cover_volume = min(diff, abs(pos))
                    buy_volume = diff - cover_volume
                else:
                    buy_volume = diff

                if cover_volume:
                    self.cover(vt_symbol, order_price, cover_volume)

                if buy_volume:
                    self.buy(vt_symbol, order_price, buy_volume)

            elif diff < 0:
                # Need to sell: close any long first, then open short
                order_price = price * (1 - percent_add)

                sell_volume: int = 0
                short_volume: int = 0

                if pos > 0:
                    sell_volume = min(abs(diff), pos)
                    short_volume = abs(diff) - sell_volume
                else:
                    short_volume = abs(diff)

                if sell_volume:
                    self.sell(vt_symbol, order_price, sell_volume)

                if short_volume:
                    self.short(vt_symbol, order_price, short_volume)

    def write_log(self, msg: str) -> None:
        """Write a log message through the engine."""
        self.engine.write_log(msg, self)

    def get_portfolio(self, portfolio_name: str) -> PortfolioData:
        """Return a registered option portfolio; raises KeyError if absent."""
        return self.portfolios[portfolio_name]

    def load_bars(self, vt_symbol: str, days: int, interval: Interval) -> list[BarData]:
        """Load historical bars through the engine."""
        return self.engine.load_bars(vt_symbol, days, interval)

    def _add_portfolio(self, portfolio_name: str, contracts: list[ContractData]) -> bool:
        """
        Register a PortfolioData built from contracts returned by the engine.

        The first contract is passed to the PortfolioData constructor. Every
        contract is then added to the portfolio and to `vt_symbols`.

        :return: False (and nothing registered) when `contracts` is empty/falsy.
        """
        if not contracts:
            return False

        portfolio: PortfolioData = PortfolioData(contracts[0])
        self.portfolios[portfolio_name] = portfolio

        for contract in contracts:
            self.vt_symbols.add(contract.vt_symbol)
            portfolio.add_contract(contract)

        return True

    def init_portfolio(self, portfolio_name: str) -> bool:
        """
        Query the option contracts of a portfolio from the engine.

        :return: True if the portfolio is already registered or contracts were
            returned; False otherwise.
        """
        # Skip if this portfolio is already registered
        if portfolio_name in self.portfolios:
            return True

        contracts: list[ContractData] = self.engine.init_portfolio(self, portfolio_name)
        return self._add_portfolio(portfolio_name, contracts)

    def subscribe_options(self, portfolio_name: str) -> bool:
        """
        Subscribe to market data for the option contracts of a portfolio.

        :return: True if the portfolio is already registered or contracts were
            returned; False otherwise.
        """
        # Skip if this portfolio is already registered.
        # NOTE(review): this guard shares `self.portfolios` with
        # init_portfolio(). If init_portfolio() was called first for the same
        # name, the engine subscription below is skipped — confirm intended.
        if portfolio_name in self.portfolios:
            return True

        contracts: list[ContractData] = self.engine.subscribe_options(self, portfolio_name)
        return self._add_portfolio(portfolio_name, contracts)

    def subscribe_data(self, vt_symbol: str) -> bool:
        """Subscribe to market data of `vt_symbol`; record it on success."""
        success: bool = self.engine.subscribe_data(self, vt_symbol)
        if success:
            self.vt_symbols.add(vt_symbol)
        return success

    def put_event(self) -> None:
        """Push a strategy update event through the engine once initialized."""
        if self.inited:
            self.engine.put_strategy_event(self)

    def send_email(self, msg: str) -> None:
        """Send an email message through the engine once initialized."""
        if self.inited:
            self.engine.send_email(msg, self)

    def sync_data(self) -> None:
        """Ask the engine to persist the strategy's state while trading."""
        if self.trading:
            self.engine.sync_strategy_data(self)