From 3879c32dbc95aacb7bb9236734e3efaf5b6004b7 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 3 Apr 2026 09:07:15 -0700 Subject: [PATCH] basic pagination + sorting --- mp_api/client/core/client.py | 7 +++++ mp_api/client/routes/materials/electrodes.py | 6 +++- mp_api/client/routes/materials/summary.py | 25 ++++++++++----- tests/client/materials/test_electrodes.py | 30 ++++++++++++++++++ tests/client/materials/test_summary.py | 32 +++++++++++++++++++- 5 files changed, 90 insertions(+), 10 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index e2127b22..2dace920 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -965,6 +965,10 @@ def _submit_requests( # noqa # No splitting needed - get first page total_data = {"data": []} initial_criteria = copy(criteria) + if isinstance( + initial_criteria.get("_page"), int + ) and not initial_criteria.get("_per_page"): + initial_criteria["_per_page"] = initial_criteria.get("_limit") data, total_num_docs = self._submit_request_and_process( url=url, verify=True, @@ -1438,6 +1442,9 @@ def _search( # This method should be customized for each end point to give more user friendly, # documented kwargs. + # If user specifies page, ensure only one chunk is returned + if isinstance(kwargs.get("_page"), int) and num_chunks is None: + num_chunks = 1 return self._get_all_documents( kwargs, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/electrodes.py b/mp_api/client/routes/materials/electrodes.py index 71a469d6..583582c1 100644 --- a/mp_api/client/routes/materials/electrodes.py +++ b/mp_api/client/routes/materials/electrodes.py @@ -15,7 +15,7 @@ class BaseElectrodeRester(BaseRester): primary_key = "battery_id" _exclude_search_fields: list[str] | None = None - def search( # pragma: ignore + def search( self, battery_ids: str | list[str] | None = None, average_voltage: tuple[float, float] | None = None, @@ -39,6 +39,8 @@ def search( # pragma: ignore chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + _page: int | None = None, + _sort_fields: str | None = None, ) -> list[InsertionElectrodeDoc | ConversionElectrodeDoc] | list[dict]: """Query using a variety of search criteria. @@ -77,6 +79,8 @@ def search( # pragma: ignore all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in InsertionElectrodeDoc or ConversionElectrodeDoc to return data for. Default is battery_id and last_updated if all_fields is False. + _page (int or None) : Page of the results to skip to. + _sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order. Returns: ([InsertionElectrodeDoc or ConversionElectrodeDoc], [dict]) List of insertion/conversion electrode documents or dictionaries. diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index 1769781e..30a186b6 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -73,6 +73,8 @@ def search( # noqa: D417 chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + _page: int | None = None, + _sort_fields: str | None = None, **kwargs, ) -> list[SummaryDoc] | list[dict]: """Query core data using a variety of search criteria. @@ -150,6 +152,8 @@ def search( # noqa: D417 all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in SummaryDoc to return data for. Default is material_id if all_fields is False. + _page (int or None) : Page of the results to skip to. + _sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order. Returns: ([SummaryDoc], [dict]) List of SummaryDoc documents or dictionaries. @@ -181,6 +185,8 @@ def search( # noqa: D417 "weighted_surface_energy", "weighted_work_function", "shape_factor", + "_page", + "_sort_fields", ] min_max_name_dict = { @@ -284,14 +290,17 @@ def _csrc(x): ) for param, value in user_settings.items(): - if isinstance(value, (int, float)): - value = (value, value) - query_params.update( - { - f"{min_max_name_dict[param]}_min": value[0], - f"{min_max_name_dict[param]}_max": value[1], - } - ) + if param.startswith("_"): + query_params[param] = value + else: + if isinstance(value, (int, float)): + value = (value, value) + query_params.update( + { + f"{min_max_name_dict[param]}_min": value[0], + f"{min_max_name_dict[param]}_max": value[1], + } + ) if material_ids: if isinstance(material_ids, str): diff --git a/tests/client/materials/test_electrodes.py b/tests/client/materials/test_electrodes.py index bc6e00ef..4953d36d 100644 --- a/tests/client/materials/test_electrodes.py +++ b/tests/client/materials/test_electrodes.py @@ -30,6 +30,8 @@ def conversion_rester(): "num_chunks", "all_fields", "fields", + "_page", + "_sort_fields", ] sub_doc_fields: list = [] @@ -80,3 +82,31 @@ def test_conversion_client(conversion_rester): }, sub_doc_fields=sub_doc_fields, ) + + +@pytest.mark.xfail(reason="sort requires API redeployment", strict=False) +@requires_api_key +def test_pagination_sort(): + num_docs = 5 + with ElectrodeRester() as rester: + results_page_1 = rester.search(_page=1, chunk_size=num_docs) + results_page_2 = rester.search(_page=2, chunk_size=num_docs) + assert all( + len(results) == num_docs for results in (results_page_1, results_page_2) + ) + assert {doc.battery_id for doc in results_page_1}.intersection( + {doc.battery_id for doc in results_page_2} + ) == set() + + ascending_e_hull = rester.search(_page=1, _sort_fields="average_voltage") + descending_e_hull = rester.search(_page=1, _sort_fields="-average_voltage") + + assert sorted( + range(num_docs), key=lambda idx: ascending_e_hull[idx].average_voltage + ) == list(range(num_docs)) + + assert sorted( + range(num_docs), + key=lambda idx: descending_e_hull[idx].average_voltage, + reverse=True, + ) == list(range(num_docs)) diff --git a/tests/client/materials/test_summary.py b/tests/client/materials/test_summary.py index 9d5b6398..e0766adf 100644 --- a/tests/client/materials/test_summary.py +++ b/tests/client/materials/test_summary.py @@ -1,10 +1,11 @@ import os from ..conftest import client_search_testing, requires_api_key -import pytest from emmet.core.summary import HasProps from emmet.core.symmetry import CrystalSystem +import numpy as np from pymatgen.analysis.magnetism import Ordering +import pytest from mp_api.client.routes.materials.summary import SummaryRester from mp_api.client.core.exceptions import MPRestWarning, MPRestError @@ -16,6 +17,8 @@ "num_chunks", "all_fields", "fields", + "_page", + "_sort_fields", ] alt_name_dict: dict = { @@ -134,3 +137,30 @@ def test_warning_messages(): with pytest.raises(MPRestError, match="not a valid property"): _ = search_method(num_elements=10, has_props=["apples"]) + + +@requires_api_key +def test_pagination_sort(): + num_docs = 5 + with SummaryRester() as rester: + results_page_1 = rester.search(_page=1, chunk_size=num_docs) + results_page_2 = rester.search(_page=2, chunk_size=num_docs) + assert all( + len(results) == num_docs for results in (results_page_1, results_page_2) + ) + assert {doc.material_id for doc in results_page_1}.intersection( + {doc.material_id for doc in results_page_2} + ) == set() + + ascending_e_hull = rester.search(_page=1, _sort_fields="energy_above_hull") + descending_e_hull = rester.search(_page=1, _sort_fields="-energy_above_hull") + + assert sorted( + range(num_docs), key=lambda idx: ascending_e_hull[idx].energy_above_hull + ) == list(range(num_docs)) + + assert sorted( + range(num_docs), + key=lambda idx: descending_e_hull[idx].energy_above_hull, + reverse=True, + ) == list(range(num_docs))