-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
96 lines (79 loc) · 3.04 KB
/
main.py
File metadata and controls
96 lines (79 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import datasets as ds
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import transformers as tr
from scipy.special import expit, softmax

from cartographer_callbacks import Cartographer
def plot_map(cartographer: Cartographer):
    """Draw a data map from the training dynamics recorded by *cartographer*.

    Each training example is scatter-plotted by variability (x) versus
    confidence (y) and colored by correctness, with KDE contours overlaid to
    show where the bulk of the examples sits.  The three canonical regions of
    a data map (easy-to-learn, ambiguous, hard-to-learn) are annotated.

    Args:
        cartographer: Cartographer callback whose ``variability``,
            ``confidence`` and ``correctness`` arrays were populated during
            training.
    """
    _, ax = plt.subplots(figsize=(9, 7))
    sns.scatterplot(x=cartographer.variability, y=cartographer.confidence,
                    hue=cartographer.correctness, ax=ax)
    # Density contours on top of the scatter to show example concentration.
    sns.kdeplot(x=cartographer.variability, y=cartographer.confidence,
                levels=8, color=sns.color_palette("Paired")[7], linewidths=1, ax=ax)
    # Fix: this script trains an ALBERT classifier on IMDB (see __main__),
    # not DistilBERT on QNLI — the title must describe what is plotted.
    ax.set(title='Data map for IMDB train set\nbased on an ALBERT classifier',
           xlabel='Variability', ylabel='Confidence')
    # Region annotations, positioned in axes (0..1) coordinates.
    box_style = {'boxstyle': 'round', 'facecolor': 'white', 'ec': 'black'}
    ax.text(0.14, 0.84,
            'easy-to-learn',
            transform=ax.transAxes,
            verticalalignment='top',
            bbox=box_style)
    ax.text(0.75, 0.5,
            'ambiguous',
            transform=ax.transAxes,
            verticalalignment='top',
            bbox=box_style)
    ax.text(0.14, 0.14,
            'hard-to-learn',
            transform=ax.transAxes,
            verticalalignment='top',
            bbox=box_style)
    ax.legend(title='Correctness')
    plt.show()
def calc_probs(predictions):
    """Convert raw classifier logits to class probabilities.

    ``AlbertForSequenceClassification`` outputs unnormalized logits; the
    model's predicted distribution is their softmax.  The previous
    sigmoid-then-renormalize approach produced rows summing to 1 but NOT the
    model's probabilities, which distorts the Cartographer's confidence axis.

    Args:
        predictions: array of shape (n_examples, n_classes) of logits.

    Returns:
        Array of the same shape whose rows are softmax distributions.
    """
    return softmax(predictions, axis=1)


if __name__ == '__main__':
    # Load the IMDB sentiment dataset (train/test splits).
    dataset = ds.load_dataset("imdb")
    # Tokenizer and 2-class classification head for ALBERT base.
    tokenizer = tr.AlbertTokenizerFast.from_pretrained("albert-base-v2")
    model = tr.AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)

    # Tokenize the whole dataset, padding/truncating to the model max length.
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True)
    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # TODO Remove sharding when not just testing anymore — keeps only
    # 1/4000th of each split so a smoke-test run finishes quickly.
    for key in tokenized_dataset:
        tokenized_dataset[key] = tokenized_dataset[key].shard(4000, 1)

    # Training configuration: 3 epochs, small batches, light weight decay.
    training_args = tr.TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=8,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=500,
    )
    trainer = tr.Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset['train'],
        eval_dataset=tokenized_dataset['test'],
    )

    # The Cartographer callback records per-example training dynamics
    # (confidence / variability / correctness) as training progresses.
    cartographer = Cartographer(tokenized_dataset['train'],
                                sparse_labels=True,
                                trainer=trainer,
                                outputs_to_probabilities=calc_probs)
    trainer.add_callback(cartographer)

    # Fine-tune the model, then visualize the collected data map.
    print("start training")
    trainer.train()
    plot_map(cartographer)