alpha_step_1 = alpha
print(f"Amount of say for the first stump: {round(alpha_step_1, 3)}")
alpha_step_2 = alpha_step_2
print(f"Amount of say for the second stump: {round(alpha_step_2, 3)}")
##########################################################################################
# Make a prediction:
# Suppose a person lives in the U.S., is 30 years old, and works about 42 hours per week.
##########################################################################################
# the first stump uses the hours worked per week (>40 hours) as the root node
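To finish the prediction, AdaBoost lets each stump vote for a class and adds up the amount of say behind each class. The class with the larger total wins. A minimal sketch, assuming two stumps with hypothetical amounts of say (0.9 and 0.4, placeholders for `alpha_step_1` and `alpha_step_2`), an assumed leaf-to-class mapping for the first stump, and a hypothetical second stump that splits on `>50 years`:

```python
# Hypothetical amounts of say; in the walkthrough these would be alpha_step_1 and alpha_step_2.
ALPHA_1 = 0.9
ALPHA_2 = 0.4


def stump_1(sample):
    # root node ">40 hours" (assumed mapping): predict "Yes" (>50k income) for more than 40 hours
    return "Yes" if sample[">40 hours"] == "Yes" else "No"


def stump_2(sample):
    # hypothetical second stump splitting on ">50 years"
    return "Yes" if sample[">50 years"] == "Yes" else "No"


def adaboost_vote(stumps, sample):
    """Sum the amount of say behind each class and return the winning class."""
    say = {"Yes": 0.0, "No": 0.0}
    for alpha, predict in stumps:
        say[predict(sample)] += alpha
    return max(say, key=say.get), say


# 30 years old, about 42 hours per week
person = {">40 hours": "Yes", ">50 years": "No"}
label, say = adaboost_vote([(ALPHA_1, stump_1), (ALPHA_2, stump_2)], person)
print(label, say)  # Yes {'Yes': 0.9, 'No': 0.4}
```

The stump with the larger amount of say dominates when the stumps disagree, which is exactly why alpha is computed from each stump's weighted error.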
################################################################################################################
# find the root node for the second stump
################################################################################################################
df_step_2 = new_data_set[["male", ">50 years", ">50k income"]]
selected_root_node_attribute_2 = find_attribute_that_shows_the_smallest_gini_index(df_step_2)
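The body of `find_attribute_that_shows_the_smallest_gini_index` is not shown here. A minimal sketch of what such a helper could do, assuming an unweighted Gini index computed on the resampled data set and the target column `>50k income`; the name `smallest_gini_attribute` is hypothetical:

```python
import pandas as pd


def gini(labels: pd.Series) -> float:
    """Gini impurity 1 - sum(p_k^2) of the class labels in one leaf."""
    p = labels.value_counts(normalize=True)
    return 1.0 - float((p ** 2).sum())


def smallest_gini_attribute(df: pd.DataFrame, target: str = ">50k income") -> str:
    """Return the feature column whose split has the smallest weighted Gini index."""
    scores = {}
    for col in df.columns.drop(target):
        # average the leaf impurities, weighting each leaf by its share of the rows
        scores[col] = sum(len(g) / len(df) * gini(g[target]) for _, g in df.groupby(col))
    return min(scores, key=scores.get)


toy = pd.DataFrame({
    "male": ["Yes", "Yes", "No", "No"],
    ">50 years": ["Yes", "No", "Yes", "No"],
    ">50k income": ["Yes", "Yes", "No", "No"],
})
print(smallest_gini_attribute(toy))  # male
```

Because the second stump is trained on a resampled data set, plain row counts already reflect the sample weights, so no explicit weights are needed in this sketch.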
####################################################################################
# define bins to select new instances
####################################################################################
import random
df_extended_2["cum_sum_upper"] = df_extended_2["sample_weight"].cumsum()
# lower bound of each bin = upper bound of the previous row (0 for the first row), for any number of rows
df_extended_2["cum_sum_low"] = [0] + df_extended_2["cum_sum_upper"].iloc[:-1].to_list()
####################################################################################
# create new dataset
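The two cumulative sums turn every row into an interval `[cum_sum_low, cum_sum_upper)` whose width equals its sample weight, so a uniform random number lands in a row's interval with probability proportional to that weight. A minimal sketch of the drawing step, assuming a `sample_weight` column; the function name `resample_by_weight` and its `seed` parameter are hypothetical:

```python
import random

import pandas as pd


def resample_by_weight(df, n, seed=None):
    """Draw n rows with replacement; each row is picked with probability proportional to its weight."""
    rng = random.Random(seed)
    upper = df["sample_weight"].cumsum()
    lower = upper.shift(1, fill_value=0.0)  # same bins as cum_sum_low / cum_sum_upper
    total = upper.iloc[-1]
    picked = []
    for _ in range(n):
        r = rng.random() * total  # scale in case the weights do not sum exactly to 1
        in_bin = (lower <= r) & (r < upper)
        picked.append(in_bin.idxmax())  # label of the first (and only) matching bin
    return df.loc[picked].reset_index(drop=True)


toy = pd.DataFrame({"id": [0, 1, 2], "sample_weight": [0.0, 1.0, 0.0]})
print(resample_by_weight(toy, 5, seed=42)["id"].tolist())  # [1, 1, 1, 1, 1]
```

Rows with zero weight have an empty interval and can never be drawn, while heavily weighted (previously misclassified) rows tend to appear several times in the new data set.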
def update_sample_weights(df_extended_1):
    # calculate the new sample weights (misclassified samples get a larger weight)
    # note: alpha is read from the enclosing (global) scope
    def calc_new_sample_weight(x, alpha):
        new_weight = plot_scale_of_weights(alpha, x["sample_weight"], x["chosen_stump_incorrect"])
        return new_weight
    df_extended_1["new_sample_weight"] = df_extended_1.apply(lambda x: calc_new_sample_weight(x, alpha), axis=1)
    # define new extended data frame
    df_extended_2 = df_extended_1[["male", ">40 hours", ">50 years", ">50k income", "new_sample_weight"]]
    df_extended_2 = df_extended_2.rename(columns={"new_sample_weight": "sample_weight"}, errors="raise")
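In AdaBoost the updated weights are usually normalized so that they sum to 1 again before the next round. That step is not visible in this excerpt. A vectorized sketch of the update plus normalization, assuming the columns `sample_weight` and `chosen_stump_incorrect` used above and the standard rule (multiply by e^alpha when misclassified, by e^-alpha otherwise); `update_and_normalize` is a hypothetical name:

```python
import numpy as np
import pandas as pd


def update_and_normalize(df, alpha):
    """Return a copy of df with AdaBoost-updated, normalized sample weights."""
    factor = np.where(df["chosen_stump_incorrect"] == 1, np.exp(alpha), np.exp(-alpha))
    new_weight = df["sample_weight"] * factor
    out = df.copy()
    out["sample_weight"] = new_weight / new_weight.sum()  # weights sum to 1 again
    return out


# 10 samples with equal weight; only the first one was misclassified (error = 0.1)
df = pd.DataFrame({"sample_weight": [0.1] * 10, "chosen_stump_incorrect": [1] + [0] * 9})
alpha = 0.5 * np.log((1 - 0.1) / 0.1)  # e^alpha = 3
out = update_and_normalize(df, alpha)
print(out["sample_weight"].round(4).tolist())  # misclassified sample: 0.5, every other sample: about 0.0556
```

With this normalization the misclassified samples always end up holding half of the total weight, whatever the error of the stump was.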
import math
def plot_scale_of_weights(alpha, current_sample_weight, incorrect):
    alpha_list = []
    new_weights = []
    if incorrect == 1:
        # adjust the sample weights for instances which were misclassified
        new_weight = current_sample_weight * math.exp(alpha)
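The lists `alpha_list` and `new_weights` suggest that the function collects values to plot how a weight is scaled for different amounts of say. A plotting-free sketch of that relationship, assuming the standard AdaBoost rule that correctly classified samples are scaled by e^-alpha (that branch is not shown in the excerpt); `weight_scale_table` is a hypothetical name:

```python
import math


def weight_scale_table(current_sample_weight, alphas):
    """For each alpha, return (alpha, weight if misclassified, weight if correct)."""
    rows = []
    for alpha in alphas:
        increased = current_sample_weight * math.exp(alpha)   # misclassified sample
        decreased = current_sample_weight * math.exp(-alpha)  # correctly classified sample
        rows.append((alpha, increased, decreased))
    return rows


for alpha, up, down in weight_scale_table(0.1, [0.0, 0.5, 1.0, 2.0]):
    print(f"alpha={alpha:.1f}  misclassified={up:.4f}  correct={down:.4f}")
```

At alpha = 0 (a stump no better than chance) the weights do not change; the larger the amount of say, the more strongly the two kinds of samples are pushed apart.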
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import helper_functions
# calculate the amount of say using the weighted error rate of the weak classifier
alpha = 1/2 * np.log((1-error)/error)
print(f'Amount of say / Alpha = {round(alpha, 3)}')
helper_functions.plot_alpha(alpha, error)
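The formula alpha = 1/2 * ln((1 - error) / error) is 0 for a stump that is right half of the time, negative when the error exceeds 0.5, and grows without bound as the error approaches 0; at exactly 0 or 1 it is undefined. A common safeguard is to clip the error away from those endpoints first. A minimal sketch, where the clipping and the `eps` value are assumptions, not part of the original code:

```python
import math


def amount_of_say(error, eps=1e-10):
    """Amount of say of a weak learner with weighted error `error` (between 0 and 1)."""
    error = min(max(error, eps), 1 - eps)  # avoid division by zero and log(0)
    return 0.5 * math.log((1 - error) / error)


print(round(amount_of_say(0.5), 3))  # 0.0
print(round(amount_of_say(0.1), 3))  # 1.099
print(round(amount_of_say(0.9), 3))  # -1.099
```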
import helper_functions
def calculate_error_for_chosen_stump(df, selected_root_node_attribute):
    '''
    Args:
        df: training data set
        selected_root_node_attribute: name of the column used for the root node of the stump
    Return:
        df_extended: df extended by the calculated weights and error
        error: calculated error for the stump, i.e. the sum of the weights of all samples that were misclassified by the decision stump
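The weighted error of a stump is the sum of the sample weights of every row it misclassifies. A minimal sketch, assuming each leaf of the stump predicts the class with the larger total sample weight and that the target column is `>50k income`; `stump_weighted_error` is a hypothetical name, not the body of `calculate_error_for_chosen_stump`:

```python
import pandas as pd


def stump_weighted_error(df, attribute, target=">50k income"):
    """Return (per-row 0/1 misclassification flags, weighted error) for a one-split stump."""
    # total sample weight per (leaf, class); rows = leaves, columns = classes
    leaf_weights = df.groupby([attribute, target])["sample_weight"].sum().unstack(fill_value=0.0)
    # each leaf predicts the class that carries the most weight in it
    leaf_prediction = leaf_weights.idxmax(axis=1)
    predicted = df[attribute].map(leaf_prediction)
    incorrect = (predicted != df[target]).astype(int)
    return incorrect, float(df.loc[incorrect == 1, "sample_weight"].sum())


toy = pd.DataFrame({
    ">40 hours": ["Yes", "Yes", "Yes", "No"],
    ">50k income": ["Yes", "Yes", "No", "No"],
    "sample_weight": [0.25] * 4,
})
incorrect, error = stump_weighted_error(toy, ">40 hours")
print(incorrect.tolist(), error)  # [0, 0, 1, 0] 0.25
```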
def calc_weighted_gini_index(attribute, df):
    '''
    Args:
        df: the training data set stored in a data frame
        attribute: the chosen attribute for the root node of the tree
    Return:
        Gini_attribute: the weighted Gini index for the chosen attribute
    '''
    d_node = df[[attribute, '>50k income']]
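The weighted Gini index of a split averages the Gini impurity 1 - sum(p_k^2) of each leaf, weighting each leaf by its share of the samples. A minimal sketch working on plain class counts per leaf; the count-based input format and the function names are assumptions for illustration:

```python
def gini_impurity(counts):
    """Gini impurity of one leaf given its class counts, e.g. [yes, no]."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


def weighted_gini(leaves):
    """Weighted Gini index of a split; leaves is a list of per-leaf class counts."""
    n = sum(sum(leaf) for leaf in leaves)
    return sum(sum(leaf) / n * gini_impurity(leaf) for leaf in leaves if sum(leaf) > 0)


# leaf "Yes": 3 x ">50k", 1 x "<=50k"; leaf "No": 1 x ">50k", 5 x "<=50k"
print(round(weighted_gini([[3, 1], [1, 5]]), 3))  # 0.317
```

Here the leaves have impurities 0.375 and 0.278, and the weights 4/10 and 6/10 give 0.4 * 0.375 + 0.6 * 0.278 = 0.317. The attribute with the smallest value becomes the root node.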
import numpy as np
# define the binary input features
df['male'] = df['sex'].apply(lambda x: 'Yes' if x.lstrip() == "Male" else "No")
df['>40 hours'] = np.where(df['hours-per-week'] > 40, 'Yes', 'No')
df['>50 years'] = np.where(df['age'] > 50, 'Yes', 'No')
# target
df['>50k income'] = df['income'].apply(lambda x: 'Yes' if x.lstrip() == '>50K' else "No")
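The same binary features can also be built fully vectorized with `np.where` and the `.str` accessor. A sketch on a made-up two-row frame (the values are illustrative, not taken from the data set) that shows the leading spaces in the raw strings and the strict `>` thresholds:

```python
import numpy as np
import pandas as pd

# made-up rows; like the raw adult.data strings, the text values start with a space
toy = pd.DataFrame({
    "sex": [" Male", " Female"],
    "age": [30, 52],
    "hours-per-week": [42, 40],
    "income": [" <=50K", " >50K"],
})
toy["male"] = np.where(toy["sex"].str.strip() == "Male", "Yes", "No")
toy[">40 hours"] = np.where(toy["hours-per-week"] > 40, "Yes", "No")  # exactly 40 hours gives "No"
toy[">50 years"] = np.where(toy["age"] > 50, "Yes", "No")
toy[">50k income"] = np.where(toy["income"].str.strip() == ">50K", "Yes", "No")
print(toy[["male", ">40 hours", ">50 years", ">50k income"]])
```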
import pandas as pd
df = pd.read_csv("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
                 names=["age",
                        "workclass",
                        "fnlwgt",
                        "education",
                        "education-num",
                        "marital-status",
                        "occupation",
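The raw `adult.data` file separates fields with a comma followed by a space, which is why the feature code strips leading whitespace with `lstrip()`. Passing `skipinitialspace=True` to `pd.read_csv` removes that space while parsing instead. A sketch on two made-up rows with a shortened column list (not the full schema of the file):

```python
import io

import pandas as pd

# made-up rows in the adult.data layout, reduced to four columns
raw = "39, State-gov, 77516, <=50K\n52, Private, 209642, >50K\n"
demo = pd.read_csv(io.StringIO(raw), names=["age", "workclass", "fnlwgt", "income"], skipinitialspace=True)
print(demo["income"].tolist())  # ['<=50K', '>50K']
```

With this option the comparisons could test `x == "Male"` directly; keeping `lstrip()` as in the original code is harmless either way.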