楼主: lyqbnu
1207 0

[交易策略] 【聚宽本地数据JQData】【转载】强化学习入门:基于Q-learning算法的日内择时策略初窥 [推广有奖]

  • 0关注
  • 0粉丝

已卖:284份资源

硕士生

76%

还不是VIP/贵宾

-

威望
0
论坛币
5473 个
通用积分
4.1431
学术水平
11 点
热心指数
12 点
信用等级
11 点
经验
3010 点
帖子
172
精华
0
在线时间
209 小时
注册时间
2008-4-19
最后登录
2023-10-2

楼主
lyqbnu 发表于 2021-5-13 11:39:05 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币

本篇文章所使用的数据,来源于JQData本地量化金融数据库。下面我将粗略的介绍一个强化学习在证券市场中应用的简单实例。
关于强化学习的算法理论及发展历史,我们不做过多的解释。我们可以很容易在互联网上找到强化学习的理论知识,虽然可能都是一些只言片语,但对于初学者来说基本也就够用了。到目前为止,还没有出现广受业内好评的中文教材,更多的参考资料还是英文版的。例如,Richard S.Sutton和Andrew G.Barto所著的《Reinforcement Learning: An Introduction》。这是比较好的强化学习教材,想要系统的、深入的学习强化学习,这本书值得一看。虽然国内学术界有很多关于强化学习的文章,但它们都看起来比较专业,我不建议初学者一上来就开始啃理论。最好的学习方式是你先入门,弄懂强化学习可以干什么?然后应用一些简单的算法搭建一个你当前正想解决的问题,再不断的去改进你的算法,并在这个过程中深入地学习。对于这篇文章而言,我们假设你已经有了一些强化学习的基础知识了,这里只是给出了一个十分简单的关于量化分析的应用demo而已。
作为量化分析领域的专业人员,我们可能对用强化学习解决玩游戏、找宝藏的Demo不感兴趣。我们更希望能够有一个简单的强化学习demo:当输入K线数据,就可以告诉我什么时候该买,什么时候该卖,即使给出的买卖点并不准确,但我们总算可以看看强化学习模型是怎么给出这个买卖点的。这篇文章就做了这样一个demo,主要是想介绍怎样在构建证券市场构建一个简单的强化学习模型。
强化学习相比于神经网络等常见的机器学习算法而言,强化学习更灵活多变。深度神经网络、卷积神经网络已经算是比较难的算法了,但对于应用人员来说,你只需要搞懂输入输出基本就能用了。但强化学习完全不行,必须要对特征的问题抽象建模,这往往是最难的。怎样从一堆证券数据中抽象各种各样的状态,以及这些状态是怎么转换的,怎么定义动作、回报等等。这些问题直接决定你的模型的质量。
在这篇文章中,我们需要解决的问题是:怎么利用一天内的48根5分钟的K线数据探索在每个5分钟结束的时候,我们是该买入(B),还是该卖出(S),或者是继续观望(W),并用一段时间内所有的5分钟数据训练这个模型,看哪个时间点最适合买入,哪个时间点最适合卖出。我们以时间点作为一个状态标识,则状态(S)转移就比较好定义了:935(早上9点35,这个时间点产生了第一根K线)->940->945->…->1455->1500,状态s->s’可以采取的动作(A)包含B、S、W。我们使用Q-learning算法来解决这个问题。因此,Q表应该是这样的:
关于Reward,我们是这样定义的:未来一段时间的收益率,比如未来3根K的涨跌幅。有了这些之后,我们基本就可以开始着手编写程序了。
首先创建一个环境类:

  1. times = [935, 940, 945, 950, 955, 1000, 1005, 1010, 1015,
  2.          1020, 1025, 1030, 1035, 1040, 1045, 1050, 1055,
  3.          1100, 1105, 1110, 1115, 1120, 1125, 1130, 1305,
  4.          1310, 1315, 1320, 1325, 1330, 1335, 1340, 1345,
  5.          1350, 1355, 1400, 1405, 1410, 1415, 1420, 1425,
  6.          1430, 1435, 1440, 1445, 1450, 1455, 1500]


  7. class Market:
  8.     def __init__(self, data):
  9.         self.action_space = ['B', 'S', 'W']  # 买进、卖出、观望
  10.         self.n_actions = len(self.action_space)
  11.         self.data = data  # 935 940 ... 1500 48根K线的数据
  12.         self.time = 935
  13.         pass

  14.     def step(self, action):
  15.         # 要知道当前在那个状态即时间点,用下一时间点的R(收益)作为
  16.         # 当前采取action的reward
  17.         tix = times.index(self.time)
  18.         nix = tix + 1
  19.         if self.time == 1500:
  20.             reward = 0
  21.             done = True
  22.             s_ = 'terminal'
  23.             # print('time is over.')
  24.         else:
  25.             reward = self.data.R.iloc[nix]
  26.             done = False
  27.             s_ = times[nix]
  28.         if action == 'B':
  29.             pass
  30.         elif action == 'S':
  31.             # 当R为-的时候,选择S,应该是正奖励
  32.             reward = reward * -1
  33.         else:
  34.             # 选择观望,既不亏损也不会盈利,但会损失机会成本
  35.             # 我们当前对观望的决策持客观态度,reward=0,这
  36.             # 可能需要在不同的大盘行情下适时调整
  37.             reward = 0
  38.             pass
  39.         self.time = s_
  40.         return s_, reward, done
  41.         pass

  42.     def reset(self):
  43.         self.time = 935
  44.         return self.time
  45.         pass
