量价分析面向对象框架

Posted by YU on June 6, 2020
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#导入常规数据处理库
#from jqdata import *
import talib
import numpy as  np
import pandas as pd
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import seaborn as sns
import copy
plt.rcParams['axes.unicode_minus']=False
plt.rcParams['font.sans-serif']=['SimHei'] #指定默认字体 SimHei为黑体
import warnings
warnings.filterwarnings("ignore")
import datetime
import time

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
#特定基准的特定区间的收益率
def cal_benchmark_return(benchmark,start_date,end_date):
    """Daily return series of a benchmark over [start_date, end_date].

    One extra trading day before the interval is fetched so that the
    first day inside the interval gets a valid return instead of NaN.
    """
    first_day = get_trade_days(start_date=start_date, end_date=end_date, count=None)[0]
    all_days = list(get_all_trade_days())
    prev_day = all_days[all_days.index(first_day) - 1]
    closes = get_price(benchmark, start_date=prev_day, end_date=end_date,
                       fields=['open', 'close'], frequency='daily').close
    return closes / closes.shift() - 1

class Report(object):
    
    def __init__(self,benchmark):
        """Load the local A-share price/volume CSVs and the benchmark return series.

        benchmark: security code (e.g. '000300.XSHG') used as performance benchmark.
        """
        #Initialization
        self.benchmark = benchmark
        #Full-history A-share data: returns, close, open, high, low, money volume,
        #turnover ratio and share volume (one CSV per field; dates as index,
        #one column per security)
        self.data_return = pd.read_csv('A_share_return.csv',index_col=0)
        self.data_close = pd.read_csv('A_share_close.csv',index_col=0)
        self.data_open = pd.read_csv('A_share_open.csv',index_col=0)
        self.data_high = pd.read_csv('A_share_high.csv',index_col=0)
        self.data_low = pd.read_csv('A_share_low.csv',index_col=0)
        self.data_money = pd.read_csv('A_share_money.csv',index_col=0)
        self.data_turnover = pd.read_csv('A_share_turnover_ratio.csv',index_col=0)
        self.data_volume = pd.read_csv('A_share_volume.csv',index_col=0)
        #Benchmark daily returns over the span of the local data.
        #NOTE(review): starts at index[1], not index[0] -- presumably skipping
        #the first (all-NaN return) row; confirm intended.
        self.benchmark_return = cal_benchmark_return(benchmark,self.data_return.index[1],self.data_return.index[-1])
        
        
    #下面是为了实现特定功能的功能函数

    #获取特定周期的时间点列表
    def get_time_inverval(self,start_date,end_date,freq):
        """Return (interval_start, interval_end) DatetimeIndex pairs for periods.

        freq: 'M' -> month starts paired with business-month ends,
              'Y' -> year starts paired with year ends.
        A trailing period start with no matching end is dropped so both
        lists stay aligned.

        Fix vs. original: any freq other than 'M'/'Y' fell through and
        crashed with NameError on `interval_start`; it now raises
        ValueError up front.
        """
        if freq == 'M':
            interval_start = pd.to_datetime(pd.date_range(start_date,end_date,freq = 'MS'))
            interval_end = pd.to_datetime(pd.date_range(start_date,end_date,freq = 'BM'))
        elif freq == 'Y':
            interval_start = pd.to_datetime(pd.date_range(start_date,end_date,freq = 'AS'))
            interval_end = pd.to_datetime(pd.date_range(start_date,end_date,freq = 'A'))
        else:
            raise ValueError("freq must be 'M' or 'Y', got %r" % (freq,))

        if len(interval_start) > len(interval_end) > 0:
            # An extra period started but did not finish inside the range
            interval_start = interval_start[:-1]
        elif len(interval_end)  == 0:
            print('请输入完整的周期区间')
        return interval_start,interval_end
    
    #转datestampe格式为datetime格式
    def datetime_date(self,datestamp):
        """Convert a timestamp-like object (e.g. pd.Timestamp) to datetime.date."""
        y, m, d = datestamp.year, datestamp.month, datestamp.day
        return datetime.date(y, m, d)
    
    #股票所属行业
    def get_stock_belong_industry(self,stocklist,date):
        """Map each stock to its Shenwan level-1 industry name on `date`.

        Returns a one-column DataFrame ('industry') indexed by stock code;
        stocks without an SW L1 classification are skipped.

        Fix vs. original: the bare `except:` hid every kind of error;
        only the expected missing-key case is now swallowed.
        """
        industry = {}
        industry_all =  get_industry(security = stocklist, date=date)
        for code in stocklist:
            try:
                industry[code] = industry_all[code]['sw_l1']['industry_name']
            except KeyError:
                # Stock absent from the result or lacking an SW L1 entry
                continue
        return pd.DataFrame([industry],index = ['industry']).T
    
    #过滤新股
    def filter_new_stock(self,stock_list,date):
        """Keep only stocks listed at least 180 days before `date`.

        Fix vs. original: `datetime_date` was called as a bare name but is
        a method of this class, so the loop raised NameError; it is now
        invoked via self.
        """
        kept = []
        for stock in stock_list:
            days_public = (self.datetime_date(date) - get_security_info(stock).start_date).days
            if days_public >= 180:
                kept.append(stock)
        return kept


    #剔除ST股
    def delete_st(self,stocks,begin_date):
        """Drop stocks whose latest is_st flag before `begin_date` is set."""
        st_flags = get_extras('is_st', stocks, count=1, end_date=begin_date)
        kept = []
        for stock in stocks:
            if not st_flags[stock][0]:
                kept.append(stock)
        return kept
    
    ###提取因子部分
    def get_MA(self,df_close,n):
        """n-day simple moving average of closing prices.

        df_close: Series/DataFrame of close prices.
        n: lookback window length (int).
        Returns a DataFrame with one 'MA_<n>' column.
        """
        rolling_mean = df_close.rolling(n).mean()
        return pd.DataFrame({'MA_' + str(n): rolling_mean}, index=rolling_mean.index)

    ## 计算变化率ROC
    def get_ROC(self,df_close, n):
        """Rate of change: (close - close n days ago) / close n days ago * 100.

        df_close: Series/DataFrame of close prices.
        n: lookback in days (int).
        """
        lagged = df_close.shift(n)
        change_pct = (df_close - lagged) / lagged * 100
        return pd.DataFrame({'ROC_' + str(n): change_pct}, index=df_close.index)

    ## 计算RSI
    def get_RSI(self,df_close,n):
        """n-period Relative Strength Index (TA-Lib) as a one-column DataFrame."""
        rsi_values = talib.RSI(df_close, timeperiod=n)
        return pd.DataFrame({'RSI_' + str(n): rsi_values}, index=df_close.index)

    ##计算OBV指标
    def get_OBV(self,df_close,df_volume):
        """On-Balance Volume (TA-Lib): volume-flow proxy for price trend."""
        obv_values = talib.OBV(df_close, df_volume)
        return pd.DataFrame({'OBV': obv_values}, index=df_close.index)

    #真实波幅
    def get_ATR(self,df_high,df_low,df_close,n):
        """n-period Average True Range (TA-Lib), a volatility measure.

        df_high/df_low/df_close: high/low/close price series.
        n: lookback in days (int).
        """
        atr_values = talib.ATR(df_high, df_low, df_close, timeperiod=n)
        return pd.DataFrame({'ATR_' + str(n): atr_values}, index=df_close.index)
    #上升动向值
    def get_MOM(self,df_close,n):
        """n-period momentum (TA-Lib): close minus close n days earlier."""
        mom_values = talib.MOM(df_close, timeperiod=n)
        return pd.DataFrame({'MOM_' + str(n): mom_values}, index=df_close.index)
    
    #阿隆指标
    def get_AROON(self,df_high,df_low,n):
        """Aroon indicator (TA-Lib) over an n-day window.

        Returns (aroon_down, aroon_up) as two one-column DataFrames.
        """
        down, up = talib.AROON(df_high, df_low, timeperiod=n)
        down_df = pd.DataFrame({'Aroondown_' + str(n): down}, index=df_high.index)
        up_df = pd.DataFrame({'Aroonup_' + str(n): up}, index=df_high.index)
        return down_df, up_df
    
    
    ###因子处理合成部分
    
    #合并新特征的函数,是添加新指标或者新特征的重要函数
    def merge_raw_factors(self,df_close,df_open,df_high,df_low,df_volume,df_money):
        """Concatenate all raw technical factors column-wise into one frame.

        Columns: MA(5/60), Aroon down/up(14), ROC(6/12), RSI(6/24), OBV,
        ATR(14), MOM(10).  Extend this list to add new features.

        Perf fix: the original computed the AROON pair twice; it is now
        computed once and both outputs reused.
        """
        aroon_down, aroon_up = self.get_AROON(df_high, df_low, 14)
        factors = [
            self.get_MA(df_close, 5),
            self.get_MA(df_close, 60),
            aroon_down,
            aroon_up,
            self.get_ROC(df_close, 6),
            self.get_ROC(df_close, 12),
            self.get_RSI(df_close, 6),
            self.get_RSI(df_close, 24),
            self.get_OBV(df_close, df_volume),
            self.get_ATR(df_high, df_low, df_close, 14),
            self.get_MOM(df_close, 10),
        ]
        return pd.concat(factors, axis=1)

    #数据去极值及标准化
    def winsorize_and_standarlize(self,data,qrange=[0.05,0.95],axis=0):
        """Winsorize (clip at quantiles) then z-score standardize.

        data: DataFrame or Series.
        qrange: [lower_quantile, upper_quantile] used as clipping bounds.
        axis: 0 -> per column; otherwise per row (DataFrame only).
        NaNs remaining after standardization are filled with 0
        (Series branch keeps the original no-fill behavior).

        Fixes vs. original: the row-wise branch clipped only the upper
        tail; the Series branch compared against a 2-element quantile
        Series (mis-aligned, effectively broken).  Both now clip at both
        tails via .clip().  The input is no longer mutated in place.
        """
        if isinstance(data,pd.DataFrame):
            if axis == 0:
                q_down = data.quantile(qrange[0])
                q_up = data.quantile(qrange[1])
                data = data.clip(lower=q_down, upper=q_up, axis=1)
                data = (data - data.mean())/data.std()
                data = data.fillna(0)
            else:
                # Transpose via stack/unstack so rows become columns
                data = data.stack().unstack(0)
                q_down = data.quantile(qrange[0])
                q_up = data.quantile(qrange[1])
                data = data.clip(lower=q_down, upper=q_up, axis=1)
                data = (data - data.mean())/data.std()
                data = data.stack().unstack(0)
                data = data.fillna(0)

        elif isinstance(data,pd.Series):
            q_down = data.quantile(qrange[0])
            q_up = data.quantile(qrange[1])
            data = data.clip(lower=q_down, upper=q_up)
            data = (data - data.mean())/data.std()
        return data
    
    
    
    #可取一个标的的一日或者多日原始因子数据
    def merge_raw_factor_perdate_percode(self,stock_code,start_date,end_date):
        """Raw factor frame for one security over a date interval.

        The full local history feeds the factor computation so rolling
        windows are not truncated at start_date; the caller slices out
        the requested dates.

        Fixes vs. original: the out-of-range branch printed a message and
        then returned the unbound `merge_raw_data` (UnboundLocalError);
        it now raises ValueError.  Unused locals (df_return, df_turnover)
        were removed.
        """
        date_interval = get_trade_days(start_date =start_date, end_date=end_date, count=None)
        date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]

        if not (set(date_interval) <= set(self.data_return.index)) or len(self.data_return.loc[date_interval]) == 0:
            raise ValueError('你输入的时间区间超过了本地数据的时间区间或不含有交易日,请输入2005年1月5日到2020年2月28日内的起始时间')

        df_close = self.data_close[stock_code]
        df_open = self.data_open[stock_code]
        df_high = self.data_high[stock_code]
        df_low = self.data_low[stock_code]
        df_money = self.data_money[stock_code]
        df_volume = self.data_volume[stock_code]
        return self.merge_raw_factors(df_close,df_open,df_high,df_low,df_volume,df_money)
    
    #可取一个标的的一日或者多日规整化因子数据
    def merge_regularfactor_multidate_percode(self,stock_code,start_date,end_date):
        """Regularized (winsorized + standardized) factors for one security.

        Computes raw factors over the interval, restricts them to the
        interval's trade days, then regularizes column-wise.
        """
        trade_days = get_trade_days(start_date =start_date, end_date=end_date, count=None)
        trade_days = [d.strftime('%Y-%m-%d') for d in trade_days]
        raw_factors = self.merge_raw_factor_perdate_percode(stock_code,start_date,end_date)
        return self.winsorize_and_standarlize(raw_factors.loc[trade_days])
       

        
    #可取多个标的的一天的原始因子数据
    def merge_rawfactor_multicode_perday(self,stock_list,start_date,end_date):
        """Raw factor rows for many securities over the interval, stacked row-wise."""
        trade_days = get_trade_days(start_date =start_date, end_date=end_date, count=None)
        trade_days = [d.strftime('%Y-%m-%d') for d in trade_days]
        frames = [
            self.merge_raw_factor_perdate_percode(code, start_date, end_date).loc[trade_days]
            for code in stock_list
        ]
        return pd.concat(frames)
    
    #可取多个标的的一天的规整化因子数据
    def merge_regularfactor_multicode_perday(self,stock_list,start_date,end_date):
        """Cross-sectional regularized factors with a leading 'code' column.

        One row per stock (expects a single-day interval, matching the
        length of stock_list for the insert below).

        Fix vs. original: a `date_interval` list was computed via an
        external get_trade_days call and never used; the dead code (and
        its wasted API call) is removed.
        """
        data_raw = self.merge_rawfactor_multicode_perday(stock_list,start_date,end_date)
        data_regular = self.winsorize_and_standarlize(data_raw)
        data_regular.insert(0, 'code',stock_list)
        return data_regular
    
    
    #贴标签方法1,贴标签方法可以自定义
    def add_label_1(self,stock_code,start_date,end_date):
        
        profit = self.data_return[stock_code]
        profit[profit > 0] = 1
        profit[profit< 0] = 0
        profit = pd.DataFrame(profit)
        profit.columns = ['Label']
        
        return profit
    
    
    #一个标的的多天的规整化因子数据+标签(方法1)
    def merge_final_factor_label1(self,stock_code,start_date,end_date):
        """Regularized factors joined with method-1 labels for one security.

        The label column is shifted by one row before the join.
        NOTE(review): shift() pairs row t's features with row t-1's
        label -- confirm the intended prediction direction.

        Fix vs. original: the out-of-range branch printed a message and
        then hit UnboundLocalError on `data_final`; it now raises
        ValueError.
        """
        date_interval = get_trade_days(start_date =start_date, end_date=end_date, count=None)
        date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]

        if not (set(date_interval) <= set(self.data_return.index)) or len(self.data_return.loc[date_interval]) == 0:
            raise ValueError('你输入的时间区间超过了本地数据的时间区间或不含有交易日,请输入2005年1月5日到2020年2月28日内的起始时间')

        data_pro = self.merge_regularfactor_multidate_percode(stock_code,start_date,end_date)
        profit = self.add_label_1(stock_code,start_date,end_date)
        return pd.concat([profit.shift(),data_pro],axis = 1).loc[date_interval]
    
    
    
    # 指数择时模型,参数:指数代码,开始时间,结束时间,机器学习模型(需要添加),训练好的模型的本地保存文件名
    def timing_model(self,code,start_date,end_date,out_start,out_end,model_name,file_name):
        """Train, evaluate, persist and out-of-sample test an index direction classifier.

        code: index security code whose daily up/down is predicted.
        start_date, end_date: in-sample interval.
        out_start, out_end: out-of-sample interval.
        model_name: 'LR' (logistic regression) or 'LDA' (QDA); extend the
            branch below to add algorithms.
        file_name: basename (without extension) for the joblib .pkl dump.
        Returns the model re-loaded from disk.

        Fixes vs. original: the out-of-sample call used the undefined name
        `index_code` (NameError) -- now `code`; an unknown model_name now
        raises ValueError immediately instead of failing later on the
        unbound `model`; the unused `predict_proba` local was removed.
        """
        print('开始获取合成特征和标签数据框...')
        # Column 0 is the label, the remaining columns are features
        data_index = self.merge_final_factor_label1(code,start_date,end_date)
        x_data,y_data = data_index.iloc[:,1:], data_index.iloc[:,0]
        print ('-' * 60)
        print('按照比例分割为训练集和测试集...')
        # Random 80/20 train/test split; no fixed random_state, so every
        # run draws a different split
        x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.2)
        print ('-' * 60)

        # Frame collecting actual vs predicted labels on the held-out split
        pred = pd.DataFrame(index=y_test.index)
        pred["Actual"] = y_test

        # Model zoo -- add new algorithms here
        if model_name == 'LR':
            model = LogisticRegression(solver='liblinear')
        elif model_name == "LDA":
            model = QDA()
        else:
            raise ValueError('不支持这个算法,请新添加这个算法')

        print('开始训练数据...')
        model.fit(x_train, y_train)
        print('\n')
        print('训练结束')
        print ('-' * 60)
        print ("预测准确率:")

        pred[model_name] = model.predict(x_test)

        # Held-out accuracy
        score=accuracy_score(pred['Actual'], pred[model_name])
        print("%s模型: %.3f" % (model_name, score))
        print ('-' * 60)

        # Confusion matrix, printed and drawn as a heatmap
        cm = pd.crosstab(pred['Actual'], pred[model_name])
        print('输出混淆矩阵...')
        print(cm)
        sns.heatmap(cm, annot = True, cmap = 'GnBu', fmt = 'd')
        print ('-' * 60)

        print('绘制曲线...')
        # Positive-class probabilities (not hard labels) feed the ROC curve
        y_score = model.predict_proba(x_test)[:,1]
        # fpr is 1-Specificity, tpr is Sensitivity
        fpr,tpr,threshold = roc_curve(y_test, y_score)
        roc_auc = metrics.auc(fpr,tpr)
        # Shaded ROC area plot with the chance diagonal
        plt.figure(figsize=(8,6))
        plt.stackplot(fpr, tpr, color='steelblue', alpha = 0.5, edgecolor = 'black')
        plt.plot(fpr, tpr, color='black', lw = 1)
        plt.plot([0,1],[0,1], color = 'red', linestyle = '--')
        plt.text(0.5,0.3,'ROC曲线 (area = %0.2f)' % roc_auc)
        plt.title('模型预测指数涨跌的ROC曲线',size=15)
        plt.xlabel('1-Specificity')
        plt.ylabel('Sensitivity')
        plt.show()
        print ('-' * 60)

        print('输出评估报告...')
        print('\n')
        print('模型的评估报告:\n',classification_report(pred['Actual'], pred[model_name]))
        print ('-' * 60)

        print('保存模型到本地...')
        joblib.dump(model,file_name+'.pkl')
        print('\n')
        print('加载本地训练好的模型...')
        model2 = joblib.load(file_name+'.pkl')
        print('加载完毕')
        print ('-' * 60)

        # Out-of-sample evaluation (was: undefined `index_code`)
        print('样本外测试结果')
        data_out_sample = self.merge_final_factor_label1(code,out_start,out_end)

        x_out_test,y_out_test = data_out_sample.iloc[:,1:], data_out_sample.iloc[:,0]
        y_out_pred = model2.predict(x_out_test)

        accuracy_out_sample = accuracy_score(y_out_test, y_out_pred)
        print('样本外准确率',accuracy_out_sample)
        # NOTE(review): AUC computed on hard labels; predict_proba would
        # give the conventional score
        roc_out_sample = roc_auc_score(y_out_test, y_out_pred)
        print('样本外AUC值',roc_out_sample)

        # Return the reloaded model so callers can reuse it
        return model2

    #贴标签方法2,贴标签方法可以自定义.data为特征数据框,带有列名为return的收益序列,前30%记为1,后30%记为0,其余数据去掉
    def add_label_2(self,data):
        """Labeling method 2: top 30% of 'return' -> 1, bottom 30% -> 0.

        data: DataFrame containing a 'return' column.  Rows are sorted by
        return (descending); the top/bottom bands get Label 1/0, middle
        rows are dropped, and the 'return' column is removed.
        Note: a 'Label' column is also added to the caller's frame
        (original behavior, preserved).

        Fix vs. original: when round(0.3 * n_rows) == 0 the statement
        data.iloc[-0:, -1] = 0 labeled EVERY row 0; zero-sized selections
        are now skipped, so tiny inputs yield an empty result instead.
        """
        percent_select = [0.3,0.3]
        data['Label'] = np.nan
        data = data.sort_values(by='return',ascending=False)
        n_stock = data.shape[0]
        n_stock_select = np.around(np.multiply(percent_select,n_stock)).astype(int)
        if n_stock_select[0] > 0:
            data.iloc[0:n_stock_select[0],-1] = 1
        if n_stock_select[1] > 0:
            data.iloc[-n_stock_select[1]:,-1] = 0
        # Unlabeled middle band is discarded
        data = data.dropna(axis=0)
        del data['return']
        return data

 
    #合成用贴标签方法2 的一期数据框,带有特征和标签
    def data_for_model_perperiod(self,start_date,end_date,index_code):
        """Build one period's feature+label cross-section (labeling method 2).

        start_date/end_date: period boundaries (e.g. one month).
        index_code: index whose constituents define the stock universe.
        Returns the regularized factor frame indexed by stock code, with a
        'Label' column (top 30% of period returns -> 1, bottom 30% -> 0).
        """
        #Universe: index members at end_date that still exist and are
        #covered by the local return data
        stock_list = get_index_stocks(index_code, date=end_date)
        stock_list_notpause = list(set(stock_list) & set(get_all_securities(date=end_date).index) & set(self.data_return.columns))
        
        #Drop companies listed for less than six months (disabled)
        #     stock_list_fillter = filter_new_stock(stock_list_notpause,date_interval[1])
        
        #Drop ST / *ST stocks (disabled)
        #stock_list= delete_st(stock_list_fillter, date1)
        date_interval = get_trade_days(start_date = start_date, end_date=end_date, count=None)
        #Factors are sampled on the FIRST trade day of the period only
        data_per_period = self.merge_regularfactor_multicode_perday(stock_list_notpause,date_interval[0],date_interval[0])
        
        
        date_interval = [date.strftime('%Y-%m-%d') for date in date_interval]
        data_return_interval = self.data_return.loc[date_interval][stock_list_notpause]
        #Cumulative return of each stock over the whole period (last row
        #of the compounded daily returns)
        profit = data_return_interval.apply(lambda x:(1 + x).cumprod() - 1).iloc[-1]
            
        data_per_period = data_per_period.set_index('code')
        stock_profit = pd.DataFrame(profit,columns = ['return'])

        data_merge_label = self.add_label_2(pd.concat([data_per_period,stock_profit],axis = 1))
        return data_merge_label
    
    #合成用贴标签方法2 的多期数据框,带有特征和标签
    def data_for_model_multiperiod(self,start_date,end_date,index_code):
        """Monthly labeled cross-sections over the interval, stacked row-wise."""
        period_starts, period_ends = self.get_time_inverval(start_date,end_date,'M')
        frames = []
        for period_start, period_end in dict(zip(period_starts, period_ends)).items():
            labeled = self.data_for_model_perperiod(period_start, period_end, index_code)
            frames.append(labeled)
        return pd.concat(frames,axis = 0)
    
    #多因子模型
 
    def multifactor_model(self,index_code,start_date,end_date,out_start,out_end,model_name,file_name):
        """Train a cross-sectional multi-factor classifier and test it monthly out-of-sample.

        index_code: index defining the stock universe.
        start_date/end_date: in-sample interval (monthly periods).
        out_start/out_end: out-of-sample interval (monthly periods).
        model_name: 'LR', 'LDA' or 'xgboost'.
        file_name: basename for the joblib .pkl dump.
        Returns the model re-loaded from disk.
        """
        data_regular = self.data_for_model_multiperiod(start_date,end_date,index_code)

        # Last column is the label, the rest are features
        y_data = data_regular['Label']
        x_data = data_regular.iloc[:,:-1]

        # Random 80/20 train/test split
        x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.2)
        #Model zoo -- add new algorithms here
        if model_name == 'LR':
            #Classifier with default parameters; tune as needed
            model = LogisticRegression(solver='liblinear')
        elif model_name == "LDA":
            model=QDA()
        elif model_name == 'xgboost':
            model = XGBClassifier(max_depth=3,subsample=0.9,random_state=0)
        else:
            #NOTE(review): this branch leaves `model` unbound, so fit()
            #below raises NameError for unsupported names -- consider
            #raising ValueError here instead
            print('不支持这个算法,请新添加这个算法')




        print('开始训练数据...')
        #Fit on the training split
        model.fit(x_train, y_train)
        print('\n')
        print('训练结束')
        print ('-' * 60)
        print ("预测准确率:")

        pred = model.predict(x_test)

        # Held-out accuracy
        score=accuracy_score(y_test, pred )
        print("%s模型: %.3f" % (model_name, score)) 
        print ('-' * 60)

        # Confusion matrix, printed and drawn as a heatmap
        cm = pd.crosstab(y_test, pred )
        print('输出混淆矩阵...')

        print(cm)
        sns.heatmap(cm, annot = True, cmap = 'GnBu', fmt = 'd')
        print ('-' * 60)


        print('绘制曲线...')
        # Positive-class probabilities (not hard labels) feed the ROC curve
        y_score = model.predict_proba(x_test)[:,1]
        #fpr is 1-Specificity, tpr is Sensitivity
        fpr,tpr,threshold = roc_curve(y_test, y_score)
        # Area under the ROC curve
        roc_auc = metrics.auc(fpr,tpr)
        # Shaded ROC area plot
        plt.figure(figsize=(8,6))
        plt.stackplot(fpr, tpr, color='steelblue', alpha = 0.5, edgecolor = 'black')
        plt.plot(fpr, tpr, color='black', lw = 1)
        # Chance diagonal
        plt.plot([0,1],[0,1], color = 'red', linestyle = '--')
        # Annotate the AUC value
        plt.text(0.5,0.3,'ROC曲线 (area = %0.2f)' % roc_auc)
        # Axis labels and title
        plt.title('模型预测的ROC曲线',size=15)
        plt.xlabel('1-Specificity')
        plt.ylabel('Sensitivity')
        plt.show()
        print ('-' * 60)


        print('输出评估报告...')
        print('\n')

        print('模型的评估报告:\n',classification_report(y_test, pred ))
        print ('-' * 60)


        print('保存模型到本地...')
        joblib.dump(model,file_name+'.pkl')
        print('\n')
        print('加载本地训练好的模型...')
        model2 = joblib.load(file_name+'.pkl')
        print('加载完毕')
        print ('-' * 60)

        #Per-month out-of-sample bookkeeping
        test_sample_predict={}
        test_sample_score=[]
        test_sample_accuracy=[]
        test_sample_roc_auc=[]
        test_sample_date=[]

        interval_start,interval_end = self.get_time_inverval(out_start,out_end,'M')
        # Walk the out-of-sample months one by one
        print('样本外测试结果...')
        for date1,date2 in dict(zip(interval_start,interval_end)).items():

            data_merge_label = self.data_for_model_perperiod(date1,date2,index_code)
            y_test=data_merge_label['Label']
            X_test=data_merge_label.iloc[:,:-1]


            # Predicted labels and class probabilities for this month
            y_pred_tmp = model2.predict(X_test)
            y_pred = pd.DataFrame(y_pred_tmp, columns=['label_predict'])
            #NOTE(review): uses `model` here but `model2` above -- same
            #fitted estimator either way, but inconsistent
            y_pred_proba = pd.DataFrame(model.predict_proba(X_test), columns=['pro1', 'pro2'])
            # Align predictions with the feature rows and join them
            y_pred.set_index(X_test.index,inplace=True)
            y_pred_proba.set_index(X_test.index,inplace=True)
            predict_pd = pd.concat((X_test, y_pred, y_pred_proba), axis=1)
            print ('Predict date:')
            print (date1)    
            print ('AUC:')
            #NOTE(review): AUC on hard labels; predict_proba would give
            #the conventional score
            print (roc_auc_score(y_test,y_pred))
            print ('Accuracy:')
            print (accuracy_score(y_test, y_pred))
            print ('-' * 60)       
            ## Collected for the summary plots below
            test_sample_date.append(date1)
            # Raw predictions per month
            test_sample_predict[date1]=y_pred_tmp
            # Monthly accuracy
            test_sample_accuracy.append(accuracy_score(y_test, y_pred))   
            # Monthly AUC
            test_sample_roc_auc.append(roc_auc_score(y_test,y_pred))

        print ('AUC mean info')
        print (np.mean(test_sample_roc_auc))
        print ('-' * 60)    
        print ('ACCURACY mean info')
        print (np.mean(test_sample_accuracy))
        print ('-' * 60)   
        
        # Monthly out-of-sample AUC over time
        f = plt.figure(figsize= (15,6))
        xs_date = test_sample_date
        ys_auc = test_sample_roc_auc
        # Date formatting on the x axis
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        plt.plot(xs_date, ys_auc,'r')
        # Auto-rotate the date tick labels
        plt.gcf().autofmt_xdate() 
        # Axis labels
        plt.xlabel('date')
        plt.ylabel("test AUC")
        plt.show()


        # Monthly out-of-sample accuracy over time
        f = plt.figure(figsize= (15,6))
        xs_date = test_sample_date
        ys_score = test_sample_accuracy
        # Date formatting on the x axis
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        plt.plot(xs_date, ys_score,'r')
        # Auto-rotate the date tick labels
        plt.gcf().autofmt_xdate() 
        # Axis labels
        plt.xlabel('date')
        plt.ylabel("test accuracy")
        plt.show()



        # Combined AUC + accuracy line plot
        f = plt.figure(figsize= (15,6))

        sns.set(style="whitegrid")
        #NOTE(review): columns={'AUC'} passes a set as the columns
        #argument -- works by accident; a list would be conventional
        data1 = pd.DataFrame(ys_auc, xs_date, columns={'AUC'})
        data2 = pd.DataFrame(ys_score, xs_date, columns={'accuracy'})
        data = pd.concat([data1,data2],sort=False)
        sns.lineplot(data=data, palette="tab10", linewidth=2.5)
        return model2

