# Import the usual data-processing libraries
#from jqdata import *
import talib
import numpy as np
import pandas as pd
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import seaborn as sns
import copy
plt.rcParams['axes.unicode_minus']=False
plt.rcParams['font.sans-serif']=['SimHei'] # set the default font to SimHei so Chinese labels render correctly
import warnings
warnings.filterwarnings("ignore")
import datetime
import time
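The Report class and the modelling cells further down also rely on scikit-learn, xgboost and joblib utilities (train_test_split, LogisticRegression, QDA, accuracy_score, classification_report, roc_curve, roc_auc_score, metrics.auc, XGBClassifier), as well as the JoinQuant data API behind the commented-out from jqdata import *. None of these are imported anywhere in the listing; below is a minimal sketch of the imports they appear to need (exact module paths can vary with the scikit-learn version, and joblib ships as a standalone package in newer environments).
# Imports assumed by the model-training code below (run inside a JoinQuant research environment for the jqdata API)
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.metrics import accuracy_score, classification_report, roc_curve, roc_auc_score
from sklearn import metrics
import joblib                      # older tutorials use: from sklearn.externals import joblib
from xgboost import XGBClassifier  # only needed for the 'xgboost' branch of multifactor_model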
#Returns of a given benchmark over a given date range
def cal_benchmark_return(benchmark,start_date,end_date):
First_date = get_trade_days(start_date = start_date, end_date=end_date, count=None)[0]
tradingday = list(get_all_trade_days())
shiftday_index = tradingday.index(First_date)-1
pre_date = tradingday[shiftday_index]
data = get_price(benchmark, start_date=pre_date, end_date=end_date,\
fields=['open','close'],frequency='daily').close
return data/data.shift()-1
class Report(object):
def __init__(self,benchmark):
#Initialisation
self.benchmark = benchmark
#Historical price/volume data for all A shares: returns, open, close, high, low, turnover ratio, shares traded and money traded
self.data_return = pd.read_csv('A_share_return.csv',index_col=0)
self.data_close = pd.read_csv('A_share_close.csv',index_col=0)
self.data_open = pd.read_csv('A_share_open.csv',index_col=0)
self.data_high = pd.read_csv('A_share_high.csv',index_col=0)
self.data_low = pd.read_csv('A_share_low.csv',index_col=0)
self.data_money = pd.read_csv('A_share_money.csv',index_col=0)
self.data_turnover = pd.read_csv('A_share_turnover_ratio.csv',index_col=0)
self.data_volume = pd.read_csv('A_share_volume.csv',index_col=0)
#Benchmark returns over the stored data interval
self.benchmark_return = cal_benchmark_return(benchmark,self.data_return.index[1],self.data_return.index[-1])
#Helper functions used by the methods below
#Build lists of period start/end dates for a given frequency
def get_time_inverval(self,start_date,end_date,freq):
if freq == 'M':
interval_start = pd.to_datetime(pd.date_range(start_date,end_date,freq = 'MS'))
interval_end = pd.to_datetime(pd.date_range(start_date,end_date,freq = 'BM'))
elif freq == 'Y':
interval_start = pd.to_datetime(pd.date_range(start_date,end_date,freq = 'AS'))
interval_end = pd.to_datetime(pd.date_range(start_date,end_date,freq = 'A'))
if len(interval_start) > len(interval_end) > 0:
interval_start = interval_start[:-1]
elif len(interval_end) == 0:
print('请输入完整的周期区间')
return interval_start,interval_end
#Convert a pandas Timestamp to a datetime.date
def datetime_date(self,datestamp):
return datetime.date(datestamp.year, datestamp.month, datestamp.day)
#Shenwan level-1 industry of each stock
def get_stock_belong_industry(self,stocklist,date):
industry = {}
industry_all = get_industry(security = stocklist, date=date)
for i in stocklist:
try:
industry[i] = industry_all[i]['sw_l1']['industry_name']
except:
continue
return pd.DataFrame([industry],index = ['industry']).T
#Filter out recently listed stocks
def filter_new_stock(self,stock_list,date):
tmpList = []
for stock in stock_list :
days_public=(self.datetime_date(date) - get_security_info(stock).start_date).days
if days_public >= 180:
tmpList.append(stock)
return tmpList
#Drop ST stocks
def delete_st(self,stocks,begin_date):
st_data=get_extras('is_st',stocks, count = 1,end_date=begin_date)
stockList = [stock for stock in stocks if not st_data[stock][0]]
return stockList
###Factor extraction
def get_MA(self,df_close,n):
'''
The moving average is the mean price over a rolling window
df_close: price data, DataFrame or Series
n: look-back window in days, integer
'''
ma = df_close.rolling(n).mean()
ma = pd.DataFrame({'MA_' + str(n): ma}, index = ma.index)
return ma
## Rate of change (ROC)
def get_ROC(self,df_close, n):
'''
ROC = (today's close - close N days ago) / close N days ago * 100
df_close: closing-price data, DataFrame or Series
n: look-back window in days, integer
'''
M = df_close
N = df_close.shift(n)
roc = pd.DataFrame({'ROC_' + str(n): (M-N) / N*100}, index = M.index)
return roc
## Relative strength index (RSI)
def get_RSI(self,df_close,n):
'''
df_close: closing-price data, DataFrame or Series
n: look-back window in days, integer
'''
rsi = talib.RSI(df_close, timeperiod=n)
return pd.DataFrame({'RSI_' + str(n): rsi}, index = df_close.index)
##On-balance volume (OBV)
def get_OBV(self,df_close,df_volume):
'''
On Balance Volume: infers the price trend from the trend in trading volume
df_close: closing-price data, DataFrame or Series
df_volume: trading-volume data, DataFrame or Series
'''
obv = talib.OBV(df_close,df_volume)
return pd.DataFrame({'OBV': obv}, index = df_close.index)
#Average true range
def get_ATR(self,df_high,df_low,df_close,n):
'''
Average true range, mainly used to measure price volatility
df_close: closing-price data, DataFrame or Series
df_high: high-price data, DataFrame or Series
df_low: low-price data, DataFrame or Series
n: look-back window in days, integer
'''
atr = talib.ATR(df_high,df_low,df_close, timeperiod=n)
return pd.DataFrame({'ATR_' + str(n): atr}, index = df_close.index)
#Momentum
def get_MOM(self,df_close,n):
'''
Momentum: the ability of a stock (or index) to keep rising. Studies find winner portfolios show positive momentum in bull markets and loser portfolios show negative momentum in bear markets.
df_close: closing-price data, DataFrame or Series
n: look-back window in days, integer
'''
mom = talib.MOM(df_close, timeperiod=n)
return pd.DataFrame({'MOM_' + str(n): mom}, index = df_close.index)
#Aroon indicator
def get_AROON(self,df_high,df_low,n):
aroondown, aroonup = talib.AROON(df_high,df_low, timeperiod=n)
return pd.DataFrame({'Aroondown_' + str(n): aroondown}, index = df_high.index),pd.DataFrame({'Aroonup_' + str(n): aroonup}, index = df_high.index)
###Factor processing and merging
#Merge the features into one frame; this is where new indicators or features get added
def merge_raw_factors(self,df_close,df_open,df_high,df_low,df_volume,df_money):
return pd.concat([self.get_MA(df_close,5),self.get_MA(df_close,60),self.get_AROON(df_high,df_low,14)[0],\
self.get_AROON(df_high,df_low,14)[1],self.get_ROC(df_close, 6),self.get_ROC(df_close, 12),self.get_RSI(df_close,6),\
self.get_RSI(df_close,24),self.get_OBV(df_close,df_volume),self.get_ATR(df_high,df_low,df_close,14),self.get_MOM(df_close,10)],axis = 1)
#Winsorize and standardize the data
def winsorize_and_standarlize(self,data,qrange=[0.05,0.95],axis=0):
'''
input:
data: DataFrame or Series, input data
qrange: list, [lower quantile, upper quantile]; values beyond them are clipped to the quantiles
'''
if isinstance(data,pd.DataFrame):
if axis == 0:
q_down = data.quantile(qrange[0])
q_up = data.quantile(qrange[1])
index = data.index
col = data.columns
for n in col:
data[n][data[n] > q_up[n]] = q_up[n]
data[n][data[n] < q_down[n]] = q_down[n]
data = (data - data.mean())/data.std()
data = data.fillna(0)
else:
data = data.stack()
data = data.unstack(0)
q = data.quantile(qrange)
index = data.index
col = data.columns
for n in col:
data[n][data[n] > q[n]] = q[n]
data = (data - data.mean())/data.std()
data = data.stack().unstack(0)
data = data.fillna(0)
elif isinstance(data,pd.Series):
name = data.name
q_down = data.quantile(qrange[0])
q_up = data.quantile(qrange[1])
data[data > q_up] = q_up
data[data < q_down] = q_down
data = (data - data.mean())/data.std()
return data
#Raw factor data for one security over one or more days
def merge_raw_factor_perdate_percode(self,stock_code,start_date,end_date):
#Fetch data for the requested interval and compute the factors
date_interval = get_trade_days(start_date =start_date, end_date=end_date, count=None)
date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]
if set(date_interval) <= set(self.data_return.index) and len(self.data_return.loc[date_interval]) != 0:
df_return = self.data_return[stock_code]
df_close = self.data_close[stock_code]
df_open = self.data_open[stock_code]
df_high = self.data_high[stock_code]
df_low = self.data_low[stock_code]
df_money = self.data_money[stock_code]
if stock_code != '000001.XSHG':
df_turnover = self.data_turnover[stock_code]
df_volume = self.data_volume[stock_code]
merge_raw_data = self.merge_raw_factors(df_close,df_open,df_high,df_low,df_volume,df_money)
else:
print('你输入的时间区间超过了本地数据的时间区间或不含有交易日,请输入2005年1月5日到2020年2月28日内的起始时间')
return merge_raw_data
#Regularized (winsorized and standardized) factor data for one security over one or more days
def merge_regularfactor_multidate_percode(self,stock_code,start_date,end_date):
#Fetch data for the requested interval and compute the factors
date_interval = get_trade_days(start_date =start_date, end_date=end_date, count=None)
date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]
merge_raw_data = self.merge_raw_factor_perdate_percode(stock_code,start_date,end_date)
data_pro = self.winsorize_and_standarlize(merge_raw_data.loc[date_interval])
return data_pro
#Raw factor data for multiple securities on a single day
def merge_rawfactor_multicode_perday(self,stock_list,start_date,end_date):
#Fetch data for the requested interval and compute the factors
date_interval = get_trade_days(start_date =start_date, end_date=end_date, count=None)
date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]
factor_list = []
for stock_code in stock_list:
factor_list.append(self.merge_raw_factor_perdate_percode(stock_code,start_date,end_date).loc[date_interval])
return pd.concat(factor_list)
#Regularized factor data for multiple securities on a single day
def merge_regularfactor_multicode_perday(self,stock_list,start_date,end_date):
#Fetch data for the requested interval and compute the factors
date_interval = get_trade_days(start_date =start_date, end_date=end_date, count=None)
date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]
data_raw = self.merge_rawfactor_multicode_perday(stock_list,start_date,end_date)
data_regular = self.winsorize_and_standarlize(data_raw)
data_regular.insert(0, 'code',stock_list)
return data_regular
#Labelling method 1 (labelling schemes can be customized)
def add_label_1(self,stock_code,start_date,end_date):
profit = self.data_return[stock_code].copy()
profit[profit > 0] = 1
profit[profit< 0] = 0
profit = pd.DataFrame(profit)
profit.columns = ['Label']
return profit
#Regularized factor data plus label (method 1) for one security over multiple days
def merge_final_factor_label1(self,stock_code,start_date,end_date):
#Fetch data for the requested interval and compute the factors
date_interval = get_trade_days(start_date =start_date, end_date=end_date, count=None)
date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]
if set(date_interval) <= set(self.data_return.index) and len(self.data_return.loc[date_interval]) != 0:
data_pro = self.merge_regularfactor_multidate_percode(stock_code,start_date,end_date)
profit = self.add_label_1(stock_code,start_date,end_date)
data_final = pd.concat([profit.shift(),data_pro],axis = 1).loc[date_interval]
else:
print('你输入的时间区间超过了本地数据的时间区间或不含有交易日,请输入2005年1月5日到2020年2月28日内的起始时间')
return data_final
# Index-timing model. Parameters: index code, in-sample start and end dates, out-of-sample start and end dates, ML model name (more can be added), file name for saving the trained model
def timing_model(self,code,start_date,end_date,out_start,out_end,model_name,file_name):
print('开始获取合成特征和标签数据框...')
#Fetch the data
data_index = self.merge_final_factor_label1(code,start_date,end_date)
# Feature data and label data
x_data,y_data = data_index.iloc[:,1:], data_index.iloc[:,0]
print ('-' * 60)
print('按照比例分割为训练集和测试集...')
'''
The raw data are split into a training set and a test set by proportion.
x_data: the feature set to split
y_data: the label set to split
test_size: fraction of samples held out for testing (an integer means an absolute count)
random_state: the random seed. Fixing it (e.g. to 1) makes repeated runs produce the same split; leaving it unset (or 0) gives a different split each time.
'''
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.2)
print ('-' * 60)
# Data frame to hold predictions
pred = pd.DataFrame(index=y_test.index)
pred["Actual"] = y_test
#可往里面添加算法
if model_name == 'LR':
#构建分类器,里面的默认参数可修改
model = LogisticRegression(solver='liblinear')
elif model_name == "LDA":
model=QDA()
else:
print('不支持这个算法,请新添加这个算法')
print('开始训练数据...')
#训练数据
model.fit(x_train, y_train)
print('\n')
print('训练结束')
print ('-' * 60)
print ("预测准确率:")
pred[model_name] = model.predict(x_test)
# 预测准确率
score=accuracy_score(pred['Actual'], pred[model_name])
print("%s模型: %.3f" % (model_name, score))
print ('-' * 60)
# 构建混淆矩阵
cm = pd.crosstab(pred['Actual'], pred[model_name])
print('输出混淆矩阵...')
print(cm)
# 绘制混淆矩阵图
sns.heatmap(cm, annot = True, cmap = 'GnBu', fmt = 'd')
print ('-' * 60)
print('绘制曲线...')
# 计算正例的预测概率,而非实际的预测值,用于生成ROC曲线的数据
y_score = model.predict_proba(x_test)[:,1]
#fpr表示1-Specificity,tpr表示Sensitivity
fpr,tpr,threshold = roc_curve(y_test, y_score)
# 计算AUC的值
roc_auc = metrics.auc(fpr,tpr)
# 绘制面积图
plt.figure(figsize=(8,6))
plt.stackplot(fpr, tpr, color='steelblue', alpha = 0.5, edgecolor = 'black')
plt.plot(fpr, tpr, color='black', lw = 1)
# 添加对角线
plt.plot([0,1],[0,1], color = 'red', linestyle = '--')
# 添加文本信息
plt.text(0.5,0.3,'ROC曲线 (area = %0.2f)' % roc_auc)
# 添加x轴与y轴标签
plt.title('模型预测指数涨跌的ROC曲线',size=15)
plt.xlabel('1-Specificity')
plt.ylabel('Sensitivity')
plt.show()
print ('-' * 60)
print('输出评估报告...')
print('\n')
print('模型的评估报告:\n',classification_report(pred['Actual'], pred[model_name]))
print ('-' * 60)
print('保存模型到本地...')
joblib.dump(model,file_name+'.pkl')
print('\n')
print('加载本地训练好的模型...')
model2 = joblib.load(file_name+'.pkl')
print('加载完毕')
print ('-' * 60)
# Out-of-sample results
print('样本外测试结果')
#Fetch the out-of-sample data; out_start and out_end are the out-of-sample start and end dates
data_out_sample = self.merge_final_factor_label1(code,out_start,out_end)
#Split into feature set and label set
x_out_test,y_out_test = data_out_sample.iloc[:,1:], data_out_sample.iloc[:,0]
y_out_pred = model2.predict(x_out_test)
predict_proba = model2.predict_proba(x_out_test)[:,1] # x_out_test is the out-of-sample feature set
# Out-of-sample accuracy
accuracy_out_sample = accuracy_score(y_out_test, y_out_pred)
print('样本外准确率',accuracy_out_sample)
# Out-of-sample AUC
roc_out_sample = roc_auc_score(y_out_test, y_out_pred)
print('样本外AUC值',roc_out_sample)
#Return the loaded model
return model2
#Labelling method 2 (customizable). data is a feature data frame with a column named 'return'; the top 30% by return are labelled 1, the bottom 30% are labelled 0, and the rest are dropped
def add_label_2(self,data):
percent_select = [0.3,0.3]
#Create the label column
data['Label'] = np.nan
#Sort by return
data = data.sort_values(by='return',ascending=False)
#Select the chosen fractions of the sample
n_stock = data.shape[0]
n_stock_select = np.multiply(percent_select,n_stock)
n_stock_select = np.around(n_stock_select).astype(int)
#Label the selected samples 1 or 0
data.iloc[0:n_stock_select[0],-1] = 1
data.iloc[-n_stock_select[1]:,-1] = 0
#Drop the stocks that were not selected
data = data.dropna(axis=0)
del data['return']
return data
#Build a one-period data frame of features and labels using labelling method 2
def data_for_model_perperiod(self,start_date,end_date,index_code):
#Stocks in the index that are still listed during the interval
stock_list = get_index_stocks(index_code, date=end_date)
stock_list_notpause = list(set(stock_list) & set(get_all_securities(date=end_date).index) & set(self.data_return.columns))
#Optionally drop companies listed for less than six months
# stock_list_fillter = filter_new_stock(stock_list_notpause,date_interval[1])
#Optionally drop ST and *ST stocks
#stock_list= delete_st(stock_list_fillter, date1)
date_interval = get_trade_days(start_date = start_date, end_date=end_date, count=None)
data_per_period = self.merge_regularfactor_multicode_perday(stock_list_notpause,date_interval[0],date_interval[0])
date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]
data_return_interval = self.data_return.loc[date_interval][stock_list_notpause]
#Cumulative return over the period
profit = data_return_interval.apply(lambda x:(1 + x).cumprod() - 1).iloc[-1]
data_per_period = data_per_period.set_index('code')
stock_profit = pd.DataFrame(profit,columns = ['return'])
data_merge_label = self.add_label_2(pd.concat([data_per_period,stock_profit],axis = 1))
return data_merge_label
#Build a multi-period data frame of features and labels using labelling method 2
def data_for_model_multiperiod(self,start_date,end_date,index_code):
interval_start,interval_end = self.get_time_inverval(start_date,end_date,'M')
factor_list = []
for date1,date2 in dict(zip(interval_start,interval_end)).items():
data_merge_label = self.data_for_model_perperiod(date1,date2,index_code)
factor_list.append(data_merge_label)
return pd.concat(factor_list,axis = 0)
#Multi-factor stock-selection model
def multifactor_model(self,index_code,start_date,end_date,out_start,out_end,model_name,file_name):
data_regular = self.data_for_model_multiperiod(start_date,end_date,index_code)
y_data = data_regular['Label']
x_data = data_regular.iloc[:,:-1]
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.2)
#可往里面添加算法
if model_name == 'LR':
#构建分类器,里面的默认参数可修改
model = LogisticRegression(solver='liblinear')
elif model_name == "LDA":
model=QDA()
elif model_name == 'xgboost':
model = XGBClassifier(max_depth=3,subsample=0.9,random_state=0)
else:
print('不支持这个算法,请新添加这个算法')
print('开始训练数据...')
#训练数据
model.fit(x_train, y_train)
print('\n')
print('训练结束')
print ('-' * 60)
print ("预测准确率:")
pred = model.predict(x_test)
# 预测准确率
score=accuracy_score(y_test, pred )
print("%s模型: %.3f" % (model_name, score))
print ('-' * 60)
# 构建混淆矩阵
cm = pd.crosstab(y_test, pred )
print('输出混淆矩阵...')
print(cm)
# 绘制混淆矩阵图
sns.heatmap(cm, annot = True, cmap = 'GnBu', fmt = 'd')
print ('-' * 60)
print('绘制曲线...')
# 计算正例的预测概率,而非实际的预测值,用于生成ROC曲线的数据
y_score = model.predict_proba(x_test)[:,1]
#fpr表示1-Specificity,tpr表示Sensitivity
fpr,tpr,threshold = roc_curve(y_test, y_score)
# 计算AUC的值
roc_auc = metrics.auc(fpr,tpr)
# 绘制面积图
plt.figure(figsize=(8,6))
plt.stackplot(fpr, tpr, color='steelblue', alpha = 0.5, edgecolor = 'black')
plt.plot(fpr, tpr, color='black', lw = 1)
# 添加对角线
plt.plot([0,1],[0,1], color = 'red', linestyle = '--')
# 添加文本信息
plt.text(0.5,0.3,'ROC曲线 (area = %0.2f)' % roc_auc)
# 添加x轴与y轴标签
plt.title('模型预测的ROC曲线',size=15)
plt.xlabel('1-Specificity')
plt.ylabel('Sensitivity')
plt.show()
print ('-' * 60)
print('输出评估报告...')
print('\n')
print('模型的评估报告:\n',classification_report(y_test, pred ))
print ('-' * 60)
print('保存模型到本地...')
joblib.dump(model,file_name+'.pkl')
print('\n')
print('加载本地训练好的模型...')
model2 = joblib.load(file_name+'.pkl')
print('加载完毕')
print ('-' * 60)
#Out-of-sample prediction
test_sample_predict={}
test_sample_score=[]
test_sample_accuracy=[]
test_sample_roc_auc=[]
test_sample_date=[]
interval_start,interval_end = self.get_time_inverval(out_start,out_end,'M')
# Out-of-sample results, month by month
print('样本外测试结果...')
for date1,date2 in dict(zip(interval_start,interval_end)).items():
data_merge_label = self.data_for_model_perperiod(date1,date2,index_code)
y_test=data_merge_label['Label']
X_test=data_merge_label.iloc[:,:-1]
# 输出预测值以及预测概率
y_pred_tmp = model2.predict(X_test)
y_pred = pd.DataFrame(y_pred_tmp, columns=['label_predict']) # 获得预测标签
y_pred_proba = pd.DataFrame(model.predict_proba(X_test), columns=['pro1', 'pro2']) # 获得预测概率
# 将预测标签、预测数据和原始数据X合并
y_pred.set_index(X_test.index,inplace=True)
y_pred_proba.set_index(X_test.index,inplace=True)
predict_pd = pd.concat((X_test, y_pred, y_pred_proba), axis=1)
print ('Predict date:')
print (date1)
print ('AUC:')
print (roc_auc_score(y_test,y_pred))
print ('Accuracy:')
print (accuracy_score(y_test, y_pred))
print ('-' * 60)
## 后续统计画图用
test_sample_date.append(date1)
# 样本外预测结果
test_sample_predict[date1]=y_pred_tmp
# 样本外准确率
test_sample_accuracy.append(accuracy_score(y_test, y_pred))
# 样本外AUC值
test_sample_roc_auc.append(roc_auc_score(y_test,y_pred))
print ('AUC mean info')
print (np.mean(test_sample_roc_auc))
print ('-' * 60)
print ('ACCURACY mean info')
print (np.mean(test_sample_accuracy))
print ('-' * 60)
f = plt.figure(figsize= (15,6))
xs_date = test_sample_date
ys_auc = test_sample_roc_auc
# 配置横坐标
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.plot(xs_date, ys_auc,'r')
# 自动旋转日期标记
plt.gcf().autofmt_xdate()
# 横坐标标记
plt.xlabel('date')
# 纵坐标标记
plt.ylabel("test AUC")
plt.show()
f = plt.figure(figsize= (15,6))
xs_date = test_sample_date
ys_score = test_sample_accuracy
# 配置横坐标
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.plot(xs_date, ys_score,'r')
# 自动旋转日期标记
plt.gcf().autofmt_xdate()
# 横坐标标记
plt.xlabel('date')
# 纵坐标标记
plt.ylabel("test accuracy")
plt.show()
f = plt.figure(figsize= (15,6))
sns.set(style="whitegrid")
data1 = pd.DataFrame(ys_auc, xs_date, columns={'AUC'})
data2 = pd.DataFrame(ys_score, xs_date, columns={'accuracy'})
data = pd.concat([data1,data2],sort=False)
sns.lineplot(data=data, palette="tab10", linewidth=2.5)
return model2
#Create an instance; the argument is the benchmark. The bundled data cover all A shares from 2005-01-05 to 2020-02-28
#Benchmark: CSI 300 (configurable)
benchmark = '000300.XSHG'
report = Report(benchmark)
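With the instance created, the two end-to-end methods defined above could be called directly. The calls below are only a sketch of their signatures; the index codes, date ranges, model names and file names are illustrative, and both assume the local CSV data plus the JoinQuant API.
# Illustrative only: index timing with logistic regression, and a monthly multi-factor model with xgboost
lr_model = report.timing_model('000001.XSHG', '2010-01-01', '2018-01-01',
                               '2018-01-01', '2020-01-01', 'LR', 'timing_lr')
xgb_model = report.multifactor_model('000300.XSHG', '2015-01-01', '2018-01-01',
                                     '2018-01-01', '2019-01-01', 'xgboost', 'multifactor_xgb')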
Fetching and inspecting the data: calling up price/volume data (pre-adjusted)
Historical price/volume data for all A shares: daily returns, open, close, high, low, turnover ratio, shares traded and money traded
Format of the data to be accessed
Stored data frame of historical stock returns
#Stored data frame of historical stock returns
report.data_return.head()
000001.XSHG | 399006.XSHE | 399001.XSHE | 399005.XSHE | 000016.XSHG | 000906.XSHG | 600000.XSHG | 600004.XSHG | 600006.XSHG | 600007.XSHG | ... | 300813.XSHE | 300815.XSHE | 300816.XSHE | 300817.XSHE | 300818.XSHE | 300819.XSHE | 300820.XSHE | 300821.XSHE | 300822.XSHE | 300823.XSHE | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2005-01-04 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-05 | 0.007379 | NaN | 0.009070 | NaN | 0.005272 | NaN | -0.012048 | 0.015015 | 0.010638 | 0.018405 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-06 | -0.009992 | NaN | -0.007904 | NaN | -0.010741 | NaN | -0.012195 | 0.002959 | -0.015789 | -0.009036 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-07 | 0.004292 | NaN | 0.002265 | NaN | 0.001362 | NaN | 0.000000 | 0.017699 | 0.005348 | 0.018237 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-10 | 0.006146 | NaN | 0.008941 | NaN | 0.011377 | NaN | 0.037037 | 0.005797 | 0.005319 | 0.017910 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 3708 columns
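As a quick sanity check on the stored frames, the daily returns can be recomputed from the stored closes; this assumes the returns are simple close-to-close percentage changes of the adjusted prices, which is consistent with the spot-check against TDX further down.
# Recompute Ping An Bank's returns from the stored closes and compare with the stored return frame
recomputed = report.data_close['000001.XSHE'].pct_change()
max_diff = (recomputed - report.data_return['000001.XSHE']).abs().max()
print('max |recomputed - stored| for 000001.XSHE:', max_diff)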
Stored data frame of historical closing prices
#Stored data frame of historical closing prices
report.data_close.head()
000001.XSHG | 399006.XSHE | 399001.XSHE | 399005.XSHE | 000016.XSHG | 000906.XSHG | 600000.XSHG | 600004.XSHG | 600006.XSHG | 600007.XSHG | ... | 300813.XSHE | 300815.XSHE | 300816.XSHE | 300817.XSHE | 300818.XSHE | 300819.XSHE | 300820.XSHE | 300821.XSHE | 300822.XSHE | 300823.XSHE | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2005-01-04 | 1242.77 | NaN | 3025.42 | NaN | 827.07 | NaN | 0.83 | 3.33 | 1.88 | 3.26 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-05 | 1251.94 | NaN | 3052.86 | NaN | 831.43 | NaN | 0.82 | 3.38 | 1.90 | 3.32 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-06 | 1239.43 | NaN | 3028.73 | NaN | 822.50 | NaN | 0.81 | 3.39 | 1.87 | 3.29 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-07 | 1244.75 | NaN | 3035.59 | NaN | 823.62 | NaN | 0.81 | 3.45 | 1.88 | 3.35 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-10 | 1252.40 | NaN | 3062.73 | NaN | 832.99 | NaN | 0.84 | 3.47 | 1.89 | 3.41 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 3708 columns
Stored data frame of historical daily lows
#Stored data frame of historical daily lows
report.data_low.head()
000001.XSHG | 399006.XSHE | 399001.XSHE | 399005.XSHE | 000016.XSHG | 000906.XSHG | 600000.XSHG | 600004.XSHG | 600006.XSHG | 600007.XSHG | ... | 300813.XSHE | 300815.XSHE | 300816.XSHE | 300817.XSHE | 300818.XSHE | 300819.XSHE | 300820.XSHE | 300821.XSHE | 300822.XSHE | 300823.XSHE | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2005-01-04 | 1238.18 | NaN | 3016.26 | NaN | 824.01 | NaN | 0.82 | 3.31 | 1.84 | 3.26 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-05 | 1235.75 | NaN | 3017.34 | NaN | 822.97 | NaN | 0.81 | 3.31 | 1.86 | 3.24 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-06 | 1234.24 | NaN | 3016.15 | NaN | 820.34 | NaN | 0.80 | 3.38 | 1.87 | 3.26 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-07 | 1235.51 | NaN | 3019.07 | NaN | 819.44 | NaN | 0.80 | 3.39 | 1.86 | 3.26 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-10 | 1236.09 | NaN | 3018.49 | NaN | 821.00 | NaN | 0.81 | 3.43 | 1.84 | 3.32 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 3708 columns
Stored data frame of historical opening prices
#Stored data frame of historical opening prices
report.data_open.head()
000001.XSHG | 399006.XSHE | 399001.XSHE | 399005.XSHE | 000016.XSHG | 000906.XSHG | 600000.XSHG | 600004.XSHG | 600006.XSHG | 600007.XSHG | ... | 300813.XSHE | 300815.XSHE | 300816.XSHE | 300817.XSHE | 300818.XSHE | 300819.XSHE | 300820.XSHE | 300821.XSHE | 300822.XSHE | 300823.XSHE | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2005-01-04 | 1260.78 | NaN | 3051.24 | NaN | 836.99 | NaN | 0.84 | 3.33 | 2.03 | 3.32 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-05 | 1241.68 | NaN | 3020.72 | NaN | 825.71 | NaN | 0.83 | 3.33 | 1.86 | 3.26 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-06 | 1252.49 | NaN | 3054.10 | NaN | 831.99 | NaN | 0.82 | 3.39 | 1.90 | 3.32 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-07 | 1239.32 | NaN | 3028.69 | NaN | 822.67 | NaN | 0.81 | 3.39 | 1.89 | 3.29 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-10 | 1243.58 | NaN | 3032.84 | NaN | 823.77 | NaN | 0.82 | 3.45 | 1.89 | 3.32 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 3708 columns
Stored data frame of historical daily highs
#Stored data frame of historical daily highs
report.data_high.head()
000001.XSHG | 399006.XSHE | 399001.XSHE | 399005.XSHE | 000016.XSHG | 000906.XSHG | 600000.XSHG | 600004.XSHG | 600006.XSHG | 600007.XSHG | ... | 300813.XSHE | 300815.XSHE | 300816.XSHE | 300817.XSHE | 300818.XSHE | 300819.XSHE | 300820.XSHE | 300821.XSHE | 300822.XSHE | 300823.XSHE | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2005-01-04 | 1260.78 | NaN | 3051.24 | NaN | 836.99 | NaN | 0.84 | 3.38 | 2.03 | 3.32 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-05 | 1258.58 | NaN | 3067.67 | NaN | 836.43 | NaN | 0.83 | 3.42 | 1.94 | 3.34 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-06 | 1252.73 | NaN | 3054.10 | NaN | 833.07 | NaN | 0.83 | 3.43 | 1.92 | 3.32 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-07 | 1256.31 | NaN | 3065.28 | NaN | 832.95 | NaN | 0.82 | 3.47 | 1.92 | 3.35 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-10 | 1252.72 | NaN | 3063.66 | NaN | 833.65 | NaN | 0.84 | 3.49 | 1.91 | 3.41 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 3708 columns
Stored data frame of historical turnover ratios
#Stored data frame of historical turnover ratios
report.data_turnover.head()
000001.XSHE | 000002.XSHE | 000004.XSHE | 000005.XSHE | 000006.XSHE | 000007.XSHE | 000008.XSHE | 000009.XSHE | 000010.XSHE | 000011.XSHE | ... | 603989.XSHG | 603990.XSHG | 603991.XSHG | 603992.XSHG | 603993.XSHG | 603995.XSHG | 603996.XSHG | 603997.XSHG | 603998.XSHG | 603999.XSHG | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2005-01-04 | 0.1249 | 0.6566 | 0.6400 | 0.2321 | 0.1816 | 0.3639 | 0.2005 | 0.1430 | 0.5186 | 0.5356 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-05 | 0.2286 | 1.1155 | 1.1179 | 0.2236 | 0.1860 | 0.4015 | 0.2400 | 0.1395 | 0.7805 | 0.3907 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-06 | 0.1892 | 1.1683 | 1.0751 | 0.2092 | 0.1633 | 0.3733 | 0.5760 | 0.1076 | 0.9594 | 0.1615 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-07 | 0.1338 | 1.1239 | 0.8072 | 0.8032 | 1.5223 | 0.5812 | 0.4816 | 0.2765 | 1.1124 | 0.1458 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-10 | 0.1868 | 0.4496 | 0.7163 | 1.9234 | 7.8899 | 1.0698 | 0.5288 | 0.6223 | 1.9922 | 0.1698 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 3695 columns
Stored data frame of historical money traded
##Stored data frame of historical money traded
report.data_money.head()
000001.XSHG | 399006.XSHE | 399001.XSHE | 399005.XSHE | 000016.XSHG | 000906.XSHG | 600000.XSHG | 600004.XSHG | 600006.XSHG | 600007.XSHG | ... | 300813.XSHE | 300815.XSHE | 300816.XSHE | 300817.XSHE | 300818.XSHE | 300819.XSHE | 300820.XSHE | 300821.XSHE | 300822.XSHE | 300823.XSHE | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2005-01-04 | 4.418452e+09 | NaN | 980468922.0 | NaN | 2.136409e+09 | NaN | 26134943.0 | 20728330.0 | 10491518.0 | 1555757.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-05 | 4.916589e+09 | NaN | 807720454.0 | NaN | 1.705649e+09 | NaN | 35366812.0 | 12969407.0 | 8801098.0 | 1330700.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-06 | 4.381370e+09 | NaN | 762259679.0 | NaN | 1.519687e+09 | NaN | 28758188.0 | 29118085.0 | 3727731.0 | 1328303.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-07 | 5.040042e+09 | NaN | 843160298.0 | NaN | 1.640665e+09 | NaN | 29239203.0 | 36353317.0 | 6311149.0 | 2460959.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-10 | 4.118292e+09 | NaN | 734534698.0 | NaN | 1.402314e+09 | NaN | 48798404.0 | 15893640.0 | 4484497.0 | 4907955.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 3708 columns
Stored data frame of historical trading volume (shares traded)
##Stored data frame of historical trading volume (shares traded)
report.data_volume.head()
000001.XSHG | 399006.XSHE | 399001.XSHE | 399005.XSHE | 000016.XSHG | 000906.XSHG | 600000.XSHG | 600004.XSHG | 600006.XSHG | 600007.XSHG | ... | 300813.XSHE | 300815.XSHE | 300816.XSHE | 300817.XSHE | 300818.XSHE | 300819.XSHE | 300820.XSHE | 300821.XSHE | 300822.XSHE | 300823.XSHE | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2005-01-04 | 816177000.0 | NaN | 138575300.0 | NaN | 403169700.0 | NaN | 31512001.0 | 6205383.0 | 5537612.0 | 473684.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-05 | 867865100.0 | NaN | 123303900.0 | NaN | 302086300.0 | NaN | 43229333.0 | 3835495.0 | 4655335.0 | 404622.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-06 | 792225400.0 | NaN | 108869300.0 | NaN | 275357400.0 | NaN | 35558905.0 | 8549882.0 | 1978640.0 | 406085.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-07 | 894087100.0 | NaN | 117752600.0 | NaN | 306608600.0 | NaN | 36094716.0 | 10584134.0 | 3332919.0 | 743356.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2005-01-10 | 723468300.0 | NaN | 94298700.0 | NaN | 247941100.0 | NaN | 58865757.0 | 4597541.0 | 2385315.0 | 1449998.0 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 3708 columns
To check that the local data are reliable, spot-check Ping An Bank's price/volume data for 2020-01-15 against TDX (通达信), a professional commercial terminal.
stock_choose = '000001.XSHE'
time_choose = '2020-01-15'
print(time_choose)
print('开盘',report.data_open.loc[time_choose,stock_choose])
print('最高',report.data_high.loc[time_choose,stock_choose])
print('最低',report.data_low.loc[time_choose,stock_choose])
print('收盘',report.data_close.loc[time_choose,stock_choose])
print('总量',report.data_volume.loc[time_choose,stock_choose])
print('换手',report.data_turnover.loc[time_choose,stock_choose])
print('总额',report.data_money.loc[time_choose,stock_choose])
print('涨跌',report.data_return.loc[time_choose,stock_choose])
2020-01-15
开盘 16.79
最高 16.86
最低 16.45
收盘 16.52
总量 85943912.0
换手 0.4429
总额 1424889228.07
涨跌 -0.014319809069212487

The comparison shows the local data can be trusted.
Stored historical benchmark returns
#Stored historical benchmark returns
report.benchmark_return.tail()
2020-02-24 -0.004013
2020-02-25 -0.002175
2020-02-26 -0.012326
2020-02-27 0.002912
2020-02-28 -0.035455
Name: close, dtype: float64
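These are daily simple returns of the benchmark; to view them as a cumulative performance curve one can compound them, for example (a small sketch, not part of the original notebook):
# Compound the daily benchmark returns into a cumulative return curve
benchmark_curve = (1 + report.benchmark_return.fillna(0)).cumprod() - 1
benchmark_curve.plot(figsize=(10, 4), title='000300.XSHG cumulative return')
plt.show()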
Processing the data and building features and labels: computing technical indicators from price/volume data step by step
#Input series
close_pingan = report.data_close['000001.XSHE']
open_pingan = report.data_open['000001.XSHE']
high_pingan = report.data_high['000001.XSHE']
low_pingan = report.data_low['000001.XSHE']
volume_pingan = report.data_volume['000001.XSHE']
return_pingan = report.data_return['000001.XSHE']
money_pingan = report.data_money['000001.XSHE']
low_pingan.head()
2005-01-04 1.48
2005-01-05 1.46
2005-01-06 1.48
2005-01-07 1.48
2005-01-10 1.46
Name: 000001.XSHE, dtype: float64
Computing the moving average
def get_MA(df_close,n):
'''
The moving average is the mean price over a rolling window
df_close: price data, DataFrame or Series
n: look-back window in days, integer
'''
ma = df_close.rolling(n).mean()
ma = pd.DataFrame({'MA_' + str(n): ma}, index = ma.index)
return ma
#20-day moving average for Ping An Bank
get_MA(close_pingan,20).tail()
MA_20 | |
---|---|
2020-02-24 | 15.1320 |
2020-02-25 | 15.0615 |
2020-02-26 | 15.0110 |
2020-02-27 | 14.9620 |
2020-02-28 | 14.9100 |
The open-source library talib can compute the same thing.
#Verify our own calculation with talib
talib.MA(np.array(close_pingan),timeperiod=20, matype=0)
array([ nan, nan, nan, ..., 15.011, 14.962, 14.91 ])
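Instead of eyeballing the last few values, the two results can be compared programmatically; a small check using the get_MA defined above, treating the NaNs of the warm-up window as equal:
# The rolling-mean implementation and talib's simple MA should agree element-wise
own_ma = get_MA(close_pingan, 20)['MA_20'].values
talib_ma = talib.MA(np.array(close_pingan), timeperiod=20, matype=0)
print(np.allclose(own_ma, talib_ma, equal_nan=True))   # expected: True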
## Rate of change (ROC)
def get_ROC(df_close, n):
'''
ROC = (today's close - close N days ago) / close N days ago * 100
df_close: closing-price data, DataFrame or Series
n: look-back window in days, integer
'''
M = df_close
N = df_close.shift(n)
roc = pd.DataFrame({'ROC_' + str(n): (M-N) / N*100}, index = M.index)
return roc
get_ROC(close_pingan, 12).tail()
ROC_12 | |
---|---|
2020-02-24 | 3.114421 |
2020-02-25 | 2.872777 |
2020-02-26 | 3.379310 |
2020-02-27 | 2.163624 |
2020-02-28 | -1.828030 |
#Verify our own calculation with talib
talib.ROC(close_pingan,12).tail()
2020-02-24 3.114421
2020-02-25 2.872777
2020-02-26 3.379310
2020-02-27 2.163624
2020-02-28 -1.828030
dtype: float64
## Relative strength index (RSI)
def get_RSI(df_close,n):
'''
df_close: closing-price data, DataFrame or Series
n: look-back window in days, integer
'''
rsi = talib.RSI(df_close, timeperiod=n)
return pd.DataFrame({'RSI_' + str(n): rsi}, index = df_close.index)
df = get_price('000001.XSHE', start_date='2015-09-16', end_date='2020-3-22', frequency='daily')
get_RSI(close_pingan,6).tail()
RSI_6 | |
---|---|
2020-02-24 | 49.477234 |
2020-02-25 | 42.341524 |
2020-02-26 | 40.497130 |
2020-02-27 | 47.129830 |
2020-02-28 | 28.054166 |
##On-balance volume (OBV)
def get_OBV(df_close,df_volume):
'''
On Balance Volume: infers the price trend from the trend in trading volume
df_close: closing-price data, DataFrame or Series
df_volume: trading-volume data, DataFrame or Series
'''
obv = talib.OBV(df_close,df_volume)
return pd.DataFrame({'OBV': obv}, index = df_close.index)
get_OBV(close_pingan,volume_pingan).tail()
OBV | |
---|---|
2020-02-24 | 4.570885e+10 |
2020-02-25 | 4.559439e+10 |
2020-02-26 | 4.547673e+10 |
2020-02-27 | 4.557426e+10 |
2020-02-28 | 4.544420e+10 |
#Average true range (ATR)
def get_ATR(df_high,df_low,df_close,n):
'''
Average true range, mainly used to measure price volatility
df_close: closing-price data, DataFrame or Series
df_high: high-price data, DataFrame or Series
df_low: low-price data, DataFrame or Series
n: look-back window in days, integer
'''
atr = talib.ATR(df_high,df_low,df_close, timeperiod=n)
return pd.DataFrame({'ATR_' + str(n): atr}, index = df_close.index)
get_ATR(high_pingan,low_pingan,close_pingan,14).tail()
ATR_14 | |
---|---|
2020-02-24 | 0.422029 |
2020-02-25 | 0.424027 |
2020-02-26 | 0.434454 |
2020-02-27 | 0.421993 |
2020-02-28 | 0.438279 |
#Momentum (MOM)
def get_MOM(df_close,n):
'''
Momentum: the ability of a stock (or index) to keep rising. Studies find winner portfolios show positive momentum in bull markets and loser portfolios show negative momentum in bear markets.
df_close: closing-price data, DataFrame or Series
n: look-back window in days, integer
'''
mom = talib.MOM(df_close, timeperiod=n)
return pd.DataFrame({'MOM_' + str(n): mom}, index = df_close.index)
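MOM is the only indicator not previewed above; it can be inspected the same way as the others, again using the close_pingan series defined earlier:
# 10-day momentum for Ping An Bank
get_MOM(close_pingan, 10).tail()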
def merge_raw_factors(df_close,df_open,df_high,df_low,df_volume,df_money):
return pd.concat([get_MA(df_close,5),get_MA(df_close,20),get_MA(df_close,60),get_ROC(df_close, 6),get_ROC(df_close, 12),get_RSI(df_close,6),\
get_RSI(df_close,12),get_RSI(df_close,24),get_OBV(df_close,df_volume),get_ATR(df_high,df_low,df_close,14),get_MOM(df_close,10)],axis = 1)
#Daily factor values for Ping An Bank
factors = merge_raw_factors(close_pingan,open_pingan,high_pingan,low_pingan,volume_pingan,money_pingan)
factors.tail()
MA_5 | MA_20 | MA_60 | ROC_6 | ROC_12 | RSI_6 | RSI_12 | RSI_24 | OBV | ATR_14 | MOM_10 | |
---|---|---|---|---|---|---|---|---|---|---|---|
2020-02-24 | 15.368 | 15.1320 | 15.831000 | 1.330672 | 3.114421 | 49.477234 | 47.099803 | 46.196881 | 4.570885e+10 | 0.422029 | 0.73 |
2020-02-25 | 15.336 | 15.0615 | 15.821833 | -2.147040 | 2.872777 | 42.341524 | 43.943630 | 44.629065 | 4.559439e+10 | 0.424027 | 0.25 |
2020-02-26 | 15.286 | 15.0110 | 15.808333 | -1.381579 | 3.379310 | 40.497130 | 43.114224 | 44.216995 | 4.547673e+10 | 0.434454 | 0.22 |
2020-02-27 | 15.190 | 14.9620 | 15.799833 | -0.853018 | 2.163624 | 47.129830 | 45.792942 | 45.477725 | 4.557426e+10 | 0.421993 | 0.46 |
2020-02-28 | 14.974 | 14.9100 | 15.783667 | -6.991661 | -1.828030 | 28.054166 | 36.310977 | 40.609408 | 4.544420e+10 | 0.438279 | -0.53 |
#Winsorize and standardize the data
def winsorize_and_standarlize(data,qrange=[0.05,0.95],axis=0):
'''
input:
data: DataFrame or Series, input data
qrange: list, [lower quantile, upper quantile]; values beyond them are clipped to the quantiles
'''
if isinstance(data,pd.DataFrame):
if axis == 0:
q_down = data.quantile(qrange[0])
q_up = data.quantile(qrange[1])
index = data.index
col = data.columns
for n in col:
data[n][data[n] > q_up[n]] = q_up[n]
data[n][data[n] < q_down[n]] = q_down[n]
data = (data - data.mean())/data.std()
data = data.fillna(0)
else:
data = data.stack()
data = data.unstack(0)
q = data.quantile(qrange)
index = data.index
col = data.columns
for n in col:
data[n][data[n] > q[n]] = q[n]
data = (data - data.mean())/data.std()
data = data.stack().unstack(0)
data = data.fillna(0)
elif isinstance(data,pd.Series):
name = data.name
q_down = data.quantile(qrange[0])
q_up = data.quantile(qrange[1])
data[data > q_up] = q_up
data[data < q_down] = q_down
data = (data - data.mean())/data.std()
return data
Inspecting the processed data
data_pro = winsorize_and_standarlize(factors)
data_pro.describe()
MA_5 | MA_20 | MA_60 | ROC_6 | ROC_12 | RSI_6 | RSI_12 | RSI_24 | OBV | ATR_14 | MOM_10 | |
---|---|---|---|---|---|---|---|---|---|---|---|
count | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 | 3.682000e+03 |
mean | 1.700794e-14 | 2.515123e-14 | -2.662883e-14 | -7.065385e-16 | -1.112786e-16 | 2.447526e-14 | 2.093756e-14 | -1.427719e-14 | 2.011150e-14 | -7.692863e-15 | 3.948800e-16 |
std | 9.994565e-01 | 9.974158e-01 | 9.919535e-01 | 9.991847e-01 | 9.983687e-01 | 9.991847e-01 | 9.983687e-01 | 9.967347e-01 | 1.000000e+00 | 9.980965e-01 | 9.986407e-01 |
min | -1.705255e+00 | -1.728499e+00 | -1.780926e+00 | -1.824671e+00 | -1.810217e+00 | -1.702614e+00 | -1.659394e+00 | -1.596035e+00 | -1.534646e+00 | -1.351977e+00 | -2.025731e+00 |
25% | -5.698024e-01 | -5.721285e-01 | -5.615280e-01 | -6.282979e-01 | -6.495794e-01 | -7.731055e-01 | -7.740355e-01 | -7.952646e-01 | -8.534952e-01 | -8.839974e-01 | -5.336564e-01 |
50% | -1.084304e-01 | -8.470670e-02 | -7.784826e-02 | -9.829604e-02 | -6.213428e-02 | -3.013210e-02 | -3.795124e-02 | -8.486868e-02 | -2.590756e-01 | -9.136803e-02 | -3.629837e-02 |
75% | 6.278916e-01 | 6.541132e-01 | 6.663069e-01 | 5.347080e-01 | 5.678273e-01 | 7.716410e-01 | 7.247613e-01 | 7.206211e-01 | 1.097747e+00 | 6.497930e-01 | 4.989021e-01 |
max | 1.994363e+00 | 1.940143e+00 | 1.881687e+00 | 2.213011e+00 | 2.117848e+00 | 1.767631e+00 | 1.851835e+00 | 1.935446e+00 | 1.523051e+00 | 1.985683e+00 | 2.114235e+00 |
Ping An Bank's daily returns
#Ping An Bank's daily returns
profit = copy.deepcopy(return_pingan)
profit.tail(10)
2020-02-17 0.022621
2020-02-18 -0.011061
2020-02-19 0.002632
2020-02-20 0.022966
2020-02-21 -0.000641
2020-02-24 -0.022465
2020-02-25 -0.012475
2020-02-26 -0.003324
2020-02-27 0.008005
2020-02-28 -0.040371
Name: 000001.XSHE, dtype: float64
A return greater than 0 is marked 1, otherwise 0; this becomes the label.
#A return greater than 0 is marked 1, otherwise 0; this becomes the label
profit[profit > 0] = 1
profit[profit< 0] = 0
profit = pd.DataFrame(profit)
profit.columns = ['Label']
profit.tail()
Label | |
---|---|
2020-02-24 | 0.0 |
2020-02-25 | 0.0 |
2020-02-26 | 0.0 |
2020-02-27 | 1.0 |
2020-02-28 | 0.0 |
#We are predicting the next day's move, so the label lags the features by one day
profit.shift().tail()
Label | |
---|---|
2020-02-24 | 0.0 |
2020-02-25 | 0.0 |
2020-02-26 | 0.0 |
2020-02-27 | 0.0 |
2020-02-28 | 1.0 |
The merged data frame: the first column (Label) is the label, the remaining columns are the features.
data_final = pd.concat([profit.shift(),data_pro],axis = 1)
data_final.tail()
Label | MA_5 | MA_20 | MA_60 | ROC_6 | ROC_12 | RSI_6 | RSI_12 | RSI_24 | OBV | ATR_14 | MOM_10 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
2020-02-24 | 0.0 | 1.994363 | 1.940143 | 1.881687 | 0.154400 | 0.287494 | -0.101360 | -0.362483 | -0.602621 | 1.523051 | 1.427686 | 1.499024 |
2020-02-25 | 0.0 | 1.994363 | 1.940143 | 1.881687 | -0.506021 | 0.254781 | -0.510047 | -0.616758 | -0.775493 | 1.523051 | 1.442254 | 0.461060 |
2020-02-26 | 0.0 | 1.994363 | 1.940143 | 1.881687 | -0.360660 | 0.323354 | -0.615682 | -0.683579 | -0.820929 | 1.523051 | 1.518279 | 0.396187 |
2020-02-27 | 0.0 | 1.994363 | 1.940143 | 1.881687 | -0.260285 | 0.158779 | -0.235804 | -0.467769 | -0.681917 | 1.523051 | 1.427421 | 0.915169 |
2020-02-28 | 1.0 | 1.994363 | 1.940143 | 1.881687 | -1.426021 | -0.381598 | -1.328334 | -1.231680 | -1.218712 | 1.523051 | 1.546171 | -1.225633 |
Integrating the functionality for later reuse
The functions above are folded into the Report class as a new method; once you understand them you can add new indicators or improve the code.
New capability: given a date range and a stock or index code, output the full feature-and-label frame.
#Call the new method on the class
report.merge_final_factor_label1('600000.XSHG','2020-01-21','2020-02-28')
Label | MA_5 | MA_60 | Aroondown_14 | Aroonup_14 | ROC_6 | ROC_12 | RSI_6 | RSI_24 | OBV | ATR_14 | MOM_10 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
2020-01-21 | 1.0 | 2.310415 | 1.680919 | 1.385505 | 0.850139 | -0.049595 | 0.151672 | -0.627981 | 1.315566 | 0.974446 | -1.829183 | 0.119921 |
2020-01-22 | 0.0 | 2.310415 | 1.680919 | 1.385505 | 0.642568 | -0.508179 | -0.039098 | -1.254577 | 0.318997 | -0.277086 | -1.829183 | -0.030231 |
2020-01-23 | 0.0 | 1.933406 | 1.563415 | 1.385505 | 0.434997 | -0.860640 | -0.534339 | -1.550447 | -0.808540 | -1.503009 | -1.113230 | -0.573091 |
2020-02-03 | 0.0 | 1.066039 | 1.348623 | 1.385505 | 0.227426 | -1.807349 | -1.157115 | -1.550447 | -1.538096 | -1.915072 | 1.662121 | -1.413946 |
2020-02-04 | 0.0 | 0.282451 | 1.128566 | 1.203410 | 0.019855 | -1.807349 | -1.157115 | -1.184782 | -1.538096 | -1.915072 | 1.662121 | -1.413946 |
2020-02-05 | 1.0 | -0.397643 | 0.913774 | 1.021315 | -0.187716 | -1.778663 | -1.134754 | -1.047883 | -1.399087 | -1.239307 | 1.462821 | -1.393156 |
2020-02-06 | 1.0 | -0.880609 | 0.686346 | 0.839220 | -0.395288 | -1.437094 | -1.055505 | -0.714797 | -1.051062 | -0.630247 | 1.287606 | -1.081300 |
2020-02-07 | 1.0 | -0.964882 | 0.469448 | 0.657125 | -0.602859 | -0.927177 | -0.998269 | -0.444543 | -0.782738 | -0.119326 | 1.216143 | -0.942698 |
2020-02-10 | 1.0 | -0.964882 | 0.259921 | 0.475030 | -0.810430 | -0.473645 | -0.924015 | -0.585924 | -0.946487 | -0.664448 | 0.876075 | -1.081300 |
2020-02-11 | 0.0 | -0.875681 | 0.079874 | 0.292935 | -1.018001 | 1.049427 | -0.775532 | -0.186210 | -0.595686 | -0.106125 | 0.651534 | -1.000449 |
2020-02-12 | 1.0 | -0.796829 | -0.070691 | 0.110840 | -1.018001 | 0.735070 | -0.805060 | -0.186210 | -0.595686 | -0.106125 | 0.260558 | -0.804095 |
2020-02-13 | 0.0 | -0.811614 | -0.226521 | -0.071255 | -1.018001 | 0.507486 | -0.935054 | -0.433283 | -0.795829 | -0.468618 | -0.102491 | -0.561541 |
2020-02-14 | 0.0 | -0.811614 | -0.370768 | -0.253350 | -1.018001 | 0.522711 | -0.655954 | 0.118971 | -0.388368 | -0.110450 | -0.378784 | 0.039070 |
2020-02-17 | 1.0 | -0.609557 | -0.471847 | -0.435445 | -1.018001 | 0.920266 | 0.031893 | 1.313898 | 0.788246 | 0.551781 | 0.003316 | 1.123634 |
2020-02-18 | 1.0 | -0.510992 | -0.586614 | -0.617540 | -1.018001 | 0.876334 | 0.364239 | 0.861870 | 0.507352 | 0.068639 | -0.280534 | 1.067038 |
2020-02-19 | 0.0 | -0.387787 | -0.705592 | -0.799635 | -1.018001 | 0.808748 | 1.464518 | 1.032863 | 0.681504 | 0.470690 | -0.361636 | 1.078588 |
2020-02-20 | 1.0 | -0.156160 | -0.810882 | -0.981729 | -1.018001 | 0.999922 | 1.432828 | 1.412282 | 1.093107 | 1.057847 | -0.406533 | 1.113239 |
2020-02-21 | 1.0 | 0.060681 | -0.888797 | -1.163824 | 1.659666 | 1.049427 | 1.464518 | 1.412282 | 1.315566 | 1.502184 | -0.539459 | 1.113239 |
2020-02-24 | 1.0 | 0.050825 | -0.973029 | -1.163824 | 1.659666 | 0.888404 | 1.173463 | 0.937318 | 0.961292 | 0.965116 | -0.480417 | 1.055488 |
2020-02-25 | 0.0 | 0.065610 | -1.067790 | -1.163824 | 1.472852 | 0.271194 | 0.996147 | 0.633880 | 0.783536 | 0.191060 | -0.699303 | 0.870685 |
2020-02-26 | 0.0 | 0.109964 | -1.159393 | -1.163824 | 1.265281 | 0.629473 | 1.249660 | 1.107259 | 1.170027 | 1.022488 | -0.385546 | 0.997737 |
2020-02-27 | 1.0 | 0.100107 | -1.239940 | -1.163824 | 1.057710 | 0.566196 | 1.145570 | 1.149823 | 1.205165 | 1.502184 | -0.550383 | 1.123634 |
2020-02-28 | 1.0 | -0.121663 | -1.239940 | -1.163824 | 0.850139 | -0.174968 | 0.697302 | -0.213361 | 0.299315 | 0.748452 | -0.125615 | 0.593480 |
Index timing
Fetching data for a custom interval
#Fetch SSE Composite index data from 2010 to 2018 as the in-sample dataset
data_index = report.merge_final_factor_label1('000001.XSHG','2010-01-01','2018-01-01')
data_index.head(20)
Label | MA_5 | MA_60 | Aroondown_14 | Aroonup_14 | ROC_6 | ROC_12 | RSI_6 | RSI_24 | OBV | ATR_14 | MOM_10 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
2010-01-04 | 1.0 | 0.988304 | 0.868502 | -0.147992 | -1.385636 | 0.990031 | -0.148243 | 0.441767 | 0.291502 | -0.922039 | 0.602991 | 1.216550 |
2010-01-05 | 0.0 | 1.025477 | 0.881695 | -0.342382 | -1.385636 | 1.577618 | 0.758380 | 0.826958 | 0.486422 | -0.922039 | 0.630772 | 1.505414 |
2010-01-06 | 1.0 | 1.042376 | 0.892517 | -0.536772 | -1.385636 | 0.694792 | 1.077039 | 0.346818 | 0.315013 | -0.922039 | 0.577094 | 1.943435 |
2010-01-07 | 0.0 | 1.014588 | 0.900080 | -0.731162 | 1.062305 | -0.265111 | 0.504234 | -0.466328 | -0.042070 | -0.922039 | 0.676621 | 1.109531 |
2010-01-08 | 0.0 | 0.982294 | 0.907437 | -0.925552 | 0.874002 | -0.791835 | 1.143227 | -0.416499 | -0.024108 | -0.922039 | 0.641097 | 0.357246 |
2010-01-11 | 1.0 | 0.969952 | 0.915472 | -1.119941 | 1.250608 | -0.764055 | 1.080711 | -0.138320 | 0.071259 | -0.922039 | 0.792552 | 0.640891 |
2010-01-12 | 1.0 | 0.966685 | 0.923492 | -1.314331 | 1.062305 | 0.287730 | 0.904679 | 0.654121 | 0.404724 | -0.922039 | 0.885783 | 0.776659 |
2010-01-13 | 1.0 | 0.934224 | 0.926494 | -1.314331 | 0.874002 | -1.262313 | 0.191839 | -0.591117 | -0.177093 | -0.922039 | 1.012879 | -0.447023 |
2010-01-14 | 0.0 | 0.943287 | 0.931427 | -1.314331 | 0.685699 | -0.482049 | 0.152193 | -0.089788 | 0.052163 | -0.922039 | 0.964007 | -0.525294 |
2010-01-15 | 1.0 | 0.954490 | 0.937305 | -1.314331 | 0.497395 | 0.306325 | 0.037760 | 0.008444 | 0.097648 | -0.916050 | 0.891891 | -0.583775 |
2010-01-18 | 1.0 | 0.964181 | 0.941703 | -1.314331 | 0.309092 | 0.416529 | -0.256650 | 0.169099 | 0.167496 | -0.906526 | 0.798950 | -0.127639 |
2010-01-19 | 1.0 | 0.953396 | 0.946375 | 0.046398 | 0.120789 | 0.335193 | -0.292484 | 0.298502 | 0.221135 | -0.897364 | 0.700342 | -0.409709 |
2010-01-20 | 1.0 | 0.945113 | 0.950812 | 1.407126 | -0.067514 | -1.405118 | -0.774132 | -0.872962 | -0.323491 | -0.908522 | 0.835592 | -1.069940 |
2010-01-21 | 0.0 | 0.922551 | 0.955152 | 1.407126 | -0.255817 | -0.208462 | -1.007122 | -0.758828 | -0.282305 | -0.900460 | 0.786744 | -0.396024 |
2010-01-22 | 1.0 | 0.884518 | 0.960873 | 1.407126 | -0.444120 | -1.032629 | -1.033166 | -1.051371 | -0.447064 | -0.910173 | 0.883784 | -0.725746 |
2010-01-25 | 0.0 | 0.827728 | 0.964227 | 1.212737 | -0.632424 | -1.511881 | -0.836571 | -1.335185 | -0.627804 | -0.915699 | 0.815792 | -1.227171 |
2010-01-26 | 0.0 | 0.737192 | 0.962278 | 1.407126 | -0.820727 | -1.923946 | -1.453195 | -1.699754 | -0.994930 | -0.922039 | 0.938067 | -1.970330 |
2010-01-27 | 0.0 | 0.671427 | 0.957935 | 1.407126 | -1.009030 | -1.923946 | -1.834742 | -1.699754 | -1.143798 | -0.922039 | 0.902419 | -1.893801 |
2010-01-28 | 0.0 | 0.605869 | 0.953362 | 1.407126 | -1.197333 | -1.867684 | -1.873060 | -1.699754 | -1.092096 | -0.922039 | 0.827277 | -1.970330 |
2010-01-29 | 1.0 | 0.550428 | 0.947721 | 1.212737 | -1.385636 | -1.923946 | -1.517184 | -1.699754 | -1.115215 | -0.922039 | 0.800694 | -1.970330 |
Splitting the data
# Split the dataset into a feature set and a label set
#x_data: the feature set to split
#y_data: the label set to split
#test_size: fraction of samples held out for testing (an integer means an absolute count)
#random_state: the random seed; fixing it (e.g. to 1) reproduces the same split on every run, leaving it unset (or 0) gives a different split each time
x_data,y_data = data_index.iloc[:,1:], data_index.iloc[:,0]
x_data.tail(20)
MA_5 | MA_60 | Aroondown_14 | Aroonup_14 | ROC_6 | ROC_12 | RSI_6 | RSI_24 | OBV | ATR_14 | MOM_10 | |
---|---|---|---|---|---|---|---|---|---|---|---|
2017-12-04 | 1.160203 | 1.302549 | 0.629567 | -1.385636 | -0.529114 | -0.724529 | -1.398063 | -1.015399 | 1.729006 | -0.574550 | -0.877069 |
2017-12-05 | 1.148271 | 1.299805 | 1.407126 | -0.444120 | -0.253274 | -0.650207 | -1.503913 | -1.094497 | 1.729006 | -0.625791 | -1.113752 |
2017-12-06 | 1.130799 | 1.296693 | 1.407126 | -0.632424 | -0.482986 | -0.791368 | -1.670960 | -1.222999 | 1.729006 | -0.570801 | -1.405963 |
2017-12-07 | 1.112833 | 1.293513 | 1.212737 | -0.820727 | -0.766524 | -1.083353 | -1.699754 | -1.496310 | 1.729006 | -0.562992 | -0.848419 |
2017-12-08 | 1.101837 | 1.290953 | 1.018347 | -1.009030 | -0.348283 | -1.092247 | -1.024953 | -1.136449 | 1.729006 | -0.544920 | -0.690499 |
2017-12-11 | 1.106843 | 1.289108 | 0.823957 | -1.197333 | -0.000361 | -0.283149 | 0.092151 | -0.557712 | 1.729006 | -0.540233 | -0.062364 |
2017-12-12 | 1.097741 | 1.285750 | 0.629567 | -1.385636 | -0.366630 | -0.608557 | -0.829096 | -1.079091 | 1.729006 | -0.513177 | -0.582397 |
2017-12-13 | 1.101355 | 1.282990 | 0.435177 | -1.385636 | -0.057532 | -0.205210 | -0.236514 | -0.717925 | 1.729006 | -0.521942 | -0.404885 |
2017-12-14 | 1.109470 | 1.280302 | 0.240787 | -1.385636 | -0.067255 | -0.371375 | -0.463819 | -0.846506 | 1.729006 | -0.541387 | -0.305742 |
2017-12-15 | 1.099978 | 1.277325 | 0.046398 | -1.385636 | -0.116092 | -0.601435 | -0.951066 | -1.148089 | 1.729006 | -0.540983 | -0.568909 |
2017-12-18 | 1.078375 | 1.274094 | 1.407126 | -1.197333 | -0.294117 | -0.434127 | -0.890981 | -1.118135 | 1.729006 | -0.560886 | -0.472621 |
2017-12-19 | 1.084635 | 1.272042 | 1.212737 | -1.385636 | -0.330999 | -0.219780 | -0.029386 | -0.655026 | 1.729006 | -0.565546 | -0.132365 |
2017-12-20 | 1.078494 | 1.269375 | 1.018347 | -1.385636 | 0.024776 | -0.227254 | -0.260446 | -0.764224 | 1.729006 | -0.590575 | -0.124587 |
2017-12-21 | 1.081527 | 1.267410 | 0.823957 | -1.385636 | -0.083261 | -0.087167 | 0.095794 | -0.568856 | 1.729006 | -0.560075 | 0.213700 |
2017-12-22 | 1.093833 | 1.265522 | 0.629567 | -1.385636 | 0.000464 | -0.035797 | -0.003177 | -0.607870 | 1.729006 | -0.616412 | 0.007538 |
2017-12-25 | 1.098824 | 1.263443 | 0.435177 | -0.632424 | 0.108730 | 0.005290 | -0.520022 | -0.821096 | 1.729006 | -0.583946 | -0.473015 |
2017-12-26 | 1.102637 | 1.262169 | 0.240787 | -0.820727 | 0.374023 | 0.064116 | 0.288744 | -0.413773 | 1.729006 | -0.580745 | 0.187118 |
2017-12-27 | 1.097928 | 1.259804 | 0.046398 | -1.009030 | -0.279201 | -0.411926 | -0.531291 | -0.793420 | 1.729006 | -0.566376 | -0.330454 |
2017-12-28 | 1.096464 | 1.258332 | -0.147992 | -1.197333 | 0.046382 | 0.060158 | 0.034153 | -0.483688 | 1.729006 | -0.542000 | -0.023278 |
2017-12-29 | 1.100487 | 1.256911 | -0.342382 | -1.385636 | 0.027748 | -0.027998 | 0.296710 | -0.327705 | 1.729006 | -0.594899 | 0.341887 |
y_data.tail(20)
2017-12-04 1.0
2017-12-05 0.0
2017-12-06 0.0
2017-12-07 0.0
2017-12-08 0.0
2017-12-11 1.0
2017-12-12 1.0
2017-12-13 0.0
2017-12-14 1.0
2017-12-15 0.0
2017-12-18 0.0
2017-12-19 1.0
2017-12-20 1.0
2017-12-21 0.0
2017-12-22 1.0
2017-12-25 0.0
2017-12-26 0.0
2017-12-27 1.0
2017-12-28 0.0
2017-12-29 1.0
Name: Label, dtype: float64
# Split the raw data into a training set and a test set by proportion
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.2)
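Note that the split above does not fix random_state, so every run gives a different split and slightly different metrics; for a repeatable experiment the seed can be pinned (the value 1 here is arbitrary):
# Reproducible variant of the split (illustrative seed)
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=1)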
Training the model
model=QDA()
model.fit(x_train, y_train)
QuadraticDiscriminantAnalysis(priors=None, reg_param=0.0,
store_covariance=False, store_covariances=None, tol=0.0001)
Training results and evaluation metrics
Confusion matrix
# Predictions on the test set
pred = model.predict(x_test)
# Build the confusion matrix
cm = pd.crosstab(y_test,pred)
cm
# Plot the confusion matrix
sns.heatmap(cm, annot = True, cmap = 'GnBu', fmt = 'd')
print('模型的准确率为:\n',accuracy_score(y_test, pred))
print('模型的评估报告:\n',classification_report(y_test, pred))
模型的准确率为:
0.6452442159383034
模型的评估报告:
precision recall f1-score support
0.0 0.63 0.63 0.63 186
1.0 0.66 0.66 0.66 203
micro avg 0.65 0.65 0.65 389
macro avg 0.64 0.64 0.64 389
weighted avg 0.65 0.65 0.65 389

ROC curve
# Predicted probability of the positive class (not the hard prediction), used to build the ROC curve
y_score = model.predict_proba(x_test)[:,1]
#fpr is 1-Specificity, tpr is Sensitivity
fpr,tpr,threshold = roc_curve(y_test, y_score)
# Compute the AUC
roc_auc = metrics.auc(fpr,tpr)
# Plot the area chart
plt.figure(figsize=(8,6))
plt.stackplot(fpr, tpr, color='steelblue', alpha = 0.5, edgecolor = 'black')
plt.plot(fpr, tpr, color='black', lw = 1)
# Add the diagonal reference line
plt.plot([0,1],[0,1], color = 'red', linestyle = '--')
# Add the text annotation
plt.text(0.5,0.3,'ROC曲线 (area = %0.2f)' % roc_auc)
# Add title and axis labels
plt.title('模型预测指数涨跌的ROC曲线',size=15)
plt.xlabel('1-Specificity')
plt.ylabel('Sensitivity')
plt.show()

Saving the model
joblib.dump(model,'train_model.pkl')
model2 = joblib.load("train_model.pkl")
Loading the model
model2 = joblib.load("train_model.pkl")
Out-of-sample test
#Fetch the out-of-sample data
data_out_sample = report.merge_final_factor_label1('000001.XSHG','2018-01-01','2020-01-01')
data_out_sample.tail(10)
Label | MA_5 | MA_60 | Aroondown_14 | Aroonup_14 | ROC_6 | ROC_12 | RSI_6 | RSI_24 | OBV | ATR_14 | MOM_10 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
2019-12-18 | 1.0 | 0.234736 | -0.061065 | -0.783525 | 1.202429 | 1.361947 | 1.427895 | 1.862737 | 1.513285 | 1.435901 | -1.419598 | 1.522202 |
2019-12-19 | 0.0 | 0.326176 | -0.059776 | -0.976240 | 1.008527 | 1.264220 | 1.337858 | 1.862737 | 1.513665 | 1.435901 | -1.419598 | 1.297170 |
2019-12-20 | 1.0 | 0.359787 | -0.059885 | -1.168955 | 0.814626 | 1.222215 | 1.286913 | 1.208598 | 1.265640 | 1.435901 | -1.419598 | 1.036778 |
2019-12-23 | 0.0 | 0.340266 | -0.060923 | -1.361671 | 0.620724 | -0.038880 | 0.665595 | -0.266910 | 0.490695 | 1.435901 | -1.419598 | 0.565392 |
2019-12-24 | 0.0 | 0.304419 | -0.061116 | -1.361671 | 0.426822 | 0.003651 | 0.733870 | 0.252256 | 0.777277 | 1.435901 | -1.419598 | 0.745777 |
2019-12-25 | 1.0 | 0.272703 | -0.059200 | -1.361671 | 0.232921 | -0.498166 | 0.701949 | 0.227258 | 0.763167 | 1.435901 | -1.419598 | 0.662392 |
2019-12-26 | 0.0 | 0.263935 | -0.053531 | -1.361671 | 0.039019 | -0.099477 | 0.918042 | 0.826302 | 1.116750 | 1.435901 | -1.419598 | 1.023268 |
2019-12-27 | 1.0 | 0.264025 | -0.048252 | -1.168955 | -0.154882 | -0.129788 | 0.826051 | 0.734788 | 1.073468 | 1.435901 | -1.419598 | 0.450237 |
2019-12-30 | 0.0 | 0.333726 | -0.038485 | -1.361671 | 1.396330 | 0.482292 | 1.247091 | 1.361105 | 1.524115 | 1.435901 | -1.355460 | 0.643077 |
2019-12-31 | 1.0 | 0.394560 | -0.028593 | -1.168955 | 1.396330 | 1.178552 | 0.831957 | 1.494480 | 1.644490 | 1.435901 | -1.412389 | 0.348276 |
#Split into feature set and label set
x_out_test,y_out_test = data_out_sample.iloc[:,1:], data_out_sample.iloc[:,0]
Once loaded, the model can be used directly
Predicted labels
y_out_pred = model2.predict(x_out_test)
y_out_pred
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0.,
0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0.,
0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0., 0.,
0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0.,
1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1.,
1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1.,
1., 1., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
1., 1., 0., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1.])
Get the predicted probabilities
predict_proba = model2.predict_proba(x_out_test)[:,1] # probability of the positive class for each out-of-sample sample
predict_proba
array([0.92768637, 0.87372613, 0.93611771, 0.92504257, 0.93830129,
0.97350163, 0.98439474, 0.98494626, 0.99302094, 0.71066415,
0.9346553 , 0.9571639 , 0.98347517, 0.98133935, 0.97773941,
0.8420986 , 0.89008306, 0.92629049, 0.9572436 , 0.5039008 ,
0.15433895, 0.19618445, 0.13506135, 0.17494294, 0.54099699,
0.03775966, 0.01478691, 0.01300564, 0.01183642, 0.03102973,
0.06048433, 0.10077356, 0.19561205, 0.16183885, 0.18157097,
0.40267378, 0.57780427, 0.76351945, 0.98894525, 0.96757754,
0.94039104, 0.69392342, 0.45948081, 0.53882556, 0.54942774,
0.55376286, 0.78704139, 0.6902328 , 0.67840564, 0.50551527,
0.60997526, 0.46440017, 0.26347114, 0.0656017 , 0.05779543,
0.21262375, 0.25199363, 0.50075914, 0.61394749, 0.27114971,
0.10048141, 0.13073217, 0.13687481, 0.8528971 , 0.90548906,
0.56790395, 0.50395654, 0.13714471, 0.08843812, 0.29217134,
0.64426371, 0.42854564, 0.32085174, 0.5137408 , 0.35244305,
0.30040986, 0.50240616, 0.457433 , 0.67157372, 0.69082967,
0.77687727, 0.79582527, 0.78517435, 0.74283735, 0.68615354,
0.80688554, 0.90175701, 0.64304015, 0.42421477, 0.75869859,
0.81017815, 0.8514314 , 0.38881122, 0.25902116, 0.16211049,
0.16614164, 0.12334865, 0.13226115, 0.47903855, 0.36390768,
0.5174554 , 0.73624912, 0.75098208, 0.40788233, 0.3967154 ,
0.2439561 , 0.54479096, 0.53777812, 0.40403812, 0.25577642,
0.24904761, 0.09929233, 0.13741176, 0.1678991 , 0.18963611,
0.1943378 , 0.14487027, 0.18659786, 0.42940769, 0.25579801,
0.52254904, 0.33436851, 0.23813746, 0.37284747, 0.74879912,
0.79164491, 0.66651398, 0.88469944, 0.74616232, 0.73853718,
0.63936874, 0.6081675 , 0.55622609, 0.75321664, 0.7815924 ,
0.86126268, 0.85880723, 0.6411158 , 0.65249363, 0.52455551,
0.54540161, 0.34515664, 0.16483909, 0.09813257, 0.07875876,
0.39561187, 0.29154285, 0.57326271, 0.59871615, 0.54026514,
0.49279335, 0.30241696, 0.42908667, 0.26428967, 0.48671753,
0.62122644, 0.56028551, 0.56033722, 0.62665293, 0.81940654,
0.7387033 , 0.66592612, 0.52672122, 0.37637344, 0.2850368 ,
0.53597262, 0.24010379, 0.18681959, 0.22768771, 0.26363987,
0.24933921, 0.19862666, 0.51994383, 0.45421865, 0.30938515,
0.74792445, 0.78869373, 0.73751946, 0.95793289, 0.87027078,
0.90282423, 0.77251597, 0.79809766, 0.37687865, 0.25473459,
0.30440401, 0.06507915, 0.09396445, 0.09478336, 0.10337532,
0.1668443 , 0.11178433, 0.39416537, 0.87374959, 0.65201754,
0.83738234, 0.88783166, 0.95112084, 0.82290048, 0.55409254,
0.83058832, 0.88418062, 0.97701667, 0.96498092, 0.96815346,
0.73549159, 0.57495442, 0.55501549, 0.45825476, 0.58237228,
0.60024485, 0.69880294, 0.77742039, 0.88159983, 0.34161263,
0.51951705, 0.5136713 , 0.14993804, 0.14164131, 0.13503646,
0.44318011, 0.29003659, 0.54459399, 0.93611822, 0.88347707,
0.79203113, 0.59555529, 0.54438551, 0.38859177, 0.25961676,
0.38008381, 0.52648356, 0.27392474, 0.52292447, 0.3474074 ,
0.28097051, 0.17816129, 0.15292792, 0.25848705, 0.21666034,
0.15186081, 0.20692974, 0.2616456 , 0.1740437 , 0.18470274,
0.6455302 , 0.74751271, 0.67978954, 0.81101168, 0.72088468,
0.86716137, 0.49636418, 0.73066438, 0.60749457, 0.48322452,
0.86254143, 0.90508905, 0.43830084, 0.4528298 , 0.48121765,
0.59174571, 0.54601153, 0.62680834, 0.19902157, 0.32335333,
0.80321108, 0.93053582, 0.90023112, 0.80395445, 0.93857671,
0.74428307, 0.83853564, 0.80653039, 0.74841075, 0.54240611,
0.76270491, 0.78043358, 0.79909342, 0.69770529, 0.72841322,
0.60916924, 0.46919568, 0.23561556, 0.27600535, 0.14867822,
0.5547519 , 0.34777691, 0.16786792, 0.03765598, 0.01467337,
0.02226528, 0.86197099, 0.36702956, 0.22010419, 0.29279984,
0.40326685, 0.24354445, 0.34090129, 0.27896042, 0.02739086,
0.11002659, 0.37231314, 0.60113585, 0.28467928, 0.30574615,
0.24217869, 0.27286212, 0.0713619 , 0.23335294, 0.20911186,
0.18794769, 0.18560383, 0.28359213, 0.18806428, 0.38684336,
0.23436443, 0.2736422 , 0.21226031, 0.06377722, 0.07374984,
0.01122629, 0.03179844, 0.00436962, 0.01504322, 0.02710046,
0.04742455, 0.21986613, 0.15414805, 0.28812503, 0.48076086,
0.61525733, 0.29366504, 0.69023641, 0.72010134, 0.73192722,
0.71350155, 0.43443765, 0.83757989, 0.91196111, 0.80597116,
0.71725584, 0.72846096, 0.68139639, 0.4055958 , 0.39676744,
0.42733194, 0.5501376 , 0.8451244 , 0.73108955, 0.71831443,
0.51176214, 0.47541834, 0.47238053, 0.86942811, 0.93746309,
0.90782211, 0.88088337, 0.69681599, 0.53907133, 0.57675902,
0.49121053, 0.73568145, 0.75968351, 0.49019147, 0.43524361,
0.48736784, 0.11410632, 0.15359888, 0.15212952, 0.18340806,
0.3275416 , 0.58926939, 0.53889026, 0.41103409, 0.20666117,
0.51170725, 0.29108907, 0.42503409, 0.58790145, 0.67938655,
0.69460852, 0.67718365, 0.81981997, 0.44427158, 0.17228037,
0.08049974, 0.05629146, 0.10224809, 0.16675326, 0.37406816,
0.34571201, 0.71925828, 0.5392739 , 0.46973142, 0.42379824,
0.4823324 , 0.57120789, 0.60050637, 0.4901272 , 0.68224158,
0.73347621, 0.27967437, 0.68220636, 0.5950859 , 0.52300998,
0.45049518, 0.86898773, 0.83804669, 0.91548209, 0.89367349,
0.85914564, 0.85690102, 0.90204988, 0.71132515, 0.79163732,
0.812289 , 0.20139774, 0.28918965, 0.49553488, 0.61543553,
0.33575905, 0.54168394, 0.15658974, 0.13426621, 0.22442951,
0.16946512, 0.1953657 , 0.42157698, 0.74126061, 0.87274383,
0.95759955, 0.83947617, 0.55318222, 0.43796698, 0.1473707 ,
0.14126038, 0.36556707, 0.24001217, 0.25518879, 0.58937928,
0.80839603, 0.36331598, 0.38230496, 0.28113962, 0.55396343,
0.72834572, 0.87771931, 0.64322465, 0.61406436, 0.3262896 ,
0.09367672, 0.11010181, 0.10681042, 0.13281207, 0.09630662,
0.41926085, 0.76901883, 0.46624134, 0.37215163, 0.23810044,
0.38668379, 0.42983361, 0.40827589, 0.21950464, 0.09326784,
0.10705637, 0.35907413, 0.22196696, 0.73719639, 0.79230004,
0.81567024, 0.85233467, 0.96740518, 0.73752961, 0.96878856,
0.96848615, 0.88184152, 0.86665458, 0.89069756, 0.57529245,
0.13776376, 0.3218207 , 0.49661417, 0.67958821, 0.79617313,
0.90001448, 0.93731303])
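For a standard binary classifier, the hard labels returned by predict correspond to cutting these probabilities at 0.5; a different threshold can be applied to predict_proba directly if a more or less aggressive signal is wanted. A minimal supplementary sketch:
# Labels from an explicit probability threshold; at 0.5 this should agree with predict().
threshold = 0.5
y_out_pred_thr = (predict_proba >= threshold).astype(float)
(y_out_pred_thr == y_out_pred).mean()  # fraction of agreement with the labels above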
# Out-of-sample accuracy
accuracy_out_sample = accuracy_score(y_out_test, y_out_pred)
accuracy_out_sample
0.6201232032854209
# Out-of-sample AUC
roc_out_sample = roc_auc_score(y_out_test, y_out_pred)
roc_out_sample
0.6200961376286052
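The AUC above is computed from the 0/1 labels; evaluating with the predicted probabilities, plus a confusion matrix and classification report, gives a fuller out-of-sample picture. A supplementary sketch (not part of the original notebook):
from sklearn.metrics import confusion_matrix, classification_report
print(confusion_matrix(y_out_test, y_out_pred))       # out-of-sample confusion matrix
print(classification_report(y_out_test, y_out_pred))  # per-class precision / recall / f1
print(roc_auc_score(y_out_test, predict_proba))       # AUC computed from predicted probabilities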
Exercise: integrate the model-training functionality above into the class used earlier, or create a new class that inherits the old class's functionality and then add the new features to it.
Answer: look through the corresponding member functions of the class and read their parameter documentation.
model_timing = report.timing_model('000001.XSHG','2010-01-01','2018-01-01','2018-01-01','2020-01-01','LR','LR_model')
开始获取合成特征和标签数据框...
------------------------------------------------------------
按照比例分割为训练集和测试集...
------------------------------------------------------------
开始训练数据...
训练结束
------------------------------------------------------------
预测准确率:
LR模型: 0.638
------------------------------------------------------------
输出混淆矩阵...
LR 0.0 1.0
Actual
0.0 113 86
1.0 55 135
------------------------------------------------------------
绘制曲线...


------------------------------------------------------------
输出评估报告...
模型的评估报告:
precision recall f1-score support
0.0 0.67 0.57 0.62 199
1.0 0.61 0.71 0.66 190
micro avg 0.64 0.64 0.64 389
macro avg 0.64 0.64 0.64 389
weighted avg 0.64 0.64 0.64 389
------------------------------------------------------------
保存模型到本地...
加载本地训练好的模型...
加载完毕
------------------------------------------------------------
样本外测试结果
样本外准确率 0.6673511293634496
样本外AUC值 0.666959015010963
Multi-factor model
Data acquisition and preprocessing
Stock pool
CSI 300 constituents, excluding ST stocks, stocks suspended on the trading day following each cross-section date, and stocks listed within the previous 6 months; each stock is treated as one sample. A rough sketch of these filters follows.
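The sketch below is supplementary and assumes the JQData-style helpers used elsewhere in this notebook (get_index_stocks, get_extras, get_all_trade_days, get_price, get_security_info) are available; the actual filtering logic lives inside the Report class methods.
def build_stock_pool(index_code, date):
    """Sketch of the stock-pool filters; `date` is assumed to be a trading day."""
    date = pd.Timestamp(date).date()
    stocks = get_index_stocks(index_code, date=date)  # index constituents on the cross-section date
    # 1) drop ST stocks on the cross-section date
    is_st = get_extras('is_st', stocks, start_date=date, end_date=date, df=True).iloc[-1]
    stocks = [s for s in stocks if not is_st[s]]
    # 2) drop stocks suspended on the next trading day
    all_days = list(get_all_trade_days())
    next_day = all_days[all_days.index(date) + 1]
    stocks = [s for s in stocks
              if get_price(s, start_date=next_day, end_date=next_day,
                           fields=['paused']).paused.iloc[0] == 0]
    # 3) drop stocks listed within the previous 6 months
    cutoff = pd.Timestamp(date) - pd.DateOffset(months=6)
    stocks = [s for s in stocks if pd.Timestamp(get_security_info(s).start_date) <= cutoff]
    return stocks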
Time range
January 1, 2014 to December 31, 2019. Roughly the first four years (48 months) are used as the training set and the final year (12 months) as the test set.
# In-sample period
start_date = '2014-01-01'
end_date = '2014-03-28'
# Index
index_code = '000300.XSHG'
Feature and label extraction
On the first trading day of each calendar month, the factor exposures are computed as the sample's raw features and the next period's return is computed as the sample's label; a sketch of the labeling idea follows.
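A supplementary sketch of the labeling idea; the rule here (label 1 when the next period's return is positive) is a hypothetical stand-in, while the actual rule is implemented inside report.data_for_model_multiperiod.
def label_from_next_return(close_monthly):
    """close_monthly: Series of month-start closes indexed by date."""
    next_ret = close_monthly.pct_change().shift(-1)  # next period's return, aligned to the current period
    return (next_ret > 0).astype(float)              # binary label: 1 if the next period's return is positive

# toy example
closes = pd.Series([10.0, 10.5, 10.2, 11.0],
                   index=pd.to_datetime(['2014-01-02', '2014-02-07', '2014-03-03', '2014-04-01']))
label_from_next_return(closes)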
data_regular = report.data_for_model_multiperiod(start_date,end_date,index_code)
data_regular.tail(20)
data_regular.shape
(352, 12)
y_train = data_regular['Label']     # label column
x_train = data_regular.iloc[:,:-1]  # feature columns
x_train.tail()
Date | MA_5 | MA_60 | Aroondown_14 | Aroonup_14 | ROC_6 | ROC_12 | RSI_6 | RSI_24 | OBV | ATR_14 | MOM_10
---|---|---|---|---|---|---|---|---|---|---|---
2012-05-25 | -0.768642 | -0.730667 | 1.407126 | -1.385636 | -0.742667 | -0.845272 | -1.066789 | -0.593849 | -0.788446 | -0.653850 | -0.666870 |
2013-12-03 | -1.046679 | -1.144522 | -1.119941 | 0.685699 | 0.556855 | 0.965889 | 0.907214 | 0.662227 | -0.641008 | -0.618633 | 0.228764 |
2011-10-13 | -0.712226 | -0.354024 | 1.212737 | -0.820727 | 0.658966 | -0.151954 | 0.580085 | -0.727583 | -0.822100 | 0.093216 | -0.104109 |
2016-04-15 | 0.630381 | 0.313270 | -0.925552 | 0.874002 | 0.790899 | 1.307709 | 0.898470 | 0.523861 | 1.297079 | 0.323287 | 0.668458 |
2010-02-05 | 0.443279 | 0.911932 | 1.018347 | -1.197333 | -0.714451 | -1.759498 | -1.198944 | -1.187684 | -0.922039 | 0.934016 | -1.924716 |
y_train.tail()
2012-05-25 0.0
2013-12-03 0.0
2011-10-13 1.0
2016-04-15 1.0
2010-02-05 0.0
Name: Label, dtype: float64
Model construction and in-sample training
# Features
x_train = data_regular.iloc[:,:-1]
# Labels
y_train = data_regular['Label']
The Pipeline approach combines feature selection and model construction into a single model_pipe object; cross-validation is then run against this object to obtain scores under different parameter settings, which guides the final choice of model parameters.
Feature selection
Feature selection uses SelectPercentile(f_classif, percentile), where f_classif sets the scoring criterion and percentile sets the proportion of features to keep; a standalone illustration follows, and the pipeline version actually used in the model comes after it.
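A supplementary standalone illustration of SelectPercentile on the training data, keeping the top 50% of features by ANOVA F-score and inspecting the scores:
from sklearn.feature_selection import SelectPercentile, f_classif

selector = SelectPercentile(f_classif, percentile=50).fit(x_train, y_train)
print(x_train.columns[selector.get_support()].tolist())  # names of the retained features
print(pd.Series(selector.scores_, index=x_train.columns).sort_values(ascending=False).round(3))  # F-score per feature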
transform = SelectPercentile(f_classif) # keep the given percentile of features with the highest f_classif scores
# Define the classifier
model = XGBClassifier()
model_pipe = Pipeline(steps=[('ANOVA', transform), ('model', model)]) # "pipeline" object combining feature selection and the classifier
# Choose the best percentile of features
# #############################################################################
# Plot the cross-validation score as a function of the percentile of features selected
score_means = list()
score_stds = list()
percentiles = (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
for percentile in percentiles:
    model_pipe.set_params(ANOVA__percentile=percentile)
    this_scores = cross_val_score(model_pipe, x_train, y_train, cv=5, n_jobs=-1)
    score_means.append(this_scores.mean())
    score_stds.append(this_scores.std())
plt.errorbar(percentiles, score_means, np.array(score_stds))
plt.title('Performance of the model-Anova varying the percentile of features selected')
plt.xlabel('Percentile')
plt.ylabel('Prediction rate')
plt.axis('tight')
plt.show()

Cross-validation parameter tuning
Once the feature percentile has been selected, the model is cross-validated over different parameter values (n_estimators, max_depth). StratifiedKFold is used to split the training data into training and validation folds: it splits with the class labels taken into account rather than purely at random, so each fold keeps roughly the same class balance. After cross-validation, the parameter set with the highest validation AUC (or f1-score) is taken as the model's optimal parameters; a GridSearchCV sketch of the same search is shown after the pipeline definition below, followed by the manual loop used here.
transform = SelectPercentile(f_classif,percentile=100) # keep the top 100% of features by f_classif score (i.e. all features)
model = XGBClassifier()
model_pipe = Pipeline(steps=[('ANOVA', transform), ('model', model)]) # "pipeline" object combining feature selection and the classifier
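As a supplementary alternative to the manual loop below, GridSearchCV performs the same parameter search in one call against this model_pipe:
from sklearn.model_selection import GridSearchCV, StratifiedKFold

param_grid = {'model__max_depth': [3, 4, 5, 6, 7, 8]}
grid = GridSearchCV(model_pipe, param_grid, scoring='roc_auc',
                    cv=StratifiedKFold(5), n_jobs=-1)
grid.fit(x_train, y_train)
print(grid.best_params_, round(grid.best_score_, 4))  # best max_depth and its mean cross-validated AUC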
cv = StratifiedKFold(5) # cross-validation splitter
# XGB max_depth candidates
parameters = [3,4,5,6,7,8]
# XGB subsample candidates
#parameters = [0.6,0.7,0.8,0.9,1]
#score_methods = ['roc_auc','accuracy', 'precision', 'recall', 'f1'] # cross-validation metrics
score_methods = ['roc_auc', 'f1'] # cross-validation metrics
#mean_list = list() # means of each metric for every parameter value
#std_list = list() # standard deviations of each metric for every parameter value
for parameter in parameters: # loop over the candidate parameter values
    t1 = time.time() # record the start time for this parameter value
    score_list = list() # per-fold scores for each metric
    print ('set parameters: %s' % parameter) # print the parameter value being evaluated
    for score_method in score_methods: # loop over the cross-validation metrics
        #model_pipe.set_params(model__n_estimators=parameter) # set the classifier parameter through the pipeline
        model_pipe.set_params(model__max_depth=parameter) # set the classifier parameter through the pipeline
        #model_pipe.set_params(model__subsample=parameter) # set the classifier parameter through the pipeline
        score_tmp = cross_val_score(model_pipe, x_train, y_train, scoring=score_method, cv=cv, n_jobs=-1) # cross-validated scores for this metric
        score_list.append(score_tmp) # store the fold scores
    score_matrix = pd.DataFrame(np.array(score_list), index=score_methods) # fold scores as a matrix
    score_mean = score_matrix.mean(axis=1).rename('mean') # mean of each metric
    score_std = score_matrix.std(axis=1).rename('std') # standard deviation of each metric
    score_pd = pd.concat([score_matrix, score_mean, score_std], axis=1) # fold scores plus mean and std
    #mean_list.append(score_mean) # collect the metric means for this parameter value
    #std_list.append(score_std) # collect the metric standard deviations for this parameter value
    print (score_pd.round(4)) # print the cross-validation table, rounded to 4 decimals
    print ('-' * 60)
    t2 = time.time() # end time for this parameter value
    tt = t2 - t1 # elapsed time
    print ('算法用时time: %s' % str(tt)) # print the elapsed time
set parameters: 3
0 1 2 3 4 mean std
roc_auc 0.6929 0.6707 0.6939 0.7505 0.7167 0.7050 0.0302
f1 0.6588 0.6429 0.6667 0.7052 0.7086 0.6764 0.0291
------------------------------------------------------------
算法用时time: 0.6454043388366699
set parameters: 4
0 1 2 3 4 mean std
roc_auc 0.6820 0.6653 0.7065 0.7323 0.6920 0.6956 0.0254
f1 0.6647 0.6551 0.6964 0.6866 0.6964 0.6798 0.0189
------------------------------------------------------------
算法用时time: 0.7695553302764893
set parameters: 5
0 1 2 3 4 mean std
roc_auc 0.6813 0.6473 0.6968 0.7254 0.6814 0.6865 0.0283
f1 0.6667 0.6331 0.6725 0.7038 0.6891 0.6730 0.0266
------------------------------------------------------------
算法用时time: 0.9359838962554932
set parameters: 6
0 1 2 3 4 mean std
roc_auc 0.6896 0.6316 0.6802 0.7292 0.6718 0.6805 0.0351
f1 0.6786 0.6243 0.6531 0.6967 0.6629 0.6631 0.0273
------------------------------------------------------------
算法用时time: 1.0583992004394531
set parameters: 7
0 1 2 3 4 mean std
roc_auc 0.6583 0.6381 0.6676 0.7175 0.6581 0.6679 0.0297
f1 0.6509 0.6126 0.6395 0.7076 0.6552 0.6532 0.0347
------------------------------------------------------------
算法用时time: 1.185107946395874
set parameters: 8
0 1 2 3 4 mean std
roc_auc 0.6712 0.6266 0.6603 0.7188 0.6675 0.6689 0.0330
f1 0.6527 0.6163 0.6388 0.6962 0.6648 0.6538 0.0298
------------------------------------------------------------
算法用时time: 1.3339207172393799
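Rather than reading the best value off the printed tables, the choice can also be made programmatically. A small supplementary sketch, reusing model_pipe, cv, x_train and y_train from above:
# Pick max_depth by the highest mean cross-validated ROC AUC.
best_depth, best_auc = None, -np.inf
for depth in [3, 4, 5, 6, 7, 8]:
    model_pipe.set_params(model__max_depth=depth)
    auc_mean = cross_val_score(model_pipe, x_train, y_train,
                               scoring='roc_auc', cv=cv, n_jobs=-1).mean()
    if auc_mean > best_auc:
        best_depth, best_auc = depth, auc_mean
print('best max_depth: %s (mean AUC %.4f)' % (best_depth, best_auc))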
Model construction
Using the optimal parameters from the cross-validation above, an XGBoost ensemble model is trained on the training set.
transform.fit(x_train, y_train) # fit the feature-selection object on the training data
X_train_final = transform.transform(x_train) # keep only the selected (significant) features
XGBoost
(1) subsample: the fraction of the training set used to fit each tree, between 0 and 1; it is an effective guard against overfitting. Its performance is evaluated in the output shown above. As subsample increases, the f1-score trends downward while training speeds up; balancing training time against performance, subsample=0.9 is chosen.
(2) max_depth: its performance is evaluated in the table above. As max_depth increases, both AUC and f1-score trend downward and training becomes slower, so max_depth=3 is chosen.
model = XGBClassifier(max_depth=3,subsample=0.9,random_state=0)
model.fit(X_train_final, y_train) # train the model on the selected features
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
n_estimators=100, n_jobs=1, nthread=None,
objective='binary:logistic', random_state=0, reg_alpha=0,
reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
subsample=0.9, verbosity=1)
Out-of-sample testing
# Out-of-sample testing
out_start = '2018-01-01'
out_end = '2019-12-31'
test_sample_predict={}
test_sample_score=[]
test_sample_accuracy=[]
test_sample_roc_auc=[]
test_sample_date=[]
interval_start,interval_end = report.get_time_inverval(out_start,out_end,'M')
for date1,date2 in dict(zip(interval_start,interval_end)).items():
    data_merge_label = report.data_for_model_perperiod(date1,date2,index_code)
    y_test=data_merge_label['Label']
    X_test=data_merge_label.iloc[:,:-1]
    # Predict on the new data set
    # Output the predicted labels and predicted probabilities
    y_pred_tmp = model.predict(X_test)
    y_pred = pd.DataFrame(y_pred_tmp, columns=['label_predict']) # predicted labels
    y_pred_proba = pd.DataFrame(model.predict_proba(X_test), columns=['pro1', 'pro2']) # predicted probabilities
    # Merge the predicted labels and probabilities with the original features X
    y_pred.set_index(X_test.index,inplace=True)
    y_pred_proba.set_index(X_test.index,inplace=True)
    predict_pd = pd.concat((X_test, y_pred, y_pred_proba), axis=1)
    print ('Predict date:')
    print (date1)
    print ('AUC:')
    print (roc_auc_score(y_test,y_pred)) # out-of-sample AUC for this period
    print ('Accuracy:')
    print (accuracy_score(y_test, y_pred)) # out-of-sample accuracy for this period
    print ('-' * 60)
    ## kept for the summary statistics and plots below
    test_sample_date.append(date1)
    # out-of-sample predictions
    test_sample_predict[date1]=y_pred_tmp
    # out-of-sample accuracy
    test_sample_accuracy.append(accuracy_score(y_test, y_pred))
    # out-of-sample AUC
    test_sample_roc_auc.append(roc_auc_score(y_test,y_pred))
print ('AUC mean info')
print (np.mean(test_sample_roc_auc))
print ('-' * 60)
print ('ACCURACY mean info')
print (np.mean(test_sample_accuracy))
print ('-' * 60)
Predict date:
2018-01-01 00:00:00
AUC:
0.6055555555555556
Accuracy:
0.6055555555555555
------------------------------------------------------------
Predict date:
2018-02-01 00:00:00
AUC:
0.4444444444444445
Accuracy:
0.4444444444444444
------------------------------------------------------------
Predict date:
2018-03-01 00:00:00
AUC:
0.5388888888888889
Accuracy:
0.5388888888888889
------------------------------------------------------------
Predict date:
2018-04-01 00:00:00
AUC:
0.55
Accuracy:
0.55
------------------------------------------------------------
Predict date:
2018-05-01 00:00:00
AUC:
0.5277777777777778
Accuracy:
0.5277777777777778
------------------------------------------------------------
Predict date:
2018-06-01 00:00:00
AUC:
0.4444444444444444
Accuracy:
0.4444444444444444
------------------------------------------------------------
Predict date:
2018-07-01 00:00:00
AUC:
0.6166666666666667
Accuracy:
0.6166666666666667
------------------------------------------------------------
Predict date:
2018-08-01 00:00:00
AUC:
0.3611111111111111
Accuracy:
0.3611111111111111
------------------------------------------------------------
Predict date:
2018-09-01 00:00:00
AUC:
0.5166666666666667
Accuracy:
0.5166666666666667
------------------------------------------------------------
Predict date:
2018-10-01 00:00:00
AUC:
0.4555555555555556
Accuracy:
0.45555555555555555
------------------------------------------------------------
Predict date:
2018-11-01 00:00:00
AUC:
0.75
Accuracy:
0.75
------------------------------------------------------------
Predict date:
2018-12-01 00:00:00
AUC:
0.4666666666666666
Accuracy:
0.4666666666666667
------------------------------------------------------------
Predict date:
2019-01-01 00:00:00
AUC:
0.4722222222222222
Accuracy:
0.4722222222222222
------------------------------------------------------------
Predict date:
2019-02-01 00:00:00
AUC:
0.6777777777777778
Accuracy:
0.6777777777777778
------------------------------------------------------------
Predict date:
2019-03-01 00:00:00
AUC:
0.5722222222222222
Accuracy:
0.5722222222222222
------------------------------------------------------------
Predict date:
2019-04-01 00:00:00
AUC:
0.40555555555555556
Accuracy:
0.40555555555555556
------------------------------------------------------------
Predict date:
2019-05-01 00:00:00
AUC:
0.5111111111111111
Accuracy:
0.5111111111111111
------------------------------------------------------------
Predict date:
2019-06-01 00:00:00
AUC:
0.46111111111111114
Accuracy:
0.46111111111111114
------------------------------------------------------------
Predict date:
2019-07-01 00:00:00
AUC:
0.5611111111111111
Accuracy:
0.5611111111111111
------------------------------------------------------------
Predict date:
2019-08-01 00:00:00
AUC:
0.4833333333333333
Accuracy:
0.48333333333333334
------------------------------------------------------------
Predict date:
2019-09-01 00:00:00
AUC:
0.4444444444444444
Accuracy:
0.4444444444444444
------------------------------------------------------------
Predict date:
2019-10-01 00:00:00
AUC:
0.5277777777777777
Accuracy:
0.5277777777777778
------------------------------------------------------------
Predict date:
2019-11-01 00:00:00
AUC:
0.5722222222222223
Accuracy:
0.5722222222222222
------------------------------------------------------------
Predict date:
2019-12-01 00:00:00
AUC:
0.5277777777777778
Accuracy:
0.5277777777777778
------------------------------------------------------------
AUC mean info
0.5206018518518518
------------------------------------------------------------
ACCURACY mean info
0.5206018518518518
------------------------------------------------------------
Per-period out-of-sample AUC
test_sample_roc_auc
[0.6055555555555556,
0.4444444444444445,
0.5388888888888889,
0.55,
0.5277777777777778,
0.4444444444444444,
0.6166666666666667,
0.3611111111111111,
0.5166666666666667,
0.4555555555555556,
0.75,
0.4666666666666666,
0.4722222222222222,
0.6777777777777778,
0.5722222222222222,
0.40555555555555556,
0.5111111111111111,
0.46111111111111114,
0.5611111111111111,
0.4833333333333333,
0.4444444444444444,
0.5277777777777777,
0.5722222222222223,
0.5277777777777778]
Predictive power
xs_date = test_sample_date
ys_auc = test_sample_roc_auc
# Format the x axis as dates
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.plot(xs_date, ys_auc,'r')
# Auto-rotate the date tick labels
plt.gcf().autofmt_xdate()
# x-axis label
plt.xlabel('date')
# y-axis label
plt.ylabel("test AUC")
plt.show()

xs_date = test_sample_date
ys_score = test_sample_accuracy
# Format the x axis as dates
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.plot(xs_date, ys_score,'r')
# Auto-rotate the date tick labels
plt.gcf().autofmt_xdate()
# x-axis label
plt.xlabel('date')
# y-axis label
plt.ylabel("test accuracy")
plt.show()

f = plt.figure(figsize= (15,6))
sns.set(style="whitegrid")
data1 = pd.DataFrame(ys_auc, index=xs_date, columns=['AUC'])
data2 = pd.DataFrame(ys_score, index=xs_date, columns=['accuracy'])
data = pd.concat([data1,data2], axis=1)  # one row per date, one column per metric
sns.lineplot(data=data, palette="tab10", linewidth=2.5)
<matplotlib.axes._subplots.AxesSubplot at 0x264b0823c88>

The feature-importance analysis below is specific to tree-based algorithms such as XGBoost and is not available for other classes of machine-learning models.
Feature importance
model = XGBClassifier(max_depth=3,subsample=0.9,random_state=0)
model.fit(x_train, y_train) # train the model on the named feature columns so importances map to column names
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
n_estimators=100, n_jobs=1, nthread=None,
objective='binary:logistic', random_state=0, reg_alpha=0,
reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
subsample=0.9, verbosity=1)
#%matplotlib inline
fig = plt.figure(figsize= (15,6))
n_features = x_train.shape[1]
plt.barh(range(n_features),model.feature_importances_,align='center')
plt.yticks(np.arange(n_features),x_train.columns)
plt.xlabel("Feature importance")
plt.ylabel("Feature")
Text(0, 0.5, 'Feature')
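As a supplementary view, xgboost's built-in plot_importance shows the same information and lets you switch the importance type (weight, gain or cover):
from xgboost import plot_importance

fig, ax = plt.subplots(figsize=(15, 6))
plot_importance(model, ax=ax, importance_type='gain')  # importance measured by average gain per split
plt.show()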
Integrate, add, call: once the functionality lives in the class, the code in the preceding cells can all be deleted.
# In-sample period
start_date = '2014-01-01'
end_date = '2018-12-31'
# Index
index_code = '000300.XSHG'
# Out-of-sample period
out_start = '2018-10-01'
out_end = '2019-12-31'
model_name = 'xgboost'
file_name = 'xgboost_model'
Call the multi-factor training and testing model from the class
The code in the preceding cells can all be deleted; just supply the in-sample and out-of-sample start and end dates, the model name, and a file name of your own choosing under which to save the model.
report.multifactor_model(index_code,start_date,end_date,out_start,out_end,model_name,file_name)
开始训练数据...
训练结束
------------------------------------------------------------
预测准确率:
xgboost模型: 0.667
------------------------------------------------------------
输出混淆矩阵...
col_0 0.0 1.0
Label
0.0 25 14
1.0 10 23
------------------------------------------------------------
绘制曲线...


------------------------------------------------------------
输出评估报告...
模型的评估报告:
precision recall f1-score support
0.0 0.71 0.64 0.68 39
1.0 0.62 0.70 0.66 33
micro avg 0.67 0.67 0.67 72
macro avg 0.67 0.67 0.67 72
weighted avg 0.67 0.67 0.67 72
------------------------------------------------------------
保存模型到本地...
加载本地训练好的模型...
加载完毕
------------------------------------------------------------
样本外测试结果...
Predict date:
2018-10-01 00:00:00
AUC:
0.4833333333333334
Accuracy:
0.48333333333333334
------------------------------------------------------------
Predict date:
2018-11-01 00:00:00
AUC:
0.6722222222222222
Accuracy:
0.6722222222222223
------------------------------------------------------------
Predict date:
2018-12-01 00:00:00
AUC:
0.49444444444444446
Accuracy:
0.49444444444444446
------------------------------------------------------------
Predict date:
2019-01-01 00:00:00
AUC:
0.5611111111111111
Accuracy:
0.5611111111111111
------------------------------------------------------------
Predict date:
2019-02-01 00:00:00
AUC:
0.6166666666666666
Accuracy:
0.6166666666666667
------------------------------------------------------------
Predict date:
2019-03-01 00:00:00
AUC:
0.5277777777777778
Accuracy:
0.5277777777777778
------------------------------------------------------------
Predict date:
2019-04-01 00:00:00
AUC:
0.46111111111111114
Accuracy:
0.46111111111111114
------------------------------------------------------------
Predict date:
2019-05-01 00:00:00
AUC:
0.5166666666666667
Accuracy:
0.5166666666666667
------------------------------------------------------------
Predict date:
2019-06-01 00:00:00
AUC:
0.4555555555555556
Accuracy:
0.45555555555555555
------------------------------------------------------------
Predict date:
2019-07-01 00:00:00
AUC:
0.5611111111111111
Accuracy:
0.5611111111111111
------------------------------------------------------------
Predict date:
2019-08-01 00:00:00
AUC:
0.5111111111111111
Accuracy:
0.5111111111111111
------------------------------------------------------------
Predict date:
2019-09-01 00:00:00
AUC:
0.4555555555555556
Accuracy:
0.45555555555555555
------------------------------------------------------------
Predict date:
2019-10-01 00:00:00
AUC:
0.4222222222222222
Accuracy:
0.4222222222222222
------------------------------------------------------------
Predict date:
2019-11-01 00:00:00
AUC:
0.5444444444444444
Accuracy:
0.5444444444444444
------------------------------------------------------------
Predict date:
2019-12-01 00:00:00
AUC:
0.5666666666666667
Accuracy:
0.5666666666666667
------------------------------------------------------------
AUC mean info
0.5233333333333333
------------------------------------------------------------
ACCURACY mean info
0.5233333333333333
------------------------------------------------------------


XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
max_delta_step=0, max_depth=3, min_child_weight=1, missing=nan,
n_estimators=100, n_jobs=1, nthread=None,
objective='binary:logistic', random_state=0, reg_alpha=0,
reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
subsample=0.9, verbosity=1)
This article is licensed under the Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License (CC BY-NC-ND 4.0). Please credit the source when reposting, and do not use it for any commercial purpose.