Skip to content

Commit 704af93

Browse files
committed
update test for cli change
1 parent 330a8c2 commit 704af93

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
label_1
2+
label_2
3+
label_3
4+
label_4
5+
label_5
6+
label_6
7+
label_7
8+
label_8
9+
label_9
10+
label_10

tests/unit/cli/mock_dm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import torch
24
from lightning.pytorch.core.datamodule import LightningDataModule
35
from torch.utils.data import DataLoader
@@ -29,6 +31,10 @@ def num_of_labels(self):
2931
def feature_vector_size(self):
3032
return self._feature_vector_size
3133

34+
@property
35+
def classes_txt_file_path(self) -> str:
36+
return os.path.join("tests", "unit", "cli", "classification_labels.txt")
37+
3238
def train_dataloader(self):
3339
assert self.feature_vector_size is not None, "feature_vector_size must be set"
3440
# Dummy dataset for example purposes

tests/unit/cli/testCLI.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def setUp(self):
99
"fit",
1010
"--trainer=configs/training/default_trainer.yml",
1111
"--model=configs/model/ffn.yml",
12-
"--model.init_args.hidden_layers=[10]",
12+
"--model.init_args.hidden_layers=[1]",
1313
"--model.train_metrics=configs/metrics/micro-macro-f1.yml",
1414
"--data=tests/unit/cli/mock_dm_config.yml",
1515
"--model.pass_loss_kwargs=false",

0 commit comments

Comments
 (0)