Skip to content
Open
2 changes: 0 additions & 2 deletions malariagen_data/af1.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
"funestus": TAXON_PALETTE[0],
}

XPEHH_GWSS_CACHE_NAME = "af1_xpehh_gwss_v1"
IHS_GWSS_CACHE_NAME = "af1_ihs_gwss_v1"
ROH_HMM_CACHE_NAME = "af1_roh_hmm_v1"


Expand Down
10 changes: 5 additions & 5 deletions malariagen_data/anoph/cnv_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,14 +545,15 @@ def _gene_cnv_frequencies_advanced(
filter_unassigned=filter_unassigned,
)

debug("group samples to make cohorts")
group_samples_by_cohort = df_samples.groupby([taxon_by, "area", "period"])
# Group samples to make cohorts.
group_samples_by_cohort = df_samples.groupby(
["cohort_taxon", "cohort_area", "cohort_period"]
)

debug("build cohorts dataframe")
df_cohorts = _build_cohorts_from_sample_grouping(
group_samples_by_cohort=group_samples_by_cohort,
min_cohort_size=min_cohort_size,
taxon_by=taxon_by,
)

debug("figure out expected copy number")
Expand All @@ -577,8 +578,7 @@ def _gene_cnv_frequencies_advanced(
debug("build event count and nobs for each cohort")
for cohort_index, cohort in enumerate(df_cohorts.itertuples()):
# construct grouping key
cohort_taxon = getattr(cohort, taxon_by)
cohort_key = cohort_taxon, cohort.area, cohort.period
cohort_key = cohort.taxon, cohort.area, cohort.period

# obtain sample indices for cohort
sample_indices = group_samples_by_cohort.indices[cohort_key]
Expand Down
58 changes: 31 additions & 27 deletions malariagen_data/anoph/frq_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def _prep_samples_for_cohort_grouping(
f"Invalid values in {period_by!r} column. Must be either pandas.Period or null."
)

# Copy the specified period_by column to a new "period" column.
df_samples["period"] = df_samples[period_by]
# Copy the specified period_by column to a new "cohort_period" column.
df_samples["cohort_period"] = df_samples[period_by]
else:
# Use the vectorized period creation function.
df_samples["period"] = period_by_func_vectorized(df_samples)
df_samples["cohort_period"] = period_by_func_vectorized(df_samples)

# Validate area_by.
if area_by not in df_samples.columns:
Expand All @@ -90,15 +90,18 @@ def _prep_samples_for_cohort_grouping(
f"Must be the name of an existing column in the sample metadata."
)

# Copy the specified area_by column to a new "area" column.
df_samples["area"] = df_samples[area_by]
# Copy the specified area_by column to a new "cohort_area" column.
df_samples["cohort_area"] = df_samples[area_by]

# Copy the specified taxon_by column to a new "cohort_taxon" column,
# normalizing it like area_by and period_by.
# See: https://github.com/malariagen/malariagen-data-python/issues/808
df_samples["cohort_taxon"] = df_samples[taxon_by]

return df_samples


def _build_cohorts_from_sample_grouping(
*, group_samples_by_cohort, min_cohort_size, taxon_by
):
def _build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_size):
# Build cohorts dataframe.
df_cohorts = group_samples_by_cohort.agg(
size=("sample_id", len),
Expand All @@ -112,7 +115,21 @@ def _build_cohorts_from_sample_grouping(
# Reset index so that the index fields are included as columns.
df_cohorts = df_cohorts.reset_index()

# Rename cohort_ fields back to standard fields to maintain API compatibility
df_cohorts.rename(
columns={
"cohort_taxon": "taxon",
"cohort_area": "area",
"cohort_period": "period",
},
inplace=True,
)

# Add cohort helper variables.
# NOTE(review): the four statements below appear to duplicate the vectorized
# period start/end extraction that follows, and unlike its fallback path they
# do not guard against null periods (no pd.notna check) — confirm whether
# they are a merge leftover and can be removed.
cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time)
cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time)
df_cohorts["period_start"] = cohort_period_start
df_cohorts["period_end"] = cohort_period_end
# Vectorized extraction of period start/end times.
period = df_cohorts["period"]
if pd.api.types.is_period_dtype(period.dtype):
Expand All @@ -127,25 +144,12 @@ def _build_cohorts_from_sample_grouping(
lambda v: v.end_time if pd.notna(v) else pd.NaT
)

