function GradientDescent
%GRADIENTDESCENT Train a 3-layer feed-forward network by batch gradient descent.
%
%   GradientDescent takes no arguments and returns nothing. It draws a
%   random training set of N input pairs, trains the network so that its
%   read-out approximates inputs(1,:) - inputs(2,:), and plots the
%   root-mean-square error after every epoch on log-log axes. Training
%   continues until the STOP toggle button is pressed or the training
%   figure is closed. Afterwards, target vs. network output is plotted.
%
%   Network layout (as set up below):
%     input layer  - inputs(1,i) encoded by Gaussian tuning curves
%                    (width sigma) centred on x = -50:50, plus two units
%                    that encode inputs(2,i) linearly
%     hidden layer - Nh logistic-sigmoid units
%     output layer - No logistic-sigmoid units, read out linearly via w3
%                    (w3 is not updated; only w1 and w2 are trained)
%
%   Written for CoSMO 2012 by Gunnar Blohm
%   http://www.compneurosci.com

%% training set
N = 100;                              % number of training examples
inputs = 15*randn(2, N);              % two input variables per example
targets = inputs(1,:) - inputs(2,:);  % desired network output

%% network variables
x = -50:1:50;        % preferred directions of the population-coded input units
sigma = 10;          % width of the Gaussian input tuning curves
Ni = 2 + length(x);  % number of input units (population code + 2 linear units)
Nh = 91;             % number of hidden-layer units
No = length(x);      % number of output-population units

lr = 0.1;            % learning rate
alp = 0;             % momentum coefficient
tau = 0;             % step size for resilient back-prop (sign-of-gradient steps)
tog = 1;             % weight of the plain gradient step
if tau > 0
    tog = 0;         % resilient back-prop replaces the plain gradient step
end

%% initialize weights (uniform in [-0.5, 0.5))
w1 = rand(Nh, Ni) - .5;  % input  -> hidden
w2 = rand(No, Nh) - .5;  % hidden -> output population
w3 = rand(1, No) - .5;   % output population -> read-out (kept fixed below)
% w3 = x./sqrt(2*pi*sigma^2); % alternative: fixed read-out weights

%% training loop
fig = figure;
h = uicontrol('style','togglebutton','string','STOP','position',[0 0 100 30]);

% per-epoch activity buffers, preallocated instead of grown in the loop
il  = zeros(Ni, N);
hl  = zeros(Nh, N);
ol  = zeros(No, N);
out = zeros(1, N);
err = zeros(1, N);

% previous mean weight steps, used by the momentum term
prev1 = zeros(Nh, Ni);
prev2 = zeros(No, Nh);
prev3 = zeros(1, No); %#ok<NASGU> only needed if the w3 update is enabled

ERR = [];
k = 1;
ok = true;
while ok
    % Sum the per-example weight changes. Their mean is applied once per
    % epoch (batch update), so no per-example gradient tensor is stored.
    sum1 = zeros(Nh, Ni);
    sum2 = zeros(No, Nh);
    sum3 = zeros(1, No);
    for i = 1:N
        % encode input
        % il(:,i) = [exp(-(x-inputs(1,i)).^2./sigma.^2./2) exp(-(x-inputs(2,i)).^2./sigma.^2./2)]';
        il(:,i) = [exp(-(x-inputs(1,i)).^2./sigma.^2./2) (inputs(2,i)+50)/100 (inputs(2,i)-50)/100]';

        % compute layer activations
        hl(:,i) = trans(w1*il(:,i));
        ol(:,i) = trans(w2*hl(:,i));

        % decode output (linear read-out)
        out(i) = w3*ol(:,i);

        % Back-propagation: each delta is the negative gradient of err^2/2
        % with respect to that layer's net input. For the logistic
        % transfer function the derivative is y.*(1-y).
        err(i) = out(i) - targets(i);
        de3 = -err(i);
        de2 = (w3' * de3) .* ol(:,i) .* (1 - ol(:,i));
        de1 = (w2' * de2) .* hl(:,i) .* (1 - hl(:,i));

        sum3 = sum3 + lr*de3*ol(:,i)';
        sum2 = sum2 + lr*de2*hl(:,i)';
        sum1 = sum1 + lr*de1*il(:,i)';
    end
    dw1 = sum1/N;
    dw2 = sum2/N;
    dw3 = sum3/N;

    % update weights: gradient step + momentum + resilient (sign) step
    w1 = w1 + tog*dw1 + alp*prev1 + tau*sign(dw1);
    w2 = w2 + tog*dw2 + alp*prev2 + tau*sign(dw2);
    % w3 is intentionally kept fixed; to train it as well, uncomment:
    % w3 = w3 + tog*dw3 + alp*prev3 + tau*sign(dw3);
    prev1 = dw1;
    prev2 = dw2;
    prev3 = dw3; %#ok<NASGU>

    % plot error changes (root-mean-square error of this epoch)
    ERR(k) = sqrt(mean(err.^2)); %#ok<AGROW> number of epochs is not known in advance
    loglog(1:k, ERR);
    k = k + 1;
    drawnow;

    % Stop when STOP is pressed or the training figure has been closed.
    % Without the ishandle check, get() would error on a deleted handle.
    ok = ishandle(h) && get(h,'Value') == 0;
end

% close only the training figure, not figures the user had open
if ishandle(fig)
    close(fig);
end
figure;
subplot(2,1,1); plot(targets, out, '.'); axis equal;
subplot(2,1,2); plot(targets); hold on; plot(out, 'r');
% keyboard
end

%% transfer function
function y = trans(x)
%TRANS Logistic sigmoid 1./(1+exp(-x)), applied element-wise.
y = 1./(1+exp(-x));
end