1
2
3
4
#创建实例,参数(基准),数据默认2005年1月5日到2020年2月28日的全A股数据
#基准:沪深300,可选
# Benchmark index for the report: CSI 300 (any supported index code works)
benchmark = '000300.XSHG'
# Constructing Report loads all local A-share CSVs (2005-01-05 to 2020-02-28)
report = Report(benchmark)

获取和储存数据:量价数据调用(前复权)

全A股的历史量价数据,涨跌,开盘价,收盘价,最高价,最低价,换手率,成交股票数,成交金额

要调用的数据的格式

储存的可调用的股票历史收益率数据框

1
2
#储存的可调用的股票历史收益率数据框
report.data_return.head()
000001.XSHG 399006.XSHE 399001.XSHE 399005.XSHE 000016.XSHG 000906.XSHG 600000.XSHG 600004.XSHG 600006.XSHG 600007.XSHG ... 300813.XSHE 300815.XSHE 300816.XSHE 300817.XSHE 300818.XSHE 300819.XSHE 300820.XSHE 300821.XSHE 300822.XSHE 300823.XSHE
2005-01-04 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-05 0.007379 NaN 0.009070 NaN 0.005272 NaN -0.012048 0.015015 0.010638 0.018405 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-06 -0.009992 NaN -0.007904 NaN -0.010741 NaN -0.012195 0.002959 -0.015789 -0.009036 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-07 0.004292 NaN 0.002265 NaN 0.001362 NaN 0.000000 0.017699 0.005348 0.018237 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-10 0.006146 NaN 0.008941 NaN 0.011377 NaN 0.037037 0.005797 0.005319 0.017910 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

5 rows × 3708 columns

储存的可调用的股票历史收盘价数据框

1
2
#储存的可调用的股票历史收盘价数据框
report.data_close.head()
000001.XSHG 399006.XSHE 399001.XSHE 399005.XSHE 000016.XSHG 000906.XSHG 600000.XSHG 600004.XSHG 600006.XSHG 600007.XSHG ... 300813.XSHE 300815.XSHE 300816.XSHE 300817.XSHE 300818.XSHE 300819.XSHE 300820.XSHE 300821.XSHE 300822.XSHE 300823.XSHE
2005-01-04 1242.77 NaN 3025.42 NaN 827.07 NaN 0.83 3.33 1.88 3.26 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-05 1251.94 NaN 3052.86 NaN 831.43 NaN 0.82 3.38 1.90 3.32 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-06 1239.43 NaN 3028.73 NaN 822.50 NaN 0.81 3.39 1.87 3.29 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-07 1244.75 NaN 3035.59 NaN 823.62 NaN 0.81 3.45 1.88 3.35 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-10 1252.40 NaN 3062.73 NaN 832.99 NaN 0.84 3.47 1.89 3.41 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

5 rows × 3708 columns

储存的可调用的股票历史最低价数据框

1
2
#储存的可调用的股票历史最低价数据框
report.data_low.head()
000001.XSHG 399006.XSHE 399001.XSHE 399005.XSHE 000016.XSHG 000906.XSHG 600000.XSHG 600004.XSHG 600006.XSHG 600007.XSHG ... 300813.XSHE 300815.XSHE 300816.XSHE 300817.XSHE 300818.XSHE 300819.XSHE 300820.XSHE 300821.XSHE 300822.XSHE 300823.XSHE
2005-01-04 1238.18 NaN 3016.26 NaN 824.01 NaN 0.82 3.31 1.84 3.26 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-05 1235.75 NaN 3017.34 NaN 822.97 NaN 0.81 3.31 1.86 3.24 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-06 1234.24 NaN 3016.15 NaN 820.34 NaN 0.80 3.38 1.87 3.26 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-07 1235.51 NaN 3019.07 NaN 819.44 NaN 0.80 3.39 1.86 3.26 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-10 1236.09 NaN 3018.49 NaN 821.00 NaN 0.81 3.43 1.84 3.32 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

5 rows × 3708 columns

储存的可调用的股票历史开盘价数据框

1
2
#储存的可调用的股票历史开盘价数据框
report.data_open.head()
000001.XSHG 399006.XSHE 399001.XSHE 399005.XSHE 000016.XSHG 000906.XSHG 600000.XSHG 600004.XSHG 600006.XSHG 600007.XSHG ... 300813.XSHE 300815.XSHE 300816.XSHE 300817.XSHE 300818.XSHE 300819.XSHE 300820.XSHE 300821.XSHE 300822.XSHE 300823.XSHE
2005-01-04 1260.78 NaN 3051.24 NaN 836.99 NaN 0.84 3.33 2.03 3.32 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-05 1241.68 NaN 3020.72 NaN 825.71 NaN 0.83 3.33 1.86 3.26 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-06 1252.49 NaN 3054.10 NaN 831.99 NaN 0.82 3.39 1.90 3.32 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-07 1239.32 NaN 3028.69 NaN 822.67 NaN 0.81 3.39 1.89 3.29 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-10 1243.58 NaN 3032.84 NaN 823.77 NaN 0.82 3.45 1.89 3.32 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

5 rows × 3708 columns

储存的可调用的股票历史最高价数据框

1
2
#储存的可调用的股票历史最高价数据框
report.data_high.head()
000001.XSHG 399006.XSHE 399001.XSHE 399005.XSHE 000016.XSHG 000906.XSHG 600000.XSHG 600004.XSHG 600006.XSHG 600007.XSHG ... 300813.XSHE 300815.XSHE 300816.XSHE 300817.XSHE 300818.XSHE 300819.XSHE 300820.XSHE 300821.XSHE 300822.XSHE 300823.XSHE
2005-01-04 1260.78 NaN 3051.24 NaN 836.99 NaN 0.84 3.38 2.03 3.32 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-05 1258.58 NaN 3067.67 NaN 836.43 NaN 0.83 3.42 1.94 3.34 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-06 1252.73 NaN 3054.10 NaN 833.07 NaN 0.83 3.43 1.92 3.32 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-07 1256.31 NaN 3065.28 NaN 832.95 NaN 0.82 3.47 1.92 3.35 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-10 1252.72 NaN 3063.66 NaN 833.65 NaN 0.84 3.49 1.91 3.41 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

5 rows × 3708 columns

储存的可调用的股票历史换手率数据框

1
2
#储存的可调用的股票历史换手率数据框
report.data_turnover.head()
000001.XSHE 000002.XSHE 000004.XSHE 000005.XSHE 000006.XSHE 000007.XSHE 000008.XSHE 000009.XSHE 000010.XSHE 000011.XSHE ... 603989.XSHG 603990.XSHG 603991.XSHG 603992.XSHG 603993.XSHG 603995.XSHG 603996.XSHG 603997.XSHG 603998.XSHG 603999.XSHG
2005-01-04 0.1249 0.6566 0.6400 0.2321 0.1816 0.3639 0.2005 0.1430 0.5186 0.5356 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-05 0.2286 1.1155 1.1179 0.2236 0.1860 0.4015 0.2400 0.1395 0.7805 0.3907 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-06 0.1892 1.1683 1.0751 0.2092 0.1633 0.3733 0.5760 0.1076 0.9594 0.1615 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-07 0.1338 1.1239 0.8072 0.8032 1.5223 0.5812 0.4816 0.2765 1.1124 0.1458 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-10 0.1868 0.4496 0.7163 1.9234 7.8899 1.0698 0.5288 0.6223 1.9922 0.1698 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

5 rows × 3695 columns

储存的可调用的股票历史成交金额数据框

1
2
##储存的可调用的股票历史成交金额数据框
report.data_money.head()
000001.XSHG 399006.XSHE 399001.XSHE 399005.XSHE 000016.XSHG 000906.XSHG 600000.XSHG 600004.XSHG 600006.XSHG 600007.XSHG ... 300813.XSHE 300815.XSHE 300816.XSHE 300817.XSHE 300818.XSHE 300819.XSHE 300820.XSHE 300821.XSHE 300822.XSHE 300823.XSHE
2005-01-04 4.418452e+09 NaN 980468922.0 NaN 2.136409e+09 NaN 26134943.0 20728330.0 10491518.0 1555757.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-05 4.916589e+09 NaN 807720454.0 NaN 1.705649e+09 NaN 35366812.0 12969407.0 8801098.0 1330700.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-06 4.381370e+09 NaN 762259679.0 NaN 1.519687e+09 NaN 28758188.0 29118085.0 3727731.0 1328303.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-07 5.040042e+09 NaN 843160298.0 NaN 1.640665e+09 NaN 29239203.0 36353317.0 6311149.0 2460959.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-10 4.118292e+09 NaN 734534698.0 NaN 1.402314e+09 NaN 48798404.0 15893640.0 4484497.0 4907955.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

5 rows × 3708 columns

储存的可调用的股票历史成交手数数据框

1
2
##储存的可调用的股票历史成交手数数据框
report.data_volume.head()
000001.XSHG 399006.XSHE 399001.XSHE 399005.XSHE 000016.XSHG 000906.XSHG 600000.XSHG 600004.XSHG 600006.XSHG 600007.XSHG ... 300813.XSHE 300815.XSHE 300816.XSHE 300817.XSHE 300818.XSHE 300819.XSHE 300820.XSHE 300821.XSHE 300822.XSHE 300823.XSHE
2005-01-04 816177000.0 NaN 138575300.0 NaN 403169700.0 NaN 31512001.0 6205383.0 5537612.0 473684.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-05 867865100.0 NaN 123303900.0 NaN 302086300.0 NaN 43229333.0 3835495.0 4655335.0 404622.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-06 792225400.0 NaN 108869300.0 NaN 275357400.0 NaN 35558905.0 8549882.0 1978640.0 406085.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-07 894087100.0 NaN 117752600.0 NaN 306608600.0 NaN 36094716.0 10584134.0 3332919.0 743356.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2005-01-10 723468300.0 NaN 94298700.0 NaN 247941100.0 NaN 58865757.0 4597541.0 2385315.0 1449998.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