# Create a label that is similar to the cohort metadata,
# although this won't be perfect.
# Vectorized string operations
if taxon_by == frq_params.taxon_by_default:
# Default case: area_taxon_short_period
area_str = df_cohorts["area"].astype(str)
taxon_short = df_cohorts[taxon_by].astype(str).str.slice(0, 4)
period_str = df_cohorts["period"].astype(str)
df_cohorts["label"] = area_str + "_" + taxon_short + "_" + period_str
else:
# Non-default case: replace non-alphanumeric characters with underscores
area_str = df_cohorts["area"].astype(str)
taxon_clean = (
df_cohorts[taxon_by]
.astype(str)
.str.replace(r"[^A-Za-z0-9]+", "_", regex=True)
)
period_str = df_cohorts["period"].astype(str)
df_cohorts["label"] = area_str + "_" + taxon_clean + "_" + period_str
# Create a label using the normalized "taxon" column.
# Vectorized string operations for better performance
# NOTE(review): the taxon part of the label is always cut to its first four
# characters, whichever column taxon_by named, so distinct taxon values that
# share a four-character prefix produce the same label — confirm this is
# intended for non-default taxon_by columns.
area_str = df_cohorts["area"].astype(str)
taxon_short = df_cohorts["taxon"].astype(str).str.slice(0, 4)
period_str = df_cohorts["period"].astype(str)
df_cohorts["label"] = area_str + "_" + taxon_short + "_" + period_str

# Apply minimum cohort size.
df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)
Expand Down
10 changes: 5 additions & 5 deletions malariagen_data/anoph/hap_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,14 @@ def haplotypes_frequencies_advanced(
)

# Group samples to make cohorts.
group_samples_by_cohort = df_samples.groupby([taxon_by, "area", "period"])
group_samples_by_cohort = df_samples.groupby(
["cohort_taxon", "cohort_area", "cohort_period"]
)

# Build cohorts dataframe.
df_cohorts = _build_cohorts_from_sample_grouping(
group_samples_by_cohort=group_samples_by_cohort,
min_cohort_size=min_cohort_size,
taxon_by=taxon_by,
)

# Access haplotypes.
Expand Down Expand Up @@ -220,9 +221,8 @@ def haplotypes_frequencies_advanced(
df_cohorts.itertuples(), desc="Compute allele frequencies"
)
for cohort in cohorts_iterator:
cohort_taxon = getattr(cohort, taxon_by)
cohort_key = cohort_taxon, cohort.area, cohort.period
cohort_key_str = cohort_taxon + "_" + cohort.area + "_" + str(cohort.period)
cohort_key = cohort.taxon, cohort.area, cohort.period
cohort_key_str = cohort.taxon + "_" + cohort.area + "_" + str(cohort.period)
# We reset all frequencies, counts to 0 for each cohort, nobs is set to the number of haplotypes
n_samples = cohort.size
hap_freq = {k: 0 for k in f_all.keys()}
Expand Down
14 changes: 5 additions & 9 deletions malariagen_data/anoph/snp_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,14 @@ def snp_allele_frequencies_advanced(
)

# Group samples to make cohorts.
group_samples_by_cohort = df_samples.groupby([taxon_by, "area", "period"])
group_samples_by_cohort = df_samples.groupby(
["cohort_taxon", "cohort_area", "cohort_period"]
)

# Build cohorts dataframe.
df_cohorts = _build_cohorts_from_sample_grouping(
group_samples_by_cohort=group_samples_by_cohort,
min_cohort_size=min_cohort_size,
taxon_by=taxon_by,
)

# Early check for no cohorts.
Expand Down Expand Up @@ -596,8 +597,7 @@ def snp_allele_frequencies_advanced(
desc="Compute SNP allele frequencies",
)
for cohort_index, cohort in cohorts_iterator:
cohort_taxon = getattr(cohort, taxon_by)
cohort_key = cohort_taxon, cohort.area, cohort.period
cohort_key = cohort.taxon, cohort.area, cohort.period
sample_indices = group_samples_by_cohort.indices[cohort_key]

cohort_ac, cohort_an = _cohort_alt_allele_counts_melt(
Expand Down Expand Up @@ -673,11 +673,7 @@ def snp_allele_frequencies_advanced(

# Cohort variables.
for coh_col in df_cohorts.columns:
if coh_col == taxon_by:
# Other functions expect cohort_taxon, e.g. plot_frequencies_interactive_map()
ds_out["cohort_taxon"] = "cohorts", df_cohorts[coh_col]
else:
ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col]
ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col]

# Variant variables.
for snp_col in df_variants.columns:
Expand Down
66 changes: 66 additions & 0 deletions tests/anoph/test_frq_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for _prep_samples_for_cohort_grouping filter_unassigned behavior.

