基于Lasso回归的涨跌预测

Posted by 高庆东 on November 24, 2018

数据

tushare提供了全面的金融数据,可以直接调用api来获取数据,需要在python环境下安装anaconda

直接pip install tushare就可以了,然后调用函数既可以获得数据,该数据集下的数据主要有股票

的变化情况和参数,

目标

用每天的数据来预测三天后的涨跌幅度

用r2来检测算法的效果,在sklearning是实践经验证r2值完全不到0.1%,

python实现

数据读取

import numpy as np
import tushare as ts
import pandas as pd
from sklearn.model_selection import train_test_split


def dataset_read(df):
    #pro = ts.pro_api()
    #df = pro.daily(ts_code='000001.SZ', start_data='20181109', end_data='20181110')
    #df = sorted(df, key=lambda x: x.trade_date)
    df=np.array(df)
    #print("读取数据")
    #print(df.shape)
    return df

def dataset_proc(df):
    df_lable=df[:,8]
    df_lable=df_lable.reshape(-1,1)
    df_lable=np.delete(df_lable,[len(df_lable)-1,len(df_lable)-2,len(df_lable)-3],0)
    df_dates=np.delete(df,[0,1,8],1)
    df_dates=np.delete(df_dates,[0,1,2],0)
    x_train, x_test, y_train, y_test = train_test_split(df_dates, df_lable, test_size=0.2, random_state=0)
    # print("数据分类")
    # print(x_train.shape,y_train.shape,x_test.shape,y_test.shape)
    return x_train,x_test,y_train,y_test

建模

import tushare as ts
import numpy as np
from sklearn.linear_model import Lasso,LarsCV,LassoLarsCV
from sklearn.metrics import r2_score
import data_test_1 as dt

def train_modl_(x,df):

    df=dt.dataset_read(df)

    x_train,x_test,y_train,y_test=dt.dataset_proc(df)

    model = Lasso(alpha=x)

    model.fit(x_train, y_train)

    predicted = model.predict(x_test)

    score = r2_score(y_test, predicted)

    return score

调参

import train_modl_1 as tm
import numpy as np
import matplotlib.pyplot as plt
import tushare as ts

pro = ts.pro_api()
df = pro.daily(ts_code='000001.SZ', start_data='20181109', end_data='20181110')
#df = pro.query('daily', ts_code='000001.SZ', start_date='20180701', end_date='20180718')#可以正常使用的

print(np.array(df).shape)
a=np.arange(0.0005,0.5,0.0005)

soc=[]
for i in range(0,999):
    x=tm.train_modl_(a[i],df)
    soc.append(x)
    if i%10==0:
        print(i)
soc=np.array(soc)


plt.plot(a,soc,color='red')
plt.xlabel('alpha')
plt.ylabel('r2')
plt.show()