1. Creating the agent
agent = DRLAgent(env=env_train)
SAC_PARAMS = {
    "batch_size": 128,
    "buffer_size": 1000000,
    "learning_rate": 0.0001,
    "learning_starts": 100,
    "ent_coef": "auto_0.1",
    "device": "cpu",
}
model_sac = agent.get_model("sac", model_kwargs=SAC_PARAMS)
if if_using_sac:
    # set up logger
    tmp_path = RESULTS_DIR + '/sac'
    new_logger_sac = configure(tmp_path, ["stdout", "csv", "tensorboard"])
    # set new logger
    model_sac.set_logger(new_logger_sac)
agent is an instance of the DRLAgent class, SAC_PARAMS supplies the SAC hyperparameters, and model_sac is obtained from DRLAgent's get_model method.
2. The DRLAgent class loads the corresponding SAC model from the stable_baselines3 package
from stable_baselines3 import SAC
DRLAgent(env) takes the training environment env as a constructor argument.
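A minimal sketch of how this might look inside DRLAgent: a MODELS dict maps the name "sac" to the stable_baselines3 class, and the constructor just stores the env (structure assumed, not copied verbatim from FinRL):

from stable_baselines3 import A2C, DDPG, PPO, SAC, TD3

# model name -> SB3 algorithm class; get_model looks the class up by name
MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}

class DRLAgent:
    def __init__(self, env):
        # the vectorized training environment, later handed to the SB3 constructor
        self.env = env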
3. The get_model function
def get_model(
    self,
    model_name,
    policy="MlpPolicy",
    policy_kwargs=None,
    model_kwargs=None,
    verbose=1,
    seed=None,
    tensorboard_log=None,
):
    if model_name not in MODELS:
        raise NotImplementedError("NotImplementedError")
    if model_kwargs is None:
        model_kwargs = MODEL_KWARGS[model_name]
    if "action_noise" in model_kwargs:
        n_actions = self.env.action_space.shape[-1]
        model_kwargs["action_noise"] = NOISE[model_kwargs["action_noise"]](
            mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
        )
    print(model_kwargs)
    return MODELS[model_name](
        policy=policy,
        env=self.env,
        tensorboard_log=tensorboard_log,
        verbose=verbose,
        policy_kwargs=policy_kwargs,
        seed=seed,
        **model_kwargs,
    )
Here env is passed into the SAC model.
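With the SAC_PARAMS above, MODELS["sac"](...) ends up roughly equivalent to constructing the model directly with stable_baselines3 (illustration only; get_model also handles default kwargs and action noise):

from stable_baselines3 import SAC

model_sac = SAC(
    policy="MlpPolicy",
    env=env_train,
    batch_size=128,
    buffer_size=1000000,
    learning_rate=0.0001,
    learning_starts=100,
    ent_coef="auto_0.1",
    device="cpu",
    verbose=1,
)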
4. Training the model
trained_sac = agent.train_model(
    model=model_sac,
    tb_log_name='sac',
    total_timesteps=30000,
) if if_using_sac else None
def train_model(self, model, tb_log_name, total_timesteps=5000):
    model = model.learn(
        total_timesteps=total_timesteps,
        tb_log_name=tb_log_name,
        callback=TensorboardCallback(),
    )
    return model
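train_model passes a TensorboardCallback into learn; a minimal sketch of such a callback built on stable_baselines3's BaseCallback (the actual FinRL callback may log different keys):

from stable_baselines3.common.callbacks import BaseCallback

class TensorboardCallback(BaseCallback):
    def _on_step(self) -> bool:
        # record the current environment reward under the logger configured earlier
        self.logger.record("train/reward", self.locals["rewards"][0])
        return True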
class SAC(OffPolicyAlgorithm):
    def learn(
        self: SACSelf,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "SAC",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SACSelf:
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
5. The environment env
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
env_kwargs = {
    "hmax": 100,
    "initial_amount": 1000000,
    "buy_cost_pct": 0.001,
    "sell_cost_pct": 0.001,
    "state_space": state_space,          # 511 = 1 + 2 * 30 + 15 * 30
    "stock_dim": stock_dimension,        # 30
    "tech_indicator_list": ratio_list,   # list of 15 technical indicators
    "action_space": stock_dimension,     # 30
    "reward_scaling": 1e-4
}
e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)
env_train, _ = e_train_gym.get_sb_env()
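get_sb_env hands back an SB3-compatible vectorized env plus the initial observation, which is why the result is unpacked into two values; a sketch of what it plausibly does (assumed, not copied from FinRL):

from stable_baselines3.common.vec_env import DummyVecEnv

def get_sb_env(self):
    # wrap this single gym env in a DummyVecEnv so stable_baselines3 can train on it
    e = DummyVecEnv([lambda: self])
    obs = e.reset()
    return e, obs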
# The model scans 30 stocks.
# The raw data is preprocessed beforehand into 15 technical-indicator factors.
# state space = 1 + 2 * 30 + 15 * 30 = 511 (the state space)
# action space = 30
# Train the model on this environment,
# then make decisions via actions = agent.predict(state)
# and trade all 30 stocks simultaneously through the API.
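A short sketch of the arithmetic behind these numbers and of the prediction call (n_indicators is a hypothetical helper name; the prediction uses the standard stable_baselines3 predict API rather than a FinRL helper):

stock_dimension = len(train_data.tic.unique())   # 30 stocks in this setup
n_indicators = 15                                # technical-indicator factors per stock
state_space = 1 + 2 * stock_dimension + n_indicators * stock_dimension
# 1 + 2 * 30 + 15 * 30 = 511

# decision step with the trained model (SB3 API)
obs = env_train.reset()
action, _states = trained_sac.predict(obs, deterministic=True)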
class StockTradingEnv(gym.Env):
    def __init__(self):  # constructor parameters omitted in these notes
        ''' parameters '''
        # initialize state
        self.state = self._initiate_state()
        # initialize reward
        self.reward = 0
        self.turbulence = 0
        self.cost = 0
        self.trades = 0
        self.episode = 0
        # memorize all the total balance changes
        self.asset_memory = [self.initial_amount]
        self.rewards_memory = []
        self.actions_memory = []
        self.date_memory = [self._get_date()]
        # self.reset()
        self._seed()
    def _sell_stock(self, index, action): ...
    def _buy_stock(self, index, action): ...

    # key method
    def step(self, actions):
        # if not terminal: first turn the raw actions into trade sizes
        actions = actions * self.hmax  # actions initially are scaled between -1 and 1
        actions = actions.astype(int)
        # execute sells/buys, then update the state
        self.state = self._update_state()
        # compute begin_total_asset and end_total_asset
        # store bookkeeping information
        self.asset_memory.append(end_total_asset)
        self.date_memory.append(self._get_date())
        self.reward = end_total_asset - begin_total_asset  # reward for this step
        self.rewards_memory.append(self.reward)
        self.reward = self.reward * self.reward_scaling
        return self.state, self.reward, self.terminal, {}
# 1. Pay close attention to what the state fed into learning actually looks like
# 2. An introduction to the SAC model
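On point 1, a sketch of how the 511-dimensional state could be laid out, following the 1 + 2 * 30 + 15 * 30 breakdown (variable names and dummy values are hypothetical; the exact ordering in FinRL's _initiate_state may differ):

import numpy as np

cash_balance = [1000000.0]              # 1 value: current cash
close_prices = list(np.ones(30))        # 30 values: latest close of each stock (dummy data)
holdings = [0] * 30                     # 30 values: shares held of each stock
tech_factors = list(np.zeros(15 * 30))  # 450 values: 15 indicators x 30 stocks (dummy data)
state = cash_balance + close_prices + holdings + tech_factors
assert len(state) == 511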
pandas sort_values
pandas DataFrame.sort_values(by,
axis=0,
ascending=True,
inplace=False,
kind='quicksort',
na_position='last', # 'last' or 'first'; default is 'last'
ignore_index=False,
key=None)
by: the column(s) or index level(s) to sort by; can be a single label or a list
axis: whether to sort along rows or columns; default is axis=0 (sort rows)
ascending: whether to sort in ascending or descending order; default is ascending
inplace: whether to sort the original DataFrame in place or return a new DataFrame
kind: the sorting algorithm: 'quicksort', 'mergesort', 'heapsort', or 'stable'; default is 'quicksort'
na_position: where to place missing values; default is 'last', the alternative is 'first'
ignore_index: whether to relabel the index of the result; default is False (keep the original index)
key: a function applied to the values before sorting
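A quick usage example (column names are hypothetical, chosen to match the stock-data setting above):

import pandas as pd

df = pd.DataFrame({
    "date": ["2024-01-02", "2024-01-01", "2024-01-01"],
    "tic": ["AAPL", "MSFT", "AAPL"],
    "close": [185.6, 370.9, 184.2],
})
# sort by date first, then by ticker, and rebuild the index
df_sorted = df.sort_values(by=["date", "tic"], ascending=True, ignore_index=True)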
# Try using 30-minute S&P futures data as the training set and see how it performs