5 rows × 3708 columns

为了验证本地的数据是否是可靠的,随机查一下2020年1月15日平安银行的量价数据,与市面专业的商用软件通达信的数据进行对比

1
2
3
4
5
6
7
8
9
10
11
stock_choose = '000001.XSHE'
time_choose = '2020-01-15'
print(time_choose)
print('开盘',report.data_open.loc[time_choose,stock_choose])
print('最高',report.data_high.loc[time_choose,stock_choose])
print('最低',report.data_low.loc[time_choose,stock_choose])
print('收盘',report.data_close.loc[time_choose,stock_choose])
print('总量',report.data_volume.loc[time_choose,stock_choose])
print('换手',report.data_turnover.loc[time_choose,stock_choose])
print('总额',report.data_money.loc[time_choose,stock_choose])
print('涨跌',report.data_return.loc[time_choose,stock_choose])
1
2
3
4
5
6
7
8
9
2020-01-15
开盘 16.79
最高 16.86
最低 16.45
收盘 16.52
总量 85943912.0
换手 0.4429
总额 1424889228.07
涨跌 -0.014319809069212487

对比下,本地数据还是可信的

储存的可调用的历史基准收益

1
2
#储存的可调用的历史基准收益
report.benchmark_return.tail()
1
2
3
4
5
6
2020-02-24   -0.004013
2020-02-25   -0.002175
2020-02-26   -0.012326
2020-02-27    0.002912
2020-02-28   -0.035455
Name: close, dtype: float64

处理数据,合成特征和标签:手把手利用量价数据计算技术指标

1
2
3
4
5
6
7
8
9
# Input data: per-field price/volume series for Ping An Bank (000001.XSHE)
close_pingan = report.data_close['000001.XSHE']
open_pingan = report.data_open['000001.XSHE']
high_pingan = report.data_high['000001.XSHE']
low_pingan = report.data_low['000001.XSHE']
volume_pingan = report.data_volume['000001.XSHE']
return_pingan = report.data_return['000001.XSHE']
# bug fix: was report.data_return (copy-paste error) — turnover money must come from data_money
money_pingan = report.data_money['000001.XSHE']
low_pingan.head()
1
2
3
4
5
6
2005-01-04    1.48
2005-01-05    1.46
2005-01-06    1.48
2005-01-07    1.48
2005-01-10    1.46
Name: 000001.XSHE, dtype: float64

计算均线

1
2
3
4
5
6
7
8
9
def get_MA(df_close,n):
    '''
    Simple moving average: the mean of the price over a trailing n-day window.

    df_close: price data, DataFrame or Series
    n: lookback window length in days, int
    '''
    rolling_mean = df_close.rolling(n).mean()
    return pd.DataFrame({'MA_' + str(n): rolling_mean},
                        index=rolling_mean.index)
1
2
#计算平安银行的平均移动线
get_MA(close_pingan,20).tail()
MA_20
2020-02-24 15.1320
2020-02-25 15.0615
2020-02-26 15.0110
2020-02-27 14.9620
2020-02-28 14.9100

也可以用开源库talib来计算

1
2
#用talib库计算验证自己的计算结果
talib.MA(np.array(close_pingan),timeperiod=20, matype=0)
1
array([   nan,    nan,    nan, ..., 15.011, 14.962, 14.91 ])
1
2
3
4
5
6
7
8
9
10
11
12
## 计算变化率ROC
def get_ROC(df_close, n):
    '''
    ROC=(今天的收盘价-N日前的收盘价)/N日前的收盘价*100
    移动平均值是在一定范围内的价格平均值
    df_close:量价收盘数据,dataframe或者series
    n: 回溯的天数,整数型
    '''
    M = df_close
    N = df_close.shift(n)
    roc = pd.DataFrame({'ROC_' + str(n): (M-N) / N*100}, index = M.index)
    return roc
1
get_ROC(close_pingan, 12).tail()
ROC_12
2020-02-24 3.114421
2020-02-25 2.872777
2020-02-26 3.379310
2020-02-27 2.163624
2020-02-28 -1.828030
1
2
#用talib库计算验证自己的计算结果
talib.ROC(close_pingan,12).tail()
1
2
3
4
5
6
2020-02-24    3.114421
2020-02-25    2.872777
2020-02-26    3.379310
2020-02-27    2.163624
2020-02-28   -1.828030
dtype: float64
1
2
3
4
5
6
7
8
## 计算RSI
def get_RSI(df_close,n):
    '''
    df_close:量价收盘数据,dataframe或者series
    n: 回溯的天数,整数型
    '''
    rsi = talib.RSI(df_close, timeperiod=n)
    return pd.DataFrame({'RSI_' + str(n): rsi}, index = df_close.index)
1
df = get_price('000001.XSHE', start_date='2015-09-16', end_date='2020-3-22', frequency='daily')
1
get_RSI(close_pingan,6).tail()
RSI_6
2020-02-24 49.477234
2020-02-25 42.341524
2020-02-26 40.497130
2020-02-27 47.129830
2020-02-28 28.054166
1
2
3
4
5
6
7
8
9
##计算OBV指标
def get_OBV(df_close,df_volume):
    '''
    On Balance Volume 能量,通过统计成交量变动的趋势推测股价趋势
    df_close:量价收盘数据,dataframe或者series
    df_volume:J量价交易数数据,dataframe或者series
    '''
    obv = talib.OBV(df_close,df_volume)
    return pd.DataFrame({'OBV': obv}, index = df_close.index)
1
get_OBV(close_pingan,volume_pingan).tail()
OBV
2020-02-24 4.570885e+10
2020-02-25 4.559439e+10
2020-02-26 4.547673e+10
2020-02-27 4.557426e+10
2020-02-28 4.544420e+10
1
2
3
4
5
6
7
8
9
10
11
#真实波幅
def get_ATR(df_high,df_low,df_close,n):
    '''
    平均真实波幅,主要用来衡量价格的波动
    df_close:量价收盘数据,dataframe或者series
    df_high:量价最高价数据,dataframe或者series
    df_low:量价最低价数据,dataframe或者series
    n: 回溯的天数,整数型
    '''
    atr = talib.ATR(df_high,df_low,df_close, timeperiod=n)
    return pd.DataFrame({'ATR_' + str(n): atr}, index = df_close.index)
1
get_ATR(high_pingan,low_pingan,close_pingan,14).tail()
ATR_14
2020-02-24 0.422029
2020-02-25 0.424027
2020-02-26 0.434454
2020-02-27 0.421993
2020-02-28 0.438279
1
2
3
4
5
6
7
8
9
10
#上升动向值
def get_MOM(df_close,n):
    '''
    上升动向值,投资学中意思为续航,指股票(或经济指数)持续增长的能力。研究发现,赢家组合在牛市中存在着正的动量效应,输家组合在熊市中存在着负的动量效应。
    df_close:量价收盘数据,dataframe或者series
    n: 回溯的天数,整数型
    '''
    mom = talib.MOM(df_close, timeperiod=n)
    return pd.DataFrame({'MOM_' + str(n): mom}, index = df_close.index)

1
2
3
def merge_raw_factors(df_close,df_open,df_high,df_low,df_volume,df_money):
    '''
    Assemble every technical-indicator column into a single feature DataFrame.

    Note: df_open and df_money are accepted for interface compatibility but
    are not consumed by any of the current indicators.
    '''
    factor_frames = [
        get_MA(df_close, 5),
        get_MA(df_close, 20),
        get_MA(df_close, 60),
        get_ROC(df_close, 6),
        get_ROC(df_close, 12),
        get_RSI(df_close, 6),
        get_RSI(df_close, 12),
        get_RSI(df_close, 24),
        get_OBV(df_close, df_volume),
        get_ATR(df_high, df_low, df_close, 14),
        get_MOM(df_close, 10),
    ]
    return pd.concat(factor_frames, axis=1)
1
2
3
#平安银行的每日指标
factors = merge_raw_factors(close_pingan,open_pingan,high_pingan,low_pingan,volume_pingan,money_pingan)
factors.tail()
MA_5 MA_20 MA_60 ROC_6 ROC_12 RSI_6 RSI_12 RSI_24 OBV ATR_14 MOM_10
2020-02-24 15.368 15.1320 15.831000 1.330672 3.114421 49.477234 47.099803 46.196881 4.570885e+10 0.422029 0.73
2020-02-25 15.336 15.0615 15.821833 -2.147040 2.872777 42.341524 43.943630 44.629065 4.559439e+10 0.424027 0.25
2020-02-26 15.286 15.0110 15.808333 -1.381579 3.379310 40.497130 43.114224 44.216995 4.547673e+10 0.434454 0.22
2020-02-27 15.190 14.9620 15.799833 -0.853018 2.163624 47.129830 45.792942 45.477725 4.557426e+10 0.421993 0.46
2020-02-28 14.974 14.9100 15.783667 -6.991661 -1.828030 28.054166 36.310977 40.609408 4.544420e+10 0.438279 -0.53
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#数据去极值及标准化
def winsorize_and_standarlize(data,qrange=[0.05,0.95],axis=0):
    '''
    Winsorize (clip extremes at quantiles) and z-score standardize.

    input:
    data: DataFrame or Series, input data
    qrange: list, qrange[0] is the lower quantile, qrange[1] the upper
            quantile; values beyond them are replaced by the quantile values
    axis: 0 -> treat each DataFrame column as one distribution; any other
          value -> transpose first so each row is treated as one distribution

    Returns a new object; unlike the original version, the caller's data is
    never modified in place, and clipping uses `clip` instead of chained
    assignment (which triggered SettingWithCopy and, in the axis!=0 branch,
    only clipped the upper tail through a mis-indexed quantile frame).
    '''
    q_low, q_high = qrange[0], qrange[1]

    if isinstance(data, pd.DataFrame):
        if axis == 0:
            # Column-wise: clip each column at its own quantiles, then z-score.
            out = data.clip(lower=data.quantile(q_low),
                            upper=data.quantile(q_high), axis=1)
            out = (out - out.mean()) / out.std()
            return out.fillna(0)
        # Row-wise: transpose, apply the same column-wise treatment, transpose back.
        out = data.stack().unstack(0)
        out = out.clip(lower=out.quantile(q_low),
                       upper=out.quantile(q_high), axis=1)
        out = (out - out.mean()) / out.std()
        return out.stack().unstack(0).fillna(0)

    if isinstance(data, pd.Series):
        # Clip both tails (the original compared the series against a
        # two-element quantile series, which never clipped anything), then z-score.
        out = data.clip(lower=data.quantile(q_low), upper=data.quantile(q_high))
        return (out - out.mean()) / out.std()

    # Unsupported types pass through unchanged, as before.
    return data

数据查看

1
2
data_pro = winsorize_and_standarlize(factors)
data_pro.describe()
MA_5 MA_20 MA_60 ROC_6 ROC_12 RSI_6 RSI_12 RSI_24 OBV ATR_14 MOM_10
count 3.682000e+03 3.682000e+03 3.682000e+03 3.682000e+03 3.682000e+03 3.682000e+03 3.682000e+03 3.682000e+03 3.682000e+03 3.682000e+03 3.682000e+03
mean 1.700794e-14 2.515123e-14 -2.662883e-14 -7.065385e-16 -1.112786e-16 2.447526e-14 2.093756e-14 -1.427719e-14 2.011150e-14 -7.692863e-15 3.948800e-16
std 9.994565e-01 9.974158e-01 9.919535e-01 9.991847e-01 9.983687e-01 9.991847e-01 9.983687e-01 9.967347e-01 1.000000e+00 9.980965e-01 9.986407e-01
min -1.705255e+00 -1.728499e+00 -1.780926e+00 -1.824671e+00 -1.810217e+00 -1.702614e+00 -1.659394e+00 -1.596035e+00 -1.534646e+00 -1.351977e+00 -2.025731e+00
25% -5.698024e-01 -5.721285e-01 -5.615280e-01 -6.282979e-01 -6.495794e-01 -7.731055e-01 -7.740355e-01 -7.952646e-01 -8.534952e-01 -8.839974e-01 -5.336564e-01
50% -1.084304e-01 -8.470670e-02 -7.784826e-02 -9.829604e-02 -6.213428e-02 -3.013210e-02 -3.795124e-02 -8.486868e-02 -2.590756e-01 -9.136803e-02 -3.629837e-02
75% 6.278916e-01 6.541132e-01 6.663069e-01 5.347080e-01 5.678273e-01 7.716410e-01 7.247613e-01 7.206211e-01 1.097747e+00 6.497930e-01 4.989021e-01
max 1.994363e+00 1.940143e+00 1.881687e+00 2.213011e+00 2.117848e+00 1.767631e+00 1.851835e+00 1.935446e+00 1.523051e+00 1.985683e+00 2.114235e+00

平安银行的每日收益

1
2
3
#平安银行的每日收益
profit = copy.deepcopy(return_pingan)
profit.tail(10)
1
2
3
4
5
6
7
8
9
10
11
2020-02-17    0.022621
2020-02-18   -0.011061
2020-02-19    0.002632
2020-02-20    0.022966
2020-02-21   -0.000641
2020-02-24   -0.022465
2020-02-25   -0.012475
2020-02-26   -0.003324
2020-02-27    0.008005
2020-02-28   -0.040371
Name: 000001.XSHE, dtype: float64

