Linear regression is a supervised machine learning algorithm, available in Python, that takes continuous features and predicts a continuous outcome. Depending on whether it is fit on a single feature or on several, it is called simple linear regression or multiple linear regression.
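
To make the simple versus multiple distinction concrete, here is a minimal sketch. It assumes NumPy is installed and uses NumPy's least-squares solver rather than scikit-learn. The synthetic data, the helper name fit_linear, and the chosen coefficients are all hypothetical, picked only for illustration.

```python
import numpy as np

# Hypothetical noise-free data: 3 features, known weights and intercept.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = 4.0 + 2.0 * X[:, 0] - 1.0 * X[:, 1] + 0.5 * X[:, 2]

def fit_linear(features, target):
    """Least-squares fit; returns (intercept, coefficients)."""
    # Prepend a column of ones so the first weight acts as the intercept.
    design = np.column_stack([np.ones(len(features)), features])
    weights, *_ = np.linalg.lstsq(design, target, rcond=None)
    return weights[0], weights[1:]

# Simple linear regression: one feature, so one slope.
simple_b, simple_a = fit_linear(X[:, :1], y)
# Multiple linear regression: three features, so three coefficients.
multi_b, multi_a = fit_linear(X, y)

print(simple_a.shape, multi_a.shape)
```

The only difference between the two fits is how many columns are passed in: the simple fit returns a single slope, while the multiple fit returns one coefficient per feature.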

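This is one of the most popular ML algorithms in Python, and it is often under-appreciated. It assigns weights to the input variables so that the line Y=a*X+b predicts the output, choosing the weights that minimize the squared error between the predictions and the observed values. We often use linear regression to estimate real values, such as the number of calls or the price of a house, from continuous variables. The resulting regression line is the best-fitting line Y=a*X+b, where a is the slope and b is the intercept, and it describes the relationship between the independent variable X and the dependent variable Y.
Let’s fit and plot a regression line for the diabetes dataset that ships with scikit-learn, using a single feature.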
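
For a single feature, the least-squares values of a and b in Y=a*X+b have a closed form. The sketch below computes them in plain Python on hypothetical toy data that lies exactly on the line y = 2*x + 1. The data and the variable names are assumptions made for illustration.

```python
# Hypothetical toy data lying exactly on y = 2*x + 1 (illustration only).
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [3.0, 5.0, 7.0, 9.0, 11.0]

n = len(xs)
mean_x = sum(xs) / n
mean_y = sum(ys) / n

# Slope a: co-variation of x and y divided by the variation of x.
numerator = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
denominator = sum((x - mean_x) ** 2 for x in xs)
a = numerator / denominator

# Intercept b: chosen so the line passes through (mean_x, mean_y).
b = mean_y - a * mean_x

print(a, b)
```

Because the toy data has no noise, the fit recovers a = 2 and b = 1. On real data the line will not pass through every point, and a and b are the values that make the total squared error smallest.
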
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from sklearn import datasets,linear_model
>>> from sklearn.metrics import mean_squared_error,r2_score
>>> diabetes=datasets.load_diabetes()
>>> diabetes_X=diabetes.data[:,np.newaxis,2] #keep only feature 2 (BMI) as a single column
>>> diabetes_X_train=diabetes_X[:-30] #splitting data into training and test sets
>>> diabetes_X_test=diabetes_X[-30:]
>>> diabetes_y_train=diabetes.target[:-30] #splitting targets into training and test sets
>>> diabetes_y_test=diabetes.target[-30:]
>>> regr=linear_model.LinearRegression() #Linear regression object
>>> regr.fit(diabetes_X_train,diabetes_y_train) #Use training sets to train the model
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
>>> diabetes_y_pred=regr.predict(diabetes_X_test) #Make predictions
>>> regr.coef_
array([941.43097333])
>>> mean_squared_error(diabetes_y_test,diabetes_y_pred)
3035.0601152912695
>>> r2_score(diabetes_y_test,diabetes_y_pred) #R^2 score (coefficient of determination)
0.410920728135835
>>> plt.scatter(diabetes_X_test,diabetes_y_test,color='lavender')
<matplotlib.collections.PathCollection object at 0x0584FF70>
>>> plt.plot(diabetes_X_test,diabetes_y_pred,color='pink',linewidth=3)
[<matplotlib.lines.Line2D object at 0x0584FF30>]
>>> plt.xticks(())
([], <a list of 0 Text xticklabel objects>)
>>> plt.yticks(())
([], <a list of 0 Text yticklabel objects>)
>>> plt.show()
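
To see what mean_squared_error and r2_score measure, the sketch below computes both by hand in plain Python. The values in y_true and y_pred are hypothetical and are not taken from the diabetes run above.

```python
# Hypothetical true values and predictions (illustration only).
y_true = [3.0, 5.0, 7.0, 9.0]
y_pred = [2.5, 5.5, 7.0, 9.5]

n = len(y_true)

# Mean squared error: average of the squared prediction errors.
ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
mse = ss_res / n

# R^2: 1 minus the ratio of residual variation to the total variation
# of the true values around their mean.
mean_true = sum(y_true) / n
ss_tot = sum((t - mean_true) ** 2 for t in y_true)
r2 = 1.0 - ss_res / ss_tot

print(mse, r2)
```

An R^2 of 1 means the predictions match the targets exactly, while 0 means the model does no better than always predicting the mean of the targets.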

