登录后复制
% Getting Started examples for animateGraDes
% The class definition below must be saved on its own as animateGraDes.m;
% run the examples from a separate script or the Command Window.
%
% Example 1: Simplest
%--------------------
%   agd = animateGraDes();                       % instantiate
%   agd.funcStr = 'x^2+2*x*y+3*y^2+4*x+5*y+6';    % cost function is required
%   agd.animate();                               % start
%
% Example 2: Alpha overshooting
%-------------------------------
%   agd = animateGraDes();                       % instantiate
%   agd.funcStr = 'x^2+2*x*y+3*y^2+4*x+5*y+6';    % cost function is required
%   % other optional parameters
%   agd.alpha = 0.2;                             % big alpha
%   agd.drawContour = true;                      % contour plot
%   agd.animate();
%
% Example 3: saddle point
%------------------------
%   agd = animateGraDes();
%   agd.alpha = 0.15;
%   agd.funcStr = 'x^4-2*x^2+y^2';   % special function with saddle points
%   agd.startPoint = [1.5 1.5];      % point not on the ridge
%   agd.drawContour = true;          % draw contour. Set to false for a 3D surface instead
%   agd.xrange = -2:0.1:2;           % xrange covers local min and start point
%   agd.yrange = -2:0.1:2;           % yrange covers local min and start point
%   agd.animate();

