diff --git a/doc/syntax/index.qmd b/doc/syntax/index.qmd index 42cc62e7..8b2f8837 100644 --- a/doc/syntax/index.qmd +++ b/doc/syntax/index.qmd @@ -33,6 +33,7 @@ There are many different layers to choose from when visualising your data. Some - [`histogram`](layer/type/histogram.qmd) bins the data along the x axis and produces a bar for each bin showing the number of records in it. - [`boxplot`](layer/type/boxplot.qmd) displays continuous variables as 5-number summaries. - [`errorbar`](layer/type/errorbar.qmd) a line segment with hinges at the endpoints. +- [`smooth`](layer/type/smooth.qmd) draws a trendline that follows the shape of the data. ### Position adjustments - [`stack`](layer/position/stack.qmd) places objects with a shared baseline on top of each other. diff --git a/doc/syntax/layer/type/smooth.qmd b/doc/syntax/layer/type/smooth.qmd new file mode 100644 index 00000000..b6e7b331 --- /dev/null +++ b/doc/syntax/layer/type/smooth.qmd @@ -0,0 +1,153 @@ +--- +title: "Smooth" +--- + +> Layers are declared with the [`DRAW` clause](../clause/draw.qmd). Read the documentation for this clause for a thorough description of how to use it. + +Smooth layers display a trendline through a series of observations. + +## Aesthetics + +### Required +* Primary axis (e.g. `x`): Position along the primary axis. +* Secondary axis (e.g. `y`): Position along the secondary axis. + +### Optional +* `colour`/`stroke`: The colour of the line. +* `opacity`: The opacity of the line. +* `linewidth`: The width of the line. +* `linetype`: The type of line, i.e. the dashing pattern. + +## Settings + +* `method`: Choice of the method for generating the trendline. One of the following: + * `'nw'` or `'nadaraya-watson'` estimates the trendline using the Nadaraya-Watson kernel regression method (default). + * `'ols'` estimates a straight trendline using the ordinary least squares method. + * `'tls'` estimates a straight trendline using the total least squares method.
+ +The settings below only apply when `method => 'nw'` and are ignored when using other methods. +* `bandwidth`: A numerical value setting the smoothing bandwidth to use. If absent (default), the bandwidth is computed using Silverman's rule of thumb. +* `adjust`: A numerical value used as a multiplier for the `bandwidth` setting, with 1 as the default. +* `kernel`: Determines the smoothing kernel shape. Can be one of the following: + * `'gaussian'` (default) + * `'epanechnikov'` + * `'triangular'` + * `'rectangular'` or `'uniform'` + * `'biweight'` or `'quartic'` + * `'cosine'` + +## Data transformation + +### Nadaraya-Watson kernel regression + +The default `method => 'nw'` computes a locally weighted average of $y$. + +$$ +y(x) = \frac{\sum_{i=1}^nW_i(x)y_i}{\sum_{i=1}^nW_i(x)} +$$ + +Where: + +* $W_i(x)$ is the kernel intensity $w_iK(\frac{x - x_i}{h})$, where + * $K$ is the kernel function + * $h$ is the bandwidth + * $w_i$ is the weight of observation $i$ + +Please note the similarity of $W_i(x)$ to the [kernel density estimation formula](density.qmd#data-transformation). + +### Ordinary least squares + +The `method => 'ols'` setting uses ordinary least squares to compute the intercept $a$ and slope $b$ of a straight line. +The method minimizes the 1-dimensional distance between a point and the vertical projection of that point on the line. +Considering only the vertical distances implies measurement error in $y$, but not in $x$. + +$$ +y = a + bx +$$ + +Wherein: + +$$ +a = E[Y] - bE[X] +$$ + +and + +$$ +b = \frac{\text{cov}(X, Y)}{\text{var}(X)} = \frac{E[XY] - E[X]E[Y]}{E[X^2]-(E[X])^2} +$$ + +### Total least squares + +The `method => 'tls'` setting uses total least squares to compute the intercept $a$ and slope $b$ of a straight line. +The method minimizes the 2-dimensional distance between a point and the perpendicular projection of that point on the line.
+Minimising the perpendicular distances (rather than just the vertical distances) makes sense if there is uncertainty or measurement error not just in $y$, but in $x$ as well. +In such a case, it gives a more accurate depiction of the relationship between $x$ and $y$, but it isn't the best predictor of $y$ given $x$. + +$$ +y = a + bx +$$ + +Wherein: + +$$ +a = E[Y] - bE[X] +$$ + +and + +$$ +b = \frac{\text{var}(Y) - \text{var}(X) + \sqrt{(\text{var}(Y) - \text{var}(X))^2 + 4\text{cov}(X, Y)^2}}{2\text{cov}(X, Y)} +$$ + +### Properties + +* `weight` is available when using `method => 'nw'`; when mapped, it sets the relative contribution $w_i$ of an observation to the average. + +### Calculated statistics + +* `intensity` corresponds to $y$ in the formulas described in the [data transformation](#data-transformation) section. + +### Default remappings + +* `intensity AS y`: By default, the smooth layer displays the computed $y$ from the formulas along the y-axis. + +## Examples + +The default `method => 'nw'` might be too coarse for time series. + + + +```{ggsql} +SELECT *, EPOCH(Date) AS numdate FROM ggsql:airquality +VISUALISE numdate AS x, Temp AS y + DRAW point + DRAW smooth +``` + +You can make the fit more granular by reducing the bandwidth, for example via the `adjust` setting. + +```{ggsql} +SELECT *, EPOCH(Date) AS numdate FROM ggsql:airquality +VISUALISE numdate AS x, Temp AS y + DRAW point + DRAW smooth SETTING adjust => 0.2 +``` + +There is a subtle difference between the ordinary and total least squares methods. + +```{ggsql} +VISUALISE bill_len AS x, bill_dep AS y FROM ggsql:penguins + DRAW point + DRAW smooth MAPPING 'Ordinary' AS colour SETTING method => 'ols' + DRAW smooth MAPPING 'Total' AS colour SETTING method => 'tls' +``` + +Simpson's Paradox is a case where a trend of the combined groups reverses when the groups are considered separately.
+ +```{ggsql} +VISUALISE bill_len AS x, bill_dep AS y, species AS stroke FROM ggsql:penguins + DRAW point SETTING opacity => 0 + DRAW smooth SETTING method => 'ols' + DRAW smooth MAPPING 'All' AS stroke SETTING method => 'ols' +``` \ No newline at end of file diff --git a/doc/syntax/layer/type/violin.qmd b/doc/syntax/layer/type/violin.qmd index 0d63d8e7..acc55eb8 100644 --- a/doc/syntax/layer/type/violin.qmd +++ b/doc/syntax/layer/type/violin.qmd @@ -34,6 +34,9 @@ The following aesthetics are recognised by the violin layer. * `'biweight'` or `'quartic'` * `'cosine'` +* `tails`: Expansion rule for drawing the tails. One of the following: + * A number setting a multiple of adjusted bandwidths by which to expand each group's range. Defaults to 3. + * `null` to use the whole data range rather than the group ranges. ## Data transformation A violin layer uses the same computation as a density layer. See the [density data transformation](density.qmd#data-transformation) section for details. @@ -71,6 +74,13 @@ VISUALISE species AS x, bill_dep AS y FROM ggsql:penguins DRAW violin SETTING adjust => 0.1 ``` +The `tails` setting controls how far the violins are drawn beyond the data range. You can set it to `0` to trim each violin to its group's exact data range. + +```{ggsql} +VISUALISE species AS x, bill_dep AS y FROM ggsql:penguins + DRAW violin SETTING tails => 0 +``` + To more clearly indicate differences in group sizes, you can use the `intensity` computed variable. Note that we have fewer (n=68) Chinstrap penguins than Adelie (n=152) or Gentoo (n=124) penguins.
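Outside the diff itself, the Nadaraya-Watson transform that the smooth layer expresses in SQL can be cross-checked with a minimal standalone sketch. This is an illustration only (the function names below are not part of the crate): it computes the kernel-weighted average $y(x_0) = \sum_i W_i(x_0) y_i / \sum_i W_i(x_0)$ with a Gaussian kernel.

```rust
/// Gaussian kernel, K(u) = exp(-u^2 / 2). The normalising constant is
/// omitted because it cancels in the Nadaraya-Watson ratio.
fn gaussian_kernel(u: f64) -> f64 {
    (-0.5 * u * u).exp()
}

/// Nadaraya-Watson estimate at `x0`:
/// y(x0) = sum_i W_i(x0) * y_i / sum_i W_i(x0),
/// with kernel intensity W_i(x0) = w_i * K((x0 - x_i) / h).
fn nadaraya_watson(x0: f64, xs: &[f64], ys: &[f64], ws: &[f64], h: f64) -> f64 {
    let (mut num, mut den) = (0.0, 0.0);
    for ((&xi, &yi), &wi) in xs.iter().zip(ys).zip(ws) {
        let k = wi * gaussian_kernel((x0 - xi) / h);
        num += k * yi;
        den += k;
    }
    num / den
}
```

Two sanity properties follow directly from the ratio form: for constant $y$ the estimate is that constant at any bandwidth, and rescaling all weights by a common factor leaves the estimate unchanged because the factor cancels.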
diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index c83d2ee4..92d8cc75 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -88,17 +88,12 @@ impl GeomTrait for Density { aesthetics: &Mappings, group_by: &[String], parameters: &std::collections::HashMap, - execute_query: &dyn Fn(&str) -> crate::Result, + _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn SqlDialect, ) -> crate::Result { + // Density geom: no tails limit (don't set tails parameter, defaults to None) stat_density( - query, - aesthetics, - "pos1", - group_by, - parameters, - execute_query, - dialect, + query, aesthetics, "pos1", None, group_by, parameters, dialect, ) } } @@ -131,9 +126,9 @@ pub(crate) fn stat_density( query: &str, aesthetics: &Mappings, value_aesthetic: &str, + smooth_aesthetic: Option<&str>, group_by: &[String], parameters: &HashMap, - execute: &dyn Fn(&str) -> crate::Result, dialect: &dyn SqlDialect, ) -> Result { let value = get_column_name(aesthetics, value_aesthetic).ok_or_else(|| { @@ -142,13 +137,26 @@ pub(crate) fn stat_density( value_aesthetic )) })?; + let smooth = smooth_aesthetic.and_then(|smth| get_column_name(aesthetics, smth)); let weight = get_column_name(aesthetics, "weight"); - let (min, max) = compute_range_sql(&value, query, execute)?; + // Get tails parameter (None = unlimited) + let tails = match parameters.get("tails") { + Some(ParameterValue::Number(n)) => Some(*n), + _ => None, + }; + let bw_cte = density_sql_bandwidth(query, group_by, &value, parameters, dialect); - let data_cte = build_data_cte(&value, weight.as_deref(), query, group_by); - let grid_cte = build_grid_cte(group_by, query, min, max, 512, dialect); - let kernel = choose_kde_kernel(parameters)?; + let data_cte = build_data_cte( + &value, + smooth.as_deref(), + weight.as_deref(), + query, + group_by, + ); + let grid_cte = build_grid_cte(group_by, 512, tails, dialect); + let kernel = choose_kde_kernel(parameters, smooth)?; + 
let density_query = compute_density( value_aesthetic, group_by, @@ -177,54 +185,6 @@ pub(crate) fn stat_density( }) } -fn compute_range_sql( - value: &str, - from: &str, - execute: &dyn Fn(&str) -> crate::Result, -) -> Result<(f64, f64)> { - let query = format!( - "SELECT - MIN({value}) AS min, - MAX({value}) AS max - FROM ({from}) - WHERE {value} IS NOT NULL", - value = value, - from = from - ); - let result = execute(&query)?; - let min = result - .column("min") - .and_then(|col| col.get(0)) - .and_then(|v| v.try_extract::()); - - let max = result - .column("max") - .and_then(|col| col.get(0)) - .and_then(|v| v.try_extract::()); - - if let (Ok(start), Ok(end)) = (min, max) { - if !start.is_finite() || !end.is_finite() { - return Err(GgsqlError::ValidationError(format!( - "Density layer needs finite numbers in '{}' column.", - value - ))); - } - if (end - start).abs() < 1e-8 { - // We need to be able to compute variance for density. Having zero - // range is guaranteed to also have zero variance. - return Err(GgsqlError::ValidationError(format!( - "Density layer needs non-zero range data in '{}' column.", - value - ))); - } - return Ok((start, end)); - } - Err(GgsqlError::ReaderError(format!( - "Density layer failed to compute range for '{}' column.", - value - ))) -} - fn density_sql_bandwidth( from: &str, groups: &[String], @@ -232,56 +192,44 @@ fn density_sql_bandwidth( parameters: &HashMap, dialect: &dyn SqlDialect, ) -> String { - let mut group_by = String::new(); - let mut comma = String::new(); - let groups_str = groups.join(", "); - - if !groups_str.is_empty() { - group_by = format!("GROUP BY {}", groups_str); - comma = ",".to_string() - } - let adjust = match parameters.get("adjust") { Some(ParameterValue::Number(adj)) => *adj, _ => 1.0, }; - if let Some(ParameterValue::Number(mut num)) = parameters.get("bandwidth") { - // When we have a user-supplied bandwidth, we don't have to compute the - // bandwidth from the data. 
Instead, we just make sure the query has - // the right shape. - num *= adjust; - let cte = if groups_str.is_empty() { - format!( - "WITH RECURSIVE bandwidth AS (SELECT {num} AS bw)", - num = num - ) - } else { - format!( - "WITH RECURSIVE bandwidth AS (SELECT {num} AS bw, {groups_str} FROM ({from}) {group_by})", - num = num, - groups_str = groups_str, - group_by = group_by - ) - }; - return cte; - } + // Preformat the bandwidth expression (either explicit or computed via Silverman's rule) + let bw_expr = if let Some(ParameterValue::Number(num)) = parameters.get("bandwidth") { + format!("{}", num * adjust) + } else { + silverman_rule(adjust, value, from, groups, dialect) + }; + + // Preformat groups and GROUP BY clause together + let (groups_select, group_by) = if groups.is_empty() { + (String::new(), String::new()) + } else { + let groups_str = groups.join(", "); + ( + format!("\n {},", groups_str), + format!("\n GROUP BY {}", groups_str), + ) + }; + format!( "WITH RECURSIVE bandwidth AS ( SELECT - {rule} AS bw{comma} - {groups_str} + {bw_expr} AS bw,{groups_select} + MIN({value}) AS x_min, + MAX({value}) AS x_max FROM ({from}) AS __ggsql_qt__ - WHERE {value} IS NOT NULL - {group_by} + WHERE {value} IS NOT NULL{group_by} )", - rule = silverman_rule(adjust, value, from, groups, dialect), + bw_expr = bw_expr, + groups_select = groups_select, value = value, - group_by = group_by, - groups_str = groups_str, - comma = comma, from = from, + group_by = group_by ) } @@ -303,7 +251,10 @@ fn silverman_rule( format!("{adjust} * {min_expr} * POW(COUNT(*), -0.2)") } -fn choose_kde_kernel(parameters: &HashMap) -> Result { +fn choose_kde_kernel( + parameters: &HashMap, + smooth: Option, +) -> Result { let kernel = match parameters.get("kernel") { Some(ParameterValue::String(krnl)) => krnl.as_str(), _ => { @@ -352,34 +303,63 @@ fn choose_kde_kernel(parameters: &HashMap) -> Result, from: &str, group_by: &[String]) -> String { +fn build_data_cte( + value: &str, + smooth: 
Option<&str>, + weight: Option<&str>, + from: &str, + group_by: &[String], +) -> String { // Include weight column if provided, otherwise default to 1.0 let weight_col = if let Some(w) = weight { format!(", {} AS weight", w) } else { ", 1.0 AS weight".to_string() }; + let smooth_col = if let Some(s) = smooth { + format!(", {}", s) + } else { + "".to_string() + }; // Only filter out nulls in value column, keep NULLs in group columns - let filter_valid = format!("{} IS NOT NULL", value); + let mut filter_valid = format!("{} IS NOT NULL", value); + if let Some(s) = smooth { + filter_valid = format!( + "{filter} AND {smth} IS NOT NULL", + filter = filter_valid, + smth = s + ); + } format!( "data AS ( - SELECT {groups}{value} AS val{weight_col} + SELECT {groups}{value} AS val{weight_col}{smooth_col} FROM ({from}) WHERE {filter_valid} )", groups = with_trailing_comma(&group_by.join(", ")), value = value, weight_col = weight_col, + smooth_col = smooth_col, from = from, filter_valid = filter_valid ) @@ -387,53 +367,115 @@ fn build_data_cte(value: &str, weight: Option<&str>, from: &str, group_by: &[Str fn build_grid_cte( groups: &[String], - from: &str, - min: f64, - max: f64, n_points: usize, + tails: Option, dialect: &dyn SqlDialect, ) -> String { let has_groups = !groups.is_empty(); - let n_points = n_points - 1; // GENERATE_SERIES gives on point for free - let diff = (max - min).abs(); - - // Expand range 10% - let expand = 0.1; - let min = min - (expand * diff * 0.5); - let max = max + (expand * diff * 0.5); - let diff = (max - min).abs(); - - let seq = dialect.sql_generate_series(n_points + 1); - - if !has_groups { - return format!( - "{seq}, grid AS ( - SELECT {min} + (__ggsql_seq__.n * {diff} / {n_points}) AS x - FROM __ggsql_seq__ - )", - seq = seq, - min = min, - diff = diff, - n_points = n_points - ); - } + let n_points_minus_1 = n_points - 1; // For formula: n-1 divisions between n points - let groups = groups.join(", "); - format!( - "{seq}, grid AS ( + // 
Generate sequence CTE using dialect-specific SQL + let seq_cte = dialect.sql_generate_series(n_points); + + // Shared: global_range CTE (computes range dynamically from bandwidth table) + let global_range_cte = "global_range AS ( + SELECT + MIN(x_min) AS min, + MAX(x_max) AS max, + 3 * MAX(bw) AS expansion + FROM bandwidth + )"; + + // Shared: x-coordinate formula + let x_formula = format!( + "(global.min - global.expansion) + (seq.n * ((global.max - global.min) + 2 * global.expansion) / {n_points})", + n_points = n_points_minus_1 + ); + + // Build base grid CTE + let base_grid_cte = if !has_groups { + // Simple grid without groups + format!( + "grid AS ( + SELECT {x_formula} AS x + FROM global_range AS global + CROSS JOIN __ggsql_seq__ AS seq + )", + x_formula = x_formula + ) + } else { + let groups_str = groups.join(", "); + // When tails is specified, create full_grid; otherwise create grid directly + let cte_name = if tails.is_some() { "full_grid" } else { "grid" }; + format!( + "{cte_name} AS ( SELECT {groups}, - {min} + (__ggsql_seq__.n * {diff} / {n_points}) AS x - FROM __ggsql_seq__ - CROSS JOIN (SELECT DISTINCT {groups} FROM ({from})) AS groups + {x_formula} AS x + FROM global_range AS global + CROSS JOIN __ggsql_seq__ AS seq + CROSS JOIN (SELECT DISTINCT {groups} FROM bandwidth) AS groups )", - seq = seq, - groups = groups, - diff = diff, - min = min, - n_points = n_points, - from = from - ) + cte_name = cte_name, + groups = groups_str, + x_formula = x_formula + ) + }; + + // If tails is specified with groups, add the trimmed grid CTE + if let Some(extent) = tails { + if has_groups { + let bandwidth_join_conds: Vec = groups + .iter() + .map(|g| { + format!( + "full_grid.{col} IS NOT DISTINCT FROM bandwidth.{col}", + col = g + ) + }) + .collect(); + let grid_groups_select: Vec = + groups.iter().map(|g| format!("full_grid.{}", g)).collect(); + + format!( + "{seq_cte}, + {global_range_cte}, + {base_grid_cte}, + grid AS ( + SELECT {grid_groups}, full_grid.x + 
FROM full_grid + INNER JOIN bandwidth ON {bandwidth_join_conds} + WHERE full_grid.x >= bandwidth.x_min - {extent} * bandwidth.bw + AND full_grid.x <= bandwidth.x_max + {extent} * bandwidth.bw + )", + seq_cte = seq_cte, + global_range_cte = global_range_cte, + base_grid_cte = base_grid_cte, + grid_groups = grid_groups_select.join(", "), + bandwidth_join_conds = bandwidth_join_conds.join(" AND "), + extent = extent + ) + } else { + // No groups but tail_extent specified - not meaningful, treat as no tail_extent + format!( + "{seq_cte}, + {global_range_cte}, + {base_grid_cte}", + seq_cte = seq_cte, + global_range_cte = global_range_cte, + base_grid_cte = base_grid_cte + ) + } + } else { + format!( + "{seq_cte}, + {global_range_cte}, + {base_grid_cte}", + seq_cte = seq_cte, + global_range_cte = global_range_cte, + base_grid_cte = base_grid_cte + ) + } } fn compute_density( @@ -471,7 +513,7 @@ fn compute_density( INNER JOIN bandwidth ON {bandwidth_conditions} CROSS JOIN grid {matching_groups}", bandwidth_conditions = bandwidth_conditions, - matching_groups = matching_groups, + matching_groups = matching_groups ); // Build group-related SQL fragments @@ -538,21 +580,34 @@ mod tests { ); let bw_cte = density_sql_bandwidth(query, &groups, "x", ¶meters, &AnsiDialect); - let data_cte = build_data_cte("x", None, query, &groups); - let grid_cte = build_grid_cte(&groups, query, 0.0, 10.0, 512, &AnsiDialect); - let kernel = choose_kde_kernel(¶meters).expect("kernel should be valid"); + let data_cte = build_data_cte("x", None, None, query, &groups); + let grid_cte = build_grid_cte(&groups, 512, None, &AnsiDialect); + let kernel = choose_kde_kernel(¶meters, None).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = "WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw), + let expected = "WITH RECURSIVE + bandwidth AS ( + SELECT + 0.5 AS bw, + MIN(x) AS x_min, + MAX(x) AS x_max + FROM (SELECT x FROM (VALUES 
(1.0), (2.0), (3.0)) AS t(x)) AS __ggsql_qt__ + WHERE x IS NOT NULL + ), data AS ( SELECT x AS val, 1.0 AS weight FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) WHERE x IS NOT NULL ), __ggsql_base__(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < 7),__ggsql_seq__(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c WHERE a.n * 64 + b.n * 8 + c.n < 512), + global_range AS ( + SELECT MIN(x_min) AS min, MAX(x_max) AS max, 3 * MAX(bw) AS expansion + FROM bandwidth + ), grid AS ( - SELECT -0.5 + (__ggsql_seq__.n * 11 / 511) AS x - FROM __ggsql_seq__ + SELECT (global.min - global.expansion) + (seq.n * ((global.max - global.min) + 2 * global.expansion) / 511) AS x + FROM global_range AS global + CROSS JOIN __ggsql_seq__ AS seq ) SELECT __ggsql_stat_x, @@ -601,24 +656,39 @@ mod tests { ); let bw_cte = density_sql_bandwidth(query, &groups, "x", ¶meters, &AnsiDialect); - let data_cte = build_data_cte("x", None, query, &groups); - let grid_cte = build_grid_cte(&groups, query, -10.0, 10.0, 512, &AnsiDialect); - let kernel = choose_kde_kernel(¶meters).expect("kernel should be valid"); + let data_cte = build_data_cte("x", None, None, query, &groups); + let grid_cte = build_grid_cte(&groups, 512, None, &AnsiDialect); + let kernel = choose_kde_kernel(¶meters, None).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = "WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw, region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) GROUP BY region, category), + let expected = "WITH RECURSIVE + bandwidth AS ( + SELECT + 0.5 AS bw, + region, category, + MIN(x) AS x_min, + MAX(x) AS x_max + FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) AS __ggsql_qt__ + WHERE x IS NOT NULL + GROUP BY region, category 
+ ), data AS ( SELECT region, category, x AS val, 1.0 AS weight FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) WHERE x IS NOT NULL ), __ggsql_base__(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < 7),__ggsql_seq__(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c WHERE a.n * 64 + b.n * 8 + c.n < 512), + global_range AS ( + SELECT MIN(x_min) AS min, MAX(x_max) AS max, 3 * MAX(bw) AS expansion + FROM bandwidth + ), grid AS ( SELECT region, category, - -11 + (__ggsql_seq__.n * 22 / 511) AS x - FROM __ggsql_seq__ - CROSS JOIN (SELECT DISTINCT region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category))) AS groups + (global.min - global.expansion) + (seq.n * ((global.max - global.min) + 2 * global.expansion) / 511) AS x + FROM global_range AS global + CROSS JOIN __ggsql_seq__ AS seq + CROSS JOIN (SELECT DISTINCT region, category FROM bandwidth) AS groups ) SELECT __ggsql_stat_x, @@ -660,8 +730,20 @@ mod tests { assert_eq!(df.height(), 1024); // 512 grid points × 2 groups // Verify density integrates to ~2 (one per group) - // Grid spacing: (max - min) / (n - 1) = 22 / 511 ≈ 0.0430 - let dx = 22.0 / 511.0; + // Compute grid spacing dynamically from actual data + let x_col = df.column("__ggsql_stat_x").expect("x exists"); + // Cast to f64 if needed (AnsiDialect generates f32 from REAL) + let x_col = x_col + .cast(&polars::prelude::DataType::Float64) + .expect("can cast to f64"); + let x_vals = x_col.f64().expect("x is f64"); + let x_min = x_vals.into_iter().flatten().fold(f64::INFINITY, f64::min); + let x_max = x_vals + .into_iter() + .flatten() + .fold(f64::NEG_INFINITY, f64::max); + let dx = (x_max - x_min) / 511.0; // (n - 1) for 512 points + let density_col = df .column("__ggsql_stat_density") .expect("density column exists"); @@ -673,10 +755,9 @@ mod tests { 
.sum(); let integral = total * dx; - // With wide range (-10 to 10), we capture essentially all density mass - // Tolerance of 1e-6 - error is dominated by floating point precision + // Should integrate to ~2 (one per group) assert!( - (integral - 2.0).abs() < 1e-6, + (integral - 2.0).abs() < 0.01, "Density should integrate to ~2 (one per group), got {}", integral ); @@ -702,10 +783,13 @@ mod tests { // Verify bandwidth computation executes let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); let df = reader - .execute_sql(&format!("{}\nSELECT bw FROM bandwidth", bw_cte)) + .execute_sql(&format!( + "{}\nSELECT bw, x_min, x_max FROM bandwidth", + bw_cte + )) .expect("Bandwidth SQL should execute"); - assert_eq!(df.get_column_names(), vec!["bw"]); + assert_eq!(df.get_column_names(), vec!["bw", "x_min", "x_max"]); assert_eq!(df.height(), 1); // Test 2: With groups @@ -723,10 +807,16 @@ mod tests { // Verify grouped bandwidth computation executes let df = reader - .execute_sql(&format!("{}\nSELECT bw, region FROM bandwidth", bw_cte)) + .execute_sql(&format!( + "{}\nSELECT bw, region, x_min, x_max FROM bandwidth", + bw_cte + )) .expect("Grouped bandwidth SQL should execute"); - assert_eq!(df.get_column_names(), vec!["bw", "region"]); + assert_eq!( + df.get_column_names(), + vec!["bw", "region", "x_min", "x_max"] + ); assert_eq!(df.height(), 2); // Two groups: A and B } @@ -742,10 +832,10 @@ mod tests { ); let bw_cte = density_sql_bandwidth(query, &groups, "x", ¶meters, &AnsiDialect); - let data_cte = build_data_cte("x", None, query, &groups); + let data_cte = build_data_cte("x", None, None, query, &groups); // Use wide range to capture essentially all density mass - let grid_cte = build_grid_cte(&groups, query, -5.0, 15.0, 512, &AnsiDialect); - let kernel = choose_kde_kernel(¶meters).expect("kernel should be valid"); + let grid_cte = build_grid_cte(&groups, 512, None, &AnsiDialect); + let kernel = choose_kde_kernel(¶meters, None).expect("kernel 
should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); // Execute query @@ -763,8 +853,20 @@ mod tests { assert_eq!(df.height(), 512); // Compute integral using trapezoidal rule - // Grid spacing: (max - min) / (n - 1) - let dx = 22.0 / 511.0; // (15 - (-5) expanded by 10%) / (512 - 1) + // Get actual grid spacing from the data (dynamically computed range) + let x_col = df.column("__ggsql_stat_x").expect("x exists"); + // Cast to f64 if needed (AnsiDialect generates f32 from REAL) + let x_col = x_col + .cast(&polars::prelude::DataType::Float64) + .expect("can cast to f64"); + let x_vals = x_col.f64().expect("x is f64"); + let x_min = x_vals.into_iter().flatten().fold(f64::INFINITY, f64::min); + let x_max = x_vals + .into_iter() + .flatten() + .fold(f64::NEG_INFINITY, f64::max); + let dx = (x_max - x_min) / (df.height() as f64 - 1.0); + let density_col = df.column("__ggsql_stat_density").expect("density exists"); let total: f64 = density_col .f64() @@ -825,7 +927,7 @@ mod tests { ParameterValue::String("invalid_kernel".to_string()), ); - let result = choose_kde_kernel(¶meters); + let result = choose_kde_kernel(¶meters, None); assert!(result.is_err()); match result { @@ -850,11 +952,11 @@ mod tests { ); let bw_cte = density_sql_bandwidth(query, &groups, "x", ¶meters, &AnsiDialect); - let grid_cte = build_grid_cte(&groups, query, 0.0, 4.0, 100, &AnsiDialect); - let kernel = choose_kde_kernel(¶meters).expect("kernel should be valid"); + let grid_cte = build_grid_cte(&groups, 100, None, &AnsiDialect); + let kernel = choose_kde_kernel(¶meters, None).expect("kernel should be valid"); // Unweighted (default weights of 1.0) - let data_cte_unweighted = build_data_cte("x", None, query, &groups); + let data_cte_unweighted = build_data_cte("x", None, None, query, &groups); let sql_unweighted = compute_density( "x", &groups, @@ -871,7 +973,7 @@ mod tests { // With explicit uniform weights (should be equivalent) let query_weighted = 
"SELECT x, 1.0 AS weight FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)"; - let data_cte_weighted = build_data_cte("x", Some("weight"), query_weighted, &groups); + let data_cte_weighted = build_data_cte("x", None, Some("weight"), query_weighted, &groups); let sql_weighted = compute_density("x", &groups, kernel, &bw_cte, &data_cte_weighted, &grid_cte); let df_weighted = reader @@ -1012,9 +1114,9 @@ mod tests { ); let bw_cte = density_sql_bandwidth(query, &groups, "x", ¶meters, &AnsiDialect); - let data_cte = build_data_cte("x", None, query, &groups); - let grid_cte = build_grid_cte(&groups, query, 0.0, 100.0, 512, &AnsiDialect); - let kernel = choose_kde_kernel(¶meters).expect("kernel should be valid"); + let data_cte = build_data_cte("x", None, None, query, &groups); + let grid_cte = build_grid_cte(&groups, 512, None, &AnsiDialect); + let kernel = choose_kde_kernel(¶meters, None).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); // Warm-up run diff --git a/src/plot/layer/geom/smooth.rs b/src/plot/layer/geom/smooth.rs index e8d55854..456f293e 100644 --- a/src/plot/layer/geom/smooth.rs +++ b/src/plot/layer/geom/smooth.rs @@ -1,8 +1,11 @@ //! Smooth geom implementation use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType}; +use crate::plot::geom::types::get_column_name; use crate::plot::types::DefaultAestheticValue; -use crate::Mappings; +use crate::plot::{ParameterValue, StatResult}; +use crate::reader::SqlDialect; +use crate::{naming, GgsqlError, Mappings, Result}; /// Smooth geom - smoothed conditional means (regression, LOESS, etc.) 
#[derive(Debug, Clone, Copy)] @@ -18,8 +21,9 @@ impl GeomTrait for Smooth { defaults: &[ ("pos1", DefaultAestheticValue::Required), ("pos2", DefaultAestheticValue::Required), + ("weight", DefaultAestheticValue::Null), ("stroke", DefaultAestheticValue::String("#3366FF")), - ("linewidth", DefaultAestheticValue::Number(1.0)), + ("linewidth", DefaultAestheticValue::Number(2.0)), ("opacity", DefaultAestheticValue::Number(1.0)), ("linetype", DefaultAestheticValue::String("solid")), ], @@ -27,17 +31,80 @@ impl GeomTrait for Smooth { } fn default_params(&self) -> &'static [DefaultParam] { - &[DefaultParam { - name: "position", - default: DefaultParamValue::String("identity"), - }] + &[ + DefaultParam { + name: "position", + default: DefaultParamValue::String("identity"), + }, + DefaultParam { + name: "method", + default: DefaultParamValue::String("nw"), + }, + DefaultParam { + name: "bandwidth", + default: DefaultParamValue::Null, + }, + DefaultParam { + name: "adjust", + default: DefaultParamValue::Number(1.0), + }, + DefaultParam { + name: "kernel", + default: DefaultParamValue::String("gaussian"), + }, + ] } fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } - // Note: stat_smooth not yet implemented - will return Identity for now + fn default_remappings(&self) -> &'static [(&'static str, DefaultAestheticValue)] { + &[ + ("pos1", DefaultAestheticValue::Column("pos1")), + ("pos2", DefaultAestheticValue::Column("intensity")), + ] + } + + fn apply_stat_transform( + &self, + query: &str, + _schema: &crate::plot::Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &std::collections::HashMap, + _execute_query: &dyn Fn(&str) -> crate::Result, + dialect: &dyn SqlDialect, + ) -> crate::Result { + let Some(ParameterValue::String(method)) = parameters.get("method") else { + return Err(GgsqlError::ValidationError( + "The `method` setting must be a string.".to_string(), + )); + }; + + match method.as_str() { + "nw" | "nadaraya-watson" => { + // 
Smooth geom: hardcode tails=0.0 (trim exactly to data range, no extrapolation) + let mut params = parameters.clone(); + params.insert("tails".to_string(), ParameterValue::Number(0.0)); + + super::density::stat_density( + query, + aesthetics, + "pos1", + Some("pos2"), + group_by, + ¶ms, + dialect, + ) + } + "ols" => stat_ols(query, aesthetics, group_by), + "tls" => stat_tls(query, aesthetics, group_by), + _ => Err(GgsqlError::ValidationError( + "The `method` setting must be 'nw', 'ols', or 'tls'.".to_string(), + )), + } + } } impl std::fmt::Display for Smooth { @@ -45,3 +112,344 @@ impl std::fmt::Display for Smooth { write!(f, "smooth") } } + +fn stat_ols(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result { + let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + GgsqlError::ValidationError("Smooth requires 'pos1' aesthetic".to_string()) + })?; + let y_col = get_column_name(aesthetics, "pos2").ok_or_else(|| { + GgsqlError::ValidationError("Smooth requires 'pos2' aesthetic".to_string()) + })?; + + // Build group-related SQL fragments + let (groups_str, group_by_clause) = if group_by.is_empty() { + (String::new(), String::new()) + } else { + ( + format!("{}, ", group_by.join(", ")), + format!("GROUP BY {}", group_by.join(", ")), + ) + }; + + // Compute regression coefficients and predict at min and max x values + // We use UNION ALL to get two rows per group (one for x_min, one for x_max) + // Slope: (E[XY] - E[X]E[Y]) / (E[X²] - E[X]²) + // Fitted: E[Y] + slope * (x - E[X]) + let final_query = format!( + "WITH + coefficients AS ( + SELECT + {groups}AVG({x}) AS x_mean, + AVG({y}) AS y_mean, + AVG({x} * {y}) AS xy_mean, + AVG({x} * {x}) AS xx_mean, + MIN({x}) AS x_min, + MAX({x}) AS x_max + FROM ({data}) + WHERE {x} IS NOT NULL AND {y} IS NOT NULL + {group_by} + ) + SELECT + {groups}x_min AS {x_out}, + (y_mean + ((xy_mean - x_mean * y_mean) / (xx_mean - x_mean * x_mean)) * (x_min - x_mean)) AS {y_out} + FROM coefficients + UNION ALL + SELECT 
+            {groups}x_max AS {x_out},
+            (y_mean + ((xy_mean - x_mean * y_mean) / (xx_mean - x_mean * x_mean)) * (x_max - x_mean)) AS {y_out}
+        FROM coefficients",
+        groups = groups_str,
+        x = x_col,
+        y = y_col,
+        data = query,
+        x_out = naming::stat_column("pos1"),
+        y_out = naming::stat_column("intensity"), // We name this 'intensity' to be consistent with the Nadaraya-Watson kernel
+        group_by = group_by_clause
+    );
+
+    Ok(StatResult::Transformed {
+        query: final_query,
+        stat_columns: vec!["pos1".to_string(), "intensity".to_string()],
+        dummy_columns: vec![],
+        consumed_aesthetics: vec!["pos1".to_string(), "pos2".to_string()],
+    })
+}
+
+fn stat_tls(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result<StatResult> {
+    let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| {
+        GgsqlError::ValidationError("Smooth requires 'pos1' aesthetic".to_string())
+    })?;
+    let y_col = get_column_name(aesthetics, "pos2").ok_or_else(|| {
+        GgsqlError::ValidationError("Smooth requires 'pos2' aesthetic".to_string())
+    })?;
+
+    // Build group-related SQL fragments
+    let (groups_str, group_by_clause) = if group_by.is_empty() {
+        (String::new(), String::new())
+    } else {
+        (
+            format!("{}, ", group_by.join(", ")),
+            format!("GROUP BY {}", group_by.join(", ")),
+        )
+    };
+
+    // Compute Total Least Squares (orthogonal regression)
+    // TLS minimizes perpendicular distances, not vertical distances
+    // Slope: β = (Var(y) - Var(x) + sqrt((Var(y) - Var(x))² + 4*Cov(x,y)²)) / (2*Cov(x,y))
+    // Where: Var(x) = E[x²] - E[x]², Var(y) = E[y²] - E[y]², Cov(x,y) = E[xy] - E[x]E[y]
+    let final_query = format!(
+        "WITH
+        coefficients AS (
+            SELECT
+                {groups}AVG({x}) AS x_mean,
+                AVG({y}) AS y_mean,
+                AVG({x} * {y}) AS xy_mean,
+                AVG({x} * {x}) AS xx_mean,
+                AVG({y} * {y}) AS yy_mean,
+                MIN({x}) AS x_min,
+                MAX({x}) AS x_max
+            FROM ({data})
+            WHERE {x} IS NOT NULL AND {y} IS NOT NULL
+            {group_by}
+        ),
+        tls_coefficients AS (
+            SELECT
+                {groups}x_mean,
+                y_mean,
+                (yy_mean - y_mean * y_mean) - (xx_mean
- x_mean * x_mean) AS var_diff,
+                (xy_mean - x_mean * y_mean) AS covariance,
+                x_min,
+                x_max
+            FROM coefficients
+        )
+        SELECT
+            {groups}x_min AS {x_out},
+            (y_mean + ((var_diff + SQRT(var_diff * var_diff + 4 * covariance * covariance)) / (2 * covariance)) * (x_min - x_mean)) AS {y_out}
+        FROM tls_coefficients
+        UNION ALL
+        SELECT
+            {groups}x_max AS {x_out},
+            (y_mean + ((var_diff + SQRT(var_diff * var_diff + 4 * covariance * covariance)) / (2 * covariance)) * (x_max - x_mean)) AS {y_out}
+        FROM tls_coefficients",
+        groups = groups_str,
+        x = x_col,
+        y = y_col,
+        data = query,
+        x_out = naming::stat_column("pos1"),
+        y_out = naming::stat_column("intensity"),
+        group_by = group_by_clause
+    );
+
+    Ok(StatResult::Transformed {
+        query: final_query,
+        stat_columns: vec!["pos1".to_string(), "intensity".to_string()],
+        dummy_columns: vec![],
+        consumed_aesthetics: vec!["pos1".to_string(), "pos2".to_string()],
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::plot::AestheticValue;
+    use crate::reader::duckdb::DuckDBReader;
+    use crate::reader::Reader;
+
+    #[test]
+    fn test_stat_ols_ungrouped() {
+        let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
+
+        let query = "SELECT x, y FROM (VALUES (1.0, 2.0), (2.0, 4.0), (3.0, 6.0)) AS t(x, y)";
+        let groups: Vec<String> = vec![];
+
+        let mut mapping = crate::Mappings::new();
+        mapping.aesthetics.insert(
+            "pos1".to_string(),
+            AestheticValue::Column {
+                name: "x".to_string(),
+                original_name: None,
+                is_dummy: false,
+            },
+        );
+        mapping.aesthetics.insert(
+            "pos2".to_string(),
+            AestheticValue::Column {
+                name: "y".to_string(),
+                original_name: None,
+                is_dummy: false,
+            },
+        );
+
+        let result = stat_ols(query, &mapping, &groups).expect("stat_ols should succeed");
+
+        if let StatResult::Transformed {
+            query: sql,
+            stat_columns,
+            ..
+ } = result + { + assert_eq!(stat_columns, vec!["pos1", "intensity"]); + + let df = reader.execute_sql(&sql).expect("SQL should execute"); + + // Should have 2 rows (min and max x) + assert_eq!(df.height(), 2); + assert_eq!( + df.get_column_names(), + vec!["__ggsql_stat_pos1", "__ggsql_stat_intensity"] + ); + } else { + panic!("Expected Transformed result"); + } + } + + #[test] + fn test_stat_ols_grouped() { + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + + let query = "SELECT x, y, category FROM (VALUES + (1.0, 2.0, 'A'), (2.0, 4.0, 'A'), (3.0, 6.0, 'A'), + (1.0, 3.0, 'B'), (2.0, 5.0, 'B'), (3.0, 7.0, 'B') + ) AS t(x, y, category)"; + let groups = vec!["category".to_string()]; + + let mut mapping = crate::Mappings::new(); + mapping.aesthetics.insert( + "pos1".to_string(), + AestheticValue::Column { + name: "x".to_string(), + original_name: None, + is_dummy: false, + }, + ); + mapping.aesthetics.insert( + "pos2".to_string(), + AestheticValue::Column { + name: "y".to_string(), + original_name: None, + is_dummy: false, + }, + ); + + let result = stat_ols(query, &mapping, &groups).expect("stat_ols should succeed"); + + if let StatResult::Transformed { + query: sql, + stat_columns, + .. 
+        } = result
+        {
+            assert_eq!(stat_columns, vec!["pos1", "intensity"]);
+
+            let df = reader.execute_sql(&sql).expect("SQL should execute");
+
+            // Should have 4 rows (2 points × 2 groups)
+            assert_eq!(df.height(), 4);
+            assert_eq!(
+                df.get_column_names(),
+                vec!["category", "__ggsql_stat_pos1", "__ggsql_stat_intensity"]
+            );
+        } else {
+            panic!("Expected Transformed result");
+        }
+    }
+
+    #[test]
+    fn test_stat_tls_ungrouped() {
+        let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
+
+        let query = "SELECT x, y FROM (VALUES (1.0, 2.0), (2.0, 4.0), (3.0, 6.0)) AS t(x, y)";
+        let groups: Vec<String> = vec![];
+
+        let mut mapping = crate::Mappings::new();
+        mapping.aesthetics.insert(
+            "pos1".to_string(),
+            AestheticValue::Column {
+                name: "x".to_string(),
+                original_name: None,
+                is_dummy: false,
+            },
+        );
+        mapping.aesthetics.insert(
+            "pos2".to_string(),
+            AestheticValue::Column {
+                name: "y".to_string(),
+                original_name: None,
+                is_dummy: false,
+            },
+        );
+
+        let result = stat_tls(query, &mapping, &groups).expect("stat_tls should succeed");
+
+        if let StatResult::Transformed {
+            query: sql,
+            stat_columns,
+            ..
+ } = result + { + assert_eq!(stat_columns, vec!["pos1", "intensity"]); + + let df = reader.execute_sql(&sql).expect("SQL should execute"); + + // Should have 2 rows (min and max x) + assert_eq!(df.height(), 2); + assert_eq!( + df.get_column_names(), + vec!["__ggsql_stat_pos1", "__ggsql_stat_intensity"] + ); + } else { + panic!("Expected Transformed result"); + } + } + + #[test] + fn test_stat_tls_grouped() { + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + + let query = "SELECT x, y, category FROM (VALUES + (1.0, 2.0, 'A'), (2.0, 4.0, 'A'), (3.0, 6.0, 'A'), + (1.0, 3.0, 'B'), (2.0, 5.0, 'B'), (3.0, 7.0, 'B') + ) AS t(x, y, category)"; + let groups = vec!["category".to_string()]; + + let mut mapping = crate::Mappings::new(); + mapping.aesthetics.insert( + "pos1".to_string(), + AestheticValue::Column { + name: "x".to_string(), + original_name: None, + is_dummy: false, + }, + ); + mapping.aesthetics.insert( + "pos2".to_string(), + AestheticValue::Column { + name: "y".to_string(), + original_name: None, + is_dummy: false, + }, + ); + + let result = stat_tls(query, &mapping, &groups).expect("stat_tls should succeed"); + + if let StatResult::Transformed { + query: sql, + stat_columns, + .. 
+        } = result
+        {
+            assert_eq!(stat_columns, vec!["pos1", "intensity"]);
+
+            let df = reader.execute_sql(&sql).expect("SQL should execute");
+
+            // Should have 4 rows (2 points × 2 groups)
+            assert_eq!(df.height(), 4);
+            assert_eq!(
+                df.get_column_names(),
+                vec!["category", "__ggsql_stat_pos1", "__ggsql_stat_intensity"]
+            );
+        } else {
+            panic!("Expected Transformed result");
+        }
+    }
+}
diff --git a/src/plot/layer/geom/violin.rs b/src/plot/layer/geom/violin.rs
index 60192d98..b213ebb7 100644
--- a/src/plot/layer/geom/violin.rs
+++ b/src/plot/layer/geom/violin.rs
@@ -63,6 +63,10 @@ impl GeomTrait for Violin {
             name: "width",
             default: DefaultParamValue::Number(0.9),
         },
+        DefaultParam {
+            name: "tails",
+            default: DefaultParamValue::Number(3.0),
+        },
     ]
 }
 
@@ -88,17 +92,10 @@ impl GeomTrait for Violin {
         aesthetics: &Mappings,
         group_by: &[String],
         parameters: &HashMap<String, ParameterValue>,
-        execute_query: &dyn Fn(&str) -> crate::Result<DataFrame>,
+        _execute_query: &dyn Fn(&str) -> crate::Result<DataFrame>,
         dialect: &dyn crate::reader::SqlDialect,
     ) -> Result<StatResult> {
-        stat_violin(
-            query,
-            aesthetics,
-            group_by,
-            parameters,
-            execute_query,
-            dialect,
-        )
+        stat_violin(query, aesthetics, group_by, parameters, dialect)
     }
 
     /// Post-process the violin DataFrame to scale offset to [0, 0.5 * width].
@@ -172,7 +169,6 @@ fn stat_violin(
     aesthetics: &Mappings,
     group_by: &[String],
     parameters: &HashMap<String, ParameterValue>,
-    execute: &dyn Fn(&str) -> crate::Result<DataFrame>,
     dialect: &dyn crate::reader::SqlDialect,
 ) -> Result<StatResult> {
     // Verify y exists
@@ -194,13 +190,14 @@ fn stat_violin(
         ));
     }
 
+    // Violin uses tails parameter from user (default 3.0 set in default_params)
     super::density::stat_density(
         query,
         aesthetics,
         "pos2",
+        None,
         group_by.as_slice(),
         parameters,
-        execute,
         dialect,
     )
 }
@@ -264,15 +261,8 @@ mod tests {
         let execute = |sql: &str| reader.execute_sql(sql);
 
-        let result = stat_violin(
-            query,
-            &aesthetics,
-            &groups,
-            &parameters,
-            &execute,
-            &AnsiDialect,
-        )
-        .expect("stat_violin should succeed");
+        let result = stat_violin(query, &aesthetics, &groups, &parameters, &AnsiDialect)
+            .expect("stat_violin should succeed");
 
         // Verify the result is a transformed stat result
         match result {
@@ -336,15 +326,8 @@ mod tests {
         let execute = |sql: &str| reader.execute_sql(sql);
 
-        let result = stat_violin(
-            query,
-            &aesthetics,
-            &groups,
-            &parameters,
-            &execute,
-            &AnsiDialect,
-        )
-        .expect("stat_violin should succeed");
+        let result = stat_violin(query, &aesthetics, &groups, &parameters, &AnsiDialect)
+            .expect("stat_violin should succeed");
 
         // Verify the result is a transformed stat result
         match result {
@@ -414,6 +397,77 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_violin_tails_parameter() {
+        // Verify that the violin geom has a tails parameter with default 3.0
+        let violin = Violin;
+        let params = violin.default_params();
+
+        let tails_param = params.iter().find(|p| p.name == "tails");
+        assert!(
+            tails_param.is_some(),
+            "Violin should have a 'tails' parameter"
+        );
+
+        if let Some(param) = tails_param {
+            match param.default {
+                DefaultParamValue::Number(n) => {
+                    assert!(
+                        (n - 3.0).abs() < 1e-6,
+                        "Default tails should be 3.0, got {}",
+                        n
+                    );
+                }
+                _ => panic!("Tails parameter should have a numeric default"),
+            }
+        }
+
+        // Test with custom tails value
+        let query = "SELECT species,
flipper_length FROM penguins";
+        let aesthetics = create_basic_aesthetics();
+        let groups: Vec<String> = vec![];
+        let mut parameters = HashMap::new();
+        parameters.insert("bandwidth".to_string(), ParameterValue::Number(5.0));
+        parameters.insert(
+            "kernel".to_string(),
+            ParameterValue::String("gaussian".to_string()),
+        );
+        parameters.insert("tails".to_string(), ParameterValue::Number(1.5)); // Custom tails
+
+        let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
+
+        // Create test data
+        let setup_sql = "CREATE TABLE penguins AS SELECT * FROM (VALUES
+            ('Adelie', 181.0), ('Adelie', 186.0), ('Adelie', 195.0),
+            ('Gentoo', 217.0), ('Gentoo', 221.0), ('Gentoo', 230.0)
+        ) AS t(species, flipper_length)";
+        reader.execute_sql(setup_sql).unwrap();
+
+        let execute = |sql: &str| reader.execute_sql(sql);
+
+        let result = stat_violin(query, &aesthetics, &groups, &parameters, &AnsiDialect)
+            .expect("stat_violin with custom tails should succeed");
+
+        // Verify the SQL includes the tails constraint
+        match result {
+            StatResult::Transformed {
+                query: stat_query, ..
+ } => { + // The generated SQL should include the tails filtering + // We verify this by checking the SQL contains the bandwidth filtering + assert!( + stat_query.contains("1.5"), + "SQL should contain the custom tails value 1.5" + ); + + // Execute and verify it produces results + let df = execute(&stat_query).expect("Generated SQL should execute"); + assert!(df.height() > 0, "Should produce density data"); + } + _ => panic!("Expected Transformed result"), + } + } + // ==================== Post-Process Tests ==================== #[test] diff --git a/src/writer/vegalite/layer.rs b/src/writer/vegalite/layer.rs index 92cdc261..3a82cf19 100644 --- a/src/writer/vegalite/layer.rs +++ b/src/writer/vegalite/layer.rs @@ -41,6 +41,7 @@ pub fn geom_to_mark(geom: &Geom) -> Value { GeomType::Boxplot => "boxplot", GeomType::Text => "text", GeomType::Segment => "rule", + GeomType::Smooth => "line", GeomType::Rule => "rule", GeomType::Linear => "rule", GeomType::ErrorBar => "rule", @@ -1333,11 +1334,6 @@ impl GeomRenderer for ViolinRenderer { continuous_col, continuous_col ); - // Filter threshold to trim very low density regions (removes thin tails) - // The offset is pre-scaled to [0, 0.5 * width] by geom post_process, - // but this filter still catches extremely low values. - let filter_expr = format!("datum.{} > 0.001", offset_col); - // Preserve existing transforms (e.g., source filter) and extend with violin-specific transforms let existing_transforms = layer_spec .get("transform") @@ -1350,10 +1346,6 @@ impl GeomRenderer for ViolinRenderer { let mut transforms = existing_transforms; transforms.extend(vec![ - json!({ - // Remove points with very low density to clean up thin tails - "filter": filter_expr - }), json!({ // Mirror offset on both sides (offset is pre-scaled to [0, 0.5 * width]) "calculate": violin_offset,
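Aside for reviewers (not part of the patch; helper names `moments`, `ols_slope`, and `tls_slope` are illustrative only): the closed-form OLS and TLS slope estimates that `stat_ols` and `stat_tls` push into SQL can be sanity-checked in plain Rust. On points lying exactly on `y = 2x`, both estimators should recover a slope of 2, and on noisy data the TLS slope should be at least as steep as the OLS slope (OLS attenuates toward zero when `x` carries error).

```rust
// Re-derives the slope formulas from the patch in plain Rust so they can be
// verified numerically. cov = E[xy] - E[x]E[y], var = E[v^2] - E[v]^2.
fn moments(xs: &[f64], ys: &[f64]) -> (f64, f64, f64) {
    let n = xs.len() as f64;
    let x_mean = xs.iter().sum::<f64>() / n;
    let y_mean = ys.iter().sum::<f64>() / n;
    let xy_mean = xs.iter().zip(ys).map(|(a, b)| a * b).sum::<f64>() / n;
    let xx_mean = xs.iter().map(|a| a * a).sum::<f64>() / n;
    let yy_mean = ys.iter().map(|b| b * b).sum::<f64>() / n;
    (
        xy_mean - x_mean * y_mean,   // cov(x, y)
        xx_mean - x_mean * x_mean,   // var(x)
        yy_mean - y_mean * y_mean,   // var(y)
    )
}

// OLS slope: cov(x, y) / var(x)
fn ols_slope(xs: &[f64], ys: &[f64]) -> f64 {
    let (cov, var_x, _) = moments(xs, ys);
    cov / var_x
}

// TLS slope: (var(y) - var(x) + sqrt((var(y) - var(x))^2 + 4*cov^2)) / (2*cov)
fn tls_slope(xs: &[f64], ys: &[f64]) -> f64 {
    let (cov, var_x, var_y) = moments(xs, ys);
    let var_diff = var_y - var_x;
    (var_diff + (var_diff * var_diff + 4.0 * cov * cov).sqrt()) / (2.0 * cov)
}

fn main() {
    // Points exactly on y = 2x: both estimators must recover slope 2.
    let xs = [1.0, 2.0, 3.0];
    let ys = [2.0, 4.0, 6.0];
    assert!((ols_slope(&xs, &ys) - 2.0).abs() < 1e-9);
    assert!((tls_slope(&xs, &ys) - 2.0).abs() < 1e-9);

    // Off-line data: the orthogonal fit is at least as steep as the vertical fit.
    let ys_noisy = [2.0, 5.0, 6.0];
    assert!(tls_slope(&xs, &ys_noisy) >= ols_slope(&xs, &ys_noisy));
}
```

This mirrors what the generated SQL computes per group via `AVG({x})`, `AVG({x} * {y})`, and friends, before evaluating the fitted line at `x_min` and `x_max`.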