Merged
201 changes: 3 additions & 198 deletions README.md
@@ -17,8 +17,6 @@ Production-ready Python SDK for FAIM (Foundation AI Models) - a unified platform
- **🔄 Async Support**: Built-in async/await support for concurrent requests
- **📊 Rich Error Handling**: Machine-readable error codes with detailed diagnostics
- **🧪 Battle-Tested**: Production-ready with comprehensive error handling
- **📈 Evaluation Tools**: Built-in metrics (MSE, MASE, CRPS) and visualization utilities
- **🔎 Retrieval-Augmented Inference**: Optional RAI for improved accuracy on small datasets

## Installation

@@ -138,24 +136,6 @@ response.probabilities # Shape: (n_samples, n_classes) - classification only
- **Chronos2**: ✅ Supports multivariate forecasting (multiple features)
- **FlowState**: ⚠️ Univariate only - automatically transforms multivariate input
- **TiRex**: ⚠️ Univariate only - automatically transforms multivariate input
- **LimiX**: ✅ Supports multivariate tabular features (standard in tabular inference)

When you provide multivariate input (features > 1) to FlowState or TiRex, the SDK automatically:
1. Issues a warning
2. Forecasts each feature independently
3. Reshapes the output back to your original structure

```python
# Multivariate input to FlowState
data = np.random.randn(2, 100, 3) # 2 series, 3 features
request = FlowStateForecastRequest(x=data, horizon=24, prediction_type="mean")

# Warning: "FlowState model only supports univariate forecasting..."
response = client.forecast(request)

# Output is automatically reshaped
print(response.point.shape) # (2, 24, 3) - original structure preserved
```
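
The three steps above (warn, forecast each feature independently, reshape back) can be sketched in plain NumPy. This is only an illustration of the reshaping idea, not the SDK's actual implementation: `forecast_per_feature` and the stand-in model `repeat_last_value` are hypothetical names introduced here, and the `(batch, seq_len, features)` layout is taken from the example above.

```python
import warnings

import numpy as np


def forecast_per_feature(x, horizon, univariate_forecast):
    """Forecast each feature independently, then restore (batch, horizon, features)."""
    batch, seq_len, n_features = x.shape
    if n_features > 1:
        warnings.warn("Model is univariate; forecasting each feature independently.")
    # Fold the feature axis into the batch axis: (batch * features, seq_len, 1)
    flat = x.transpose(0, 2, 1).reshape(batch * n_features, seq_len, 1)
    # The univariate model is expected to return (batch * features, horizon, 1)
    point = univariate_forecast(flat, horizon)
    # Unfold and move features back to the last axis: (batch, horizon, features)
    return point.reshape(batch, n_features, horizon).transpose(0, 2, 1)


def repeat_last_value(x, horizon):
    """Hypothetical stand-in model: repeats the last observed value."""
    return np.repeat(x[:, -1:, :], horizon, axis=1)


data = np.random.randn(2, 100, 3)  # 2 series, 100 steps, 3 features
out = forecast_per_feature(data, 24, repeat_last_value)
print(out.shape)  # (2, 24, 3)
```

Folding features into the batch axis lets a univariate model process every feature in a single call, at the cost of ignoring any cross-feature dependence.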

## Available Models

@@ -355,177 +335,9 @@ print(response.metadata)
# {'model_name': 'chronos2', 'model_version': '1.0', 'inference_time_ms': 123}
```

## Evaluation & Metrics (Time-Series Forecasting)

The SDK includes a comprehensive evaluation toolkit (`faim_sdk.eval`) for measuring time-series forecast quality with standard metrics and visualizations.

**Note**: These metrics are designed for time-series forecasting evaluation. For tabular model evaluation (classification/regression), use standard scikit-learn metrics like `accuracy_score`, `mean_squared_error`, etc. (see tabular examples above).

### Installation

For visualization support, install with the viz extra:

```bash
# Quote the requirement so shells such as zsh do not expand the brackets
pip install "faim-sdk[viz]"
```

### Available Metrics for Time-Series

#### Mean Squared Error (MSE)

Measures average squared difference between predictions and ground truth.

```python
from faim_sdk.eval import mse

# Evaluate point forecast
mse_score = mse(test_data, response.point, reduction='mean')
print(f"MSE: {mse_score:.4f}")

# Per-sample MSE
mse_per_sample = mse(test_data, response.point, reduction='none')
print(f"MSE per sample shape: {mse_per_sample.shape}") # (batch_size,)
```
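
For reference, the quantity being computed can be sketched in NumPy. This is an assumption about the semantics suggested by the example above (`reduction='none'` returning one value per sample), not the SDK's source; `mse_sketch` is a hypothetical name.

```python
import numpy as np


def mse_sketch(y_true, y_pred, reduction="mean"):
    """Mean squared error over (batch, horizon, features) arrays."""
    sq_err = (np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)) ** 2
    # Average over every axis except the batch axis: one value per sample
    per_sample = sq_err.reshape(sq_err.shape[0], -1).mean(axis=1)
    if reduction == "none":
        return per_sample
    if reduction == "mean":
        return float(per_sample.mean())
    raise ValueError(f"Unknown reduction: {reduction!r}")


y_true = np.zeros((2, 3, 1))
y_pred = np.ones((2, 3, 1))
print(mse_sketch(y_true, y_pred))                           # 1.0
print(mse_sketch(y_true, y_pred, reduction="none").shape)   # (2,)
```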

#### Mean Absolute Scaled Error (MASE)

Scale-independent metric that compares forecast error to the error of a naive baseline. Unlike MAPE, it stays well-defined when the series contains zeros.

```python
from faim_sdk.eval import mase

# MASE requires training data for baseline
mase_score = mase(test_data, response.point, train_data, reduction='mean')
print(f"MASE: {mase_score:.4f}")

# Interpretation:
# MASE < 1: Better than naive baseline
# MASE = 1: Equivalent to naive baseline
# MASE > 1: Worse than naive baseline
```
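
The standard MASE definition divides the forecast's mean absolute error by the in-sample mean absolute error of a one-step naive forecast (each value predicted by the previous one). A minimal NumPy sketch under that assumption follows; the SDK may use a different baseline (for example a seasonal one), and `mase_sketch` is a hypothetical name.

```python
import numpy as np


def mase_sketch(y_true, y_pred, y_train, reduction="mean"):
    """MASE with a one-step naive baseline, for (batch, time, features) arrays."""
    # Forecast MAE per sample and feature: (batch, features)
    mae = np.abs(y_true - y_pred).mean(axis=1)
    # In-sample MAE of the naive forecast y[t] = y[t-1]: (batch, features)
    # Note: this is zero for a constant training series, which makes MASE undefined.
    scale = np.abs(np.diff(y_train, axis=1)).mean(axis=1)
    per_sample = (mae / scale).mean(axis=1)  # (batch,)
    return per_sample if reduction == "none" else float(per_sample.mean())


train = np.arange(10, dtype=float).reshape(1, 10, 1)  # naive one-step error is always 1
test = np.array([10.0, 11.0]).reshape(1, 2, 1)
pred = np.array([10.5, 11.5]).reshape(1, 2, 1)
print(mase_sketch(test, pred, train))  # 0.5
```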

#### Continuous Ranked Probability Score (CRPS)

Proper scoring rule for probabilistic forecasts - generalizes MAE to distributions.

```python
from faim_sdk.eval import crps_from_quantiles

# Evaluate probabilistic forecast with quantiles
crps_score = crps_from_quantiles(
    test_data,
    response.quantiles,
    quantile_levels=[0.1, 0.5, 0.9],
    reduction='mean'
)
print(f"CRPS: {crps_score:.4f}")
```
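
One common way to approximate CRPS from a finite set of quantiles is to average the pinball (quantile) loss over the quantile levels and multiply by two. The sketch below assumes that approximation and a univariate layout where quantiles have shape `(batch, horizon, n_quantiles)`; the SDK's `crps_from_quantiles` may weight the levels differently, and `crps_from_quantiles_sketch` is a hypothetical name.

```python
import numpy as np


def crps_from_quantiles_sketch(y_true, quantiles, quantile_levels, reduction="mean"):
    """Approximate CRPS as twice the mean pinball loss across quantile levels."""
    levels = np.asarray(quantile_levels, dtype=float)  # (n_quantiles,)
    # y_true: (batch, horizon, 1) broadcasts against (batch, horizon, n_quantiles)
    diff = y_true - quantiles
    # Pinball loss: tau * diff when under-predicting, (tau - 1) * diff otherwise
    pinball = np.maximum(levels * diff, (levels - 1.0) * diff)
    per_sample = 2.0 * pinball.mean(axis=(1, 2))  # (batch,)
    return per_sample if reduction == "none" else float(per_sample.mean())


y_true = np.ones((1, 2, 1))
q_pred = np.zeros((1, 2, 3))  # degenerate forecast: every quantile equals 0
score = crps_from_quantiles_sketch(y_true, q_pred, [0.1, 0.5, 0.9])
print(score)  # 1.0
```

In this degenerate case the score equals the absolute error of the point forecast, which is consistent with CRPS reducing to MAE when the predictive distribution is a single point.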

### Visualization (Time-Series Only)

Plot time-series forecasts with training context and ground truth:

```python
from faim_sdk.eval import plot_forecast

# Plot single sample (remember to index batch dimension!)
fig, ax = plot_forecast(
    train_data=train_data[0],    # (seq_len, features) - 2D array
    forecast=response.point[0],  # (horizon, features) - 2D array
    test_data=test_data[0],      # (horizon, features) - optional
    title="Time Series Forecast"
)

# Save to file
fig.savefig("forecast.png", dpi=300, bbox_inches="tight")
```

#### Multi-Feature Visualization

```python
# Option 1: All features on same plot
fig, ax = plot_forecast(
    train_data[0],
    response.point[0],
    test_data[0],
    features_on_same_plot=True,
    feature_names=["Temperature", "Humidity", "Pressure"]
)

# Option 2: Separate subplots per feature
fig, axes = plot_forecast(
    train_data[0],
    response.point[0],
    test_data[0],
    features_on_same_plot=False,
    feature_names=["Temperature", "Humidity", "Pressure"]
)
```

### Complete Evaluation Example

```python
import numpy as np
from faim_sdk import ForecastClient, Chronos2ForecastRequest
from faim_sdk.eval import mse, mase, crps_from_quantiles, plot_forecast

# Initialize client
client = ForecastClient()

# Prepare data splits
train_data = np.random.randn(32, 100, 1)
test_data = np.random.randn(32, 24, 1)

# Generate forecast
request = Chronos2ForecastRequest(
    x=train_data,
    horizon=24,
    output_type="quantiles",
    quantiles=[0.1, 0.5, 0.9]
)
response = client.forecast(request)

# Evaluate point forecast (use median)
point_pred = response.quantiles[:, :, 1:2] # Extract median, keep 3D shape
mse_score = mse(test_data, point_pred)
mase_score = mase(test_data, point_pred, train_data)

# Evaluate probabilistic forecast
crps_score = crps_from_quantiles(
    test_data,
    response.quantiles,
    quantile_levels=[0.1, 0.5, 0.9]
)

print(f"MSE: {mse_score:.4f}")
print(f"MASE: {mase_score:.4f}")
print(f"CRPS: {crps_score:.4f}")

# Visualize best and worst predictions
mse_per_sample = mse(test_data, point_pred, reduction='none')
best_idx = np.argmin(mse_per_sample)
worst_idx = np.argmax(mse_per_sample)

fig1, ax1 = plot_forecast(
    train_data[best_idx],
    point_pred[best_idx],
    test_data[best_idx],
    title=f"Best Forecast (MSE: {mse_per_sample[best_idx]:.4f})"
)
fig1.savefig("best_forecast.png")

fig2, ax2 = plot_forecast(
    train_data[worst_idx],
    point_pred[worst_idx],
    test_data[worst_idx],
    title=f"Worst Forecast (MSE: {mse_per_sample[worst_idx]:.4f})"
)
fig2.savefig("worst_forecast.png")
```

## Error Handling

The SDK provides **machine-readable error codes** for robust error handling:
The SDK provides **error codes** for robust error handling:

```python
from faim_sdk import (
@@ -629,16 +441,9 @@ See the `examples/` directory for complete Jupyter notebook examples:

### Tabular Inference with LimiX
- **`limix_classification_example.ipynb`** - Binary classification on breast cancer dataset
  - Standard approach with LimiX
  - Retrieval-Augmented Inference (RAI) comparison
  - Side-by-side metrics comparison (Accuracy, Precision, Recall, F1-Score)


- **`limix_regression_example.ipynb`** - Regression on California housing dataset
  - Standard approach with LimiX
  - Retrieval-Augmented Inference (RAI) comparison
  - Comprehensive metrics comparison (MSE, RMSE, MAE, R²)
  - Residual statistics analysis


## Requirements

- Python >= 3.10