See: https://github.com/malariagen/malariagen-data-python/issues/806
See: https://github.com/malariagen/malariagen-data-python/issues/808
"""

import pandas as pd
Expand All @@ -12,6 +13,7 @@ def _make_test_df(taxon_col="taxon"):
"""Create a test DataFrame with intermediate and unassigned taxon values."""
return pd.DataFrame(
{
"sample_id": ["S1", "S2", "S3", "S4"],
taxon_col: [
"gambiae",
"intermediate_gambcolu_arabiensis",
Expand All @@ -25,6 +27,52 @@ def _make_test_df(taxon_col="taxon"):
)


class TestPrepSamplesNormalizeTaxon:
    """Tests for copying taxon_by / area_by into 'cohort_*' columns. See #808."""

    def test_default_taxon_column_unchanged(self):
        """When taxon_by='taxon', cohort_taxon is created from the taxon column."""
        df = _make_test_df(taxon_col="taxon")
        result = _prep_samples_for_cohort_grouping(
            df_samples=df,
            area_by="admin1_iso",
            period_by="year",
            taxon_by="taxon",
            filter_unassigned=False,
        )
        assert "cohort_taxon" in result.columns
        assert result["cohort_taxon"].iloc[0] == "gambiae"
        # Original taxon column is preserved
        assert "taxon" in result.columns
        # With filtering disabled, every row of cohort_taxon (not only the
        # first) must match the source column it was copied from.
        assert result["cohort_taxon"].tolist() == result["taxon"].tolist()

    def test_custom_taxon_creates_cohort_column(self):
        """When taxon_by is custom, a 'cohort_taxon' column is created."""
        df = _make_test_df(taxon_col="custom_taxon")
        result = _prep_samples_for_cohort_grouping(
            df_samples=df,
            area_by="admin1_iso",
            period_by="year",
            taxon_by="custom_taxon",
            filter_unassigned=False,
        )
        assert "cohort_taxon" in result.columns
        assert result["cohort_taxon"].iloc[0] == "gambiae"
        # The custom source column is kept alongside the copy.
        assert "custom_taxon" in result.columns
        # Whole-column check, so a partial or misaligned copy is caught.
        assert result["cohort_taxon"].tolist() == result["custom_taxon"].tolist()

    def test_area_column_created(self):
        """area_by is normalized to 'cohort_area' column."""
        df = _make_test_df(taxon_col="taxon")
        result = _prep_samples_for_cohort_grouping(
            df_samples=df,
            area_by="admin1_iso",
            period_by="year",
            taxon_by="taxon",
            filter_unassigned=False,
        )
        assert "cohort_area" in result.columns
        assert result["cohort_area"].iloc[0] == "KE-01"
        # cohort_area is a straight copy of the area_by column.
        assert result["cohort_area"].tolist() == result["admin1_iso"].tolist()


class TestPrepSamplesFilterUnassigned:
"""Tests for the filter_unassigned parameter in _prep_samples_for_cohort_grouping."""

Expand All @@ -38,10 +86,16 @@ def test_default_taxon_column_filters(self):
period_by="year",
taxon_by="taxon",
)
# The original taxon column is filtered in-place
assert result["taxon"].iloc[0] == "gambiae"
assert result["taxon"].iloc[1] is None
assert result["taxon"].iloc[2] is None
assert result["taxon"].iloc[3] == "coluzzii"
# cohort_taxon is created from the filtered taxon column
assert result["cohort_taxon"].iloc[0] == "gambiae"
assert result["cohort_taxon"].iloc[1] is None
assert result["cohort_taxon"].iloc[2] is None
assert result["cohort_taxon"].iloc[3] == "coluzzii"

def test_custom_column_preserves(self):
"""When taxon_by is a custom column and filter_unassigned=None (default),
Expand All @@ -58,6 +112,12 @@ def test_custom_column_preserves(self):
assert result["custom_taxon"].iloc[2] == "unassigned"
assert result["custom_taxon"].iloc[3] == "coluzzii"

# The custom taxon column is also copied, unfiltered, to cohort_taxon
assert result["cohort_taxon"].iloc[0] == "gambiae"
assert result["cohort_taxon"].iloc[1] == "intermediate_gambcolu_arabiensis"
assert result["cohort_taxon"].iloc[2] == "unassigned"
assert result["cohort_taxon"].iloc[3] == "coluzzii"

def test_explicit_filter_true(self):
"""When filter_unassigned=True, always filter regardless of column name."""
df = _make_test_df(taxon_col="custom_taxon")
Expand All @@ -73,6 +133,9 @@ def test_explicit_filter_true(self):
assert result["custom_taxon"].iloc[2] is None
assert result["custom_taxon"].iloc[3] == "coluzzii"

# cohort_taxon mirrors the filtered custom_taxon column
assert result["cohort_taxon"].iloc[1] is None

def test_explicit_filter_false(self):
"""When filter_unassigned=False, never filter even for default 'taxon' column."""
df = _make_test_df(taxon_col="taxon")
Expand All @@ -87,6 +150,9 @@ def test_explicit_filter_false(self):
assert result["taxon"].iloc[1] == "intermediate_gambcolu_arabiensis"
assert result["taxon"].iloc[2] == "unassigned"
assert result["taxon"].iloc[3] == "coluzzii"
# cohort_taxon follows the same (unfiltered) values
assert result["cohort_taxon"].iloc[0] == "gambiae"
assert result["cohort_taxon"].iloc[1] == "intermediate_gambcolu_arabiensis"

def test_does_not_modify_original(self):
"""Ensure the original DataFrame is not modified."""
Expand Down
Loading