Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/python/systemds/scuro/drsearch/operator_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def get_not_self_contained_representations(self, modality: ModalityType):
return reps

def get_context_operators(self, modality_type):
    """Return the context operators registered for *modality_type*.

    :param modality_type: the ModalityType key to look up
    :return: the list of registered context operators, or an empty list
        when nothing is registered for this modality type

    Uses a single ``dict.get`` lookup instead of a ``.keys()`` membership
    test followed by a second indexing lookup.
    """
    return self._context_operators.get(modality_type, [])

def get_dimensionality_reduction_operators(self, modality_type):
Expand Down
107 changes: 63 additions & 44 deletions src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ def _build_modality_dag(
operator.__class__, [leaf_id], operator.get_current_parameters()
)
current_node_id = rep_node_id
dags.append(builder.build(current_node_id))
rep_dag = builder.build(current_node_id)
dags.append(rep_dag)

dimensionality_reduction_dags = self.add_dimensionality_reduction_operators(
builder, current_node_id
Expand Down Expand Up @@ -387,11 +388,6 @@ def _build_modality_dag(
[context_node_id],
operator.get_current_parameters(),
)
dimensionality_reduction_dags = self.add_dimensionality_reduction_operators(
builder, context_rep_node_id
) # TODO: check if this is correctly using the 3d approach of the dimensionality reduction operator
if dimensionality_reduction_dags is not None:
dags.extend(dimensionality_reduction_dags)

agg_operator = AggregatedRepresentation()
context_agg_node_id = builder.create_operation_node(
Expand All @@ -409,64 +405,88 @@ def _build_modality_dag(
not_self_contained_reps = [
rep for rep in not_self_contained_reps if rep != operator.__class__
]
rep_id = current_node_id

for combination in self._combination_operators:
current_node_id = rep_node_id
for other_rep in not_self_contained_reps:
other_rep_id = builder.create_operation_node(
other_rep, [leaf_id], other_rep().parameters
)

for rep in not_self_contained_reps:
other_rep_id = builder.create_operation_node(
rep, [leaf_id], rep().parameters
)
for combination in self._combination_operators:
combine_id = builder.create_operation_node(
combination.__class__,
[current_node_id, other_rep_id],
[rep_id, other_rep_id],
combination.get_current_parameters(),
)
dags.append(builder.build(combine_id))
current_node_id = combine_id
if modality.modality_type in [
ModalityType.EMBEDDING,
ModalityType.IMAGE,
ModalityType.AUDIO,
]:
dags.extend(
self.default_context_operators(
modality, builder, leaf_id, current_node_id
rep_dag = builder.build(combine_id)
dags.append(rep_dag)
if modality.modality_type in [
ModalityType.EMBEDDING,
ModalityType.IMAGE,
ModalityType.AUDIO,
]:
dags.extend(
self.default_context_operators(
modality, builder, leaf_id, rep_dag, False
)
)
)
elif modality.modality_type == ModalityType.TIMESERIES:
dags.extend(
self.temporal_context_operators(
modality, builder, leaf_id, current_node_id
elif modality.modality_type == ModalityType.TIMESERIES:
dags.extend(
self.temporal_context_operators(
modality,
builder,
leaf_id,
)
)
)
rep_id = combine_id

if rep_dag.nodes[-1].operation().output_modality_type in [
ModalityType.EMBEDDING
]:
dags.extend(
self.default_context_operators(
modality, builder, leaf_id, rep_dag, True
)
)

if (
modality.modality_type == ModalityType.TIMESERIES
or modality.modality_type == ModalityType.AUDIO
):
dags.extend(self.temporal_context_operators(modality, builder, leaf_id))
return dags

def default_context_operators(self, modality, builder, leaf_id, current_node_id):
def default_context_operators(
self, modality, builder, leaf_id, rep_dag, apply_context_to_leaf=False
):
dags = []
context_operators = self._get_context_operators(modality.modality_type)
for context_op in context_operators:
if apply_context_to_leaf:
if (
modality.modality_type != ModalityType.TEXT
and modality.modality_type != ModalityType.VIDEO
):
context_node_id = builder.create_operation_node(
context_op,
[leaf_id],
context_op().get_current_parameters(),
)
dags.append(builder.build(context_node_id))
context_operators = self._get_context_operators(modality.modality_type)
for context_op in context_operators:
context_node_id = builder.create_operation_node(
context_op,
[leaf_id],
context_op().get_current_parameters(),
)
dags.append(builder.build(context_node_id))

context_operators = self._get_context_operators(
rep_dag.nodes[-1].operation().output_modality_type
)
for context_op in context_operators:
context_node_id = builder.create_operation_node(
context_op,
[current_node_id],
[rep_dag.nodes[-1].node_id],
context_op().get_current_parameters(),
)
dags.append(builder.build(context_node_id))

return dags

def temporal_context_operators(self, modality, builder, leaf_id, current_node_id):
def temporal_context_operators(self, modality, builder, leaf_id):
aggregators = self.operator_registry.get_representations(modality.modality_type)
context_operators = self._get_context_operators(modality.modality_type)

Expand Down Expand Up @@ -561,12 +581,11 @@ def get_k_best_results(

results = results[: self.k]
sorted_indices = sorted_indices[: self.k]

task_cache = self.cache.get(modality.modality_id, {}).get(task.model.name, None)
if not task_cache:
cache = [
list(task_results[i].dag.execute([modality]).values())[-1]
for i in sorted_indices
list(results[i].dag.execute([modality]).values())[-1]
for i in range(len(results))
]
elif isinstance(task_cache, list):
cache = task_cache
Expand Down
9 changes: 8 additions & 1 deletion src/main/python/systemds/scuro/modality/transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
class TransformedModality(Modality):

def __init__(
self, modality, transformation, new_modality_type=None, self_contained=True
self,
modality,
transformation,
new_modality_type=None,
self_contained=True,
set_data=False,
):
"""
Parent class of the different Modalities (unimodal & multimodal)
Expand All @@ -49,6 +54,8 @@ def __init__(
modality.data_type,
modality.transform_time,
)
if set_data:
self.data = modality.data
self.transformation = None
self.self_contained = (
self_contained and transformation.self_contained
Expand Down
19 changes: 19 additions & 0 deletions src/main/python/systemds/scuro/modality/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,25 @@ class ModalityType(Flag):
def get_schema(self):
return ModalitySchemas.get(self.name)

def has_field(self, md, field):
    """Return True if any per-instance metadata entry contains *field*.

    :param md: metadata mapping of instance-id -> per-instance metadata dict
    :param field: the metadata key to look for
    :return: True if at least one entry contains *field*, False otherwise
        (including when *md* is empty)

    Bug fix: the original loop returned False from an ``else`` branch on
    the first entry that lacked the field, so only the first metadata
    entry was ever inspected; ``any`` now scans every entry.
    """
    return any(field in entry for entry in md.values())

def get_field_for_instances(self, md, field):
    """Collect *field* from every per-instance metadata entry.

    :param md: metadata mapping of instance-id -> per-instance metadata dict
    :param field: the metadata key to extract from each entry
    :return: list with one value per entry, in iteration order; entries
        without the field contribute whatever ``get_field`` returns for them
    """
    return [self.get_field(entry, field) for entry in md.values()]

def get_field(self, md, field):
    """Look up *field* in a single metadata dict.

    :param md: one per-instance metadata dict
    :param field: the key to retrieve
    :return: the stored value, or None when the key is absent
    """
    # dict.get already yields None for a missing key, matching the
    # explicit membership-test-then-index formulation.
    return md.get(field)

def update_metadata(self, md, data):
return ModalitySchemas.update_metadata(self.name, md, data)

Expand Down
24 changes: 19 additions & 5 deletions src/main/python/systemds/scuro/modality/unimodal_modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,14 @@ def context(self, context_operator):
if not self.has_data():
self.extract_raw_data()

transformed_modality = TransformedModality(self, context_operator)

transformed_modality.data = context_operator.execute(self)
transformed_modality = TransformedModality(
self, context_operator, set_data=True
)
d = context_operator.execute(transformed_modality)
if d is not None:
transformed_modality.data = d
else:
transformed_modality.data = self.data
transformed_modality.transform_time += time.time() - start
return transformed_modality

Expand Down Expand Up @@ -212,14 +217,23 @@ def _apply_padding(self, modality, original_lengths, pad_dim_one):
mode="constant",
constant_values=0,
)
else:
elif len(embeddings.shape) == 2:
padded = np.pad(
embeddings,
((0, padding_needed), (0, 0)),
mode="constant",
constant_values=0,
)
padded_embeddings.append(padded)
elif len(embeddings.shape) == 3:
padded = np.pad(
embeddings,
((0, padding_needed), (0, 0), (0, 0)),
mode="constant",
constant_values=0,
)
padded_embeddings.append(padded)
else:
raise ValueError(f"Unsupported shape: {embeddings.shape}")
else:
padded_embeddings.append(embeddings)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def transform(self, modality):
aggregated_modality = TransformedModality(
modality, self, self_contained=modality.self_contained
)
aggregated_modality.data = self.aggregation.execute(modality)
end = time.perf_counter()
aggregated_modality.transform_time += end - start
aggregated_modality.data = self.aggregation.execute(modality)
return aggregated_modality
35 changes: 7 additions & 28 deletions src/main/python/systemds/scuro/representations/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,12 @@
from systemds.scuro.drsearch.operator_registry import register_representation
from systemds.scuro.utils.static_variables import get_device
import os
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader
from systemds.scuro.utils.torch_dataset import TextDataset, TextSpanDataset

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class TextDataset(Dataset):
def __init__(self, texts):

self.texts = []
if isinstance(texts, list):
self.texts = texts
else:
for text in texts:
if text is None:
self.texts.append("")
elif isinstance(text, np.ndarray):
self.texts.append(str(text.item()) if text.size == 1 else str(text))
elif not isinstance(text, str):
self.texts.append(str(text))
else:
self.texts.append(text)

def __len__(self):
return len(self.texts)

def __getitem__(self, idx):
return self.texts[idx]


class BertFamily(UnimodalRepresentation):
def __init__(
self,
Expand Down Expand Up @@ -96,10 +73,12 @@ def hook(model, input, output):
layer.register_forward_hook(get_activation(name))
break

if isinstance(modality.data[0], list):
if ModalityType.TEXT.has_field(modality.metadata, "text_spans"):
dataset = TextSpanDataset(modality.data, modality.metadata)
embeddings = []
for d in modality.data:
embeddings.append(self.create_embeddings(d, self.model, tokenizer))
for text in dataset:
embedding = self.create_embeddings(text, self.model, tokenizer)
embeddings.append(embedding)
else:
embeddings = self.create_embeddings(modality.data, self.model, tokenizer)

Expand Down
26 changes: 1 addition & 25 deletions src/main/python/systemds/scuro/representations/elmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,10 @@
from systemds.scuro.utils.static_variables import get_device
from flair.embeddings import ELMoEmbeddings
from flair.data import Sentence
from torch.utils.data import Dataset
from systemds.scuro.utils.torch_dataset import TextDataset
from torch.utils.data import DataLoader


class TextDataset(Dataset):
def __init__(self, texts):

self.texts = []
if isinstance(texts, list):
self.texts = texts
else:
for text in texts:
if text is None:
self.texts.append("")
elif isinstance(text, np.ndarray):
self.texts.append(str(text.item()) if text.size == 1 else str(text))
elif not isinstance(text, str):
self.texts.append(str(text))
else:
self.texts.append(text)

def __len__(self):
return len(self.texts)

def __getitem__(self, idx):
return self.texts[idx]


# @register_representation([ModalityType.TEXT])
class ELMoRepresentation(UnimodalRepresentation):
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _extract_text(instance: Any) -> str:
return text


@register_context_operator(ModalityType.TEXT)
# @register_context_operator(ModalityType.TEXT)
class SentenceBoundarySplit(Context):
"""
Splits text at sentence boundaries while respecting maximum word count.
Expand Down Expand Up @@ -154,7 +154,7 @@ def execute(self, modality):
return chunked_data


@register_context_operator(ModalityType.TEXT)
# @register_context_operator(ModalityType.TEXT)
class OverlappingSplit(Context):
"""
Splits text with overlapping chunks using a sliding window approach.
Expand Down
Loading
Loading