Building an ANN-based Customer Churn Prediction System: Deep Dive into Implementation

Try the ANN-Classification webapp for churn prediction via the Launch Application link.

Introduction

Customer churn prediction is a critical application of machine learning in today’s business environment. Companies across industries strive to identify customers who are likely to discontinue their services, enabling proactive retention strategies. This blog post provides a comprehensive breakdown of a customer churn prediction system built using Artificial Neural Networks (ANNs) and deployed as an interactive web application via Streamlit.

The complete project, available on GitHub, demonstrates the end-to-end pipeline from data preprocessing to model development and deployment. Let’s explore how this system works and the technical implementation details behind it.

Understanding the Churn Prediction Problem

Customer churn occurs when customers stop doing business with a company. In the context of this project, we’re predicting whether a telecom customer will leave the service provider (churn) based on various behavioral and demographic features. This is framed as a binary classification problem:

  • Class 0: Customer stays
  • Class 1: Customer churns (leaves)

Early identification of potential churners allows companies to implement targeted retention campaigns, which is typically more cost-effective than acquiring new customers.

Dataset Overview

The project uses the Telco Customer Churn dataset, which includes information about:

  • Customer demographics: Gender, age, partners, dependents
  • Account information: Tenure, contract type, payment method, billing preferences
  • Service subscriptions: Phone, internet, streaming, backup, protection services
  • Financial metrics: Monthly charges, total charges
  • Churn status: Whether the customer left the company (target variable)

Application Architecture

The Streamlit application provides a user-friendly interface for:

  1. Exploratory Data Analysis: Visualizing patterns and relationships in the churn data
  2. Model Exploration: Understanding the ANN architecture and performance metrics
  3. Real-time Prediction: Making churn predictions for individual customers
  4. Model Interpretation: Explaining predictions using SHAP values

Let’s dive into the implementation details, starting with the structure of the Streamlit application.

Streamlit Application Implementation

The Streamlit app is organized into multiple pages, each with distinct functionality. Here’s a breakdown of app.py, which serves as the main entry point:

import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.models import load_model
import shap
import pickle
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

# Set page configuration
st.set_page_config(
    page_title="Churn Prediction App",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Load the saved model
@st.cache_resource
def load_model_and_components():
    model = load_model('models/churn_prediction_model.h5')
    with open('models/preprocessor.pkl', 'rb') as f:
        preprocessor = pickle.load(f)
    return model, preprocessor

# Load the dataset
@st.cache_data
def load_data():
    df = pd.read_csv('data/telco_churn.csv')
    # Basic preprocessing
    df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce')
    df['TotalCharges'] = df['TotalCharges'].fillna(df['MonthlyCharges'])  # avoid chained inplace assignment
    # Convert target to binary
    df['Churn'] = df['Churn'].map({'Yes': 1, 'No': 0})
    return df

# Load resources
model, preprocessor = load_model_and_components()
df = load_data()

# Sidebar navigation
st.sidebar.title("Navigation")
page = st.sidebar.radio(
    "Select a page:",
    ["Home", "Data Exploration", "Model Performance", "Prediction", "Model Explanation"]
)

# Home page
if page == "Home":
    st.title("Customer Churn Prediction System")
    st.image("images/churn_banner.jpg", use_column_width=True)
    
    st.markdown("""
    ## Welcome to the Churn Prediction App!
    
    This application uses Artificial Neural Networks to predict whether a customer
    will churn (leave) based on various features like demographics, services subscribed,
    and billing information.
    
    ### What can you do with this app?
    
    - **Explore the data**: Understand the patterns and relationships in the customer data
    - **View model performance**: See how well our neural network performs
    - **Make predictions**: Predict if a specific customer will churn
    - **Interpret predictions**: Understand the factors influencing churn predictions
    
    Navigate through the different sections using the sidebar on the left.
    """)
    
    # Display key metrics
    col1, col2, col3 = st.columns(3)
    with col1:
        churn_rate = df['Churn'].mean() * 100
        st.metric("Current Churn Rate", f"{churn_rate:.2f}%")
    
    with col2:
        avg_tenure = df['tenure'].mean()
        st.metric("Average Customer Tenure", f"{avg_tenure:.1f} months")
    
    with col3:
        avg_monthly = df['MonthlyCharges'].mean()
        st.metric("Average Monthly Charge", f"${avg_monthly:.2f}")

# Data Exploration page
elif page == "Data Exploration":
    st.title("Data Exploration")
    
    # Dataset overview
    st.subheader("Dataset Overview")
    st.dataframe(df.head())
    
    # Basic statistics
    st.subheader("Basic Statistics")
    st.dataframe(df.describe())
    
    # Missing values
    st.subheader("Missing Values")
    missing_values = df.isnull().sum()
    st.write(missing_values[missing_values > 0] if any(missing_values > 0) else "No missing values")
    
    # Feature distributions
    st.subheader("Feature Distributions")
    
    # Select feature for visualization
    feature = st.selectbox(
        "Select a feature to visualize:",
        options=df.columns.tolist(),
        index=df.columns.get_loc("tenure")
    )
    
    # Plot based on feature type
    if df[feature].dtype == 'object' or df[feature].nunique() < 10:
        fig, ax = plt.subplots(figsize=(10, 6))
        counts = df[feature].value_counts()
        sns.barplot(x=counts.index, y=counts.values, ax=ax)
        plt.xticks(rotation=45)
        plt.title(f"Distribution of {feature}")
        plt.tight_layout()
        st.pyplot(fig)
    else:
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(data=df, x=feature, hue="Churn", multiple="stack", bins=20)
        plt.title(f"Distribution of {feature} by Churn Status")
        plt.tight_layout()
        st.pyplot(fig)
    
    # Correlation heatmap
    st.subheader("Correlation Between Numerical Features")
    numerical_df = df.select_dtypes(include=['float64', 'int64'])
    corr = numerical_df.corr()
    
    fig, ax = plt.subplots(figsize=(10, 8))
    mask = np.triu(np.ones_like(corr, dtype=bool))
    cmap = sns.diverging_palette(230, 20, as_cmap=True)
    sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0,
                annot=True, square=True, linewidths=.5, cbar_kws={"shrink": .5})
    plt.title("Correlation Heatmap")
    st.pyplot(fig)
    
    # Churn analysis
    st.subheader("Churn Analysis")
    
    # Churn by categorical features
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    if 'Contract' not in categorical_cols:
        categorical_cols.append('Contract')  # Ensure Contract is available for churn analysis
    
    selected_cat = st.selectbox("Select categorical feature:", categorical_cols)
    
    fig, ax = plt.subplots(figsize=(10, 6))
    churn_by_cat = df.groupby([selected_cat, 'Churn']).size().unstack()
    churn_rate_by_cat = churn_by_cat[1] / (churn_by_cat[0] + churn_by_cat[1])
    
    sns.barplot(x=churn_rate_by_cat.index, y=churn_rate_by_cat.values)
    plt.title(f"Churn Rate by {selected_cat}")
    plt.xticks(rotation=45)
    plt.ylabel("Churn Rate")
    plt.tight_layout()
    st.pyplot(fig)

# Model Performance page
elif page == "Model Performance":
    st.title("Model Performance")
    
    # Load performance metrics
    with open('models/performance_metrics.pkl', 'rb') as f:
        performance = pickle.load(f)
    
    # Display metrics
    st.subheader("Model Metrics")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Accuracy", f"{performance['accuracy']:.4f}")
    with col2:
        st.metric("Precision", f"{performance['precision']:.4f}")
    with col3:
        st.metric("Recall", f"{performance['recall']:.4f}")
    with col4:
        st.metric("F1 Score", f"{performance['f1']:.4f}")
    
    # ROC Curve
    st.subheader("ROC Curve")
    fig, ax = plt.subplots(figsize=(10, 6))
    plt.plot(performance['fpr'], performance['tpr'], label=f"AUC = {performance['auc']:.4f}")
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend()
    st.pyplot(fig)
    
    # Confusion Matrix
    st.subheader("Confusion Matrix")
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(performance['confusion_matrix'], annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    st.pyplot(fig)
    
    # Learning Curves
    st.subheader("Learning Curves")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Accuracy
    ax1.plot(performance['history']['accuracy'], label='Training Accuracy')
    ax1.plot(performance['history']['val_accuracy'], label='Validation Accuracy')
    ax1.set_title('Model Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    
    # Loss
    ax2.plot(performance['history']['loss'], label='Training Loss')
    ax2.plot(performance['history']['val_loss'], label='Validation Loss')
    ax2.set_title('Model Loss')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    
    plt.tight_layout()
    st.pyplot(fig)
    
    # Feature Importance
    st.subheader("Feature Importance")
    
    if 'feature_importance' in performance:
        fig, ax = plt.subplots(figsize=(12, 8))
        sorted_idx = np.argsort(performance['feature_importance'])
        plt.barh(range(len(sorted_idx)), performance['feature_importance'][sorted_idx])
        plt.yticks(range(len(sorted_idx)), np.array(performance['feature_names'])[sorted_idx])
        plt.title('Feature Importance')
        plt.tight_layout()
        st.pyplot(fig)
    else:
        st.info("Feature importance information is not available for this model.")

# Prediction page
elif page == "Prediction":
    st.title("Customer Churn Prediction")
    
    st.markdown("""
    ## Make predictions for individual customers
    
    Fill in the customer information below to predict their likelihood of churning.
    """)
    
    # Create input form with columns for better layout
    col1, col2 = st.columns(2)
    
    with col1:
        gender = st.selectbox('Gender', ['Male', 'Female'])
        senior_citizen = st.selectbox('Senior Citizen', ['No', 'Yes'])
        partner = st.selectbox('Partner', ['No', 'Yes'])
        dependents = st.selectbox('Dependents', ['No', 'Yes'])
        tenure = st.slider('Tenure (months)', 0, 72, 24)
        phone_service = st.selectbox('Phone Service', ['No', 'Yes'])
        
        if phone_service == 'Yes':
            multiple_lines = st.selectbox('Multiple Lines', ['No', 'Yes'])
        else:
            multiple_lines = 'No phone service'
    
    with col2:
        internet_service = st.selectbox('Internet Service', ['DSL', 'Fiber optic', 'No'])
        
        if internet_service != 'No':
            online_security = st.selectbox('Online Security', ['No', 'Yes'])
            online_backup = st.selectbox('Online Backup', ['No', 'Yes'])
            device_protection = st.selectbox('Device Protection', ['No', 'Yes'])
            tech_support = st.selectbox('Tech Support', ['No', 'Yes'])
            streaming_tv = st.selectbox('Streaming TV', ['No', 'Yes'])
            streaming_movies = st.selectbox('Streaming Movies', ['No', 'Yes'])
        else:
            online_security = 'No internet service'
            online_backup = 'No internet service'
            device_protection = 'No internet service'
            tech_support = 'No internet service'
            streaming_tv = 'No internet service'
            streaming_movies = 'No internet service'
    
    col3, col4 = st.columns(2)
    
    with col3:
        contract = st.selectbox('Contract', ['Month-to-month', 'One year', 'Two year'])
        paperless_billing = st.selectbox('Paperless Billing', ['No', 'Yes'])
    
    with col4:
        payment_method = st.selectbox('Payment Method', [
            'Electronic check', 
            'Mailed check', 
            'Bank transfer (automatic)', 
            'Credit card (automatic)'
        ])
        monthly_charges = st.slider('Monthly Charges ($)', 0, 150, 70)
        total_charges = st.slider('Total Charges ($)', 0, 10000, min(monthly_charges * tenure, 10000))  # cap the default at the slider maximum
    
    # Create a dictionary with the input values
    input_data = {
        'gender': gender,
        'SeniorCitizen': 1 if senior_citizen == 'Yes' else 0,
        'Partner': partner,
        'Dependents': dependents,
        'tenure': tenure,
        'PhoneService': phone_service,
        'MultipleLines': multiple_lines,
        'InternetService': internet_service,
        'OnlineSecurity': online_security,
        'OnlineBackup': online_backup,
        'DeviceProtection': device_protection,
        'TechSupport': tech_support,
        'StreamingTV': streaming_tv,
        'StreamingMovies': streaming_movies,
        'Contract': contract,
        'PaperlessBilling': paperless_billing,
        'PaymentMethod': payment_method,
        'MonthlyCharges': monthly_charges,
        'TotalCharges': total_charges
    }
    
    # Create DataFrame from input
    input_df = pd.DataFrame([input_data])
    
    # Prediction button
    if st.button('Predict Churn Probability'):
        # Preprocess the input data
        X_processed = preprocessor.transform(input_df)
        
        # Make prediction
        prediction = model.predict(X_processed)[0][0]
        
        # Display prediction
        st.subheader("Prediction Result")
        
        # Create gauge chart for probability
        fig = plt.figure(figsize=(8, 4))
        ax = fig.add_subplot(111)
        
        # Create gauge
        pos = ax.barh([0], [prediction], left=0, height=0.5, color='red')
        neg = ax.barh([0], [1-prediction], left=prediction, height=0.5, color='green')
        
        # Remove axis
        ax.set_yticks([])
        ax.set_xlim(0, 1)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        
        # Add text
        ax.text(0.5, -0.5, f"Churn Probability: {prediction:.2%}", ha='center', va='center', fontsize=15)
        plt.tight_layout()
        
        st.pyplot(fig)
        
        # Show interpretation
        if prediction > 0.5:
            st.error(f"⚠️ High Risk of Churn: This customer has a {prediction:.2%} probability of churning.")
            st.markdown("""
            ### Recommended Actions:
            1. Reach out to the customer with a retention offer
            2. Address any service issues they may be experiencing
            3. Consider offering a contract upgrade with better terms
            """)
        else:
            st.success(f"✅ Low Risk of Churn: This customer has a {prediction:.2%} probability of churning.")
            st.markdown("""
            ### Recommended Actions:
            1. Maintain regular engagement to ensure continued satisfaction
            2. Consider cross-selling or upselling opportunities
            3. Encourage referrals from this stable customer
            """)

# Model Explanation page
elif page == "Model Explanation":
    st.title("Model Explanation")
    
    st.markdown("""
    ## Understanding Model Predictions
    
    This page uses SHAP (SHapley Additive exPlanations) values to explain how different features
    contribute to the model's predictions.
    """)
    
    # Load precomputed SHAP values or calculate them (resource intensive)
    @st.cache_resource
    def get_shap_values():
        # Create a small sample for SHAP analysis (for demonstration purposes)
        sample_df = df.sample(100, random_state=42)
        X_sample = preprocessor.transform(sample_df.drop('Churn', axis=1))
        
        # Create explainer
        explainer = shap.DeepExplainer(model, X_sample[:10])
        
        # Calculate SHAP values
        shap_values = explainer.shap_values(X_sample)
        
        # Get feature names after preprocessing
        feature_names = []
        for name, transformer, cols in preprocessor.transformers_:
            if name != 'remainder':
                if hasattr(transformer, 'get_feature_names_out'):
                    feature_names.extend(transformer.get_feature_names_out(cols))
                else:
                    feature_names.extend(cols)
        
        return shap_values, X_sample, feature_names, explainer
    
    with st.spinner("Loading SHAP values... This may take a moment."):
        try:
            shap_values, X_sample, feature_names, explainer = get_shap_values()
            
            # SHAP Summary Plot
            st.subheader("Feature Importance (SHAP Summary Plot)")
            
            fig, ax = plt.subplots(figsize=(10, 8))
            shap.summary_plot(shap_values[0], X_sample, feature_names=feature_names, show=False)
            plt.tight_layout()
            st.pyplot(fig)
            
            # Individual SHAP Explanation
            st.subheader("Individual Prediction Explanation")
            
            # Let user select a sample
            sample_idx = st.slider("Select a sample to explain:", 0, len(X_sample)-1, 0)
            
            # Force plot for selected sample
            st.write("SHAP Force Plot (showing how each feature contributes to the prediction):")
            
            fig, ax = plt.subplots(figsize=(12, 3))
            shap.force_plot(
                explainer.expected_value[0], 
                shap_values[0][sample_idx], 
                X_sample[sample_idx],
                feature_names=feature_names, 
                matplotlib=True,
                show=False
            )
            plt.tight_layout()
            st.pyplot(fig)
            
            # Decision Plot
            st.subheader("Decision Plot")
            fig, ax = plt.subplots(figsize=(10, 8))
            shap.decision_plot(
                explainer.expected_value[0], 
                shap_values[0][sample_idx], 
                feature_names=feature_names,
                show=False
            )
            plt.tight_layout()
            st.pyplot(fig)
            
        except Exception as e:
            st.error(f"Error generating SHAP explanations: {e}")
            st.info("SHAP analysis requires significant computational resources. Try with a smaller sample or check the model configuration.")

Model Training Architecture

The project implements a neural network using TensorFlow/Keras. Here’s the architecture of the ANN model:

def build_churn_model(input_shape):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(input_shape,)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.3),
        
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.2),
        
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.1),
        
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC()]
    )
    
    return model

This architecture includes:

  1. Input Layer: Accepts customer features after preprocessing
  2. Hidden Layers: Three dense layers with decreasing neuron counts (64 → 32 → 16)
  3. Regularization: Each layer has BatchNormalization and Dropout to prevent overfitting
  4. Output Layer: A single neuron with sigmoid activation for binary classification

The model is compiled with:

  • Optimizer: Adam with learning rate of 0.001
  • Loss Function: Binary cross-entropy (standard for binary classification)
  • Metrics: Accuracy and AUC (Area Under the ROC Curve)
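
As a quick sanity check, the model can be built for a hypothetical feature count and inspected with Keras' summary(). The 45-feature width below is an illustrative assumption; the real width is whatever the preprocessing pipeline produces after one-hot encoding.

# Illustrative usage of build_churn_model; 45 is an assumed post-encoding feature count,
# in practice use preprocessor.fit_transform(X_train).shape[1]
num_features = 45
model = build_churn_model(num_features)
model.summary()  # prints the 64 -> 32 -> 16 -> 1 layer stack with parameter counts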

Data Preprocessing Pipeline

The preprocessing pipeline handles both numerical and categorical features:

def create_preprocessor(X):
    # Identify categorical and numerical columns
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()
    
    # Create preprocessing pipelines
    numerical_transformer = Pipeline(steps=[
        ('scaler', StandardScaler())
    ])
    
    categorical_transformer = Pipeline(steps=[
        ('onehot', OneHotEncoder(handle_unknown='ignore'))
    ])
    
    # Combine preprocessing steps
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numerical_transformer, numerical_cols),
            ('cat', categorical_transformer, categorical_cols)
        ])
    
    return preprocessor

Key preprocessing steps include:

  • Numerical features: Standardized using StandardScaler
  • Categorical features: Transformed using OneHotEncoder
  • Missing values: Handled appropriately (e.g., TotalCharges nulls filled with MonthlyCharges)

Model Training and Evaluation

Model training incorporates several best practices:

# Split data into train and test sets
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Preprocess the data
preprocessor = create_preprocessor(X)
X_train_processed = preprocessor.fit_transform(X_train)
X_test_processed = preprocessor.transform(X_test)

# Create and train the model
model = build_churn_model(X_train_processed.shape[1])

# Define callbacks
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', 
    patience=10, 
    restore_best_weights=True
)

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', 
    factor=0.2, 
    patience=5, 
    min_lr=1e-6
)

# Train the model
history = model.fit(
    X_train_processed, y_train,
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    callbacks=[early_stopping, reduce_lr],
    verbose=1
)

Training includes:

  • Data splitting: 80% training, 20% testing
  • Validation: Additional 20% of training data used for validation
  • Callbacks:
    • Early stopping to prevent overfitting
    • Learning rate reduction when performance plateaus
  • Batch size: 32 (balances computing efficiency and gradient accuracy)
  • Epochs: Up to 100, with early stopping
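
For the Streamlit app to work, the trained artifacts have to be written to the paths that app.py loads (models/churn_prediction_model.h5 and models/preprocessor.pkl). A minimal persistence sketch, assuming that directory layout:

import os
import pickle

# Save the trained model and the fitted preprocessor where load_model_and_components() expects them
os.makedirs('models', exist_ok=True)
model.save('models/churn_prediction_model.h5')   # Keras HDF5 format, reloaded with load_model()
with open('models/preprocessor.pkl', 'wb') as f:
    pickle.dump(preprocessor, f)                 # fitted ColumnTransformer reused at inference time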

Model Training with TensorBoard Integration

One of the project’s highlights is the integration of TensorBoard for visualization and monitoring:

# Set up TensorBoard callback
import datetime
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True
)

# Early stopping to prevent overfitting
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True
)

# Model checkpoint to save the best model
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'best_model.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max'
)

# Train the model
history = model.fit(
    X_train_processed, y_train,
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    callbacks=[tensorboard_callback, early_stopping, model_checkpoint]
)

This section includes:

  1. TensorBoard Configuration: Sets up logging for real-time training visualization
  2. Early Stopping: Prevents overfitting by monitoring validation loss and stopping when it no longer improves
  3. Model Checkpoint: Saves the model with the highest validation accuracy
  4. Training Process: Runs for up to 100 epochs with a batch size of 32, reserving 20% of the training data for validation

TensorBoard Features

TensorBoard provides several valuable visualizations:

  1. Scalars: Tracks metrics like loss and accuracy over time
  2. Distributions: Shows how weights and biases evolve during training
  3. Graphs: Visualizes the model’s computational graph
  4. Histograms: Displays weight distributions across layers
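
With the log directory configured above, these dashboards can be viewed locally by starting TensorBoard from the project root and opening http://localhost:6006 in a browser:

tensorboard --logdir logs/fit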

Model Evaluation

The project includes comprehensive evaluation metrics:

# Evaluate the model on the test set (evaluate returns the loss plus the two compiled metrics)
test_loss, test_accuracy, test_auc = model.evaluate(X_test_processed, y_test)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test AUC: {test_auc:.4f}")

# Make predictions
y_pred_prob = model.predict(X_test_processed).ravel()  # flatten the (n, 1) sigmoid output to 1-D
y_pred = (y_pred_prob > 0.5).astype(int)

# Calculate evaluation metrics
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()

# Classification report
print(classification_report(y_test, y_pred))

# ROC curve
fpr, tpr, thresholds = roc_curve(y_test, y_pred_prob)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.show()

The evaluation includes:

  1. Test Loss and Accuracy: Basic performance metrics on the test set
  2. Confusion Matrix: Visualizes true positives, false positives, true negatives, and false negatives
  3. Classification Report: Detailed metrics including precision, recall, and F1-score
  4. ROC Curve: Plots the true positive rate against the false positive rate at different thresholds
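
The Model Performance page of the app reads these results from models/performance_metrics.pkl. A sketch of how the dictionary it expects might be assembled from the evaluation above (key names inferred from app.py rather than taken from the repository's training script):

import pickle
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Bundle the metrics the Streamlit dashboard visualizes
performance = {
    'accuracy': accuracy_score(y_test, y_pred),
    'precision': precision_score(y_test, y_pred),
    'recall': recall_score(y_test, y_pred),
    'f1': f1_score(y_test, y_pred),
    'fpr': fpr,
    'tpr': tpr,
    'auc': roc_auc,
    'confusion_matrix': cm,
    'history': history.history,  # per-epoch accuracy/loss used for the learning curves
}

with open('models/performance_metrics.pkl', 'wb') as f:
    pickle.dump(performance, f)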

Model Interpretation with SHAP

One of the most valuable aspects of this project is its use of SHAP (SHapley Additive exPlanations) values to interpret model predictions:

def generate_shap_explanations(model, X_processed, feature_names):
    # Create a background dataset for SHAP
    background = X_processed[:100]  # Use first 100 instances as background
    
    # Create explainer
    explainer = shap.DeepExplainer(model, background)
    
    # Calculate SHAP values
    shap_values = explainer.shap_values(X_processed)
    
    return explainer, shap_values

SHAP analysis provides:

  1. Global interpretability: Which features most affect churn predictions overall
  2. Local interpretability: How each feature contributes to individual predictions
  3. Force plots: Visual representation of feature impacts on specific predictions
  4. Decision plots: Showing how the model progresses from baseline to final prediction
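
A minimal usage sketch, assuming the helper is run on the processed (dense) test matrix and that feature_names is built the same way as on the Model Explanation page:

# Global view: which features drive churn predictions across the sample
explainer, shap_values = generate_shap_explanations(model, X_test_processed, feature_names)
shap.summary_plot(shap_values[0], X_test_processed, feature_names=feature_names)

# Local view: how each feature pushes a single customer's prediction up or down
shap.force_plot(explainer.expected_value[0], shap_values[0][0],
                X_test_processed[0], feature_names=feature_names, matplotlib=True)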

Key Performance Features of the Streamlit App

The Streamlit application stands out with several innovative features:

1. Interactive EDA

  • Dynamic feature selection for visualization
  • Comparative analysis of churned vs. non-churned customers
  • Analysis of churn rates across different categorical variables

2. Comprehensive Performance Metrics

  • Confusion matrix visualization
  • ROC curve and AUC score
  • Learning curves showing model convergence
  • Feature importance rankings

3. Real-time Prediction Interface

  • Intuitive form for entering customer details
  • Visual probability gauge for prediction results
  • Tailored retention recommendations based on churn risk

4. Advanced Model Interpretation

  • SHAP summary plots showing global feature importance
  • Individual prediction explanations using force plots
  • Decision plots showing prediction pathways

Deployment Considerations

The Streamlit app is deployed on Streamlit’s cloud platform, making it accessible to users without requiring local setup. Key deployment considerations include:

  1. Model Serialization: The trained model and preprocessor are saved and loaded efficiently
  2. Caching: Resource-intensive operations use Streamlit’s caching mechanisms
  3. Error Handling: Robust error handling for SHAP calculations and predictions
  4. Performance Optimization: Sample-based SHAP calculations to manage computational load
  5. Responsive Design: Column layouts for better user experience across devices

Business Impact and Applications

This churn prediction system offers several business benefits:

  1. Proactive Retention: Identifying high-risk customers before they leave
  2. Resource Optimization: Focusing retention efforts on customers most likely to churn
  3. Root Cause Analysis: Understanding the key drivers of churn
  4. Strategy Development: Informing long-term product and service improvements
  5. ROI Measurement: Quantifying the impact of retention initiatives

Technical Insights and Learnings

From this implementation, we can derive several technical insights:

  1. Architecture Choices: The moderate-sized neural network with dropout and batch normalization balances complexity and generalization
  2. Feature Importance: Contract type, tenure, and monthly charges typically emerge as the most influential features
  3. Model Training: Early stopping and learning rate reduction help find optimal parameters
  4. Hyperparameter Sensitivity: The model’s performance is particularly sensitive to learning rate and dropout rates (a small sweep is sketched after this list)
  5. Explainability: SHAP values provide crucial transparency for a traditionally “black box” neural network
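
To make the hyperparameter point concrete, a small sweep over learning rate and dropout could look like the sketch below. This is not part of the repository; it simply re-creates a slimmed-down version of the network for each setting and compares validation accuracy.

from itertools import product

# Hypothetical grid over the two most sensitive hyperparameters
results = {}
for lr, dropout in product([1e-2, 1e-3, 1e-4], [0.1, 0.3, 0.5]):
    m = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(X_train_processed.shape[1],)),
        tf.keras.layers.Dropout(dropout),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dropout(dropout),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    m.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
              loss='binary_crossentropy', metrics=['accuracy'])
    hist = m.fit(X_train_processed, y_train, epochs=20, batch_size=32,
                 validation_split=0.2, verbose=0)
    results[(lr, dropout)] = max(hist.history['val_accuracy'])

best = max(results, key=results.get)
print(f"Best (learning rate, dropout): {best} with val_accuracy {results[best]:.4f}")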

Conclusion

The ANN-based Customer Churn Prediction System demonstrates a complete machine learning pipeline from data preparation to deployment and interpretation. By combining the predictive power of neural networks with the explainability of SHAP analysis, it delivers both accurate predictions and actionable insights.

The Streamlit interface makes these sophisticated techniques accessible to business users without requiring technical expertise, bridging the gap between advanced AI and practical business applications.

For organizations looking to reduce customer attrition, this system provides a powerful tool for identifying at-risk customers, understanding churn drivers, and implementing targeted retention strategies.

Next Steps and Future Enhancements

Potential improvements to the system could include:

  1. Model Ensemble: Combining the ANN with other algorithms like XGBoost
  2. Time Series Analysis: Incorporating temporal patterns in customer behavior
  3. Automatic Monitoring: Detecting concept drift and model performance degradation
  4. Recommendation Engine: Suggesting specific retention offers based on customer profiles
