diff --git a/src/main/python/systemds/scuro/drsearch/operator_registry.py b/src/main/python/systemds/scuro/drsearch/operator_registry.py index bf9547ddbf6..e9c302ba901 100644 --- a/src/main/python/systemds/scuro/drsearch/operator_registry.py +++ b/src/main/python/systemds/scuro/drsearch/operator_registry.py @@ -97,6 +97,8 @@ def get_not_self_contained_representations(self, modality: ModalityType): return reps def get_context_operators(self, modality_type): + if modality_type not in self._context_operators.keys(): + return [] return self._context_operators[modality_type] def get_dimensionality_reduction_operators(self, modality_type): diff --git a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py index c555c2b677d..5b03147ec11 100644 --- a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py +++ b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py @@ -356,7 +356,8 @@ def _build_modality_dag( operator.__class__, [leaf_id], operator.get_current_parameters() ) current_node_id = rep_node_id - dags.append(builder.build(current_node_id)) + rep_dag = builder.build(current_node_id) + dags.append(rep_dag) dimensionality_reduction_dags = self.add_dimensionality_reduction_operators( builder, current_node_id @@ -387,11 +388,6 @@ def _build_modality_dag( [context_node_id], operator.get_current_parameters(), ) - dimensionality_reduction_dags = self.add_dimensionality_reduction_operators( - builder, context_rep_node_id - ) # TODO: check if this is correctly using the 3d approach of the dimensionality reduction operator - if dimensionality_reduction_dags is not None: - dags.extend(dimensionality_reduction_dags) agg_operator = AggregatedRepresentation() context_agg_node_id = builder.create_operation_node( @@ -409,64 +405,88 @@ def _build_modality_dag( not_self_contained_reps = [ rep for rep in not_self_contained_reps if rep != operator.__class__ ] + rep_id = current_node_id - for combination in self._combination_operators: - current_node_id = rep_node_id - for other_rep in not_self_contained_reps: - other_rep_id = builder.create_operation_node( - other_rep, [leaf_id], other_rep().parameters - ) - + for rep in not_self_contained_reps: + other_rep_id = builder.create_operation_node( + rep, [leaf_id], rep().parameters + ) + for combination in self._combination_operators: combine_id = builder.create_operation_node( combination.__class__, - [current_node_id, other_rep_id], + [rep_id, other_rep_id], combination.get_current_parameters(), ) - dags.append(builder.build(combine_id)) - current_node_id = combine_id - if modality.modality_type in [ - ModalityType.EMBEDDING, - ModalityType.IMAGE, - ModalityType.AUDIO, - ]: - dags.extend( - self.default_context_operators( - modality, builder, leaf_id, current_node_id + rep_dag = builder.build(combine_id) + dags.append(rep_dag) + if modality.modality_type in [ + ModalityType.EMBEDDING, + ModalityType.IMAGE, + ModalityType.AUDIO, + ]: + dags.extend( + self.default_context_operators( + modality, builder, leaf_id, rep_dag, False + ) ) - ) - elif modality.modality_type == ModalityType.TIMESERIES: - dags.extend( - self.temporal_context_operators( - modality, builder, leaf_id, current_node_id + elif modality.modality_type == ModalityType.TIMESERIES: + dags.extend( + self.temporal_context_operators( + modality, + builder, + leaf_id, + ) ) - ) + rep_id = combine_id + + if rep_dag.nodes[-1].operation().output_modality_type in [ + ModalityType.EMBEDDING + ]: + dags.extend( + self.default_context_operators( + modality, builder, leaf_id, rep_dag, True + ) + ) + + if ( + modality.modality_type == ModalityType.TIMESERIES + or modality.modality_type == ModalityType.AUDIO + ): + dags.extend(self.temporal_context_operators(modality, builder, leaf_id)) return dags - def default_context_operators(self, modality, builder, leaf_id, current_node_id): + def default_context_operators( + self, modality, builder, leaf_id, rep_dag, apply_context_to_leaf=False + ): dags = [] - context_operators = self._get_context_operators(modality.modality_type) - for context_op in context_operators: + if apply_context_to_leaf: if ( modality.modality_type != ModalityType.TEXT and modality.modality_type != ModalityType.VIDEO ): - context_node_id = builder.create_operation_node( - context_op, - [leaf_id], - context_op().get_current_parameters(), - ) - dags.append(builder.build(context_node_id)) + context_operators = self._get_context_operators(modality.modality_type) + for context_op in context_operators: + context_node_id = builder.create_operation_node( + context_op, + [leaf_id], + context_op().get_current_parameters(), + ) + dags.append(builder.build(context_node_id)) + context_operators = self._get_context_operators( + rep_dag.nodes[-1].operation().output_modality_type + ) + for context_op in context_operators: context_node_id = builder.create_operation_node( context_op, - [current_node_id], + [rep_dag.nodes[-1].node_id], context_op().get_current_parameters(), ) dags.append(builder.build(context_node_id)) return dags - def temporal_context_operators(self, modality, builder, leaf_id, current_node_id): + def temporal_context_operators(self, modality, builder, leaf_id): aggregators = self.operator_registry.get_representations(modality.modality_type) context_operators = self._get_context_operators(modality.modality_type) @@ -561,12 +581,11 @@ def get_k_best_results( results = results[: self.k] sorted_indices = sorted_indices[: self.k] - task_cache = self.cache.get(modality.modality_id, {}).get(task.model.name, None) if not task_cache: cache = [ - list(task_results[i].dag.execute([modality]).values())[-1] - for i in sorted_indices + list(results[i].dag.execute([modality]).values())[-1] + for i in range(len(results)) ] elif isinstance(task_cache, list): cache = task_cache diff --git a/src/main/python/systemds/scuro/modality/transformed.py b/src/main/python/systemds/scuro/modality/transformed.py index 078b65f0bc3..a443f5a313a 100644 --- a/src/main/python/systemds/scuro/modality/transformed.py +++ b/src/main/python/systemds/scuro/modality/transformed.py @@ -31,7 +31,12 @@ class TransformedModality(Modality): def __init__( - self, modality, transformation, new_modality_type=None, self_contained=True + self, + modality, + transformation, + new_modality_type=None, + self_contained=True, + set_data=False, ): """ Parent class of the different Modalities (unimodal & multimodal) @@ -49,6 +54,8 @@ def __init__( modality.data_type, modality.transform_time, ) + if set_data: + self.data = modality.data self.transformation = None self.self_contained = ( self_contained and transformation.self_contained diff --git a/src/main/python/systemds/scuro/modality/type.py b/src/main/python/systemds/scuro/modality/type.py index 23d97e869b0..85f4d04e9ba 100644 --- a/src/main/python/systemds/scuro/modality/type.py +++ b/src/main/python/systemds/scuro/modality/type.py @@ -210,6 +210,25 @@ class ModalityType(Flag): def get_schema(self): return ModalitySchemas.get(self.name) + def has_field(self, md, field): + for value in md.values(): + if field in value: + return True + else: + return False + return False + + def get_field_for_instances(self, md, field): + data = [] + for items in md.values(): + data.append(self.get_field(items, field)) + return data + + def get_field(self, md, field): + if field in md: + return md[field] + return None + def update_metadata(self, md, data): return ModalitySchemas.update_metadata(self.name, md, data) diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py b/src/main/python/systemds/scuro/modality/unimodal_modality.py index 4efaa7d7333..89d95810e01 100644 --- a/src/main/python/systemds/scuro/modality/unimodal_modality.py +++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py @@ -91,9 +91,14 @@ def context(self, context_operator): if not self.has_data(): self.extract_raw_data() - transformed_modality = TransformedModality(self, context_operator) - - transformed_modality.data = context_operator.execute(self) + transformed_modality = TransformedModality( + self, context_operator, set_data=True + ) + d = context_operator.execute(transformed_modality) + if d is not None: + transformed_modality.data = d + else: + transformed_modality.data = self.data transformed_modality.transform_time += time.time() - start return transformed_modality @@ -212,14 +217,23 @@ def _apply_padding(self, modality, original_lengths, pad_dim_one): mode="constant", constant_values=0, ) - else: + elif len(embeddings.shape) == 2: padded = np.pad( embeddings, ((0, padding_needed), (0, 0)), mode="constant", constant_values=0, ) - padded_embeddings.append(padded) + elif len(embeddings.shape) == 3: + padded = np.pad( + embeddings, + ((0, padding_needed), (0, 0), (0, 0)), + mode="constant", + constant_values=0, + ) + padded_embeddings.append(padded) + else: + raise ValueError(f"Unsupported shape: {embeddings.shape}") else: padded_embeddings.append(embeddings) diff --git a/src/main/python/systemds/scuro/representations/aggregated_representation.py b/src/main/python/systemds/scuro/representations/aggregated_representation.py index bcc36f46210..cad1a4a448b 100644 --- a/src/main/python/systemds/scuro/representations/aggregated_representation.py +++ b/src/main/python/systemds/scuro/representations/aggregated_representation.py @@ -38,7 +38,7 @@ def transform(self, modality): aggregated_modality = TransformedModality( modality, self, self_contained=modality.self_contained ) + aggregated_modality.data = self.aggregation.execute(modality) end = time.perf_counter() aggregated_modality.transform_time += end - start - aggregated_modality.data = self.aggregation.execute(modality) return aggregated_modality diff --git a/src/main/python/systemds/scuro/representations/bert.py b/src/main/python/systemds/scuro/representations/bert.py index be579c0dd6c..6f4d3705a14 100644 --- a/src/main/python/systemds/scuro/representations/bert.py +++ b/src/main/python/systemds/scuro/representations/bert.py @@ -28,35 +28,12 @@ from systemds.scuro.drsearch.operator_registry import register_representation from systemds.scuro.utils.static_variables import get_device import os -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader +from systemds.scuro.utils.torch_dataset import TextDataset, TextSpanDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" -class TextDataset(Dataset): - def __init__(self, texts): - - self.texts = [] - if isinstance(texts, list): - self.texts = texts - else: - for text in texts: - if text is None: - self.texts.append("") - elif isinstance(text, np.ndarray): - self.texts.append(str(text.item()) if text.size == 1 else str(text)) - elif not isinstance(text, str): - self.texts.append(str(text)) - else: - self.texts.append(text) - - def __len__(self): - return len(self.texts) - - def __getitem__(self, idx): - return self.texts[idx] - - class BertFamily(UnimodalRepresentation): def __init__( self, @@ -96,10 +73,12 @@ def hook(model, input, output): layer.register_forward_hook(get_activation(name)) break - if isinstance(modality.data[0], list): + if ModalityType.TEXT.has_field(modality.metadata, "text_spans"): + dataset = TextSpanDataset(modality.data, modality.metadata) embeddings = [] - for d in modality.data: - embeddings.append(self.create_embeddings(d, self.model, tokenizer)) + for text in dataset: + embedding = self.create_embeddings(text, self.model, tokenizer) + embeddings.append(embedding) else: embeddings = self.create_embeddings(modality.data, self.model, tokenizer) diff --git a/src/main/python/systemds/scuro/representations/elmo.py b/src/main/python/systemds/scuro/representations/elmo.py index ba2a99f8e1d..33e4f741414 100644 --- a/src/main/python/systemds/scuro/representations/elmo.py +++ b/src/main/python/systemds/scuro/representations/elmo.py @@ -29,34 +29,10 @@ from systemds.scuro.utils.static_variables import get_device from flair.embeddings import ELMoEmbeddings from flair.data import Sentence -from torch.utils.data import Dataset +from systemds.scuro.utils.torch_dataset import TextDataset from torch.utils.data import DataLoader -class TextDataset(Dataset): - def __init__(self, texts): - - self.texts = [] - if isinstance(texts, list): - self.texts = texts - else: - for text in texts: - if text is None: - self.texts.append("") - elif isinstance(text, np.ndarray): - self.texts.append(str(text.item()) if text.size == 1 else str(text)) - elif not isinstance(text, str): - self.texts.append(str(text)) - else: - self.texts.append(text) - - def __len__(self): - return len(self.texts) - - def __getitem__(self, idx): - return self.texts[idx] - - # @register_representation([ModalityType.TEXT]) class ELMoRepresentation(UnimodalRepresentation): def __init__( diff --git a/src/main/python/systemds/scuro/representations/text_context.py b/src/main/python/systemds/scuro/representations/text_context.py index b98b90e187f..b4f82bda19c 100644 --- a/src/main/python/systemds/scuro/representations/text_context.py +++ b/src/main/python/systemds/scuro/representations/text_context.py @@ -72,7 +72,7 @@ def _extract_text(instance: Any) -> str: return text -@register_context_operator(ModalityType.TEXT) +# @register_context_operator(ModalityType.TEXT) class SentenceBoundarySplit(Context): """ Splits text at sentence boundaries while respecting maximum word count. @@ -154,7 +154,7 @@ def execute(self, modality): return chunked_data -@register_context_operator(ModalityType.TEXT) +# @register_context_operator(ModalityType.TEXT) class OverlappingSplit(Context): """ Splits text with overlapping chunks using a sliding window approach. diff --git a/src/main/python/systemds/scuro/representations/text_context_with_indices.py b/src/main/python/systemds/scuro/representations/text_context_with_indices.py index 7daf93855f3..5a3c3b34e0d 100644 --- a/src/main/python/systemds/scuro/representations/text_context_with_indices.py +++ b/src/main/python/systemds/scuro/representations/text_context_with_indices.py @@ -134,7 +134,7 @@ def execute(self, modality): return chunked_data -# @register_context_operator(ModalityType.TEXT) +@register_context_operator(ModalityType.TEXT) class SentenceBoundarySplitIndices(Context): """ Splits text at sentence boundaries while respecting maximum word count. @@ -162,18 +162,17 @@ def execute(self, modality): Returns: List of lists, where each inner list contains text chunks (strings) """ - chunked_data = [] - for instance in modality.data: + for instance, metadata in zip(modality.data, modality.metadata.values()): text = _extract_text(instance) if not text: - chunked_data.append((0, 0)) + ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 0)]) continue sentences = _split_into_sentences(text) if not sentences: - chunked_data.append((0, len(text))) + ModalityType.TEXT.add_field(metadata, "text_spans", [(0, len(text))]) continue chunks = [] @@ -225,12 +224,12 @@ def execute(self, modality): if not chunks: chunks = [(0, len(text))] - chunked_data.append(chunks) + ModalityType.TEXT.add_field(metadata, "text_spans", chunks) - return chunked_data + return None -# @register_context_operator(ModalityType.TEXT) +@register_context_operator(ModalityType.TEXT) class OverlappingSplitIndices(Context): """ Splits text with overlapping chunks using a sliding window approach. @@ -263,18 +262,17 @@ def execute(self, modality): Returns: List of tuples, where each tuple contains start and end index to the text chunks """ - chunked_data = [] - for instance in modality.data: + for instance, metadata in zip(modality.data, modality.metadata.values()): text = _extract_text(instance) if not text: - chunked_data.append((0, 0)) + ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 0)]) continue words = _split_into_words(text) if len(words) <= self.max_words: - chunked_data.append((0, len(text))) + ModalityType.TEXT.add_field(metadata, "text_spans", [(0, len(text))]) continue chunks = [] @@ -295,6 +293,6 @@ def execute(self, modality): if not chunks: chunks = [(0, len(text))] - chunked_data.append(chunks) + ModalityType.TEXT.add_field(metadata, "text_spans", chunks) - return chunked_data + return None diff --git a/src/main/python/systemds/scuro/utils/torch_dataset.py b/src/main/python/systemds/scuro/utils/torch_dataset.py index 9c462e36753..ba3e24a3178 100644 --- a/src/main/python/systemds/scuro/utils/torch_dataset.py +++ b/src/main/python/systemds/scuro/utils/torch_dataset.py @@ -24,6 +24,8 @@ import torch import torchvision.transforms as transforms +from systemds.scuro.modality.type import ModalityType + class CustomDataset(torch.utils.data.Dataset): def __init__(self, data, data_type, device, size=None, tf=None): @@ -78,3 +80,43 @@ def __getitem__(self, index) -> Dict[str, object]: def __len__(self) -> int: return len(self.data) + + +class TextDataset(torch.utils.data.Dataset): + def __init__(self, texts): + + self.texts = [] + if isinstance(texts, list): + self.texts = texts + else: + for text in texts: + if text is None: + self.texts.append("") + elif isinstance(text, np.ndarray): + self.texts.append(str(text.item()) if text.size == 1 else str(text)) + elif not isinstance(text, str): + self.texts.append(str(text)) + else: + self.texts.append(text) + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx): + return self.texts[idx] + + +class TextSpanDataset(torch.utils.data.Dataset): + def __init__(self, full_texts, metadata): + self.full_texts = full_texts + self.spans_per_text = ModalityType.TEXT.get_field_for_instances( + metadata, "text_spans" + ) + + def __len__(self): + return len(self.full_texts) + + def __getitem__(self, idx): + text = self.full_texts[idx] + spans = self.spans_per_text[idx] + return [text[s:e] for (s, e) in spans] diff --git a/src/main/python/tests/scuro/test_operator_registry.py b/src/main/python/tests/scuro/test_operator_registry.py index 189e3e44d71..443cc039d6b 100644 --- a/src/main/python/tests/scuro/test_operator_registry.py +++ b/src/main/python/tests/scuro/test_operator_registry.py @@ -21,9 +21,9 @@ import unittest -from systemds.scuro.representations.text_context import ( - SentenceBoundarySplit, - OverlappingSplit, +from systemds.scuro.representations.text_context_with_indices import ( + SentenceBoundarySplitIndices, + OverlappingSplitIndices, ) from systemds.scuro.representations.covarep_audio_features import ( @@ -134,8 +134,8 @@ def test_context_operator_in_registry(self): DynamicWindow, ] assert registry.get_context_operators(ModalityType.TEXT) == [ - SentenceBoundarySplit, - OverlappingSplit, + SentenceBoundarySplitIndices, + OverlappingSplitIndices, ] # def test_fusion_operator_in_registry(self): diff --git a/src/main/python/tests/scuro/test_text_context_operators.py b/src/main/python/tests/scuro/test_text_context_operators.py index 1f041654076..ffa702b7c82 100644 --- a/src/main/python/tests/scuro/test_text_context_operators.py +++ b/src/main/python/tests/scuro/test_text_context_operators.py @@ -36,6 +36,7 @@ ) from systemds.scuro.modality.unimodal_modality import UnimodalModality from systemds.scuro.modality.type import ModalityType +from systemds.scuro.representations.bert import Bert class TestTextContextOperator(unittest.TestCase): @@ -80,33 +81,30 @@ def test_overlapping_split(self): def test_sentence_boundary_split_indices(self): sentence_boundary_split = SentenceBoundarySplitIndices(10, min_words=4) - chunks = sentence_boundary_split.execute(self.text_modality) - for i in range(0, len(chunks)): - for chunk in chunks[i]: - text = self.text_modality.data[i][chunk[0] : chunk[1]].split(" ") + sentence_boundary_split.execute(self.text_modality) + for instance, md in zip( + self.text_modality.data, self.text_modality.metadata.values() + ): + for chunk in md["text_spans"]: + text = instance[chunk[0] : chunk[1]].split(" ") assert len(text) <= 10 and ( text[-1][-1] == "." or text[-1][-1] == "!" or text[-1][-1] == "?" ) def test_overlapping_split_indices(self): overlapping_split = OverlappingSplitIndices(40, 0.1) - chunks = overlapping_split.execute(self.text_modality) - for i in range(len(chunks)): + overlapping_split.execute(self.text_modality) + for instance, md in zip( + self.text_modality.data, self.text_modality.metadata.values() + ): prev_chunk = (0, 0) - for j, chunk in enumerate(chunks[i]): + for j, chunk in enumerate(md["text_spans"]): if j > 0: - prev_words = self.text_modality.data[i][ - prev_chunk[0] : prev_chunk[1] - ].split(" ") - curr_words = self.text_modality.data[i][chunk[0] : chunk[1]].split( - " " - ) + prev_words = instance[prev_chunk[0] : prev_chunk[1]].split(" ") + curr_words = instance[chunk[0] : chunk[1]].split(" ") assert prev_words[-4:] == curr_words[:4] prev_chunk = chunk - assert ( - len(self.text_modality.data[i][chunk[0] : chunk[1]].split(" ")) - <= 40 - ) + assert len(instance[chunk[0] : chunk[1]].split(" ")) <= 40 if __name__ == "__main__": diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py b/src/main/python/tests/scuro/test_unimodal_optimizer.py index 0d8ae901778..7fa606d835c 100644 --- a/src/main/python/tests/scuro/test_unimodal_optimizer.py +++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py @@ -36,6 +36,7 @@ ) from systemds.scuro.representations.word2vec import W2V from systemds.scuro.representations.bow import BoW +from systemds.scuro.representations.bert import Bert from systemds.scuro.modality.unimodal_modality import UnimodalModality from systemds.scuro.representations.resnet import ResNet from tests.scuro.data_generator import ( @@ -124,7 +125,7 @@ def optimize_unimodal_representation_for_modality(self, modality): ): registry = Registry() - unimodal_optimizer = UnimodalOptimizer([modality], self.tasks, False) + unimodal_optimizer = UnimodalOptimizer([modality], self.tasks, False, k=1) unimodal_optimizer.optimize() assert ( @@ -133,7 +134,7 @@ def optimize_unimodal_representation_for_modality(self, modality): ) assert len(unimodal_optimizer.operator_performance.task_names) == 2 result, cached = unimodal_optimizer.operator_performance.get_k_best_results( - modality, 1, self.tasks[0], "accuracy" + modality, self.tasks[0], "accuracy" ) assert len(result) == 1 assert len(cached) == 1