你的位置:首页 > 信息动态 > 新闻中心
信息动态
联系我们

Sklearn初学(线性回归)

2021/12/25 8:04:35

data.csv文件

from  sklearn.linear_model import LinearRegression
import numpy as np
from matplotlib import pyplot as plt

# 用numpy读取文件,加载数据
data = np.genfromtxt('data.csv',delimiter=',')
x_data = data[:,0]
y_data = data[:,1]
# 画出散点图
plt.scatter(x_data,y_data)
plt.show()
# 改变数组维度
x_data = data[:,0,np.newaxis]
y_data = data[:,1,np.newaxis]
# 创建并拟合模型
model = LinearRegression()
# 训练模型
model.fit(x_data,y_data)
# 绘制折线图
plt.plot(x_data,y_data,'.')
plt.plot(x_data,model.predict(x_data),'r')
plt.show()

运行结果如下: