ChurnShield is a complete, end-to-end AI-powered Customer Churn Prediction & Retention Intelligence platform. It is designed to help e-commerce businesses identify at-risk customers, understand why they are leaving, and generate actionable, personalized retention strategies.
The ChurnShield ecosystem is divided into three distinct layers that communicate seamlessly:
- Machine Learning Pipeline (`src/`): Generates data, trains the model, and exposes Python inference functions.
- FastAPI Backend (`backend/`): Acts as the middleman, taking HTTP requests from the frontend and executing ML inferences or data aggregations.
- React Frontend (`client/`): The user-facing dashboard for business stakeholders.
Here is the detailed workflow of how data moves through the system:
The heart of ChurnShield is an XGBoost Classifier.
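XGBoost (Extreme Gradient Boosting) is an efficient, scalable machine learning library based on the gradient boosting framework. It builds an ensemble of decision trees sequentially. Each new tree is fit to correct the errors of the ensemble built so far: formally, it follows the gradient (and, in XGBoost, the second-order curvature) of the loss with respect to the current predictions, and for squared-error loss those gradients are simply the residuals. The final prediction is the sum of all trees' outputs.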
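The "each tree corrects the previous ones" idea can be made concrete with a minimal first-order gradient-boosting sketch for binary log loss. It uses one-split trees (stumps) written in NumPy. This illustrates the principle only and is not XGBoost's algorithm. XGBoost sets each leaf to -G/(H + λ) from first and second derivatives with L2 regularization, and it grows deeper trees. All function names and the toy data are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def fit_stump(X, residual):
    """Find the one-split tree (feature, threshold) whose leaf means best fit the residuals."""
    best = None
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j])[:-1]:  # the largest value would leave the right leaf empty
            left = X[:, j] <= t
            lv, rv = residual[left].mean(), residual[~left].mean()
            sse = ((residual[left] - lv) ** 2).sum() + ((residual[~left] - rv) ** 2).sum()
            if best is None or sse < best[0]:
                best = (sse, j, t, lv, rv)
    return best[1:]  # (feature index, threshold, left value, right value)


def stump_output(stump, X):
    j, t, lv, rv = stump
    return np.where(X[:, j] <= t, lv, rv)


def fit_boosted_stumps(X, y, n_trees=50, lr=0.3):
    """Illustrative first-order boosting (not XGBoost): each stump is fit to y - p."""
    base = np.log(y.mean() / (1.0 - y.mean()))  # log-odds of the base churn rate
    margin = np.full(len(y), base)
    stumps = []
    for _ in range(n_trees):
        residual = y - sigmoid(margin)  # negative gradient of log loss w.r.t. the margin
        stump = fit_stump(X, residual)
        margin += lr * stump_output(stump, X)  # shrink each correction by the learning rate
        stumps.append(stump)
    return {"base": base, "lr": lr, "stumps": stumps}


def predict_proba(model, X):
    margin = np.full(len(X), model["base"])
    for stump in model["stumps"]:
        margin += model["lr"] * stump_output(stump, X)
    return sigmoid(margin)


def log_loss(y, p):
    p = np.clip(p, 1e-12, 1 - 1e-12)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))


# Toy data (hypothetical features): label is 1 when x1 exceeds x0.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 1] - X[:, 0] > 0).astype(float)
model = fit_boosted_stumps(X, y)
p = predict_proba(model, X)
train_accuracy = float(((p >= 0.5) == y).mean())
base_loss = log_loss(y, np.full(len(y), y.mean()))  # loss of always predicting the base rate
model_loss = log_loss(y, p)
```

The training log loss falls below the base-rate loss as stumps are added. That drop is the "correcting previous errors" behaviour in numbers.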
We chose XGBoost over other algorithms (like Logistic Regression, SVM, Random Forest, or Neural Networks) for several key reasons:
- State-of-the-Art Performance on Tabular Data: E-commerce customer data is structured (tabular features like `Tenure`, `OrderCount`, `SatisfactionScore`). On structured, spreadsheet-like data, gradient-boosted trees such as XGBoost typically match or outperform both traditional models and deep learning approaches.
- Complex, Non-Linear Interactions: Variables in e-commerce interact in complicated ways. For example, high application usage usually indicates loyalty, but paired with recent support complaints it flags severe churn risk. Each tree splits on one feature and then on another within that branch, so XGBoost captures such non-linear interactions without manual feature engineering.
- Imbalance Handling: Churn datasets are inherently imbalanced (most customers stay; only a fraction churn). XGBoost's `scale_pos_weight` hyperparameter up-weights the churn class, so missing a churner costs more during training.
- Explainability: Neural networks are largely a "black box." A large tree ensemble is not directly readable either. However, XGBoost exposes per-feature importance (Gain/Weight) and works with explainability frameworks such as SHAP (whose `TreeExplainer` computes exact values for tree models efficiently) and Anchor.
In the context of this project, XGBoost maps historical customer behaviors and demographics to a future risk state.
- Training: It analyzes the patterns in `ecommerce_churn_enhanced.csv` to learn threshold-based split rules that separate loyal customers from those who abandon the platform.
- Inference: During live prediction via the FastAPI backend, the model passes new customer data through its ensemble of trees to compute a continuous churn probability (e.g., `0.85`).
- Explainable Insights: We extract the model's split statistics (feature importance) to populate the global "Top churn drivers" on the dashboard. For individual customers, we compute local explanations that state in human-readable terms why that user is flagged as high-risk.
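For a `binary:logistic` XGBoost model, the probability comes from a simple last step. The model sums the leaf values a customer reaches in every tree, adds a base margin, and applies the sigmoid. The sketch below shows that step and a banding into High/Medium/Low. The leaf values and the 0.70/0.40 cut-offs are hypothetical placeholders, not the project's configured values.

```python
import math

# Hypothetical risk cut-offs; the project may use different bands.
HIGH_RISK = 0.70
MEDIUM_RISK = 0.40


def churn_probability(leaf_values, base_margin=0.0):
    """Sum the leaf output reached in each tree, add the base margin, and apply the sigmoid."""
    margin = base_margin + sum(leaf_values)
    return 1.0 / (1.0 + math.exp(-margin))


def risk_level(probability):
    """Map a probability to a risk band using the hypothetical cut-offs above."""
    if probability >= HIGH_RISK:
        return "High"
    if probability >= MEDIUM_RISK:
        return "Medium"
    return "Low"


# Example: leaf outputs one customer landed in across three hypothetical trees.
p = churn_probability([0.9, 0.6, 0.25])
print(round(p, 2), risk_level(p))  # 0.85 High
```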
Because we are predicting something as deeply human as customer loyalty, providing just a percentage (e.g., "85% risk") is not actionable for support agents. To solve this, our project uses two different Explainable AI frameworks:
1. SHAP (SHapley Additive exPlanations)
- What it is: SHAP is a framework based on cooperative game theory. It breaks down a single prediction by treating each feature (e.g., `Tenure`, `OrderCount`) as a "player." Each player receives a share of the difference between this prediction and the model's average prediction.
- How we use it: During a single customer prediction (`/predict`), we run SHAP's `TreeExplainer` on the XGBoost model to compute exact Shapley values. This lets the API return the top 3 risk factors for that specific customer (e.g., "`Complain = 1` is the largest factor pushing this customer toward churn"). For an XGBoost classifier, `TreeExplainer` reports contributions in log-odds units by default, so expressing them as percentage points of probability requires an extra conversion.
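`TreeExplainer` uses the tree structure to compute Shapley values quickly. The quantity itself can be shown by brute force on a tiny example. The sketch below enumerates every feature coalition of a toy scoring function and picks the top 3 factors by absolute contribution. It treats "absent" features as taking a single reference customer's values, which is one of several ways to define absence and not exactly what `TreeExplainer`'s default mode computes. The toy model and all feature values are hypothetical. Brute force is exponential in the number of features, so it is only practical for a handful of them.

```python
from itertools import combinations
from math import factorial


def toy_model(x):
    """Hypothetical churn score in log-odds, with one interaction term."""
    score = -1.0 + 1.5 * x["Complain"] - 0.1 * x["Tenure"]
    if x["Complain"] == 1 and x["SatisfactionScore"] <= 2:
        score += 1.0  # complaints hurt more when satisfaction is low
    return score


def shapley_values(model, x, reference):
    """Exact Shapley values; features outside a coalition take the reference values."""
    features = list(x)
    n = len(features)
    phi = {}
    for i in features:
        others = [f for f in features if f != i]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                with_i = {f: x[f] if (f in subset or f == i) else reference[f] for f in features}
                without_i = {f: x[f] if f in subset else reference[f] for f in features}
                total += weight * (model(with_i) - model(without_i))
        phi[i] = total
    return phi


customer = {"Tenure": 2, "Complain": 1, "SatisfactionScore": 1}
reference = {"Tenure": 10, "Complain": 0, "SatisfactionScore": 3}
phi = shapley_values(toy_model, customer, reference)
top3 = sorted(phi.items(), key=lambda kv: abs(kv[1]), reverse=True)[:3]
print(top3)
```

The values add up to the customer's score minus the reference score, which is SHAP's "additive" property. The interaction bonus is split equally between `Complain` and `SatisfactionScore`.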
2. Anchor Explanations & Calibration
- What it is: Anchors are local, high-precision "if-then" rules. An anchor says: "If A and B are true, the model's prediction stays the same with high probability, whatever the other features are."
- How we use it: SHAP provides numerical attribution, while Anchor-style rules (`src/anchor_calibrate.py`) run after XGBoost to calibrate its probabilities. Some customers meet conditions that business logic identifies as extreme risk. For them, the rule "anchors" the raw XGBoost probability, adjusting it to reflect that domain knowledge and improving domain relevance.
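The project's `src/anchor_calibrate.py` is not reproduced here. As an illustration of the general idea only, the sketch below applies if-then rules after the model: when every condition of a rule holds, the raw probability is raised to at least that rule's floor. The rule, its thresholds, and the floor value are hypothetical and are not the project's actual rules.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

Customer = Dict[str, float]


@dataclass
class Rule:
    name: str
    conditions: List[Callable[[Customer], bool]]  # all must hold for the rule to fire
    floor: float  # minimum churn probability once the rule fires


# Hypothetical rule for illustration; not taken from src/anchor_calibrate.py.
RULES = [
    Rule(
        name="complained_and_inactive",
        conditions=[lambda c: c["Complain"] == 1, lambda c: c["LastLoginDaysAgo"] > 30],
        floor=0.80,
    ),
]


def calibrate(raw_probability: float, customer: Customer,
              rules: List[Rule] = RULES) -> Tuple[float, List[str]]:
    """Return the adjusted probability and the names of the rules that fired."""
    fired = [r for r in rules if all(cond(customer) for cond in r.conditions)]
    adjusted = max([raw_probability] + [r.floor for r in fired])
    return adjusted, [r.name for r in fired]


print(calibrate(0.55, {"Complain": 1, "LastLoginDaysAgo": 45}))  # (0.8, ['complained_and_inactive'])
print(calibrate(0.55, {"Complain": 0, "LastLoginDaysAgo": 45}))  # (0.55, [])
```

Using a floor means a rule can only raise the model's estimate, never lower it. The XGBoost signal therefore still wins whenever it is already more pessimistic than the rule.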
- Data Generation: Instead of using raw Kaggle data with leaky features, we use a synthetic data generator (`generate_churn_dataset.py`). It creates a realistic dataset of 5,630 e-commerce customers with 28 features (e.g., `Tenure`, `SatisfactionScore`, `Complain`, `LastLoginDaysAgo`).
- Causality: The target label (`Churn`) is generated from the feature values after they are sampled. The features therefore drive the label and no feature is derived from it, which gives clear causal relationships without data leakage. The dataset is saved to `data/ecommerce_churn_enhanced.csv`.
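The key idea is to sample features first and derive the label from them. The sketch below uses only the standard library and is far simpler than `generate_churn_dataset.py`. Its four features, coefficients, and logistic link are hypothetical. What matters is the ordering: `Churn` is computed from the features, and no feature is computed from `Churn`.

```python
import csv
import io
import math
import random


def generate_rows(n_rows, seed=42):
    rng = random.Random(seed)
    rows = []
    for customer_id in range(1, n_rows + 1):
        # 1) Sample the features without looking at any label.
        tenure = rng.randint(0, 60)
        complain = 1 if rng.random() < 0.3 else 0
        satisfaction = rng.randint(1, 5)
        last_login = rng.randint(0, 60)
        # 2) Turn the features into a churn probability (hypothetical coefficients).
        logit = -1.5 - 0.05 * tenure + 1.2 * complain - 0.3 * (satisfaction - 3) + 0.03 * last_login
        p_churn = 1.0 / (1.0 + math.exp(-logit))
        # 3) Sample the label last, so it depends on the features and not the other way round.
        churn = 1 if rng.random() < p_churn else 0
        rows.append({
            "CustomerID": customer_id, "Tenure": tenure, "Complain": complain,
            "SatisfactionScore": satisfaction, "LastLoginDaysAgo": last_login, "Churn": churn,
        })
    return rows


rows = generate_rows(5630)
buffer = io.StringIO()  # in practice, open a file such as data/ecommerce_churn_enhanced.csv
writer = csv.DictWriter(buffer, fieldnames=list(rows[0]))
writer.writeheader()
writer.writerows(rows)
churn_rate = sum(r["Churn"] for r in rows) / len(rows)
```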
When you train the model, the following pipeline executes:
- Load & Clean: Loads the CSV, drops identifier columns that carry no signal (`CustomerID`), and splits the data into training and test sets before any fitting, to prevent leakage.
- Encode: Categorical features (like `PreferredLoginDevice`) are converted to integers with `LabelEncoder`. Crucially, encoders are fitted only on the training set to prevent "look-ahead" bias.
- Hyperparameter Tuning: `RandomizedSearchCV` with 5-fold cross-validation samples combinations of `max_depth`, `learning_rate`, `n_estimators`, and L1/L2 regularization to find the best-scoring configuration.
- Class Imbalance: Churn is naturally imbalanced (~83% stay, ~17% churn). The pipeline computes `scale_pos_weight` (typically the ratio of non-churners to churners, about 83/17 ≈ 4.9 here) to penalize the model more heavily when it misses a churner.
- Sanity Checks: Compares train AUC with test AUC. A large gap indicates overfitting.
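The ordering above can be sketched in plain Python: split first, fit encoders on the training rows only, then derive `scale_pos_weight` from the training labels. The helper names, the 80/20 split, and the `-1` code for categories unseen during training are assumptions for illustration. The project itself uses scikit-learn's `LabelEncoder`, which raises an error on unseen labels instead.

```python
import random
from collections import Counter


def split_rows(rows, test_fraction=0.2, seed=42):
    """Shuffle once and split before any fitting, so test rows never influence preprocessing."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * (1 - test_fraction))
    return shuffled[:cut], shuffled[cut:]


def fit_label_map(train_rows, column):
    """Assign integer codes using only the categories that appear in the training rows."""
    categories = sorted({row[column] for row in train_rows})
    return {value: code for code, value in enumerate(categories)}


def encode(rows, column, mapping, unseen=-1):
    """Apply a training-set mapping; the -1 code for unseen categories is an assumption."""
    return [mapping.get(row[column], unseen) for row in rows]


def scale_pos_weight(train_labels):
    """Common XGBoost heuristic: number of negatives divided by number of positives."""
    counts = Counter(train_labels)
    return counts[0] / counts[1]


# Usage with made-up rows.
train, test = split_rows([{"PreferredLoginDevice": d} for d in ["Phone", "Computer"] * 10])
mapping = fit_label_map(train, "PreferredLoginDevice")  # {'Computer': 0, 'Phone': 1}
print(encode([{"PreferredLoginDevice": "Phone"}, {"PreferredLoginDevice": "Tablet"}],
             "PreferredLoginDevice", mapping))  # [1, -1]
print(round(scale_pos_weight([0] * 83 + [1] * 17), 2))  # 4.88
```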
The model is then evaluated, and the threshold for predicting "Churn" is tuned to maximize the F1-Score. Evaluation writes `outputs/metrics.json` (accuracy, AUC-ROC, recall, and a naive-baseline comparison) and saves visualizations (ROC curves, confusion matrices) to `outputs/plots/`.
Because churn data is inherently imbalanced (most customers stay, a small percentage leave), standard metrics such as accuracy can be misleading. Here is how we measure success in ChurnShield:
- Threshold: XGBoost outputs a continuous probability (0.0 to 1.0). The threshold is the cutoff at or above which a user is classified as "At Risk of Churn". Instead of the default `0.50`, `evaluate.py` sweeps candidate thresholds and picks the one that maximizes the F1-Score.
- Accuracy: The overall percentage of correct predictions. In this project: This is recorded but treated with caution. If 83% of customers stay, a useless model that predicts "Stay" for everyone achieves 83% accuracy while catching zero churners, so it offers no business value.
- Recall (Sensitivity): Out of all the customers who actually churned, what percentage did the model successfully catch? In this project: High recall is crucial so no at-risk customer slips away without a retention intervention (minimizing False Negatives).
- Precision: Out of everyone the model predicted would churn, how many actually did? In this project: High precision ensures we don't waste expensive retention campaigns or discounts on customers who were perfectly happy and going to stay anyway (minimizing False Positives).
- F1-Score: The harmonic mean of Precision and Recall. In this project: This is our "North Star" metric. Choosing the threshold that maximizes F1 balances catching as many churners as possible against firing off too many false alarms.
- ROC-AUC (Receiver Operating Characteristic, Area Under the Curve): Measures how well the model's scores separate the "Churn" class from the "Stay" class across all possible thresholds. Equivalently, it is the probability that a randomly chosen churner receives a higher score than a randomly chosen non-churner. In this project: An AUC of 0.5 is a random coin toss; 1.0 is perfect. Because it is threshold-independent, it shows whether the XGBoost scores rank customers well, separately from the choice of cutoff.
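These metrics and the F1-maximizing threshold sweep can be sketched without any libraries. Using the distinct predicted probabilities as candidate thresholds is one reasonable strategy, and `evaluate.py` may use a different grid. The AUC function applies the pairwise (Mann-Whitney) definition and counts ties as half a win. The labels and scores are made up.

```python
def classification_metrics(y_true, y_prob, threshold):
    """Confusion-matrix metrics when predicting churn for scores >= threshold."""
    tp = sum(1 for y, p in zip(y_true, y_prob) if p >= threshold and y == 1)
    fp = sum(1 for y, p in zip(y_true, y_prob) if p >= threshold and y == 0)
    fn = sum(1 for y, p in zip(y_true, y_prob) if p < threshold and y == 1)
    tn = len(y_true) - tp - fp - fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / len(y_true), "precision": precision, "recall": recall, "f1": f1}


def best_f1_threshold(y_true, y_prob):
    """Try every distinct score as a cutoff and keep the one with the highest F1."""
    return max(sorted(set(y_prob)), key=lambda t: classification_metrics(y_true, y_prob, t)["f1"])


def roc_auc(y_true, y_prob):
    """Share of (churner, non-churner) pairs where the churner is scored higher."""
    pos = [p for y, p in zip(y_true, y_prob) if y == 1]
    neg = [p for y, p in zip(y_true, y_prob) if y == 0]
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))


# Made-up example: 10 customers, 2 of whom churned.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_prob = [0.05, 0.10, 0.15, 0.20, 0.30, 0.35, 0.60, 0.40, 0.55, 0.90]

threshold = best_f1_threshold(y_true, y_prob)
print(threshold, classification_metrics(y_true, y_prob, threshold))
print("AUC:", roc_auc(y_true, y_prob))  # AUC: 0.9375
# Naive baseline: predict "Stay" for everyone (80% accuracy, zero recall).
print("Baseline:", classification_metrics(y_true, [0.0] * len(y_true), 0.5))
```

Choosing the threshold on the same rows used to report metrics makes those numbers somewhat optimistic. Selecting it on a validation split avoids that.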
The backend is built with FastAPI (Python) and serves as the bridge between the heavy ML scripts and the lightweight frontend.
We chose FastAPI over traditional frameworks for three critical reasons tailored to Machine Learning applications:
- Pydantic Validation: Machine learning models fail, or worse, silently return nonsense, when fed bad data types (e.g., a string instead of an int for `Tenure`). FastAPI uses Pydantic to strictly validate the JSON payload before it reaches the ML functions.
- Speed & Async: FastAPI is built on Starlette and supports asynchronous endpoints (`async def`). For I/O-bound work such as CSV uploads in bulk prediction, this non-blocking model lets the server keep serving concurrent frontend requests more efficiently than a standard synchronous Flask app. CPU-bound XGBoost inference itself does not get faster from `async`. FastAPI runs plain `def` endpoints in a thread pool so they do not block the event loop.
- Overhead & Auto-Docs: Django is a "batteries-included" framework built around an ORM, which is unnecessary here since we are serving an ML model, not a relational database app. FastAPI is a lightweight micro-framework that still generates interactive Swagger UI documentation (`/docs`), making API-frontend integration straightforward.
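Here is a minimal sketch of the kind of request model FastAPI validates against, assuming Pydantic is installed. The class name and the subset of fields are hypothetical, not the project's actual schema. When an endpoint declares a parameter of this type, FastAPI rejects invalid payloads with a 422 response before the endpoint body runs.

```python
from pydantic import BaseModel, Field, ValidationError


class CustomerFeatures(BaseModel):
    """Hypothetical subset of the /predict payload."""
    Tenure: int = Field(..., ge=0)
    SatisfactionScore: int = Field(..., ge=1, le=5)
    Complain: int = Field(..., ge=0, le=1)
    PreferredLoginDevice: str


ok = CustomerFeatures(Tenure=12, SatisfactionScore=3, Complain=0, PreferredLoginDevice="Phone")

try:
    CustomerFeatures(Tenure="twelve", SatisfactionScore=3, Complain=0, PreferredLoginDevice="Phone")
    rejected = False
except ValidationError as err:
    rejected = True
    print(err)  # the error names the offending field (Tenure) and the reason
```

Pydantic's default (lax) mode still coerces numeric strings such as `"12"` to integers. Strict types like `StrictInt` can be used if that coercion is undesirable.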
- `/predict` (Single Inference): The frontend sends a JSON payload with one customer's features. The backend calls `src.predict.predict()`, which runs the data through the saved encoders and feeds it to the XGBoost model. It then uses SHAP (SHapley Additive exPlanations) to find the top 3 features driving this customer's risk score. It returns the churn probability, the risk level (High/Medium/Low), and the SHAP factors.
- `/predict/bulk` (Batch Inference): Accepts a CSV file upload. The backend parses it with `pandas`, runs the entire batch through the model, and computes a risk level for each customer. It then assigns recommended business actions and returns a formatted Excel spreadsheet for the user to download.
- `/analytics` & `/analytics/trends`: Instead of querying a database, the backend loads `ecommerce_churn_enhanced.csv` into memory on startup. It runs `pandas` group-by operations to serve JSON statistics (e.g., "Churn by City Tier", "Rolling Churn Rates") to the frontend dashboard.
- `/metrics/revenue-impact`: Takes a list of customer churn probabilities and their average order values and calculates "Revenue at Risk". It also computes the return on investment (ROI) of a hypothetical retention campaign from expected retention rates and campaign costs.
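The revenue arithmetic can be sketched as follows. The formulas are an assumption about what "Revenue at Risk" and campaign ROI usually mean, and `backend/revenue.py` may define them differently. Revenue at risk is taken as expected lost revenue, weighting each customer's order value by their churn probability. ROI is taken as saved revenue minus campaign cost, divided by cost, with every listed customer contacted. All names are hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class RevenueImpact:
    revenue_at_risk: float
    expected_saved: float
    campaign_cost: float
    roi: float


def revenue_impact(probabilities: List[float], order_values: List[float],
                   retention_rate: float, cost_per_customer: float) -> RevenueImpact:
    """Assumed formulas: probability-weighted revenue at risk and simple campaign ROI."""
    if len(probabilities) != len(order_values):
        raise ValueError("probabilities and order_values must have the same length")
    # Expected revenue lost if nobody intervenes.
    at_risk = sum(p * v for p, v in zip(probabilities, order_values))
    # Assume the campaign saves a fixed fraction of the at-risk revenue.
    saved = at_risk * retention_rate
    # Assume every customer in the list is contacted.
    cost = cost_per_customer * len(probabilities)
    roi = (saved - cost) / cost if cost else float("inf")
    return RevenueImpact(at_risk, saved, cost, roi)


impact = revenue_impact([0.9, 0.5, 0.2], [200.0, 100.0, 50.0],
                        retention_rate=0.3, cost_per_customer=10.0)
print(impact)
```

With these made-up inputs, 240 in revenue is at risk and 72 is expected to be saved for a cost of 30, so the ROI is 1.4 (a 140% return).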
The frontend is a Single Page Application (SPA) built with React, TypeScript, and Vite. It uses TailwindCSS for rapid styling and shadcn/ui for accessible components built on unstyled Radix UI primitives.
- Vite: Provides fast Hot Module Replacement (HMR) and optimized production builds.
- TypeScript: Catches mismatches between the UI code and the shapes of the FastAPI JSON responses at compile time, preventing a class of runtime UI crashes. Types are erased at runtime, so they do not validate the JSON itself.
- TailwindCSS: Lets us build responsive designs with utility classes directly inside the JSX.
- Overview Dashboard: Fetches data from `/analytics` and displays overarching business KPIs (Total Customers, Overall Churn Rate) and demographic pie/bar charts.
- Predictor Form: A form built with React Hook Form that lets customer service agents enter customer data manually. It validates the inputs, calls the `/predict` backend endpoint, and displays a dynamic Risk Gauge and the SHAP-derived "Top Risk Factors" banner.
- Model Performance: Reads `outputs/metrics.json` to show stakeholders how the model performs. It displays the F1-Score and compares the XGBoost model against a "Naive Baseline" to quantify how much value the model adds.
- Revenue Calculator: A sandbox where marketers adjust sliders (Campaign Cost, Expected Retention) to instantly recalculate the ROI of saving at-risk customers, using the `/metrics/revenue-impact` backend logic.
You will need two terminals to run the full application: one for the backend and one for the frontend.
- Python 3.9+ (For the ML/Backend)
- Node.js 18+ (For the Frontend)
Open your first terminal and navigate to the project root.
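(Note: File paths must resolve to the project root. Start the server from the project root as shown below. If you start it from the `backend/` directory instead, make sure the paths still resolve to the project root.)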
```powershell
# Activate your virtual environment (Windows PowerShell)
.\.venv\Scripts\Activate.ps1
# Run the backend server
python backend/main.py
```

The API will be available at: http://localhost:8000
Open a second terminal and navigate to the client directory:
```bash
cd client
# Install dependencies (if you haven't already)
npm install
# Start the development server
npm run dev
```

The UI will be available at: http://localhost:8080 (or the port Vite prints in the console).
If you want to tweak hyperparameters or retrain the model from scratch, follow these steps from the project root:
- (Optional) Regenerate Dataset:
  ```bash
  python generate_churn_dataset.py --output data/ecommerce_churn_enhanced.csv --rows 5630
  ```
- Clear Old Models: Delete the old model files in the `models/` directory.
- Train:
  ```bash
  python -m src.train --tune
  ```
- Evaluate:
  ```bash
  python -m src.evaluate
  ```
  This will update the `outputs/metrics.json` file, which the frontend automatically reads to populate the Model Performance page.
```
ChurnShield/
│
├── backend/                    # FastAPI Application
│   ├── main.py                 # API Endpoints
│   ├── analytics.py            # Dashboard aggregations
│   └── revenue.py              # ROI calculations
│
├── client/                     # React Frontend (Vite/TypeScript)
│   ├── src/
│   │   ├── components/         # Reusable UI components (shadcn/ui)
│   │   ├── pages/              # Main dashboard views (Analytics, Predict, Performance)
│   │   └── services/           # API integration logic
│   ├── package.json            # Node dependencies
│   └── tsconfig.json           # TypeScript configuration
│
├── data/                       # Datasets
│   └── ecommerce_churn_enhanced.csv   # The ground-truth training data
│
├── models/                     # Saved ML Models (DO NOT EDIT MANUALLY)
│   ├── xgboost_churn.pkl       # Trained XGBoost model
│   └── encoders.pkl            # Fitted categorical encoders
│
├── outputs/                    # Model Evaluation Outputs
│   ├── metrics.json            # Live metrics read by the frontend
│   └── plots/                  # SHAP, ROC, and Confusion Matrix charts
│
├── src/                        # Machine Learning Source Code
│   ├── train.py                # XGBoost training & tuning logic
│   ├── evaluate.py             # Metrics calculation & threshold optimization
│   ├── predict.py              # Inference & SHAP explanation logic
│   └── preprocess.py           # Data cleaning & train/test splitting
│
├── config.yaml                 # ML Configuration & Hyperparameters
└── generate_churn_dataset.py   # Synthetic dataset generator
```