@nipuntalukdar
Created September 11, 2024 02:04
Get the number of trees in an XGBoost model
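import json
import os
import sys
import tempfile

import xgboost as xgb

if len(sys.argv) < 2:
    print(f'Usage: {sys.argv[0]} <model-file>', file=sys.stderr)
    sys.exit(1)

# Load the model from any format XGBoost can read (JSON, UBJSON or legacy binary)
loaded_model = xgb.Booster()
loaded_model.load_model(sys.argv[1])

# Re-save as JSON (format chosen from the .json extension) in a temporary directory
with tempfile.TemporaryDirectory() as tmpdir:
    json_path = os.path.join(tmpdir, 'a_model.json')
    loaded_model.save_model(json_path)
    with open(json_path, 'r') as fp:
        jsonrepr = json.load(fp)

# Key path valid for the default 'gbtree' booster; num_trees is stored as a string
print(jsonrepr['learner']['gradient_booster']['model']['gbtree_model_param']['num_trees'])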