复制代码

然后创建Q-learning算法类(或者称这个类为一个Agent):

  1. class QLearning:

  2.     #Agent


  3.     def __init__(self, actions, q_table=None, learning_rate=0.01,
  4.                  discount_factor=0.9, e_greedy=0.1):
  5.         self.actions = actions  # action 列表
  6.         self.lr = learning_rate  # 学习速率
  7.         self.gamma = discount_factor  # 折扣因子
  8.         self.epsilon = e_greedy  # 贪婪度
  9.         # 列是action。
  10.         if q_table is None:
  11.             self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float32)  # Q 表
  12.         else:
  13.             self.q_table = q_table

  14.     # 检测 q_table 中有没有这个 state
  15.     # 如果还没有当前 state, 那我们就插入一组全 0 数据, 作为这个 state 的所有 action 的初始值
  16.     def check_state_exist(self, state):
  17.         # state对应每一行,如果不在Q表中。
  18.         if state not in self.q_table.index:
  19.             # 插入一组全 0 数据,给每个action赋值为0
  20.             self.q_table = self.q_table.append(
  21.                 pd.Series(
  22.                     [0] * len(self.actions),
  23.                     index=self.q_table.columns,
  24.                     name=state,
  25.                 )
  26.             )

  27.     # 根据 state 来选择 action
  28.     def choose_action(self, state):
  29.         self.check_state_exist(state)  # 检测此 state 是否在 q_table 中存在
  30.         # 选行为,用 Epsilon Greedy 贪婪方法
  31.         if np.random.uniform() < self.epsilon:
  32.             # 随机选择 action
  33.             action = np.random.choice(self.actions)
  34.         else:  # 选择 Q 值最高的 action
  35.             state_action = self.q_table.loc[state, :]
  36.             # 同一个 state, 可能会有多个相同的 Q action 值, 所以我们乱序一下
  37.             state_action = state_action.reindex(np.random.permutation(state_action.index))
  38.             # 每一行中取到Q值最大的那个
  39.             action = state_action.idxmax()
  40.         return action

  41.     # 学习。更新 Q 表中的值
  42.     def learn(self, s, a, r, s_):
  43.         # s_是下一个状态
  44.         self.check_state_exist(s_)  # 检测 q_table 中是否存在 s_

  45.         # Q(S,A) <- Q(S,A)+a*[R+v*max(Q(S',a))-Q(S,A)]

  46.         q_predict = self.q_table.loc[s, a]  # 根据 Q 表得到的 估计(predict)值

  47.         # q_target 是现实值
  48.         if s_ != 'terminal':  # 下个 state 不是 终止符
  49.             q_target = r + self.gamma * self.q_table.loc[s_, :].max()
  50.         else:
  51.             q_target = r  # 下个 state 是 终止符

  52.         # 更新 Q 表中 state-action 的值
  53.         self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