classdef animateGraDes < handle
    % animateGraDes  Animate gradient descent on a cost function f(x, y).
    %   Plots f over xrange/yrange (contour or 3-D surface) and draws each
    %   step (x, y) <- (x, y) - alpha * grad f(x, y) as a red segment.
    %   Stops after maxStepCount steps, or earlier when the squared distance
    %   between consecutive points (x, y, f) drops below stopThreshold.
    %   Requires the Symbolic Math Toolbox (sym, diff, subs).

    properties (Access = public)
        funcStr;        % Function of x, y as a string, e.g. 'x^2+y^2'
        alpha;          % alpha (step size) for gradient descent
        startPoint;     % start point [x y] for gradient descent
        maxStepCount;   % Stop after # steps even if the step is still above threshold
        stopThreshold;  % stop when squared distance between consecutive points is below this
        xrange;         % x range for showing funcStr
        yrange;         % y range for showing funcStr
        showAnnotation; % show annotation or not
        drawContour;    % drawContour instead of 3D surface
        stepsPerSecond; % advance # steps each second
        outfile;        % output an animation GIF if set to a filename
    end

    % private properties for internal use only
    properties (Access = private)
        func;       % function handle built from funcStr
        xPartial;   % symbolic x partial derivative of func
        yPartial;   % symbolic y partial derivative of func
    end

    methods (Access = public)
        function agd = animateGraDes()
            % Constructor: set default values. Users can overwrite any public
            % property after instantiation; funcStr must always be set.
            agd.stepsPerSecond = 5;
            agd.alpha = 0.1;
            agd.startPoint = [5 5];
            agd.xrange = -10:1:10;
            agd.yrange = -10:1:10;
            agd.maxStepCount = 100;
            agd.stopThreshold = 1E-10;
            agd.showAnnotation = true;
            agd.drawContour = false;
            agd.outfile = [];
        end

        function animate(obj)
            % Plot funcStr and animate gradient descent from startPoint.
            % If funcStr cannot be parsed or differentiated, the error is
            % displayed and the method returns without animating.
            clf
            try
                % NOTE: str2func evaluates funcStr as MATLAB code; only pass
                % trusted expressions.
                obj.func = str2func(['@(x, y)' obj.funcStr]);
                symFunc = sym(obj.func);
                xs = sym('x');
                ys = sym('y');
                obj.xPartial = diff(symFunc, xs);
                obj.yPartial = diff(symFunc, ys);
            catch ME
                disp(ME);
                return;
            end

            pauseInSec = 1/obj.stepsPerSecond;
            [X, Y] = meshgrid(obj.xrange, obj.yrange);
            Z = obj.computeZ(X, Y);
            if obj.drawContour
                contour(X, Y, Z, 20);
            else
                % Semi-transparent surface so the descent path stays visible.
                surf(X, Y, Z, 'FaceAlpha', 0.5);
            end
            hold on

            xStart = obj.startPoint(1);
            yStart = obj.startPoint(2);
            if ~isempty(obj.outfile)
                obj.writeGifFrame(true, 0.5);
            end

            ann = [];
            if obj.showAnnotation
                dim = [0.05 0.81 0.38 0.13];
                strDisplay = 'Running ...';
                ann = annotation('textbox', dim, 'String', strDisplay, ...
                    'BackgroundColor', 'white', 'FitBoxToText', 'on');
            end

            % Initialise the end point so the summary is defined even when
            % maxStepCount is 0 and the loop body never runs.
            xEnd = xStart;
            yEnd = yStart;
            zEnd = double(obj.func(xStart, yStart));
            stepCount = 0;
            for i = 1:obj.maxStepCount
                zStart = double(obj.func(xStart, yStart));
                xEnd = double(xStart - obj.getXpartial(xStart, yStart));
                yEnd = double(yStart - obj.getYpartial(xStart, yStart));
                zEnd = double(obj.func(xEnd, yEnd));
                stepCount = i;
                if obj.drawContour
                    plot([xStart xEnd], [yStart yEnd], 'r-*');
                else
                    plot3([xStart xEnd], [yStart yEnd], [zStart zEnd], 'r-*');
                end
                % set() works with both pre- and post-R2014b graphics handles.
                if ~isempty(ann)
                    set(ann, 'String', ['Running ' num2str(i) '/' num2str(obj.maxStepCount)]);
                end
                if ~isempty(obj.outfile)
                    obj.writeGifFrame(false, pauseInSec);
                else
                    pause(pauseInSec);
                end
                done = obj.checkStop(xStart, xEnd, yStart, yEnd, zStart, zEnd);
                xStart = xEnd;
                yStart = yEnd;
                if done
                    break;
                end
            end

            if ~isempty(ann)
                strDisplay = {['\alpha: ' num2str(obj.alpha)], ...
                    ['step count: ' num2str(stepCount)], ...
                    ['Min: (' num2str(xEnd) ', ' num2str(yEnd) ', ' num2str(zEnd) ')']};
                set(ann, 'String', strDisplay);
                if ~isempty(obj.outfile)
                    obj.writeGifFrame(false, pauseInSec);
                end
            end
        end

        function zValue = getXpartial(obj, xIn, yIn)
            % Return alpha * df/dx evaluated at (xIn, yIn): the x component of
            % one descent step, as a symbolic value (callers convert with
            % double). Valid only after animate() has set up the derivatives.
            zValue = obj.alpha * subs(obj.xPartial, {sym('x'), sym('y')}, {xIn, yIn});
        end

        function zValue = getYpartial(obj, xIn, yIn)
            % Return alpha * df/dy evaluated at (xIn, yIn): the y component of
            % one descent step, as a symbolic value (callers convert with
            % double). Valid only after animate() has set up the derivatives.
            zValue = obj.alpha * subs(obj.yPartial, {sym('x'), sym('y')}, {xIn, yIn});
        end

        function Z = computeZ(obj, X, Y)
            % Evaluate func element-wise on the grids X and Y (same size).
            Z = arrayfun(obj.func, X, Y);
        end

        function done = checkStop(obj, xStart, xEnd, yStart, yEnd, zStart, zEnd)
            % True when the squared Euclidean distance between the points
            % (xStart, yStart, zStart) and (xEnd, yEnd, zEnd) is below
            % stopThreshold.
            dis = (xStart-xEnd)^2 + (yStart-yEnd)^2 + (zStart-zEnd)^2;
            done = dis < obj.stopThreshold;
        end
    end

    methods (Access = private)
        function writeGifFrame(obj, isFirst, delay)
            % Capture the current figure and write it to obj.outfile as a GIF
            % frame: create the file when isFirst is true, otherwise append.
            [img, map] = rgb2ind(frame2im(getframe(gcf)), 256);
            if isFirst
                imwrite(img, map, obj.outfile, 'gif', 'DelayTime', delay);
            else
                imwrite(img, map, obj.outfile, 'gif', 'writemode', 'append', 'delaytime', delay);
            end
        end
    end
end
MATLAB 版本：R2014a
免责声明：本文系网络转载或改编，未找到原创作者，版权归原作者所有。如涉及版权问题，请联系删除。