diff --git a/.gitignore b/.gitignore
index 8dbd223f..44bf87ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -107,6 +107,7 @@ venv.bak/
 downloaded_datasets/
 .data/
 .vscode/
+sst/
 .idea/

 *.pkl
diff --git a/podium/datasets/dataset.py b/podium/datasets/dataset.py
index e13919d4..975ba593 100644
--- a/podium/datasets/dataset.py
+++ b/podium/datasets/dataset.py
@@ -19,10 +19,17 @@
     Optional,
     Sequence,
     Tuple,
+    TypeVar,
     Union,
     overload,
 )

+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol
+
 import numpy as np

 from podium.field import Field, unpack_fields
@@ -33,6 +40,39 @@
 FieldType = Optional[Union[Field, List[Field]]]


+class Comparable(Protocol):
+    """
+    Protocol for annotating comparable types.
+    """
+
+    @abstractmethod
+    def __lt__(self: "CT", other: "CT") -> bool:
+        pass
+
+
+CT = TypeVar("CT", bound=Comparable)
+
+
+def _get_permutation(size, seed, generator):
+    if seed is not None and generator is not None:
+        raise ValueError(
+            "Both `seed` and `generator` were provided. Please specify just one of them."
+        )
+
+    # Reject generators of the wrong type; `None` means "create one below"
+    if generator is not None and not isinstance(generator, np.random.Generator):
+        raise ValueError(
+            "The provided generator must be an instance of numpy.random.Generator"
+        )
+
+    if generator is None:
+        # `default_rng(None)` seeds itself from fresh OS entropy
+        generator = np.random.default_rng(seed)
+
+    return generator.permutation(size)
+
+
 class DatasetBase(ABC):
     """
     Abstract base class for all datasets in Podium.
@@ -44,7 +84,7 @@ def __init__(self, fields: Union[Dict[str, FieldType], List[FieldType]]):
     @property
     def fields(self) -> Tuple[Field]:
         """
-        List containing all fields of this dataset.
+        Tuple containing all fields of this dataset.
         """
         return self._fields
@@ -82,7 +122,7 @@ def __getattr__(self, field: Union[str, Field]) -> Iterator[Tuple[Any, Any]]:

         Parameters
         ----------
-        field_name : str
+        field : str or podium.Field
            The name of the field whose values are to be returned.

         Returns
@@ -104,7 +144,6 @@ def attr_generator(_dataset, _field_name):
                     yield x[field_name]

             return attr_generator(self, field_name)
-
         else:
             raise AttributeError(f"Dataset has no field {field_name}.")
@@ -132,7 +171,7 @@ def finalize_fields(self, *datasets: "DatasetBase") -> None:
             data_sources.append(self)

         # for each example in each dataset,
-        # update _all_ non-eager fields
+        # update all non-eager fields
         for dataset in data_sources:
             for example in dataset:
                 for field in fields_to_build:
@@ -158,7 +197,7 @@ def batch(self) -> Tuple[NamedTuple, NamedTuple]:

         return next(iter(SingleBatchIterator(self, shuffle=False)))

-    def sorted(self, key: Callable[[Example], Any], reverse=False) -> "DatasetBase":
+    def sort(self, key: Callable[[Example], CT], reverse: bool = False) -> "DatasetBase":
         """
         Creates a new DatasetBase instance in which all Examples are sorted
         according to the value returned by `key`.
@@ -186,7 +225,7 @@ def index_key(i):
         indices.sort(key=index_key, reverse=reverse)
         return self[indices]

-    def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetBase":
+    def filter(self, predicate: Callable[[Example], bool]) -> "DatasetBase":
         """
         Filters examples with given predicate and returns a new DatasetBase
         instance containing those examples.
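A minimal usage sketch of the renamed copying API (`dataset` is a placeholder for any `DatasetBase` instance, and the `number` field is borrowed from the tests further below). The `key` callable may return any value that supports `<`, which is exactly what the `Comparable` protocol above encodes:

    by_value = dataset.sort(key=lambda ex: ex["number"][0])      # new, sorted dataset
    evens = dataset.filter(lambda ex: ex["number"][0] % 2 == 0)  # new, filtered dataset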
@@ -206,19 +245,29 @@ def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetBase":
         indices = [i for i, example in enumerate(self) if predicate(example)]
         return self[indices]

-    def shuffled(self) -> "DatasetBase":
+    def shuffle(
+        self, seed: Optional[int] = None, generator: Optional[np.random.Generator] = None
+    ) -> "DatasetBase":
         """
         Creates a new DatasetBase instance containing all Examples, but in
         shuffled order.

+        Parameters
+        ----------
+        seed : int, optional
+            A seed used to initialize the default NumPy random Generator.
+            Default: None.
+        generator : np.random.Generator, optional
+            NumPy random Generator used to compute the permutation.
+            Default: None.
+
         Returns
         -------
         DatasetBase
             A new DatasetBase instance containing all Examples, but in shuffled
             order.
         """
-        shuffled_indices = np.random.permutation(len(self))
-        return self[shuffled_indices]
+        return self[_get_permutation(len(self), seed, generator)]

     def __repr__(self):
         # Distribute field prints across lines for readability
@@ -249,11 +298,13 @@ def __getitem__(self, i: int) -> Example:
         ...

     @overload
-    def __getitem__(self, i: Iterable[int]) -> "DatasetBase":
+    def __getitem__(self, i: Union[slice, Iterable[int]]) -> "DatasetBase":
         ...

     @abstractmethod
-    def __getitem__(self, i: slice) -> "DatasetBase":
+    def __getitem__(
+        self, i: Union[int, slice, Iterable[int]]
+    ) -> Union[Example, "DatasetBase"]:
         """
         Returns an example or a new dataset containing the indexed examples.
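A reproducibility sketch for the new `shuffle` signature (again with a placeholder `dataset`): `seed` and `generator` are mutually exclusive, and the two calls below yield the same permutation because `_get_permutation` constructs `np.random.default_rng(seed)` internally:

    import numpy as np

    shuffled_a = dataset.shuffle(seed=42)
    shuffled_b = dataset.shuffle(generator=np.random.default_rng(42))
    # dataset.shuffle(seed=42, generator=np.random.default_rng(42))  # raises ValueError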
@@ -327,153 +378,149 @@ def __init__(self, examples, fields, sort_key=None):
             together examples with similar lengths to minimize padding.
         """
         self._examples = examples
-        self.sort_key = sort_key
+        self._sort_key = sort_key
         super().__init__(fields)

     def __getitem__(
         self, i: Union[int, Iterable[int], slice]
-    ) -> Union["DatasetBase", Example]:
-        """
-        Returns an example or a new dataset containing the indexed examples.
-
-        If indexed with an int, only the example at that position will be returned.
-        If Indexed with a slice or iterable, all examples indexed by the object
-        will be collected and a new dataset containing only those examples will be
-        returned. The new dataset will contain copies of the old dataset's fields and
-        will be identical to the original dataset, with the exception of the example
-        number and ordering. See wiki for detailed examples.
-
-        Examples in the returned Dataset are the same ones present in the
-        original dataset. If a complete deep-copy of the dataset, or its slice,
-        is needed please refer to the `get` method.
+    ) -> Union["Dataset", Example]:
+        if isinstance(i, int):
+            return self._examples[i]
+
+        examples = (
+            self._examples[i]
+            if isinstance(i, slice)
+            else [self._examples[idx] for idx in i]
+        )
+        return Dataset(examples, self._fields)

-        Usage example:
+    def __len__(self) -> int:
+        return len(self._examples)

-            example = dataset[1]  # Indexing by single integer returns a single example
+    def _get_examples(self) -> List[Example]:
+        return self._examples

-            new_dataset = dataset[1:10]  # Multi-indexing returns a new dataset containing
-                                         # the indexed examples.
+    def sort(
+        self, key: Callable[[Example], CT], reverse: bool = False, inplace: bool = False
+    ) -> "Dataset":
+        """
+        Creates a new Dataset instance in which all Examples are sorted
+        according to the value returned by `key`.

         Parameters
         ----------
-        i : int or slice or iterable
-            Index used to index examples.
+        key : callable
+            Specifies a function of one argument that is used to extract a
+            comparison key from each Example.
+        reverse : bool
+            If set to True, the Examples are sorted as if each comparison were
+            reversed.
+        inplace : bool
+            If True, the dataset is sorted in-place and returned.

         Returns
         -------
-        single example or Dataset
-            If i is an int, a single example will be returned.
-            If i is a slice or iterable, a copy of this dataset containing
-            only the indexed examples will be returned.
-        """
-
-        return self.get(i)
-
-    def get(self, i, deep_copy=False):
+        Dataset
+            A new Dataset instance with sorted Examples, or this dataset if
+            `inplace` is True.
         """
-        Returns an example or a new dataset containing the indexed examples.
-        If indexed with an int, only the example at that position
-        will be returned.
-        If Indexed with a slice or iterable, all examples indexed by the object
-        will be collected and a new dataset containing only those examples will be
-        returned. The new dataset will contain copies of the old dataset's fields
-        and will be identical to the original dataset, with the exception of the
-        example number and ordering. See wiki for detailed examples.
-
-        Example::
+        def index_key(i):
+            return key(self[i])

-            # Indexing by a single integers returns a single example
-            example = dataset.get(1)
+        indices = list(range(len(self)))
+        indices.sort(key=index_key, reverse=reverse)

-            # Same as the first example, but returns a deep_copy of the example
-            example_copy = dataset.get(1, deep_copy=True)
+        if inplace:
+            self._examples = [self._examples[idx] for idx in indices]
+            return self

-            # Multi-indexing returns a new dataset containing the indexed examples
-            s = slice(1, 10)
-            new_dataset = dataset.get(s)
+        return super().sort(key, reverse)

-            new_dataset_copy = dataset.get(s, deep_copy=True)
+    def filter(
+        self, predicate: Callable[[Example], bool], inplace: bool = False
+    ) -> "Dataset":
+        """
+        Filters examples with the given predicate and returns a new Dataset
+        instance containing those examples. If `inplace` is True, the dataset
+        is modified in-place and returned.

         Parameters
         ----------
-        i : int or slice or iterable
-            Index used to index examples.
-
-        deep_copy: bool
-            If true, the returned dataset will contain deep-copies of this
-            dataset's examples and fields.
-            If false, existing examples and fields will be reused.
+        predicate : callable
+            A callable that accepts an Example as input and returns True for
+            Examples that should be kept, otherwise False.
+        inplace : bool
+            If True, the dataset is filtered in-place and returned.

         Returns
         -------
-        single example or Dataset
-            If i is an int, a single example will be returned.
-            If i is a slice or iterable, a dataset containing
-            only the indexed examples will be returned.
+        Dataset
+            A new or the original Dataset instance
+            containing only the Examples for which `predicate` returned True.
         """
+        if inplace:
+            self._examples = [example for example in self if predicate(example)]
+            return self

-        if isinstance(i, slice):
-            return self._dataset_copy_with_examples(self.examples[i], deep_copy=deep_copy)
-
-        elif isinstance(i, int):
-            example = self.examples[i]
-            return copy.deepcopy(example) if deep_copy else example
+        return super().filter(predicate)

-        else:
-            # Numpy style multi-indexing
-            indexed_examples = [self.examples[index] for index in i]
-            return self._dataset_copy_with_examples(indexed_examples, deep_copy=deep_copy)
-
-    def __len__(self) -> int:
+    def shuffle(
+        self,
+        seed: Optional[int] = None,
+        generator: Optional[np.random.Generator] = None,
+        inplace: bool = False,
+    ) -> "Dataset":
         """
-        Returns the number of examples in the dataset.
+        Creates a new Dataset instance containing all Examples, but in shuffled
+        order. If `inplace` is True, the dataset is modified in-place and
+        returned.
+
+        Parameters
+        ----------
+        seed : int, optional
+            A seed used to initialize the default NumPy random Generator.
+            Default: None.
+        generator : np.random.Generator, optional
+            NumPy random Generator used to compute the permutation.
+            Default: None.
+        inplace : bool
+            If True, the dataset is shuffled in-place and returned.

         Returns
         -------
-        int
-            The number of examples in the dataset.
-        """
-        return len(self._examples)
-
-    def _get_examples(self) -> List[Example]:
-        return self._examples
-
-    def __iter__(self):
+        Dataset
+            A new or the original Dataset instance containing all Examples, but
+            in shuffled order.
         """
-        Iterates over all examples in the dataset in order.
+        if inplace:
+            self._examples = [
+                self._examples[idx]
+                for idx in _get_permutation(len(self), seed, generator)
+            ]
+            return self

-        Yields
-        ------
-        example
-            Yields examples in the dataset.
-        """
-        yield from self._examples
+        return super().shuffle(seed, generator)

-    def filter(self, predicate, inplace=False):
+    def copy(self, copy_fields: bool = False):
         """
-        Method filters examples with given predicate.
+        Returns a new Dataset instance with deep-copied examples. If
+        `copy_fields` is True, the dataset fields are deep-copied as well.

         Parameters
         ----------
-        predicate : callable
-            predicate should be a callable that accepts example as input and returns
-            true if the example shouldn't be filtered, otherwise returns false
-        inplace : bool, default False
-            if True, do operation inplace and return None
-        """
-        filtered_examples = [ex for ex in self.examples if predicate(ex)]
+        copy_fields : bool
+            If True, the dataset fields are copied as well.

-        if inplace:
-            self._examples = filtered_examples
-            return
-        else:
-            return Dataset(
-                examples=filtered_examples, fields=self.fields, sort_key=self.sort_key
-            )
-
-    def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetBase":
-        return self.filter(predicate, inplace=False)
+        Returns
+        -------
+        Dataset
+            A copied Dataset.
+        """
+        return Dataset(
+            copy.deepcopy(self._examples),
+            copy.deepcopy(self._fields) if copy_fields else self._fields,
+            sort_key=self._sort_key,
+        )

     def split(
         self,
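A sketch of the `inplace` flag on the concrete `Dataset` overrides (placeholder names as above): with `inplace=True` the methods mutate `_examples` and return `self`; otherwise they defer to the copying `DatasetBase` implementations:

    evens = dataset.filter(lambda ex: ex["number"][0] % 2 == 0)        # copying
    dataset.filter(lambda ex: ex["number"][0] % 2 == 0, inplace=True)  # mutating
    dataset.sort(key=lambda ex: ex["number"][0], inplace=True)
    dataset.shuffle(seed=0, inplace=True)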
@@ -570,7 +617,7 @@ def split(
         )

         splits = tuple(
-            Dataset(example_list, self.fields, sort_key=self.sort_key)
+            Dataset(example_list, self.fields, sort_key=self._sort_key)
             for example_list in (train_data, val_data, test_data)
             if example_list
         )
@@ -625,49 +672,6 @@ def __setstate__(self, state):
         """
         self.__dict__ = state

-    def _dataset_copy_with_examples(
-        self, examples: list, deep_copy: bool = False
-    ) -> "Dataset":
-        """
-        Creates a new dataset with the same fields and sort_key. The new dataset
-        contains only the fields passed to this function.Fields are deep-copied
-        into the new dataset, but examples are used as-is.
-
-        Parameters
-        ----------
-        examples
-            examples to be kept in the copy of the dataset.
-
-        deep_copy
-            Whether to deep-copy the examples nad fields of this dataset.
-            if False, existing fields and examples will be reused.
-
-        Returns
-        -------
-        Dataset
-            a copy of this dataset containing only the passed examples.
-        """
-        # Deep-copy if needed
-        examples = copy.deepcopy(examples) if deep_copy else examples
-        fields = copy.deepcopy(self.fields) if deep_copy else self.fields
-
-        return Dataset(examples, fields, self.sort_key)
-
-    def shuffle_examples(self, random_state=None):
-        """
-        Shuffles the examples in this dataset.
-
-        Parameters
-        ----------
-        random_state : int
-            The random seed used for shuffling.
-        """
-
-        if random_state is not None:
-            random.seed(random_state)
-
-        random.shuffle(self.examples)
-
     @staticmethod
     def from_dataset(dataset: DatasetBase) -> "Dataset":
         """
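A sketch of the `copy` semantics that replace the removed `get(..., deep_copy=...)` and `_dataset_copy_with_examples` paths: examples are always deep-copied, fields only when `copy_fields=True` (this mirrors `test_dataset_copy` further below):

    clone = dataset.copy()                     # deep-copied examples, shared fields
    detached = dataset.copy(copy_fields=True)  # fields deep-copied as well
    assert clone.fields[0] is dataset.fields[0]
    assert detached.fields[0] is not dataset.fields[0]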
diff --git a/podium/datasets/iterator.py b/podium/datasets/iterator.py
index d34446aa..eaf225f1 100644
--- a/podium/datasets/iterator.py
+++ b/podium/datasets/iterator.py
@@ -252,7 +252,7 @@ def __iter__(self) -> PythonIterator[Tuple[NamedTuple, NamedTuple]]:
         data = self._dataset[indices]

         if self._sort_key is not None:
-            data = data.sorted(key=self._sort_key)
+            data = data.sort(key=self._sort_key)

         for i in range(0, len(data), self._batch_size):
             batch_dataset = data[i : i + self._batch_size]
@@ -426,7 +426,7 @@ def __init__(self, dataset: DatasetBase = None, shuffle=True):
             If sort_key is not None, this flag being True may not have any
             effect since the dataset will always be sorted after being
             shuffled (the only difference shuffling can make is in the
-            order of elements with the same value of sort_key)..
+            order of elements with the same value of sort_key).
             Default is False.
         """
         super().__init__(dataset=dataset, batch_size=len(dataset), shuffle=shuffle)
@@ -504,12 +504,12 @@ def __iter__(self) -> PythonIterator[Tuple[NamedTuple, NamedTuple]]:
         step = self._batch_size * self.look_ahead_multiplier
         dataset = self._dataset
         if self._sort_key is not None:
-            dataset = dataset.sorted(key=self._sort_key)
+            dataset = dataset.sort(key=self._sort_key)

         for i in range(0, len(dataset), step):
             bucket = dataset[i : i + step]
             if self.bucket_sort_key is not None:
-                bucket = bucket.sorted(key=self.bucket_sort_key)
+                bucket = bucket.sort(key=self.bucket_sort_key)

             for j in range(0, len(bucket), self._batch_size):
                 batch_dataset = bucket[j : j + self._batch_size]
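A sketch of how the iterators exercise the renamed `sort`; the import path, constructor arguments, and field name below are assumptions for illustration, not part of this diff (`bucket_sort_key` and `look_ahead_multiplier` do appear in the code above):

    from podium.datasets.iterator import BucketIterator

    iterator = BucketIterator(
        dataset,
        batch_size=32,
        bucket_sort_key=lambda ex: len(ex["text"][1]),  # sort within each bucket
    )
    for input_batch, target_batch in iterator:
        ...  # batches arrive with similar lengths grouped together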
diff --git a/tests/datasets/test_pyarrow_tabular_dataset.py b/tests/datasets/test_arrow_tabular_dataset.py
similarity index 96%
rename from tests/datasets/test_pyarrow_tabular_dataset.py
rename to tests/datasets/test_arrow_tabular_dataset.py
index 50f7efa6..13e7eb68 100644
--- a/tests/datasets/test_pyarrow_tabular_dataset.py
+++ b/tests/datasets/test_arrow_tabular_dataset.py
@@ -169,11 +169,11 @@ def test_finalize_fields(data, fields, mocker):
     dataset.delete_cache()


-def test_filtered(data, pyarrow_dataset):
+def test_filter(data, pyarrow_dataset):
     def filter_even(ex):
         return ex["number"][0] % 2 == 0

-    filtered_dataset = pyarrow_dataset.filtered(filter_even)
+    filtered_dataset = pyarrow_dataset.filter(filter_even)

     filtered_data = [d[0] for d in data if d[0] % 2 == 0]
     for (raw, _), d in zip(filtered_dataset.number, filtered_data):
@@ -276,19 +276,19 @@ def test_delete_cache(data, fields):
     assert not os.path.exists(cache_dir)


-def test_sorted(data, pyarrow_dataset):
+def test_sort(data, pyarrow_dataset):
     indices = [1, 5, 2, 7, 3]
     data_slice = [data[i] for i in indices]
     dataset_slice = pyarrow_dataset[indices]

     sorted_data = sorted(data_slice, key=lambda x: x[0], reverse=False)
-    sorted_dataset = dataset_slice.sorted(key=lambda ex: ex["number"][0], reverse=False)
+    sorted_dataset = dataset_slice.sort(key=lambda ex: ex["number"][0], reverse=False)

     for d, ex in zip(sorted_data, sorted_dataset):
         assert d[0] == ex["number"][0]

     reverse_sorted_data = sorted(data_slice, key=lambda x: x[0], reverse=True)
-    reverse_sorted_dataset = dataset_slice.sorted(
+    reverse_sorted_dataset = dataset_slice.sort(
         key=lambda ex: ex["number"][0], reverse=True
     )
     for d, ex in zip(reverse_sorted_data, reverse_sorted_dataset):
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 9e2cb384..0c3fd435 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -418,14 +418,14 @@ def test_tabular_dataset_preserve_sort_key(
     dataset = create_tabular_dataset(
         tabular_dataset_fields, file_format, file_path, use_dict
     )
-    dataset.sort_key = sort_key_str
+    dataset._sort_key = sort_key_str

     dataset.finalize_fields()
     d_train, d_test = dataset.split(split_ratio=0.5, shuffle=False)

     # the sort key should be passed from the original dataset
-    assert d_train.sort_key == sort_key_str
-    assert d_test.sort_key == sort_key_str
+    assert d_train._sort_key == sort_key_str
+    assert d_test._sort_key == sort_key_str


 @pytest.mark.parametrize("file_format, use_dict", FORMAT_USE_DICT_COMBINATIONS)
@@ -445,7 +445,7 @@ def test_tabular_dataset_pickle_sort_key(
     with open(dataset_file, "rb") as fdata:
         loaded_dataset = dill.load(fdata)

-    assert loaded_dataset.sort_key == sort_key_str
+    assert loaded_dataset._sort_key == sort_key_str


 @pytest.mark.parametrize("file_format, use_dict", FORMAT_USE_DICT_COMBINATIONS)
@@ -615,37 +615,20 @@ def test_indexing(indexes):
     test_indexing(list(range(1, 10, 3)))


-def test_dataset_deep_copy(data, field_list):
-    original_dataset = create_dataset(data, field_list)
-    original_examples = original_dataset.examples
-
-    dataset_no_deep_copy = original_dataset.get(slice(0, 5), deep_copy=False)
-    for original, copy in zip(original_dataset.fields, dataset_no_deep_copy.fields):
-        assert copy is original
-    for original, copy in zip(original_examples, dataset_no_deep_copy.examples):
-        assert copy is original
-
-    dataset_deep_copy = original_dataset.get(slice(0, 5), deep_copy=True)
+def test_dataset_copy(data, field_list):
+    dataset = create_dataset(data, field_list)
+    dataset_copy = dataset.copy()

-    assert original_dataset.fields is not dataset_deep_copy.fields
-    for original, copy in zip(original_dataset.fields, dataset_deep_copy.fields):
-        assert copy is not original
+    for ex, ex_copy in zip(dataset.examples, dataset_copy.examples):
+        assert ex is not ex_copy

-    for original, copy in zip(original_examples, dataset_deep_copy.examples):
-        assert copy is not original
-        assert copy["text"] == original["text"]
-        assert copy["label"] == original["label"]
+    for field, field_copy in zip(dataset.fields, dataset_copy.fields):
+        assert field is field_copy

-    original_example = original_examples[2]
-    no_copy_example = original_dataset.get(2, deep_copy=False)
-    indexed_example = original_dataset[2]
-    deep_copied_example = original_dataset.get(2, deep_copy=True)
+    dataset_copy_with_fields = dataset.copy(copy_fields=True)

-    assert no_copy_example is original_example
-    assert indexed_example is original_example
-    assert deep_copied_example is not original_example
-    assert deep_copied_example["text"] == original_example["text"]
-    assert deep_copied_example["label"] == original_example["label"]
+    for field, field_copy in zip(dataset.fields, dataset_copy_with_fields.fields):
+        assert field is not field_copy


 def test_dataset_multiindexing_pickling(data, field_list):
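The tests now reach for the private `_sort_key` attribute directly; downstream code that previously assigned `dataset.sort_key` needs the same one-line change (a sketch reusing `sort_key_str` and `split` from the tests above):

    dataset._sort_key = sort_key_str  # was: dataset.sort_key = sort_key_str
    d_train, d_test = dataset.split(split_ratio=0.5, shuffle=False)
    assert d_train._sort_key == d_test._sort_key == sort_key_str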