# Topics: sklearn, Implementation
# Indian context: telecom churn (Jio, Airtel), credit scoring, fraud detection for UPI
# Concept level: Beginner
# ENSEMBLE LEARNING: THE POWER OF COMBINING MODELS
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
# Create dataset: a reproducible synthetic classification problem with
# 1000 samples and 20 features (15 informative, 5 redundant).
X, y = make_classification(
    random_state=42,
    n_samples=1000,
    n_features=20,
    n_informative=15,
    n_redundant=5,
)
# Compare single tree vs ensemble
print("="*60)
print("SINGLE TREE vs ENSEMBLE COMPARISON")
print("="*60)

# Single Decision Tree, scored with 10-fold cross-validation.
single_tree = DecisionTreeClassifier(max_depth=10, random_state=42)
single_scores = cross_val_score(single_tree, X, y, cv=10)

# Random Forest (100 trees) with the same depth limit and the same cv=10
# setting, so the two score arrays can be compared directly.
forest = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)
forest_scores = cross_val_score(forest, X, y, cv=10)

print("\nSingle Decision Tree:")
print(f" Mean Accuracy: {single_scores.mean():.4f}")
print(f" Std Deviation: {single_scores.std():.4f}")
print("\nRandom Forest (100 trees):")
print(f" Mean Accuracy: {forest_scores.mean():.4f}")
print(f" Std Deviation: {forest_scores.std():.4f}")

# The gain is an absolute difference in mean accuracy, scaled by 100:
# that is percentage points, not a relative percent improvement.
print(f"\n✓ Improvement: "
      f"{(forest_scores.mean() - single_scores.mean())*100:.2f} percentage points")
# This is a ratio of fold-to-fold standard deviations (not variances).
print(f"✓ More stable: {single_scores.std()/forest_scores.std():.1f}x "
      "lower standard deviation across folds")
# Visualize the effect of number of trees.
# Each forest is scored with the same cv=10 setting used for the single-tree
# baseline, so the resulting curve is comparable with the baseline line
# drawn on the same plot.
n_trees_range = [1, 5, 10, 25, 50, 100, 200]
accuracies = []
for n in n_trees_range:
    rf = RandomForestClassifier(n_estimators=n, max_depth=10, random_state=42)
    scores = cross_val_score(rf, X, y, cv=10)
    accuracies.append(scores.mean())
# Plot mean CV accuracy against forest size, with the single tree's mean
# accuracy as a dashed horizontal reference line.
plt.figure(figsize=(10, 6))
# Label the forest curve so the legend shows both series, not just the baseline.
plt.plot(n_trees_range, accuracies, 'b-', marker='o', linewidth=2,
         label='Random Forest')
plt.axhline(y=single_scores.mean(), color='r', linestyle='--',
            label=f'Single Tree: {single_scores.mean():.4f}')
plt.xlabel('Number of Trees')
plt.ylabel('Cross-Validation Accuracy')
plt.title('More Trees = Better Performance (up to a point)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Closing takeaway printed after the plot (the emoji was mis-encoded as "š”").
print("\n💡 Insight: Performance improves with more trees, then plateaus!")