增加交易策略、交易指标、量化库代码等文件夹
This commit is contained in:
318
5.课程代码/2.Option_spread_strategy/使用文档/20/template.py
Normal file
318
5.课程代码/2.Option_spread_strategy/使用文档/20/template.py
Normal file
@@ -0,0 +1,318 @@
|
||||
from abc import ABC
|
||||
from typing import Union, Type, TYPE_CHECKING
|
||||
from collections import defaultdict
|
||||
|
||||
from vnpy.trader.constant import Interval, Direction, Offset
|
||||
from vnpy.trader.object import BarData, TickData, OrderData, TradeData, ContractData
|
||||
from vnpy.trader.utility import virtual
|
||||
|
||||
from .object import PortfolioData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .engine import StrategyEngine
|
||||
|
||||
|
||||
class StrategyTemplate(ABC):
|
||||
"""期权策略模板"""
|
||||
|
||||
author: str = ""
|
||||
|
||||
def __init__(self, engine: "StrategyEngine", name: str) -> None:
|
||||
"""构造函数"""
|
||||
self.engine: "StrategyEngine" = engine
|
||||
self.name: str = name
|
||||
|
||||
self.vt_symbols: set[str] = set()
|
||||
self.inited: bool = False
|
||||
self.trading: bool = False
|
||||
|
||||
# 初始化变量和参数列表
|
||||
if not hasattr(self, "parameters"):
|
||||
self.parameters: list[str] = []
|
||||
|
||||
if not hasattr(self, "variables"):
|
||||
self.variables: list[str] = []
|
||||
|
||||
self.variables = ["inited", "trading", "pos_data", "target_data"] + self.variables
|
||||
|
||||
# 委托缓存数据
|
||||
self.orders: dict[str, OrderData] = {}
|
||||
self.active_orderids: set[str] = set()
|
||||
|
||||
# 持仓和目标数据
|
||||
self.pos_data: dict[str, int] = defaultdict(int) # 实际持仓
|
||||
self.target_data: dict[str, int] = defaultdict(int) # 目标持仓
|
||||
|
||||
# 期权组合
|
||||
self.portfolios: dict[str, PortfolioData] = {}
|
||||
|
||||
def update_setting(self, setting: dict) -> None:
|
||||
"""更新策略参数"""
|
||||
for name in self.parameters:
|
||||
if name in setting:
|
||||
setattr(self, name, setting[name])
|
||||
|
||||
def get_parameters(self) -> dict:
|
||||
"""查询策略参数"""
|
||||
strategy_parameters: dict = {}
|
||||
for name in self.parameters:
|
||||
strategy_parameters[name] = getattr(self, name)
|
||||
return strategy_parameters
|
||||
|
||||
def get_variables(self) -> dict:
|
||||
"""查询策略变量"""
|
||||
strategy_variables: dict = {}
|
||||
for name in self.variables:
|
||||
strategy_variables[name] = getattr(self, name)
|
||||
return strategy_variables
|
||||
|
||||
def get_data(self) -> dict:
|
||||
"""查询策略状态数据"""
|
||||
strategy_data: dict = {
|
||||
"strategy_name": self.name,
|
||||
"vt_symbols": self.vt_symbols,
|
||||
"class_name": self.__class__.__name__,
|
||||
"author": self.author,
|
||||
"parameters": self.get_parameters(),
|
||||
"variables": self.get_variables(),
|
||||
}
|
||||
return strategy_data
|
||||
|
||||
@virtual
|
||||
def on_init(self) -> None:
|
||||
"""策略初始化"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_start(self) -> None:
|
||||
"""策略启动"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_stop(self) -> None:
|
||||
"""策略停止"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_tick(self, tick: TickData) -> None:
|
||||
"""Tick推送"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_bars(self, bars: dict[str, BarData]) -> None:
|
||||
"""K线推送"""
|
||||
pass
|
||||
|
||||
def update_trade(self, trade: TradeData) -> None:
|
||||
"""更新成交数据"""
|
||||
if trade.direction == Direction.LONG:
|
||||
self.pos_data[trade.vt_symbol] += trade.volume
|
||||
else:
|
||||
self.pos_data[trade.vt_symbol] -= trade.volume
|
||||
|
||||
def update_order(self, order: OrderData) -> None:
|
||||
"""更新委托数据"""
|
||||
self.orders[order.vt_orderid] = order
|
||||
|
||||
if not order.is_active() and order.vt_orderid in self.active_orderids:
|
||||
self.active_orderids.remove(order.vt_orderid)
|
||||
|
||||
def buy(self, vt_symbol: str, price: float, volume: float) -> list[str]:
|
||||
"""买入开仓"""
|
||||
return self.send_order(vt_symbol, Direction.LONG, Offset.OPEN, price, volume)
|
||||
|
||||
def sell(self, vt_symbol: str, price: float, volume: float) -> list[str]:
|
||||
"""卖出平仓"""
|
||||
return self.send_order(vt_symbol, Direction.SHORT, Offset.CLOSE, price, volume)
|
||||
|
||||
def short(self, vt_symbol: str, price: float, volume: float) -> list[str]:
|
||||
"""卖出开仓"""
|
||||
return self.send_order(vt_symbol, Direction.SHORT, Offset.OPEN, price, volume)
|
||||
|
||||
def cover(self, vt_symbol: str, price: float, volume: float) -> list[str]:
|
||||
"""买入平仓"""
|
||||
return self.send_order(vt_symbol, Direction.LONG, Offset.CLOSE, price, volume)
|
||||
|
||||
def send_order(
|
||||
self,
|
||||
vt_symbol: str,
|
||||
direction: Direction,
|
||||
offset: Offset,
|
||||
price: float,
|
||||
volume: float,
|
||||
) -> list[str]:
|
||||
"""委托下单"""
|
||||
if not self.trading:
|
||||
return []
|
||||
|
||||
vt_orderids = self.engine.send_order(
|
||||
self,
|
||||
vt_symbol,
|
||||
direction,
|
||||
offset,
|
||||
price,
|
||||
volume
|
||||
)
|
||||
self.active_orderids.update(vt_orderids)
|
||||
return vt_orderids
|
||||
|
||||
def cancel_order(self, vt_orderid: str) -> None:
|
||||
"""委托撤单"""
|
||||
if not self.trading:
|
||||
return
|
||||
|
||||
self.engine.cancel_order(self, vt_orderid)
|
||||
|
||||
def cancel_all(self) -> None:
|
||||
"""全撤委托"""
|
||||
for vt_orderid in list(self.active_orderids):
|
||||
self.cancel_order(vt_orderid)
|
||||
|
||||
def get_pos(self, vt_symbol: str) -> int:
|
||||
"""查询当前持仓"""
|
||||
return self.pos_data.get(vt_symbol, 0)
|
||||
|
||||
def get_target(self, vt_symbol: str) -> int:
|
||||
"""查询目标仓位"""
|
||||
return self.target_data[vt_symbol]
|
||||
|
||||
def set_target(self, vt_symbol: str, target: int) -> None:
|
||||
"""设置目标仓位"""
|
||||
self.target_data[vt_symbol] = target
|
||||
|
||||
def clear_targets(self) -> None:
|
||||
"""清空目标仓位"""
|
||||
self.target_data.clear()
|
||||
|
||||
def execute_trading(self, price_data: dict[str, float], percent_add: float) -> None:
|
||||
"""基于目标执行调仓交易"""
|
||||
self.cancel_all()
|
||||
|
||||
# 只发出当前K线切片有行情的合约的委托
|
||||
for vt_symbol, price in price_data.items():
|
||||
# 计算仓差
|
||||
target: int = self.get_target(vt_symbol)
|
||||
pos: int = self.get_pos(vt_symbol)
|
||||
diff: int = target - pos
|
||||
|
||||
# 多头
|
||||
if diff > 0:
|
||||
# 计算多头委托价
|
||||
order_price: float = price * (1 + percent_add)
|
||||
|
||||
# 计算买平和买开数量
|
||||
cover_volume: int = 0
|
||||
buy_volume: int = 0
|
||||
|
||||
if pos < 0:
|
||||
cover_volume = min(diff, abs(pos))
|
||||
buy_volume = diff - cover_volume
|
||||
else:
|
||||
buy_volume = diff
|
||||
|
||||
# 发出对应委托
|
||||
if cover_volume:
|
||||
self.cover(vt_symbol, order_price, cover_volume)
|
||||
|
||||
if buy_volume:
|
||||
self.buy(vt_symbol, order_price, buy_volume)
|
||||
# 空头
|
||||
elif diff < 0:
|
||||
# 计算空头委托价
|
||||
order_price: float = price * (1 - percent_add)
|
||||
|
||||
# 计算卖平和卖开数量
|
||||
sell_volume: int = 0
|
||||
short_volume: int = 0
|
||||
|
||||
if pos > 0:
|
||||
sell_volume = min(abs(diff), pos)
|
||||
short_volume = abs(diff) - sell_volume
|
||||
else:
|
||||
short_volume = abs(diff)
|
||||
|
||||
# 发出对应委托
|
||||
if sell_volume:
|
||||
self.sell(vt_symbol, order_price, sell_volume)
|
||||
|
||||
if short_volume:
|
||||
self.short(vt_symbol, order_price, short_volume)
|
||||
|
||||
def write_log(self, msg: str) -> None:
|
||||
"""输出日志"""
|
||||
self.engine.write_log(msg, self)
|
||||
|
||||
def get_portfolio(self, portfolio_name: str) -> PortfolioData:
|
||||
"""获取期权组合"""
|
||||
return self.portfolios[portfolio_name]
|
||||
|
||||
def load_bars(self, vt_symbol: str, days: int, interval: Interval) -> list[BarData]:
|
||||
"""加载历史K线数据"""
|
||||
return self.engine.load_bars(vt_symbol, days, interval)
|
||||
|
||||
def init_portfolio(self, portfolio_name: str) -> bool:
|
||||
"""查询期权合约"""
|
||||
# 避免重复订阅
|
||||
if portfolio_name in self.portfolios:
|
||||
return True
|
||||
|
||||
# 向引擎发起查询
|
||||
contracts: list[ContractData] = self.engine.init_portfolio(self, portfolio_name)
|
||||
|
||||
# 如果有返回合约,则创建PortfolioData对象
|
||||
if contracts:
|
||||
portfolio: PortfolioData = PortfolioData(contracts[0])
|
||||
self.portfolios[portfolio_name] = portfolio
|
||||
|
||||
for contract in contracts:
|
||||
self.vt_symbols.add(contract.vt_symbol)
|
||||
portfolio.add_contract(contract)
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def subscribe_options(self, portfolio_name: str) -> bool:
|
||||
"""订阅期权行情"""
|
||||
# 避免重复订阅
|
||||
if portfolio_name in self.portfolios:
|
||||
return True
|
||||
|
||||
# 向引擎发起订阅
|
||||
contracts: list[ContractData] = self.engine.subscribe_options(self, portfolio_name)
|
||||
|
||||
# 如果有返回合约,则创建PortfolioData对象
|
||||
if contracts:
|
||||
portfolio: PortfolioData = PortfolioData(contracts[0])
|
||||
self.portfolios[portfolio_name] = portfolio
|
||||
|
||||
for contract in contracts:
|
||||
self.vt_symbols.add(contract.vt_symbol)
|
||||
portfolio.add_contract(contract)
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def subscribe_data(self, vt_symbol: str) -> bool:
|
||||
"""订阅标的行情"""
|
||||
n: bool = self.engine.subscribe_data(self, vt_symbol)
|
||||
if n:
|
||||
self.vt_symbols.add(vt_symbol)
|
||||
|
||||
return n
|
||||
|
||||
def put_event(self) -> None:
|
||||
"""推送策略数据更新事件"""
|
||||
if self.inited:
|
||||
self.engine.put_strategy_event(self)
|
||||
|
||||
def send_email(self, msg: str) -> None:
|
||||
"""发送邮件信息"""
|
||||
if self.inited:
|
||||
self.engine.send_email(msg, self)
|
||||
|
||||
def sync_data(self):
|
||||
"""同步策略状态数据到文件"""
|
||||
if self.trading:
|
||||
self.engine.sync_strategy_data(self)
|
||||
Reference in New Issue
Block a user