-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathprint_plot_eval.py
More file actions
43 lines (32 loc) · 1.22 KB
/
print_plot_eval.py
File metadata and controls
43 lines (32 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import numpy as np
import os
import sys
def _mean_and_se(values):
    """Return ``(mean, standard_error)`` for a sequence of numbers.

    The standard error is ``s / sqrt(n)`` where ``s`` is the Bessel-corrected
    sample standard deviation (``ddof=1``). It is undefined for fewer than two
    values, in which case ``None`` is returned in its place (serialised as
    JSON ``null`` rather than the non-standard ``NaN``). An empty sequence
    yields ``(None, None)``.
    """
    n = len(values)
    if n == 0:
        return None, None
    mean = float(np.mean(values))
    if n < 2:
        return mean, None
    se = float(np.std(values, ddof=1) / np.sqrt(n))
    return mean, se


def main(argv=None):
    """Aggregate per-seed evaluation results in a directory and summarise them.

    Every ``*.pt`` file directly inside the directory given as the first
    argument is loaded with ``torch.load``. The ``'steps_taken'`` and
    ``'reward'`` entries of all files are pooled, and their means and standard
    errors are printed and written to ``summary_eval.json`` in that directory.

    Args:
        argv: Command-line arguments without the program name; defaults to
            ``sys.argv[1:]``. ``argv[0]`` is the results directory.

    Returns:
        The summary dict that was written to ``summary_eval.json``.

    Raises:
        SystemExit: if no directory argument is given, or the directory
            contains no ``.pt`` files.
    """
    import json

    args = sys.argv[1:] if argv is None else argv
    if not args:
        sys.exit("usage: python print_plot_eval.py <results_dir>")
    results_dir = args[0]

    # os.listdir order is arbitrary; sort so output is reproducible.
    files = [
        os.path.join(results_dir, name)
        for name in sorted(os.listdir(results_dir))
        if name.endswith(".pt")
    ]
    if not files:
        sys.exit(f"no .pt files found in {results_dir}")
    for path in files:
        print(path)

    steps = []
    rewards = []
    for path in files:
        # NOTE(security): torch.load unpickles the file; only run this on
        # trusted result files. map_location="cpu" lets results saved from a
        # GPU run be loaded on a CPU-only machine.
        seed_data = torch.load(path, map_location="cpu")
        steps.extend(seed_data['steps_taken'])
        rewards.extend(seed_data['reward'])

    # Entries from all files are pooled, so the standard error is computed
    # over all pooled entries rather than over per-seed means.
    steps_mean, steps_se = _mean_and_se(steps)
    reward_mean, reward_se = _mean_and_se(rewards)
    n_seeds = len(files)
    print(f'Average performance over {n_seeds} seeds is :', steps_mean,
          ' with standard ERROR: ', steps_se)
    print(f'Average reward over {n_seeds} seeds is :', reward_mean,
          ' with standard ERROR: ', reward_se)

    # Save to a JSON file next to the inputs.
    save_dict = {
        'steps_mean': steps_mean,
        'steps_se': steps_se,
        'reward_mean': reward_mean,
        'reward_se': reward_se,
    }
    summary_path = os.path.join(results_dir, 'summary_eval.json')
    with open(summary_path, 'w') as f:
        json.dump(save_dict, f)
    print('Saved summary to ', summary_path)
    return save_dict


if __name__ == "__main__":
    main()