收益率大于0标记为1,否则为0,记作标签

1
2
3
4
5
6
#收益率大于0标记为1,否则为0,记作标签
profit[profit > 0] = 1
profit[profit< 0] = 0
profit = pd.DataFrame(profit)
profit.columns = ['Label']
profit.tail()
Label
2020-02-24 0.0
2020-02-25 0.0
2020-02-26 0.0
2020-02-27 1.0
2020-02-28 0.0
1
2
#预测的是未来一天的涨跌,所以label比特征要滞后一天
profit.shift().tail()
Label
2020-02-24 0.0
2020-02-25 0.0
2020-02-26 0.0
2020-02-27 0.0
2020-02-28 1.0

合并后的数据框,第一列label为标签,其余列称为特征

1
2
data_final = pd.concat([profit.shift(),data_pro],axis = 1)
data_final.tail()
Label MA_5 MA_20 MA_60 ROC_6 ROC_12 RSI_6 RSI_12 RSI_24 OBV ATR_14 MOM_10
2020-02-24 0.0 1.994363 1.940143 1.881687 0.154400 0.287494 -0.101360 -0.362483 -0.602621 1.523051 1.427686 1.499024
2020-02-25 0.0 1.994363 1.940143 1.881687 -0.506021 0.254781 -0.510047 -0.616758 -0.775493 1.523051 1.442254 0.461060
2020-02-26 0.0 1.994363 1.940143 1.881687 -0.360660 0.323354 -0.615682 -0.683579 -0.820929 1.523051 1.518279 0.396187
2020-02-27 0.0 1.994363 1.940143 1.881687 -0.260285 0.158779 -0.235804 -0.467769 -0.681917 1.523051 1.427421 0.915169
2020-02-28 1.0 1.994363 1.940143 1.881687 -1.426021 -0.381598 -1.328334 -1.231680 -1.218712 1.523051 1.546171 -1.225633

功能集成,方便以后调用

将上述函数集成,添加到上面的类对象里,作为一个新的功能, 看懂了你也可以添加新的指标,或者改进代码

新功能:给定时间,给定股票或者指数,就能输出各种特征

1
2
#调用类中的新功能
report.merge_final_factor_label1('600000.XSHG','2020-01-21','2020-02-28')
Label MA_5 MA_60 Aroondown_14 Aroonup_14 ROC_6 ROC_12 RSI_6 RSI_24 OBV ATR_14 MOM_10
2020-01-21 1.0 2.310415 1.680919 1.385505 0.850139 -0.049595 0.151672 -0.627981 1.315566 0.974446 -1.829183 0.119921
2020-01-22 0.0 2.310415 1.680919 1.385505 0.642568 -0.508179 -0.039098 -1.254577 0.318997 -0.277086 -1.829183 -0.030231
2020-01-23 0.0 1.933406 1.563415 1.385505 0.434997 -0.860640 -0.534339 -1.550447 -0.808540 -1.503009 -1.113230 -0.573091
2020-02-03 0.0 1.066039 1.348623 1.385505 0.227426 -1.807349 -1.157115 -1.550447 -1.538096 -1.915072 1.662121 -1.413946
2020-02-04 0.0 0.282451 1.128566 1.203410 0.019855 -1.807349 -1.157115 -1.184782 -1.538096 -1.915072 1.662121 -1.413946
2020-02-05 1.0 -0.397643 0.913774 1.021315 -0.187716 -1.778663 -1.134754 -1.047883 -1.399087 -1.239307 1.462821 -1.393156
2020-02-06 1.0 -0.880609 0.686346 0.839220 -0.395288 -1.437094 -1.055505 -0.714797 -1.051062 -0.630247 1.287606 -1.081300
2020-02-07 1.0 -0.964882 0.469448 0.657125 -0.602859 -0.927177 -0.998269 -0.444543 -0.782738 -0.119326 1.216143 -0.942698
2020-02-10 1.0 -0.964882 0.259921 0.475030 -0.810430 -0.473645 -0.924015 -0.585924 -0.946487 -0.664448 0.876075 -1.081300
2020-02-11 0.0 -0.875681 0.079874 0.292935 -1.018001 1.049427 -0.775532 -0.186210 -0.595686 -0.106125 0.651534 -1.000449
2020-02-12 1.0 -0.796829 -0.070691 0.110840 -1.018001 0.735070 -0.805060 -0.186210 -0.595686 -0.106125 0.260558 -0.804095
2020-02-13 0.0 -0.811614 -0.226521 -0.071255 -1.018001 0.507486 -0.935054 -0.433283 -0.795829 -0.468618 -0.102491 -0.561541
2020-02-14 0.0 -0.811614 -0.370768 -0.253350 -1.018001 0.522711 -0.655954 0.118971 -0.388368 -0.110450 -0.378784 0.039070
2020-02-17 1.0 -0.609557 -0.471847 -0.435445 -1.018001 0.920266 0.031893 1.313898 0.788246 0.551781 0.003316 1.123634
2020-02-18 1.0 -0.510992 -0.586614 -0.617540 -1.018001 0.876334 0.364239 0.861870 0.507352 0.068639 -0.280534 1.067038
2020-02-19 0.0 -0.387787 -0.705592 -0.799635 -1.018001 0.808748 1.464518 1.032863 0.681504 0.470690 -0.361636 1.078588
2020-02-20 1.0 -0.156160 -0.810882 -0.981729 -1.018001 0.999922 1.432828 1.412282 1.093107 1.057847 -0.406533 1.113239
2020-02-21 1.0 0.060681 -0.888797 -1.163824 1.659666 1.049427 1.464518 1.412282 1.315566 1.502184 -0.539459 1.113239
2020-02-24 1.0 0.050825 -0.973029 -1.163824 1.659666 0.888404 1.173463 0.937318 0.961292 0.965116 -0.480417 1.055488
2020-02-25 0.0 0.065610 -1.067790 -1.163824 1.472852 0.271194 0.996147 0.633880 0.783536 0.191060 -0.699303 0.870685
2020-02-26 0.0 0.109964 -1.159393 -1.163824 1.265281 0.629473 1.249660 1.107259 1.170027 1.022488 -0.385546 0.997737
2020-02-27 1.0 0.100107 -1.239940 -1.163824 1.057710 0.566196 1.145570 1.149823 1.205165 1.502184 -0.550383 1.123634
2020-02-28 1.0 -0.121663 -1.239940 -1.163824 0.850139 -0.174968 0.697302 -0.213361 0.299315 0.748452 -0.125615 0.593480

指数择时

获取自定义区间的数据

1
2
3
#获取指标2010年到2018年的上证指数数据,作为数据集
data_index = report.merge_final_factor_label1('000001.XSHG','2010-01-01','2018-01-01')
data_index.head(20)
Label MA_5 MA_60 Aroondown_14 Aroonup_14 ROC_6 ROC_12 RSI_6 RSI_24 OBV ATR_14 MOM_10
2010-01-04 1.0 0.988304 0.868502 -0.147992 -1.385636 0.990031 -0.148243 0.441767 0.291502 -0.922039 0.602991 1.216550
2010-01-05 0.0 1.025477 0.881695 -0.342382 -1.385636 1.577618 0.758380 0.826958 0.486422 -0.922039 0.630772 1.505414
2010-01-06 1.0 1.042376 0.892517 -0.536772 -1.385636 0.694792 1.077039 0.346818 0.315013 -0.922039 0.577094 1.943435
2010-01-07 0.0 1.014588 0.900080 -0.731162 1.062305 -0.265111 0.504234 -0.466328 -0.042070 -0.922039 0.676621 1.109531
2010-01-08 0.0 0.982294 0.907437 -0.925552 0.874002 -0.791835 1.143227 -0.416499 -0.024108 -0.922039 0.641097 0.357246
2010-01-11 1.0 0.969952 0.915472 -1.119941 1.250608 -0.764055 1.080711 -0.138320 0.071259 -0.922039 0.792552 0.640891
2010-01-12 1.0 0.966685 0.923492 -1.314331 1.062305 0.287730 0.904679 0.654121 0.404724 -0.922039 0.885783 0.776659
2010-01-13 1.0 0.934224 0.926494 -1.314331 0.874002 -1.262313 0.191839 -0.591117 -0.177093 -0.922039 1.012879 -0.447023
2010-01-14 0.0 0.943287 0.931427 -1.314331 0.685699 -0.482049 0.152193 -0.089788 0.052163 -0.922039 0.964007 -0.525294
2010-01-15 1.0 0.954490 0.937305 -1.314331 0.497395 0.306325 0.037760 0.008444 0.097648 -0.916050 0.891891 -0.583775
2010-01-18 1.0 0.964181 0.941703 -1.314331 0.309092 0.416529 -0.256650 0.169099 0.167496 -0.906526 0.798950 -0.127639
2010-01-19 1.0 0.953396 0.946375 0.046398 0.120789 0.335193 -0.292484 0.298502 0.221135 -0.897364 0.700342 -0.409709
2010-01-20 1.0 0.945113 0.950812 1.407126 -0.067514 -1.405118 -0.774132 -0.872962 -0.323491 -0.908522 0.835592 -1.069940
2010-01-21 0.0 0.922551 0.955152 1.407126 -0.255817 -0.208462 -1.007122 -0.758828 -0.282305 -0.900460 0.786744 -0.396024
2010-01-22 1.0 0.884518 0.960873 1.407126 -0.444120 -1.032629 -1.033166 -1.051371 -0.447064 -0.910173 0.883784 -0.725746
2010-01-25 0.0 0.827728 0.964227 1.212737 -0.632424 -1.511881 -0.836571 -1.335185 -0.627804 -0.915699 0.815792 -1.227171
2010-01-26 0.0 0.737192 0.962278 1.407126 -0.820727 -1.923946 -1.453195 -1.699754 -0.994930 -0.922039 0.938067 -1.970330
2010-01-27 0.0 0.671427 0.957935 1.407126 -1.009030 -1.923946 -1.834742 -1.699754 -1.143798 -0.922039 0.902419 -1.893801
2010-01-28 0.0 0.605869 0.953362 1.407126 -1.197333 -1.867684 -1.873060 -1.699754 -1.092096 -0.922039 0.827277 -1.970330
2010-01-29 1.0 0.550428 0.947721 1.212737 -1.385636 -1.923946 -1.517184 -1.699754 -1.115215 -0.922039 0.800694 -1.970330

数据分割

1
2
3
4
5
6
7
8
9
10
11
# 将数据集分为样本特征集和样本标签集

#x_data:所要划分的样本特征集

#y_data:所要划分的样本标签

#test_size:样本占比,如果是整数的话就是样本的数量

#random_state:是随机数的种子。随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。

x_data,y_data = data_index.iloc[:,1:], data_index.iloc[:,0]
1
x_data.tail(20)
MA_5 MA_60 Aroondown_14 Aroonup_14 ROC_6 ROC_12 RSI_6 RSI_24 OBV ATR_14 MOM_10
2017-12-04 1.160203 1.302549 0.629567 -1.385636 -0.529114 -0.724529 -1.398063 -1.015399 1.729006 -0.574550 -0.877069
2017-12-05 1.148271 1.299805 1.407126 -0.444120 -0.253274 -0.650207 -1.503913 -1.094497 1.729006 -0.625791 -1.113752
2017-12-06 1.130799 1.296693 1.407126 -0.632424 -0.482986 -0.791368 -1.670960 -1.222999 1.729006 -0.570801 -1.405963
2017-12-07 1.112833 1.293513 1.212737 -0.820727 -0.766524 -1.083353 -1.699754 -1.496310 1.729006 -0.562992 -0.848419
2017-12-08 1.101837 1.290953 1.018347 -1.009030 -0.348283 -1.092247 -1.024953 -1.136449 1.729006 -0.544920 -0.690499
2017-12-11 1.106843 1.289108 0.823957 -1.197333 -0.000361 -0.283149 0.092151 -0.557712 1.729006 -0.540233 -0.062364
2017-12-12 1.097741 1.285750 0.629567 -1.385636 -0.366630 -0.608557 -0.829096 -1.079091 1.729006 -0.513177 -0.582397
2017-12-13 1.101355 1.282990 0.435177 -1.385636 -0.057532 -0.205210 -0.236514 -0.717925 1.729006 -0.521942 -0.404885
2017-12-14 1.109470 1.280302 0.240787 -1.385636 -0.067255 -0.371375 -0.463819 -0.846506 1.729006 -0.541387 -0.305742
2017-12-15 1.099978 1.277325 0.046398 -1.385636 -0.116092 -0.601435 -0.951066 -1.148089 1.729006 -0.540983 -0.568909
2017-12-18 1.078375 1.274094 1.407126 -1.197333 -0.294117 -0.434127 -0.890981 -1.118135 1.729006 -0.560886 -0.472621
2017-12-19 1.084635 1.272042 1.212737 -1.385636 -0.330999 -0.219780 -0.029386 -0.655026 1.729006 -0.565546 -0.132365
2017-12-20 1.078494 1.269375 1.018347 -1.385636 0.024776 -0.227254 -0.260446 -0.764224 1.729006 -0.590575 -0.124587
2017-12-21 1.081527 1.267410 0.823957 -1.385636 -0.083261 -0.087167 0.095794 -0.568856 1.729006 -0.560075 0.213700
2017-12-22 1.093833 1.265522 0.629567 -1.385636 0.000464 -0.035797 -0.003177 -0.607870 1.729006 -0.616412 0.007538
2017-12-25 1.098824 1.263443 0.435177 -0.632424 0.108730 0.005290 -0.520022 -0.821096 1.729006 -0.583946 -0.473015
2017-12-26 1.102637 1.262169 0.240787 -0.820727 0.374023 0.064116 0.288744 -0.413773 1.729006 -0.580745 0.187118
2017-12-27 1.097928 1.259804 0.046398 -1.009030 -0.279201 -0.411926 -0.531291 -0.793420 1.729006 -0.566376 -0.330454
2017-12-28 1.096464 1.258332 -0.147992 -1.197333 0.046382 0.060158 0.034153 -0.483688 1.729006 -0.542000 -0.023278
2017-12-29 1.100487 1.256911 -0.342382 -1.385636 0.027748 -0.027998 0.296710 -0.327705 1.729006 -0.594899 0.341887
1
y_data.tail(20)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
2017-12-04    1.0
2017-12-05    0.0
2017-12-06    0.0
2017-12-07    0.0
2017-12-08    0.0
2017-12-11    1.0
2017-12-12    1.0
2017-12-13    0.0
2017-12-14    1.0
2017-12-15    0.0
2017-12-18    0.0
2017-12-19    1.0
2017-12-20    1.0
2017-12-21    0.0
2017-12-22    1.0
2017-12-25    0.0
2017-12-26    0.0
2017-12-27    1.0
2017-12-28    0.0
2017-12-29    1.0
Name: Label, dtype: float64
1
2
# #原始数据按照比例分割为“训练集”和“测试集”
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.2)