复制代码

最后就是创建一个文件来协调上面两个类开始工作:

  1. def update(data, q_table=None):
  2.     env = Market(data)
  3.     RL = QLearning(actions=env.action_space, q_table=q_table)

  4.     for episode in range(100):
  5.         # 初始化 state(状态)
  6.         state = env.reset()

  7.         step_count = 0  # 记录走过的步数

  8.         while True:
  9.             # 更新可视化环境
  10.             # env.render()
  11.             # RL 大脑根据 state 挑选 action
  12.             action = RL.choose_action(str(state))
  13.             # 探索者在环境中实施这个 action, 并得到环境返回的下一个 state, reward 和 done (是否到了1500)
  14.             state_, reward, done = env.step(action)
  15.             step_count += 1  # 增加步数
  16.             # 机器人大脑从这个过渡(transition) (state, action, reward, state_) 中学习
  17.             RL.learn(str(state), action, reward, str(state_))
  18.             # 机器人移动到下一个 state
  19.             state = state_
  20.             # 如果时间到了1500, 这回合就结束了,或者是某个止损条件达到了
  21.             if done:
  22.                 # print("回合 {} 结束. 总步数 : {}\n".format(episode + 1, step_count))
  23.                 break

  24.     # print('模拟交易结束了。')
  25.     # print('\nQ 表:')
  26.     # print(RL.q_table)
  27.     return RL.q_table


  28. def train():
  29.     code = '000001'  # 上证指数
  30.     sd = dt.datetime(2018, 10, 1)
  31.     ed = dt.datetime(2018, 11, 1)
  32.     # 我已经把从jqdata读取到了数据存在了本地,这里只是读取出来
  33.     data = md().read_data('index_min5', stock_code=code,
  34.                           date={'gte': sd, 'lt': ed},
  35.                           field={'_id': 0, 'time': 1, 'close': 1, 'date': 1})
  36.     data = data.sort_values(['date', 'time'], ascending=False)
  37.     # 计算每根K线收盘时未来三根K线的涨跌幅
  38.     data['R'] = (data.close.shift(3) / data.close - 1) * 100
  39.     data.fillna(0, inplace=True)
  40.     data = data.round({'R': 3})
  41.     data = data.sort_values(['date', 'time'], ascending=True)
  42.     qtb = None
  43.     for k, g in data.groupby(['date']):
  44.         print('train to:', k)
  45.         try:
  46.             # 开始一天一天的训练
  47.             qtb = update(g, qtb)
  48.         except Exception as e:
  49.             ExceptionInfo(e)
  50.         print('\nQ 表:')
  51.         print(qtb)
  52.     qtb['time'] = qtb.index
  53.     qtb.to_csv(path_or_buf='E:\wv\ReinfL\model_param\qtb({})_{}.csv'.
  54.                format(code, sd.strftime('%Y_%m_%d')), index=False)
  55.     pass


  56. train()
复制代码

二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:Learning earning Learn Earn ning

已有 1 人评分论坛币 学术水平 热心指数 信用等级 收起 理由
ithjesuxf + 5 + 5 + 5 + 5 优秀

总评分: 论坛币 + 5  学术水平 + 5  热心指数 + 5  信用等级 + 5   查看全部评分

学习

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
加好友,备注jr
拉您进交流群
GMT+8, 2026-1-27 13:27