319 lines
9.9 KiB
Python
319 lines
9.9 KiB
Python
from abc import ABC
|
|
from typing import Union, Type, TYPE_CHECKING
|
|
from collections import defaultdict
|
|
|
|
from vnpy.trader.constant import Interval, Direction, Offset
|
|
from vnpy.trader.object import BarData, TickData, OrderData, TradeData, ContractData
|
|
from vnpy.trader.utility import virtual
|
|
|
|
from .object import PortfolioData
|
|
|
|
if TYPE_CHECKING:
|
|
from .engine import StrategyEngine
|
|
|
|
|
|
class StrategyTemplate(ABC):
    """Option strategy template.

    Base class that concrete strategies derive from. It keeps per-symbol
    position and target bookkeeping, caches orders, and forwards order,
    subscription, logging and persistence requests to the engine passed
    to the constructor.

    Subclasses may declare ``parameters`` and ``variables`` (lists of
    attribute names) as class attributes, and override the ``on_*``
    callbacks.
    """

    author: str = ""

    def __init__(self, engine: "StrategyEngine", name: str) -> None:
        """Bind the strategy to its engine and create empty state.

        Args:
            engine: Engine object that all order, subscription, logging
                and persistence calls of this template are delegated to.
            name: Name of this strategy instance (reported by
                ``get_data`` as ``strategy_name``).
        """
        self.engine: "StrategyEngine" = engine
        self.name: str = name

        self.vt_symbols: set[str] = set()
        self.inited: bool = False
        self.trading: bool = False

        # Only fall back to empty lists when the subclass did not declare
        # its own ``parameters`` / ``variables``.
        if not hasattr(self, "parameters"):
            self.parameters: list[str] = []

        if not hasattr(self, "variables"):
            self.variables: list[str] = []

        # Built-in state names come first. Concatenation builds a new
        # list, so a class-level ``variables`` list is never mutated.
        builtin_variables: list[str] = ["inited", "trading", "pos_data", "target_data"]
        self.variables = builtin_variables + self.variables

        # Order cache: every order seen, plus the ids still active.
        self.orders: dict[str, OrderData] = {}
        self.active_orderids: set[str] = set()

        # Position bookkeeping, keyed by vt_symbol.
        self.pos_data: dict[str, int] = defaultdict(int)        # actual position
        self.target_data: dict[str, int] = defaultdict(int)     # target position

        # Option portfolios, keyed by portfolio name.
        self.portfolios: dict[str, PortfolioData] = {}

    def update_setting(self, setting: dict) -> None:
        """Update strategy parameters from ``setting``.

        Only names listed in ``self.parameters`` are applied; any other
        key in ``setting`` is ignored.
        """
        for name in self.parameters:
            if name in setting:
                setattr(self, name, setting[name])

    def get_parameters(self) -> dict:
        """Return a dict of the current value of every declared parameter."""
        return {name: getattr(self, name) for name in self.parameters}

    def get_variables(self) -> dict:
        """Return a dict of the current value of every declared variable."""
        return {name: getattr(self, name) for name in self.variables}

    def get_data(self) -> dict:
        """Return a summary dict describing the strategy's state.

        Note that ``vt_symbols`` is the live set object, not a copy.
        """
        strategy_data: dict = {
            "strategy_name": self.name,
            "vt_symbols": self.vt_symbols,
            "class_name": self.__class__.__name__,
            "author": self.author,
            "parameters": self.get_parameters(),
            "variables": self.get_variables(),
        }
        return strategy_data

    @virtual
    def on_init(self) -> None:
        """Callback: strategy initialisation. No-op by default."""
        pass

    @virtual
    def on_start(self) -> None:
        """Callback: strategy start. No-op by default."""
        pass

    @virtual
    def on_stop(self) -> None:
        """Callback: strategy stop. No-op by default."""
        pass

    @virtual
    def on_tick(self, tick: TickData) -> None:
        """Callback: tick data push. No-op by default."""
        pass

    @virtual
    def on_bars(self, bars: dict[str, BarData]) -> None:
        """Callback: bar data push. No-op by default.

        Args:
            bars: Bars keyed by string (presumably vt_symbol; confirm
                against the engine).
        """
        pass

    def update_trade(self, trade: TradeData) -> None:
        """Apply a trade to the actual position of its symbol.

        A LONG trade adds ``trade.volume``; any other direction
        subtracts it.
        """
        # NOTE(review): every non-LONG direction is treated as a
        # reduction; confirm that is intended for directions other than
        # SHORT.
        if trade.direction == Direction.LONG:
            self.pos_data[trade.vt_symbol] += trade.volume
        else:
            self.pos_data[trade.vt_symbol] -= trade.volume

    def update_order(self, order: OrderData) -> None:
        """Cache the latest order state and retire it once inactive."""
        self.orders[order.vt_orderid] = order

        # ``discard`` is a no-op for ids this strategy is not tracking.
        if not order.is_active():
            self.active_orderids.discard(order.vt_orderid)

    def buy(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Buy to open (LONG + OPEN). Returns the resulting order ids."""
        return self.send_order(vt_symbol, Direction.LONG, Offset.OPEN, price, volume)

    def sell(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Sell to close (SHORT + CLOSE). Returns the resulting order ids."""
        return self.send_order(vt_symbol, Direction.SHORT, Offset.CLOSE, price, volume)

    def short(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Sell to open (SHORT + OPEN). Returns the resulting order ids."""
        return self.send_order(vt_symbol, Direction.SHORT, Offset.OPEN, price, volume)

    def cover(self, vt_symbol: str, price: float, volume: float) -> list[str]:
        """Buy to close (LONG + CLOSE). Returns the resulting order ids."""
        return self.send_order(vt_symbol, Direction.LONG, Offset.CLOSE, price, volume)

    def send_order(
        self,
        vt_symbol: str,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
    ) -> list[str]:
        """Send an order through the engine.

        Returns:
            The order ids returned by the engine, which are also added
            to ``active_orderids``. An empty list is returned, and
            nothing is sent, while ``self.trading`` is false.
        """
        if not self.trading:
            return []

        vt_orderids = self.engine.send_order(
            self,
            vt_symbol,
            direction,
            offset,
            price,
            volume,
        )
        self.active_orderids.update(vt_orderids)
        return vt_orderids

    def cancel_order(self, vt_orderid: str) -> None:
        """Ask the engine to cancel one order; ignored while not trading."""
        if not self.trading:
            return

        self.engine.cancel_order(self, vt_orderid)

    def cancel_all(self) -> None:
        """Request cancellation of every currently active order."""
        # Iterate over a snapshot so the loop stays valid if
        # ``active_orderids`` is modified while cancels are being sent.
        for vt_orderid in list(self.active_orderids):
            self.cancel_order(vt_orderid)

    def get_pos(self, vt_symbol: str) -> int:
        """Return the actual position of ``vt_symbol`` (0 if unknown)."""
        return self.pos_data.get(vt_symbol, 0)

    def get_target(self, vt_symbol: str) -> int:
        """Return the target position of ``vt_symbol`` (0 if unset).

        Uses ``.get`` like ``get_pos``, so a pure query never inserts a
        zero entry into ``target_data``.
        """
        return self.target_data.get(vt_symbol, 0)

    def set_target(self, vt_symbol: str, target: int) -> None:
        """Set the target position of ``vt_symbol``."""
        self.target_data[vt_symbol] = target

    def clear_targets(self) -> None:
        """Remove every target position."""
        self.target_data.clear()

    def execute_trading(self, price_data: dict[str, float], percent_add: float) -> None:
        """Rebalance towards the target positions.

        All active orders are cancelled first. Then, for each symbol in
        ``price_data``, the gap between target and actual position is
        traded: existing opposite positions are closed first and the
        remainder is opened.

        Args:
            price_data: Reference price per vt_symbol. Only symbols
                present here are traded.
            percent_add: Fraction added to the reference price for buy
                orders and subtracted from it for sell orders.
        """
        self.cancel_all()

        for vt_symbol, price in price_data.items():
            target: int = self.get_target(vt_symbol)
            pos: int = self.get_pos(vt_symbol)
            diff: int = target - pos

            if diff > 0:
                # Need to buy: pay up by ``percent_add``.
                order_price: float = price * (1 + percent_add)

                # Close the short position first, open long with the rest.
                cover_volume: int = min(diff, abs(pos)) if pos < 0 else 0
                buy_volume: int = diff - cover_volume

                if cover_volume:
                    self.cover(vt_symbol, order_price, cover_volume)

                if buy_volume:
                    self.buy(vt_symbol, order_price, buy_volume)
            elif diff < 0:
                # Need to sell: price down by ``percent_add``.
                order_price = price * (1 - percent_add)
                gap: int = abs(diff)

                # Close the long position first, open short with the rest.
                sell_volume: int = min(gap, pos) if pos > 0 else 0
                short_volume: int = gap - sell_volume

                if sell_volume:
                    self.sell(vt_symbol, order_price, sell_volume)

                if short_volume:
                    self.short(vt_symbol, order_price, short_volume)

    def write_log(self, msg: str) -> None:
        """Write a log message through the engine."""
        self.engine.write_log(msg, self)

    def get_portfolio(self, portfolio_name: str) -> PortfolioData:
        """Return a registered option portfolio.

        Raises:
            KeyError: If ``portfolio_name`` has not been registered.
        """
        return self.portfolios[portfolio_name]

    def load_bars(self, vt_symbol: str, days: int, interval: Interval) -> list[BarData]:
        """Load historical bar data through the engine."""
        return self.engine.load_bars(vt_symbol, days, interval)

    def _register_portfolio(self, portfolio_name: str, contracts: list[ContractData]) -> bool:
        """Create and store a PortfolioData from the contracts the engine returned.

        Returns:
            False when ``contracts`` is empty or falsy (nothing is
            stored), True otherwise.
        """
        if not contracts:
            return False

        # The portfolio is constructed from the first contract, then
        # every contract (including the first) is added to it.
        portfolio: PortfolioData = PortfolioData(contracts[0])
        self.portfolios[portfolio_name] = portfolio

        for contract in contracts:
            self.vt_symbols.add(contract.vt_symbol)
            portfolio.add_contract(contract)

        return True

    def init_portfolio(self, portfolio_name: str) -> bool:
        """Query the option contracts of a portfolio and register it.

        Returns:
            True if the portfolio is already registered or the engine
            returned contracts; False if the engine returned none.
        """
        # Skip the engine call when the portfolio is already registered.
        # NOTE(review): ``subscribe_options`` uses the same guard, so
        # whichever of the two runs first turns the other into a no-op
        # for that name; confirm this is intended.
        if portfolio_name in self.portfolios:
            return True

        contracts: list[ContractData] = self.engine.init_portfolio(self, portfolio_name)
        return self._register_portfolio(portfolio_name, contracts)

    def subscribe_options(self, portfolio_name: str) -> bool:
        """Subscribe to the option market data of a portfolio and register it.

        Returns:
            True if the portfolio is already registered or the engine
            returned contracts; False if the engine returned none.
        """
        # Skip the engine call when the portfolio is already registered.
        # NOTE(review): ``init_portfolio`` uses the same guard, so
        # whichever of the two runs first turns the other into a no-op
        # for that name; confirm this is intended.
        if portfolio_name in self.portfolios:
            return True

        contracts: list[ContractData] = self.engine.subscribe_options(self, portfolio_name)
        return self._register_portfolio(portfolio_name, contracts)

    def subscribe_data(self, vt_symbol: str) -> bool:
        """Subscribe to the market data of an underlying symbol.

        The symbol is recorded in ``vt_symbols`` only when the engine
        reports success; the engine's result is returned unchanged.
        """
        subscribed: bool = self.engine.subscribe_data(self, vt_symbol)
        if subscribed:
            self.vt_symbols.add(vt_symbol)

        return subscribed

    def put_event(self) -> None:
        """Push a strategy data update event; ignored before initialisation."""
        if self.inited:
            self.engine.put_strategy_event(self)

    def send_email(self, msg: str) -> None:
        """Send an email message through the engine; ignored before initialisation."""
        if self.inited:
            self.engine.send_email(msg, self)

    def sync_data(self) -> None:
        """Ask the engine to persist the strategy state; ignored while not trading."""
        if self.trading:
            self.engine.sync_strategy_data(self)
|