训练模型

1
2
model=QDA()
model.fit(x_train, y_train)
1
2
QuadraticDiscriminantAnalysis(priors=None, reg_param=0.0,
               store_covariance=False, store_covariances=None, tol=0.0001)

输出训练结果和评价指标

混淆矩阵

1
2
3
4
5
6
7
8
9
10
# 模型在测试数据集上的预测
pred = model.predict(x_test)
# 构建混淆矩阵
cm = pd.crosstab(y_test,pred)
cm
# 绘制混淆矩阵图
sns.heatmap(cm, annot = True, cmap = 'GnBu', fmt = 'd')

print('模型的准确率为:\n',accuracy_score(y_test, pred))
print('模型的评估报告:\n',classification_report(y_test, pred))
1
2
3
4
5
6
7
8
9
10
11
模型的准确率为:
 0.6452442159383034
模型的评估报告:
               precision    recall  f1-score   support

         0.0       0.63      0.63      0.63       186
         1.0       0.66      0.66      0.66       203

   micro avg       0.65      0.65      0.65       389
   macro avg       0.64      0.64      0.64       389
weighted avg       0.65      0.65      0.65       389

ROC曲线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 计算正例的预测概率,而非实际的预测值,用于生成ROC曲线的数据
y_score = model.predict_proba(x_test)[:,1]
#fpr表示1-Specificity,tpr表示Sensitivity
fpr,tpr,threshold = roc_curve(y_test, y_score)
# 计算AUC的值
roc_auc = metrics.auc(fpr,tpr)
# 绘制面积图
plt.figure(figsize=(8,6))
plt.stackplot(fpr, tpr, color='steelblue', alpha = 0.5, edgecolor = 'black')
plt.plot(fpr, tpr, color='black', lw = 1)
# 添加对角线
plt.plot([0,1],[0,1], color = 'red', linestyle = '--')
# 添加文本信息
plt.text(0.5,0.3,'ROC曲线 (area = %0.2f)' % roc_auc)
# 添加x轴与y轴标签
plt.title('模型预测指数涨跌的ROC曲线',size=15)
plt.xlabel('1-Specificity')
plt.ylabel('Sensitivity')
plt.show()

保存模型

1
2
joblib.dump(model,'train_model.pkl')
model2 = joblib.load("train_model.pkl")

加载模型

1
model2 = joblib.load("train_model.pkl")

样本外测试

1
2
#获取样本外数据
data_out_sample = report.merge_final_factor_label1('000001.XSHG','2018-01-01','2020-01-01')
1
data_out_sample.tail(10)
Label MA_5 MA_60 Aroondown_14 Aroonup_14 ROC_6 ROC_12 RSI_6 RSI_24 OBV ATR_14 MOM_10
2019-12-18 1.0 0.234736 -0.061065 -0.783525 1.202429 1.361947 1.427895 1.862737 1.513285 1.435901 -1.419598 1.522202
2019-12-19 0.0 0.326176 -0.059776 -0.976240 1.008527 1.264220 1.337858 1.862737 1.513665 1.435901 -1.419598 1.297170
2019-12-20 1.0 0.359787 -0.059885 -1.168955 0.814626 1.222215 1.286913 1.208598 1.265640 1.435901 -1.419598 1.036778
2019-12-23 0.0 0.340266 -0.060923 -1.361671 0.620724 -0.038880 0.665595 -0.266910 0.490695 1.435901 -1.419598 0.565392
2019-12-24 0.0 0.304419 -0.061116 -1.361671 0.426822 0.003651 0.733870 0.252256 0.777277 1.435901 -1.419598 0.745777
2019-12-25 1.0 0.272703 -0.059200 -1.361671 0.232921 -0.498166 0.701949 0.227258 0.763167 1.435901 -1.419598 0.662392
2019-12-26 0.0 0.263935 -0.053531 -1.361671 0.039019 -0.099477 0.918042 0.826302 1.116750 1.435901 -1.419598 1.023268
2019-12-27 1.0 0.264025 -0.048252 -1.168955 -0.154882 -0.129788 0.826051 0.734788 1.073468 1.435901 -1.419598 0.450237
2019-12-30 0.0 0.333726 -0.038485 -1.361671 1.396330 0.482292 1.247091 1.361105 1.524115 1.435901 -1.355460 0.643077
2019-12-31 1.0 0.394560 -0.028593 -1.168955 1.396330 1.178552 0.831957 1.494480 1.644490 1.435901 -1.412389 0.348276
1
2
#划分特征集和标签集
x_out_test,y_out_test = data_out_sample.iloc[:,1:], data_out_sample.iloc[:,0]

加载模型后可直接使用

输出预测标签

1
2
y_out_pred = model2.predict(x_out_test)
y_out_pred
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
       0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0.,
       0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
       1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0.,
       0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 0., 0., 0., 0.,
       0., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0., 0.,
       0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0.,
       1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1.,
       1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1.,
       1., 1., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
       0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
       1., 1., 0., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
       0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
       0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1.])

获得预测概率

1
2
predict_proba = model2.predict_proba(x_out_test)[:,1] #此处test_X为特征集
predict_proba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
array([0.92768637, 0.87372613, 0.93611771, 0.92504257, 0.93830129,
       0.97350163, 0.98439474, 0.98494626, 0.99302094, 0.71066415,
       0.9346553 , 0.9571639 , 0.98347517, 0.98133935, 0.97773941,
       0.8420986 , 0.89008306, 0.92629049, 0.9572436 , 0.5039008 ,
       0.15433895, 0.19618445, 0.13506135, 0.17494294, 0.54099699,
       0.03775966, 0.01478691, 0.01300564, 0.01183642, 0.03102973,
       0.06048433, 0.10077356, 0.19561205, 0.16183885, 0.18157097,
       0.40267378, 0.57780427, 0.76351945, 0.98894525, 0.96757754,
       0.94039104, 0.69392342, 0.45948081, 0.53882556, 0.54942774,
       0.55376286, 0.78704139, 0.6902328 , 0.67840564, 0.50551527,
       0.60997526, 0.46440017, 0.26347114, 0.0656017 , 0.05779543,
       0.21262375, 0.25199363, 0.50075914, 0.61394749, 0.27114971,
       0.10048141, 0.13073217, 0.13687481, 0.8528971 , 0.90548906,
       0.56790395, 0.50395654, 0.13714471, 0.08843812, 0.29217134,
       0.64426371, 0.42854564, 0.32085174, 0.5137408 , 0.35244305,
       0.30040986, 0.50240616, 0.457433  , 0.67157372, 0.69082967,
       0.77687727, 0.79582527, 0.78517435, 0.74283735, 0.68615354,
       0.80688554, 0.90175701, 0.64304015, 0.42421477, 0.75869859,
       0.81017815, 0.8514314 , 0.38881122, 0.25902116, 0.16211049,
       0.16614164, 0.12334865, 0.13226115, 0.47903855, 0.36390768,
       0.5174554 , 0.73624912, 0.75098208, 0.40788233, 0.3967154 ,
       0.2439561 , 0.54479096, 0.53777812, 0.40403812, 0.25577642,
       0.24904761, 0.09929233, 0.13741176, 0.1678991 , 0.18963611,
       0.1943378 , 0.14487027, 0.18659786, 0.42940769, 0.25579801,
       0.52254904, 0.33436851, 0.23813746, 0.37284747, 0.74879912,
       0.79164491, 0.66651398, 0.88469944, 0.74616232, 0.73853718,
       0.63936874, 0.6081675 , 0.55622609, 0.75321664, 0.7815924 ,
       0.86126268, 0.85880723, 0.6411158 , 0.65249363, 0.52455551,
       0.54540161, 0.34515664, 0.16483909, 0.09813257, 0.07875876,
       0.39561187, 0.29154285, 0.57326271, 0.59871615, 0.54026514,
       0.49279335, 0.30241696, 0.42908667, 0.26428967, 0.48671753,
       0.62122644, 0.56028551, 0.56033722, 0.62665293, 0.81940654,
       0.7387033 , 0.66592612, 0.52672122, 0.37637344, 0.2850368 ,
       0.53597262, 0.24010379, 0.18681959, 0.22768771, 0.26363987,
       0.24933921, 0.19862666, 0.51994383, 0.45421865, 0.30938515,
       0.74792445, 0.78869373, 0.73751946, 0.95793289, 0.87027078,
       0.90282423, 0.77251597, 0.79809766, 0.37687865, 0.25473459,
       0.30440401, 0.06507915, 0.09396445, 0.09478336, 0.10337532,
       0.1668443 , 0.11178433, 0.39416537, 0.87374959, 0.65201754,
       0.83738234, 0.88783166, 0.95112084, 0.82290048, 0.55409254,
       0.83058832, 0.88418062, 0.97701667, 0.96498092, 0.96815346,
       0.73549159, 0.57495442, 0.55501549, 0.45825476, 0.58237228,
       0.60024485, 0.69880294, 0.77742039, 0.88159983, 0.34161263,
       0.51951705, 0.5136713 , 0.14993804, 0.14164131, 0.13503646,
       0.44318011, 0.29003659, 0.54459399, 0.93611822, 0.88347707,
       0.79203113, 0.59555529, 0.54438551, 0.38859177, 0.25961676,
       0.38008381, 0.52648356, 0.27392474, 0.52292447, 0.3474074 ,
       0.28097051, 0.17816129, 0.15292792, 0.25848705, 0.21666034,
       0.15186081, 0.20692974, 0.2616456 , 0.1740437 , 0.18470274,
       0.6455302 , 0.74751271, 0.67978954, 0.81101168, 0.72088468,
       0.86716137, 0.49636418, 0.73066438, 0.60749457, 0.48322452,
       0.86254143, 0.90508905, 0.43830084, 0.4528298 , 0.48121765,
       0.59174571, 0.54601153, 0.62680834, 0.19902157, 0.32335333,
       0.80321108, 0.93053582, 0.90023112, 0.80395445, 0.93857671,
       0.74428307, 0.83853564, 0.80653039, 0.74841075, 0.54240611,
       0.76270491, 0.78043358, 0.79909342, 0.69770529, 0.72841322,
       0.60916924, 0.46919568, 0.23561556, 0.27600535, 0.14867822,
       0.5547519 , 0.34777691, 0.16786792, 0.03765598, 0.01467337,
       0.02226528, 0.86197099, 0.36702956, 0.22010419, 0.29279984,
       0.40326685, 0.24354445, 0.34090129, 0.27896042, 0.02739086,
       0.11002659, 0.37231314, 0.60113585, 0.28467928, 0.30574615,
       0.24217869, 0.27286212, 0.0713619 , 0.23335294, 0.20911186,
       0.18794769, 0.18560383, 0.28359213, 0.18806428, 0.38684336,
       0.23436443, 0.2736422 , 0.21226031, 0.06377722, 0.07374984,
       0.01122629, 0.03179844, 0.00436962, 0.01504322, 0.02710046,
       0.04742455, 0.21986613, 0.15414805, 0.28812503, 0.48076086,
       0.61525733, 0.29366504, 0.69023641, 0.72010134, 0.73192722,
       0.71350155, 0.43443765, 0.83757989, 0.91196111, 0.80597116,
       0.71725584, 0.72846096, 0.68139639, 0.4055958 , 0.39676744,
       0.42733194, 0.5501376 , 0.8451244 , 0.73108955, 0.71831443,
       0.51176214, 0.47541834, 0.47238053, 0.86942811, 0.93746309,
       0.90782211, 0.88088337, 0.69681599, 0.53907133, 0.57675902,
       0.49121053, 0.73568145, 0.75968351, 0.49019147, 0.43524361,
       0.48736784, 0.11410632, 0.15359888, 0.15212952, 0.18340806,
       0.3275416 , 0.58926939, 0.53889026, 0.41103409, 0.20666117,
       0.51170725, 0.29108907, 0.42503409, 0.58790145, 0.67938655,
       0.69460852, 0.67718365, 0.81981997, 0.44427158, 0.17228037,
       0.08049974, 0.05629146, 0.10224809, 0.16675326, 0.37406816,
       0.34571201, 0.71925828, 0.5392739 , 0.46973142, 0.42379824,
       0.4823324 , 0.57120789, 0.60050637, 0.4901272 , 0.68224158,
       0.73347621, 0.27967437, 0.68220636, 0.5950859 , 0.52300998,
       0.45049518, 0.86898773, 0.83804669, 0.91548209, 0.89367349,
       0.85914564, 0.85690102, 0.90204988, 0.71132515, 0.79163732,
       0.812289  , 0.20139774, 0.28918965, 0.49553488, 0.61543553,
       0.33575905, 0.54168394, 0.15658974, 0.13426621, 0.22442951,
       0.16946512, 0.1953657 , 0.42157698, 0.74126061, 0.87274383,
       0.95759955, 0.83947617, 0.55318222, 0.43796698, 0.1473707 ,
       0.14126038, 0.36556707, 0.24001217, 0.25518879, 0.58937928,
       0.80839603, 0.36331598, 0.38230496, 0.28113962, 0.55396343,
       0.72834572, 0.87771931, 0.64322465, 0.61406436, 0.3262896 ,
       0.09367672, 0.11010181, 0.10681042, 0.13281207, 0.09630662,
       0.41926085, 0.76901883, 0.46624134, 0.37215163, 0.23810044,
       0.38668379, 0.42983361, 0.40827589, 0.21950464, 0.09326784,
       0.10705637, 0.35907413, 0.22196696, 0.73719639, 0.79230004,
       0.81567024, 0.85233467, 0.96740518, 0.73752961, 0.96878856,
       0.96848615, 0.88184152, 0.86665458, 0.89069756, 0.57529245,
       0.13776376, 0.3218207 , 0.49661417, 0.67958821, 0.79617313,
       0.90001448, 0.93731303])
