Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added SectionB/Flux_GroupB.pdf
Binary file not shown.
201 changes: 201 additions & 0 deletions SectionB/LearningRateandSoftmax.mlx
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
Code and Figures for Section B: How to develop a computational model
How should we choose?
clear; close all; clc
dir ='C:\Users\dream\OneDrive - University College London\Lab\FLUX\2021FluxCompModellingWorkshop-main'
cd(dir);
addpath('.\general_functions\') % add softmax function here



outcome = 1; V = .5; alpha = .6;
PE = outcome - V
V = V + alpha * PE
When reward probabilities are fixed, what is the best learning rate?
% 80:20 rewarding outcomes
outcomes = [zeros(20,1); ones(80,1)]; outcomes = outcomes(randperm(length(outcomes)));

p_reward =[];

p_reward(1:100)= 0.8;


% prior at 0
V = 0.5;
% learning rate
params = [];
params.alpha = 1; %play with this paramenter: try .3, .1
PE = [];
for t = 1:length(outcomes)
[V(t+1), PE(t)] = learn_RescWagn(V(t),outcomes(t),params);
end

figure()
plot(V,'DisplayName','value','LineWidth',5)
hold on

plot(outcomes,'o','DisplayName','outcomes','MarkerSize',48,'Marker','.',...
'LineStyle','none',...
'Color',[0.929411764705882 0.694117647058824 0.125490196078431])
hold on
plot(p_reward,'DisplayName','p(reward)','LineWidth',5,'Color', [0.6350, 0.0780, 0.1840])

xlabel('trial','FontSize', 24);
legend('FontSize', 22)

When reward probabilities are alternating, what is the best learning rate?

% alternating 80:20 rewarding outcomes
outcomes1 = [zeros(5,1); ones(20,1)]; outcomes1 = outcomes1(randperm(length(outcomes1)));
outcomes2 = [zeros(20,1); ones(5,1)]; outcomes2 = outcomes2(randperm(length(outcomes2)));

outcomes=[]
outcomes(1:25)=outcomes1
outcomes(26:50)=outcomes2
outcomes(51:75)=outcomes1
outcomes(76:100)=outcomes2


p_reward =[];

p_reward(1:25)= 0.8;
p_reward(26:50)= 0.2;
p_reward(51:75)= 0.8;
p_reward(76:100)= 0.2;


% learning rate
params = [];
params.alpha = .3; % try .1
PE = [];
for t = 1:length(outcomes)
[V(t+1), PE(t)] = learn_RescWagn(V(t),outcomes(t),params);
end

figure()
plot(V,'DisplayName','value','LineWidth',5)
hold on
plot(outcomes,'o','DisplayName','outcomes','MarkerSize',48,'Marker','.',...
'LineStyle','none',...
'Color',[0.929411764705882 0.694117647058824 0.125490196078431])
hold on
plot(p_reward,'DisplayName','p(reward)','LineWidth',5,'Color', [0.6350, 0.0780, 0.1840])

xlabel('trial','FontSize', 24);
legend('FontSize', 22)





How does the softmax work?
% Value of both purple (1) and orange (2) slot machines ranging from 0 and 100
values = [0:100; 100:-1:0];

Explotiation behaviour - where we choose the slot machine whenever its value is better than the other
Temperature is LOW, choices are LESS NOISY, and more DETERMINSITIC
taus = 0.0001;

figure()
for i = 1:length(taus)
params = []; params.tau = taus(i);
for v = 1:length(values)
policy(i,v,:) = softmax(values(:,v),params);
end
plot(values(1,:)-values(2,:),squeeze(policy(i,:,1)),'DisplayName',['tau = ' num2str(taus(i))],'LineWidth',2)
hold on
end

legend
ax = gca;
ax.FontSize = 16;
xlabel('Value (Purple - Orange)','FontSize',20)
ylabel('P(Choose Purple)','FontSize',20)
xlim([-100 100])
ylim([0 1])


Exploration behaviour - where we choose the slot machines randomly
Temperature is HIGH, choices are MORE NOISY, and more RANDOM
taus = [10000];

figure()
for i = 1:length(taus)
params = []; params.tau = taus(i);
for v = 1:length(values)
policy(i,v,:) = softmax(values(:,v),params);
end
plot(values(1,:)-values(2,:),squeeze(policy(i,:,1)),'DisplayName',['tau = ' num2str(taus(i))],'LineWidth',2)
hold on
end
legend
ax = gca;
ax.FontSize = 16;
xlabel('Value (Purple - Orange)','FontSize',20)
ylabel('P(Choose Purple)','FontSize',20)
xlim([-100 100])
ylim([0 1])



Varying temperature leads to different choice curves....
taus = [0.01; 15; 1000];

figure()
for i = 1:length(taus)
params = []; params.tau = taus(i);
for v = 1:length(values)
policy(i,v,:) = softmax(values(:,v),params);
end
plot(values(1,:)-values(2,:),squeeze(policy(i,:,1)),'DisplayName',['tau = ' num2str(taus(i))],'LineWidth',2)
hold on
end
legend
ax = gca;
ax.FontSize = 16;
xlabel('Value (Purple - Orange)','FontSize',20)
ylabel('P(Choose Purple)','FontSize',20)
xlim([-100 100])
ylim([0 1])



When temperature = 1,
taus = [1];

figure()
for i = 1:length(taus)
params = []; params.tau = taus(i);
for v = 1:length(values)
policy(i,v,:) = softmax(values(:,v),params);
end
plot(values(1,:)-values(2,:),squeeze(policy(i,:,1)),'DisplayName',['tau = ' num2str(taus(i))],'LineWidth',2, 'Color','b')
hold on
end
legend
ax = gca;
ax.FontSize = 16;
xlabel('Value (Purple - Orange)','FontSize',20)
ylabel('P(Choose Purple)','FontSize',20)
xlim([-100 100])
ylim([0 1])


When temperature = 15,
taus = [15];

figure()
for i = 1:length(taus)
params = []; params.tau = taus(i);
for v = 1:length(values)
policy(i,v,:) = softmax(values(:,v),params);
end
plot(values(1,:)-values(2,:),squeeze(policy(i,:,1)),'DisplayName',['tau = ' num2str(taus(i))],'LineWidth',2, 'Color','r')
hold on
end
legend
ax = gca;
ax.FontSize = 16;
xlabel('Value (Purple - Orange)','FontSize',20)
ylabel('P(Choose Purple)','FontSize',20)
xlim([-100 100])
ylim([0 1])