Created February 17, 2021 21:31
sohiniroych/62f644d76994c01464cae3b6f160e7b5
Model Fitting and visualization
```python
import numpy as np
import matplotlib.pyplot as plt

# Assumes X (training features, 3 columns), y (training target), x_test and a
# fitted regressor `model` are already defined by the earlier fitting step.

# To generate a best-fit curve, build 50 evenly spaced input points per feature.
# plt.hist returns (counts, bin_edges, patches); 49 bins give 50 edges that run
# from the column's minimum to its maximum.
X_range = np.zeros((50, 3))
plt.figure()
for i in range(3):
    Xi = X[:, i]
    counts, edges, _ = plt.hist(Xi, bins=49)
    X_range[:, i] = edges
plt.xlabel("Feature value")
plt.ylabel("Frequency")

# Each row of X_range sweeps all three features together from min to max.
y_range = model.predict(X_range)

# Plot the results
plt.figure()
plt.scatter(X[:, 0], y, s=20, edgecolor="black", c="darkorange", label="train data")
plt.scatter(x_test[:, 0], model.predict(x_test), s=30, color="yellowgreen",
            label="test data", linewidth=2)
plt.plot(X_range[:, 0], y_range, color="cornflowerblue",
         label="Regression model", linewidth=2)
plt.xlabel("R&D Cost")
plt.ylabel("Profit")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
```
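The 50-point grid comes from the bin edges that `plt.hist` returns. With an integer bin count and no explicit range, NumPy places those edges evenly between each column's minimum and maximum, so `np.linspace` should give the same grid without drawing a histogram. The sketch below checks that on random synthetic data (not the gist's dataset). It also builds an alternative grid, which is an assumption and not the author's method: sweep only one feature and hold the other two at their column means, so a curve plotted against R&D cost reflects that feature alone. The helper names `histogram_grid`, `linspace_grid` and `single_feature_grid` are hypothetical.

```python
import numpy as np

def histogram_grid(X, n_points=50):
    """Per-column grid taken from the edges of an (n_points - 1)-bin histogram."""
    grid = np.zeros((n_points, X.shape[1]))
    for i in range(X.shape[1]):
        _, edges = np.histogram(X[:, i], bins=n_points - 1)
        grid[:, i] = edges
    return grid

def linspace_grid(X, n_points=50):
    """The same grid built directly from each column's min and max."""
    return np.linspace(X.min(axis=0), X.max(axis=0), n_points)

def single_feature_grid(X, feature=0, n_points=50):
    """Hypothetical alternative: sweep one feature across its range and
    hold every other feature at its column mean."""
    grid = np.tile(X.mean(axis=0), (n_points, 1))
    grid[:, feature] = np.linspace(X[:, feature].min(), X[:, feature].max(), n_points)
    return grid

# Synthetic stand-in for the three input features (illustrative only).
rng = np.random.default_rng(0)
X_demo = rng.uniform(0, 100, size=(40, 3))

same = np.allclose(histogram_grid(X_demo), linspace_grid(X_demo))
sweep = single_feature_grid(X_demo, feature=0)
print(same, sweep.shape)
```

With the gist's fitted `model`, the alternative curve would be `model.predict(single_feature_grid(X))` plotted against `single_feature_grid(X)[:, 0]`. The original grid moves all three features at once, so its curve mixes their effects; holding the others fixed trades that for a view that ignores how the features co-vary in the data.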