1
2
3
# Out-of-sample accuracy of the timing model's predictions
accuracy_out_sample = accuracy_score(y_out_test, y_out_pred)
accuracy_out_sample 
1
0.6201232032854209
1
2
3
# Out-of-sample AUC of the timing model's predictions
roc_out_sample = roc_auc_score(y_out_test, y_out_pred)
roc_out_sample
1
0.6200961376286052

练习:集成上述训练模型功能,加入上面类对象里,或者创建新类,集成旧类的功能,再新添功能

1
1

答案:翻看类的对应功能函数,查看相应的参数说明

1
model_timing = report.timing_model('000001.XSHG','2010-01-01','2018-01-01','2018-01-01','2020-01-01','LR','LR_model')
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
开始获取合成特征和标签数据框...
------------------------------------------------------------
按照比例分割为训练集和测试集...
------------------------------------------------------------
开始训练数据...


训练结束
------------------------------------------------------------
预测准确率:
LR模型: 0.638
------------------------------------------------------------
输出混淆矩阵...
LR      0.0  1.0
Actual          
0.0     113   86
1.0      55  135
------------------------------------------------------------
绘制曲线...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
------------------------------------------------------------
输出评估报告...


模型的评估报告:
               precision    recall  f1-score   support

         0.0       0.67      0.57      0.62       199
         1.0       0.61      0.71      0.66       190

   micro avg       0.64      0.64      0.64       389
   macro avg       0.64      0.64      0.64       389
weighted avg       0.64      0.64      0.64       389

------------------------------------------------------------
保存模型到本地...


加载本地训练好的模型...
加载完毕
------------------------------------------------------------
样本外测试结果
样本外准确率 0.6673511293634496
样本外AUC值 0.666959015010963

多因子模型

数据获取和预处理

股票池设定

沪深300成分股,剔除ST股票,剔除每个截面期下一交易日停牌的股票,剔除上市6个月内的股票,每只股票视作一个样本。

时间区间

2014年1月1日-2019年12月31日的5年区间。其中前4年区间(48个月)作为训练集,后1年区间(12个月)作为测试集。

1
2
3
4
5
# In-sample (training) window.
# NOTE(review): the prose above describes a 4-year training window (2014-2017),
# but this demo cell uses a much shorter range — confirm the intended end date.
start_date = '2014-01-01'
end_date = '2014-03-28'
# Index universe code (presumably CSI 300 on the JoinQuant platform — verify)
index_code = '000300.XSHG'

特征和标签提取

每个自然月的第一个交易日,计算因子暴露度,作为样本的原始特征;计算下期收益率,作为样本的标签

1
2
# Build the merged feature/label DataFrame over every monthly period in the
# sample window (`report` is a framework object defined earlier in the file).
data_regular = report.data_for_model_multiperiod(start_date,end_date,index_code)
data_regular.tail(20)
1
data_regular.shape
1
(352, 12)
1
2
y_train = data_regular['Label']  # split off the label column
x_train = data_regular.iloc[:,:-1]  # every column except the trailing Label is a feature
1
x_train.tail()
MA_5 MA_60 Aroondown_14 Aroonup_14 ROC_6 ROC_12 RSI_6 RSI_24 OBV ATR_14 MOM_10
2012-05-25 -0.768642 -0.730667 1.407126 -1.385636 -0.742667 -0.845272 -1.066789 -0.593849 -0.788446 -0.653850 -0.666870
2013-12-03 -1.046679 -1.144522 -1.119941 0.685699 0.556855 0.965889 0.907214 0.662227 -0.641008 -0.618633 0.228764
2011-10-13 -0.712226 -0.354024 1.212737 -0.820727 0.658966 -0.151954 0.580085 -0.727583 -0.822100 0.093216 -0.104109
2016-04-15 0.630381 0.313270 -0.925552 0.874002 0.790899 1.307709 0.898470 0.523861 1.297079 0.323287 0.668458
2010-02-05 0.443279 0.911932 1.018347 -1.197333 -0.714451 -1.759498 -1.198944 -1.187684 -0.922039 0.934016 -1.924716
1
y_train.tail()
1
2
3
4
5
6
2012-05-25    0.0
2013-12-03    0.0
2011-10-13    1.0
2016-04-15    1.0
2010-02-05    0.0
Name: Label, dtype: float64

模型构建和样本内训练

1
2
3
4
# Features: every column except the trailing Label column
x_train = data_regular.iloc[:,:-1]
# Label
y_train = data_regular['Label']  

通过Pipeline方法,将特征选择和模型构建结合起来,形成model_pipe对象,然后针对该对象做交叉验证并得到不同参数下的检验结果,辅助于最终模型的参数设置。

特征选择

用SelectPercentile(f_classif, percentile)来做特征选择,其中f_classif用来确定特征选择的得分标准,percentile用来确定特征选择的比例。

1
2
3
4
5
transform = SelectPercentile(f_classif)  # select the top ?% most significant features (ANOVA F-test scoring)

# Define the classifier
model = XGBClassifier()
model_pipe = Pipeline(steps=[('ANOVA', transform), ('model', model)])  # "pipeline" object: feature selection followed by the classifier
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Choose the best feature-selection percentile.
# #############################################################################
# Cross-validate the pipeline at each candidate percentile, then plot the mean
# CV score with its standard deviation as an error bar.
percentiles = (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
fold_scores = []
for pct in percentiles:
    model_pipe.set_params(ANOVA__percentile=pct)
    fold_scores.append(cross_val_score(model_pipe, x_train, y_train, cv=5, n_jobs=-1))
score_means = [scores.mean() for scores in fold_scores]
score_stds = [scores.std() for scores in fold_scores]
plt.errorbar(percentiles, score_means, np.array(score_stds))
plt.title('Performance of the model-Anova varying the percentile of features selected')
plt.xlabel('Percentile')
plt.ylabel('Prediction rate')
plt.axis('tight')
plt.show()

交叉验证调参

特征(比例)选择完成后,根据不同的参数(n_estimators,max_depth),对模型进行交叉验证。采用StratifiedKFold来将训练集分成训练集和验证集。StratifiedKFold能够有效结合分类样本标签做数据集分割,而不是完全的随机选择和分割。完成交叉验证后,选取交叉验证集AUC(或f1-score)最高的一组参数作为模型的最优参数。

1
2
3
4
transform = SelectPercentile(f_classif,percentile=100)  # keep 100% of the features (ANOVA F-test scoring)

model = XGBClassifier()
model_pipe = Pipeline(steps=[('ANOVA', transform), ('model', model)])  # "pipeline" object: feature selection followed by the classifier
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
cv = StratifiedKFold(5)  # 5-fold stratified cross-validation splitter

# Candidate values for XGB max_depth (swap in the commented lists/setters
# below to tune n_estimators or subsample instead).
parameters = [3,4,5,6,7,8]
#parameters = [0.6,0.7,0.8,0.9,1]  # XGB subsample candidates
#score_methods = ['roc_auc','accuracy', 'precision', 'recall', 'f1']  # full metric set
score_methods = ['roc_auc', 'f1']  # cross-validation metrics to report
for parameter in parameters:  # evaluate each candidate parameter value
    t1 = time.time()  # start timestamp for this parameter
    # Hoisted out of the metric loop: the pipeline parameter depends only on
    # `parameter`, not on which metric is being scored.
    model_pipe.set_params(model__max_depth=parameter)
    #model_pipe.set_params(model__n_estimators=parameter)
    #model_pipe.set_params(model__subsample=parameter)
    score_list = list()  # per-metric arrays of CV fold scores
    print ('set parameters: %s' % parameter)  # current parameter under test
    for score_method in score_methods:  # score the pipeline under each metric
        score_tmp = cross_val_score(model_pipe, x_train, y_train, scoring=score_method, cv=cv, n_jobs=-1)
        score_list.append(score_tmp)
    score_matrix = pd.DataFrame(np.array(score_list), index=score_methods)  # rows: metrics, columns: folds
    score_mean = score_matrix.mean(axis=1).rename('mean')  # per-metric mean across folds
    score_std = score_matrix.std(axis=1).rename('std')  # per-metric std across folds
    score_pd = pd.concat([score_matrix, score_mean, score_std], axis=1)  # fold scores plus summary columns
    print (score_pd.round(4))  # 4-decimal summary for this parameter
    print ('-' * 60)
    t2 = time.time()  # end timestamp
    tt = t2 - t1  # elapsed seconds for this parameter
    print ('算法用时time: %s' % str(tt))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
set parameters: 3
              0       1       2       3       4    mean     std
roc_auc  0.6929  0.6707  0.6939  0.7505  0.7167  0.7050  0.0302
f1       0.6588  0.6429  0.6667  0.7052  0.7086  0.6764  0.0291
------------------------------------------------------------
算法用时time: 0.6454043388366699
set parameters: 4
              0       1       2       3       4    mean     std
roc_auc  0.6820  0.6653  0.7065  0.7323  0.6920  0.6956  0.0254
f1       0.6647  0.6551  0.6964  0.6866  0.6964  0.6798  0.0189
------------------------------------------------------------
算法用时time: 0.7695553302764893
set parameters: 5
              0       1       2       3       4    mean     std
roc_auc  0.6813  0.6473  0.6968  0.7254  0.6814  0.6865  0.0283
f1       0.6667  0.6331  0.6725  0.7038  0.6891  0.6730  0.0266
------------------------------------------------------------
算法用时time: 0.9359838962554932
set parameters: 6
              0       1       2       3       4    mean     std
roc_auc  0.6896  0.6316  0.6802  0.7292  0.6718  0.6805  0.0351
f1       0.6786  0.6243  0.6531  0.6967  0.6629  0.6631  0.0273
------------------------------------------------------------
算法用时time: 1.0583992004394531
set parameters: 7
              0       1       2       3       4    mean     std
roc_auc  0.6583  0.6381  0.6676  0.7175  0.6581  0.6679  0.0297
f1       0.6509  0.6126  0.6395  0.7076  0.6552  0.6532  0.0347
------------------------------------------------------------
算法用时time: 1.185107946395874
set parameters: 8
              0       1       2       3       4    mean     std
roc_auc  0.6712  0.6266  0.6603  0.7188  0.6675  0.6689  0.0330
f1       0.6527  0.6163  0.6388  0.6962  0.6648  0.6538  0.0298
------------------------------------------------------------
算法用时time: 1.3339207172393799

模型构建

根据上述交叉验证的最优模型,使用XGBoosting集成学习模型对训练集进行训练。

1
2
transform.fit(x_train, y_train)  # fit the feature selector on the training data
X_train_final = transform.transform(x_train)  # keep only the significant feature columns

XGBoost

(1)subsample subsample是训练集参与模型训练的比例,取值在0-1之间,可有效地防止过拟合。subsample参数的性能评价参考上面执行结果所示。随着subsample的上升,f1-score呈下降趋势,模型训练速度加快,综合训练时间和效果提升考量,选取subsample=0.9。

(2)max_depth max_depth参数的性能评价参考表所示。随着max_depth的上升,AUC和f1-score呈下降趋势,模型训练时间变慢。选取max_depth=3。

1
2
# Final model with the parameters chosen by the cross-validation above
model = XGBClassifier(max_depth=3,subsample=0.9,random_state=0)
model.fit(X_train_final, y_train)  # train the model
1
2
3
4
5
6
7
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
       max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
       n_estimators=100, n_jobs=1, nthread=None,
       objective='binary:logistic', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
       subsample=0.9, verbosity=1)

样本外测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Out-of-sample test: walk month by month through the hold-out window,
# predict with the trained model and record accuracy / AUC per period.
out_start = '2018-01-01'
out_end = '2019-12-31'

test_sample_predict={}   # period start -> predicted label array
test_sample_accuracy=[]  # per-period accuracy (for later plots/means)
test_sample_roc_auc=[]   # per-period AUC
test_sample_date=[]      # period start dates (x-axis for later plots)

# Split the hold-out window into monthly intervals.
# (sic: the framework method is spelled "inverval")
interval_start,interval_end = report.get_time_inverval(out_start,out_end,'M')

for date1,date2 in dict(zip(interval_start,interval_end)).items():

    data_merge_label = report.data_for_model_perperiod(date1,date2,index_code)
    y_test=data_merge_label['Label']
    X_test=data_merge_label.iloc[:,:-1]

    # Predict labels and class probabilities for this period
    y_pred_tmp = model.predict(X_test)
    y_pred = pd.DataFrame(y_pred_tmp, columns=['label_predict'])  # predicted labels
    y_pred_proba = pd.DataFrame(model.predict_proba(X_test), columns=['pro1', 'pro2'])  # class probabilities
    # Align predictions with the test index and keep a merged frame for inspection
    y_pred.set_index(X_test.index,inplace=True)
    y_pred_proba.set_index(X_test.index,inplace=True)
    predict_pd = pd.concat((X_test, y_pred, y_pred_proba), axis=1)

    # Compute each metric once and reuse it for both printing and bookkeeping
    period_auc = roc_auc_score(y_test,y_pred)
    period_accuracy = accuracy_score(y_test, y_pred)
    print ('Predict date:')
    print (date1)    
    print ('AUC:')
    print (period_auc)
    print ('Accuracy:')
    print (period_accuracy)
    print ('-' * 60)       
    # Bookkeeping for the later statistics/plots
    test_sample_date.append(date1)
    test_sample_predict[date1]=y_pred_tmp
    test_sample_accuracy.append(period_accuracy)
    test_sample_roc_auc.append(period_auc)

print ('AUC mean info')
print (np.mean(test_sample_roc_auc))
print ('-' * 60)    
print ('ACCURACY mean info')
print (np.mean(test_sample_accuracy))
print ('-' * 60)    
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
Predict date:
2018-01-01 00:00:00
AUC:
0.6055555555555556
Accuracy:
0.6055555555555555
------------------------------------------------------------
Predict date:
2018-02-01 00:00:00
AUC:
0.4444444444444445
Accuracy:
0.4444444444444444
------------------------------------------------------------
Predict date:
2018-03-01 00:00:00
AUC:
0.5388888888888889
Accuracy:
0.5388888888888889
------------------------------------------------------------
Predict date:
2018-04-01 00:00:00
AUC:
0.55
Accuracy:
0.55
------------------------------------------------------------
Predict date:
2018-05-01 00:00:00
AUC:
0.5277777777777778
Accuracy:
0.5277777777777778
------------------------------------------------------------
Predict date:
2018-06-01 00:00:00
AUC:
0.4444444444444444
Accuracy:
0.4444444444444444
------------------------------------------------------------
Predict date:
2018-07-01 00:00:00
AUC:
0.6166666666666667
Accuracy:
0.6166666666666667
------------------------------------------------------------
Predict date:
2018-08-01 00:00:00
AUC:
0.3611111111111111
Accuracy:
0.3611111111111111
------------------------------------------------------------
Predict date:
2018-09-01 00:00:00
AUC:
0.5166666666666667
Accuracy:
0.5166666666666667
------------------------------------------------------------
Predict date:
2018-10-01 00:00:00
AUC:
0.4555555555555556
Accuracy:
0.45555555555555555
------------------------------------------------------------
Predict date:
2018-11-01 00:00:00
AUC:
0.75
Accuracy:
0.75
------------------------------------------------------------
Predict date:
2018-12-01 00:00:00
AUC:
0.4666666666666666
Accuracy:
0.4666666666666667
------------------------------------------------------------
Predict date:
2019-01-01 00:00:00
AUC:
0.4722222222222222
Accuracy:
0.4722222222222222
------------------------------------------------------------
Predict date:
2019-02-01 00:00:00
AUC:
0.6777777777777778
Accuracy:
0.6777777777777778
------------------------------------------------------------
Predict date:
2019-03-01 00:00:00
AUC:
0.5722222222222222
Accuracy:
0.5722222222222222
------------------------------------------------------------
Predict date:
2019-04-01 00:00:00
AUC:
0.40555555555555556
Accuracy:
0.40555555555555556
------------------------------------------------------------
Predict date:
2019-05-01 00:00:00
AUC:
0.5111111111111111
Accuracy:
0.5111111111111111
------------------------------------------------------------
Predict date:
2019-06-01 00:00:00
AUC:
0.46111111111111114
Accuracy:
0.46111111111111114
------------------------------------------------------------
Predict date:
2019-07-01 00:00:00
AUC:
0.5611111111111111
Accuracy:
0.5611111111111111
------------------------------------------------------------
Predict date:
2019-08-01 00:00:00
AUC:
0.4833333333333333
Accuracy:
0.48333333333333334
------------------------------------------------------------
Predict date:
2019-09-01 00:00:00
AUC:
0.4444444444444444
Accuracy:
0.4444444444444444
------------------------------------------------------------
Predict date:
2019-10-01 00:00:00
AUC:
0.5277777777777777
Accuracy:
0.5277777777777778
------------------------------------------------------------
Predict date:
2019-11-01 00:00:00
AUC:
0.5722222222222223
Accuracy:
0.5722222222222222
------------------------------------------------------------
Predict date:
2019-12-01 00:00:00
AUC:
0.5277777777777778
Accuracy:
0.5277777777777778
------------------------------------------------------------
AUC mean info
0.5206018518518518
------------------------------------------------------------
ACCURACY mean info
0.5206018518518518
------------------------------------------------------------

样本外每期AUC

1
test_sample_roc_auc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
[0.6055555555555556,
 0.4444444444444445,
 0.5388888888888889,
 0.55,
 0.5277777777777778,
 0.4444444444444444,
 0.6166666666666667,
 0.3611111111111111,
 0.5166666666666667,
 0.4555555555555556,
 0.75,
 0.4666666666666666,
 0.4722222222222222,
 0.6777777777777778,
 0.5722222222222222,
 0.40555555555555556,
 0.5111111111111111,
 0.46111111111111114,
 0.5611111111111111,
 0.4833333333333333,
 0.4444444444444444,
 0.5277777777777777,
 0.5722222222222223,
 0.5277777777777778]

预测能力

1
2
3
4
5
6
7
8
9
10
11
12
# Per-period out-of-sample AUC over time
xs_date = test_sample_date
ys_auc = test_sample_roc_auc
# Format the x-axis tick labels as dates
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.plot(xs_date, ys_auc,'r')
# Auto-rotate the date tick labels
plt.gcf().autofmt_xdate() 
# x-axis label
plt.xlabel('date')
# y-axis label
plt.ylabel("test AUC")
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
# Per-period out-of-sample accuracy over time
xs_date = test_sample_date
ys_score = test_sample_accuracy
# Format the x-axis tick labels as dates
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.plot(xs_date, ys_score,'r')
# Auto-rotate the date tick labels
plt.gcf().autofmt_xdate() 
# x-axis label
plt.xlabel('date')
# y-axis label
plt.ylabel("test accuracy")
plt.show()
1
2
3
4
5
6
7
8
# Plot out-of-sample AUC and accuracy on one axes.
f = plt.figure(figsize= (15,6))

sns.set(style="whitegrid")
# `columns` must be a list, not a set — a one-element set only works by accident
# and set ordering is undefined.
data1 = pd.DataFrame(ys_auc, xs_date, columns=['AUC'])
data2 = pd.DataFrame(ys_score, xs_date, columns=['accuracy'])
# Align the two series column-wise on the shared date index; row-wise concat
# left each column half-filled with NaN.
data = pd.concat([data1,data2],axis=1,sort=False)
sns.lineplot(data=data, palette="tab10", linewidth=2.5)
1
<matplotlib.axes._subplots.AxesSubplot at 0x264b0823c88>

下面特征重要度是这类算法的独有分析部分,不兼容其他类机器学习算法

特征重要度

1
2
# Refit the tuned XGB model on ALL features (no SelectPercentile step) so that
# feature_importances_ lines up one-to-one with the columns of x_train.
model = XGBClassifier(max_depth=3,subsample=0.9,random_state=0)
model.fit(x_train, y_train)  # train the model
1
2
3
4
5
6
7
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
       max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
       n_estimators=100, n_jobs=1, nthread=None,
       objective='binary:logistic', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
       subsample=0.9, verbosity=1)
1
2
3
4
5
6
7
8
#%matplotlib inline
# Horizontal bar chart of XGBoost feature importances, one bar per feature
fig = plt.figure(figsize= (15,6))

n_features = x_train.shape[1]
plt.barh(range(n_features),model.feature_importances_,align='center')
plt.yticks(np.arange(n_features),x_train.columns)  # label each bar with its feature name
plt.xlabel("Feature importance")
plt.ylabel("Feature")
1
Text(0, 0.5, 'Feature')

集成,添加,调用,前面的单元格代码都可删掉

1
2
3
4
5
6
7
8
9
10
11
12
# In-sample (training) window
start_date = '2014-01-01'
end_date = '2018-12-31'
# Index universe code (presumably CSI 300 — confirm against the framework)
index_code = '000300.XSHG'

# Out-of-sample (test) window
out_start = '2018-10-01'
out_end = '2019-12-31'

# Model type understood by the framework, and the file name to save it under
model_name = 'xgboost'
file_name = 'xgboost_model'

调用类中的多因子训练和测试模型

前面的单元格代码都可删掉,写入样本内和样本外的开始时间和结束时间,模型名,保存文件的名字(自己创建)

1
report.multifactor_model(index_code,start_date,end_date,out_start,out_end,model_name,file_name)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
开始训练数据...


训练结束
------------------------------------------------------------
预测准确率:
xgboost模型: 0.667
------------------------------------------------------------
输出混淆矩阵...
col_0  0.0  1.0
Label          
0.0     25   14
1.0     10   23
------------------------------------------------------------
绘制曲线...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
------------------------------------------------------------
输出评估报告...


模型的评估报告:
               precision    recall  f1-score   support

         0.0       0.71      0.64      0.68        39
         1.0       0.62      0.70      0.66        33

   micro avg       0.67      0.67      0.67        72
   macro avg       0.67      0.67      0.67        72
weighted avg       0.67      0.67      0.67        72

------------------------------------------------------------
保存模型到本地...


加载本地训练好的模型...
加载完毕
------------------------------------------------------------
样本外测试结果...
Predict date:
2018-10-01 00:00:00
AUC:
0.4833333333333334
Accuracy:
0.48333333333333334
------------------------------------------------------------
Predict date:
2018-11-01 00:00:00
AUC:
0.6722222222222222
Accuracy:
0.6722222222222223
------------------------------------------------------------
Predict date:
2018-12-01 00:00:00
AUC:
0.49444444444444446
Accuracy:
0.49444444444444446
------------------------------------------------------------
Predict date:
2019-01-01 00:00:00
AUC:
0.5611111111111111
Accuracy:
0.5611111111111111
------------------------------------------------------------
Predict date:
2019-02-01 00:00:00
AUC:
0.6166666666666666
Accuracy:
0.6166666666666667
------------------------------------------------------------
Predict date:
2019-03-01 00:00:00
AUC:
0.5277777777777778
Accuracy:
0.5277777777777778
------------------------------------------------------------
Predict date:
2019-04-01 00:00:00
AUC:
0.46111111111111114
Accuracy:
0.46111111111111114
------------------------------------------------------------
Predict date:
2019-05-01 00:00:00
AUC:
0.5166666666666667
Accuracy:
0.5166666666666667
------------------------------------------------------------
Predict date:
2019-06-01 00:00:00
AUC:
0.4555555555555556
Accuracy:
0.45555555555555555
------------------------------------------------------------
Predict date:
2019-07-01 00:00:00
AUC:
0.5611111111111111
Accuracy:
0.5611111111111111
------------------------------------------------------------
Predict date:
2019-08-01 00:00:00
AUC:
0.5111111111111111
Accuracy:
0.5111111111111111
------------------------------------------------------------
Predict date:
2019-09-01 00:00:00
AUC:
0.4555555555555556
Accuracy:
0.45555555555555555
------------------------------------------------------------
Predict date:
2019-10-01 00:00:00
AUC:
0.4222222222222222
Accuracy:
0.4222222222222222
------------------------------------------------------------
Predict date:
2019-11-01 00:00:00
AUC:
0.5444444444444444
Accuracy:
0.5444444444444444
------------------------------------------------------------
Predict date:
2019-12-01 00:00:00
AUC:
0.5666666666666667
Accuracy:
0.5666666666666667
------------------------------------------------------------
AUC mean info
0.5233333333333333
------------------------------------------------------------
ACCURACY mean info
0.5233333333333333
------------------------------------------------------------
1
2
3
4
5
6
7
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
       max_delta_step=0, max_depth=3, min_child_weight=1, missing=nan,
       n_estimators=100, n_jobs=1, nthread=None,
       objective='binary:logistic', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
       subsample=0.9, verbosity=1)

1
report.multifactor_model(index_code,start_date,end_date,out_start,out_end,model_name,file_name)
1

本文采用知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议(CC BY-NC-ND 4.0)进行许可,转载请注明出处,请勿用于任何商业用途采用。

☛决定关注我了吗☚