A Harder Learning Problem

alexis%yummy at gateway.mitre.org
Thu Aug 4 11:23:00 EDT 1988


There are many problems with the current standard "benchmark"
tasks used with neural networks (NNs), but one of them is that they're
just too simple.  It's hard to compare learning algorithms when
the task the network has to perform is so easy.

One of the tasks that we've been using at MITRE to test and compare our 
learning algorithms is to distinguish between two intertwined spirals.
This task uses a net with 2 inputs and 1 output.  The inputs correspond
to <x,y> points, and the net should output a 1 on one spiral and
a 0 on the other.  Each of the spirals contains 3 full revolutions.
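To make the geometry concrete, here is the parametrization implemented by the generator program included with this message, written out as equations. The symbols r_i and θ_i are our notation, not the program's:

```latex
\begin{aligned}
\theta_i &= \frac{i\pi}{16}, \qquad r_i = 6.5\,\frac{104 - i}{104}, \qquad i = 0, 1, \ldots, 96 \\
\text{spiral 1 (class 1.0):}\quad (x_i, y_i) &= (r_i \sin\theta_i,\ r_i \cos\theta_i) \\
\text{spiral 0 (class 0.0):}\quad (x_i', y_i') &= (-x_i,\ -y_i) \\
\theta_{96} &= 6\pi \quad (3 \text{ revolutions}), \qquad r_0 = 6.5, \quad r_{96} = 0.5 \\
r_i - r_{i+32} &= 6.5 \cdot \frac{32}{104} = 2 \quad (\text{radius lost per revolution})
\end{aligned}
```

Each quarter turn (8 steps) therefore moves the point 0.5 closer to the origin, and the two spirals together give 2 x 97 = 194 exemplars.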

This task has some nice features: it's highly non-linear; it's relatively
difficult (our spiffed-up learning algorithm requires ~15-20 million
presentations = ~150-200 thousand epochs = ~1-2 days of CPU time on a
(loaded) Sun4/280 to learn, and we've never succeeded in getting vanilla
back-propagation to converge correctly); and because there are 2 inputs and
1 output, you can *PLOT* the current transfer function of the entire
network as it learns.
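Because the network is a function of just two inputs, one simple way to watch it learn is to sample its output on a grid covering the data and print the result. The sketch below assumes the network can be called as a C function of (x, y) returning a value in [0, 1]. The names `net_fn`, `render_map`, `print_map`, and `toy_net` are hypothetical, and `toy_net` is only a stand-in (a single unit that splits the plane along y = x), not a trained spiral solver:

```c
#include <stdio.h>

/* Assumed interface: a network maps an (x, y) input to an output in [0, 1]. */
typedef double (*net_fn)(double x, double y);

/* Sample net over the square [-extent, extent]^2 on a rows x cols grid
   (rows, cols >= 2) and store one character per cell in map, row-major:
   '#' where the output is >= 0.5, '.' otherwise.
   Row 0 is the top of the plot (y = +extent). */
void render_map(net_fn net, double extent, int rows, int cols, char *map)
{
    for (int r = 0; r < rows; r++) {
        /* y runs from +extent at the top down to -extent at the bottom */
        double y = extent - 2.0 * extent * r / (rows - 1);
        for (int c = 0; c < cols; c++) {
            /* x runs from -extent at the left to +extent at the right */
            double x = -extent + 2.0 * extent * c / (cols - 1);
            map[r * cols + c] = net(x, y) >= 0.5 ? '#' : '.';
        }
    }
}

/* Print a map produced by render_map, one grid row per line. */
void print_map(const char *map, int rows, int cols)
{
    for (int r = 0; r < rows; r++) {
        fwrite(map + r * cols, 1, (size_t)cols, stdout);
        putchar('\n');
    }
}

/* Hypothetical stand-in for a network: one unit on t = y - x, squashed
   into (0, 1) by 0.5 + 0.5 * t / (1 + |t|). It does NOT solve the
   spiral task; it only gives render_map something to draw. */
double toy_net(double x, double y)
{
    double t = y - x;
    double abs_t = t < 0 ? -t : t;
    return 0.5 + 0.5 * t / (1.0 + abs_t);
}
```

For example, after every few thousand epochs one might call `render_map(my_net, 7.0, 21, 41, map)` and then `print_map(map, 21, 41)`, where `my_net` is your own (hypothetical) network function; an extent of 7.0 covers the data, whose largest radius is 6.5. A real plotting package gives a finer picture, but the character map keeps the sketch dependency-free.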

I'd be interested in seeing other people try this or a related problem.
Below is a simple C program that we use to generate the input/output data.

Alexis P. Wieland           wieland at mitre.arpa
MITRE Corporation               or
7525 Colshire Dr.           alexis%yummy at gateway.mitre.org
McLean, VA  22102


/=========================================================================/

#include <stdio.h>
#include <math.h>

/*************************************************************************
 **
 **  mkspiral.c
 **
 **  A program to generate input and output data for a neural network
 **  with 2 inputs and 1 output.
 **
 **  If the 2 inputs are taken to represent an x-y position and the
 **  output (which is either 0.0 or 1.0) is taken to represent which of
 **  two classes the input point is in, then the data forms two coiled
 **  spirals.  Each spiral forms 3 complete revolutions and contains
 **  97 points (32 points per revolution, plus the final end point).  Spiral 1 passes
 **  from (0, 6.5) -> (6, 0) -> (0, -5.5) -> (-5, 0) -> (0, 4.5) ->
 **  ... -> (0, 0.5).  Likewise, Spiral 0 passes from (0, -6.5) ->
 **  (-6, 0) -> (0, 5.5) -> (5, 0) -> (0, -4.5) -> ... -> (0, -0.5).
 **
 **  This program writes out data in ASCII, one exemplar per line, in
 **  the form:  ((x-pt y-pt) (class)).
 **
 **  This data set was created to test learning algorithms developed
 **  at the MITRE Corporation.  The intention was to create a data set
 **  which would be non-trivial to learn.  We at MITRE have never
 **  succeeded at learning this task with vanilla back-propagation.
 **
 **  Any questions or comments (reports of success or failure with this
 **  task are as interesting as anything to us) contact:
 **
 **        Alexis P. Wieland
 **        MITRE Corporation
 **        7525 Colshire Dr.
 **        McLean, VA  22102
 **        (703) 883-7476
 **        wieland at mitre.ARPA
 **
 *************************************************************************/

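/* M_PI is a common extension (POSIX/XSI), not ISO C; define it if absent. */
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

int main(void)
{
  int i;
  double x, y, angle, radius;

  /* Write both spirals, interleaved: point i of spiral 1, then its
     reflection through the origin, which is point i of spiral 0. */
  for (i = 0; i <= 96; i++) {
    angle = i * M_PI / 16.0;            /* 32 points per revolution */
    radius = 6.5 * (104 - i) / 104.0;   /* shrinks from 6.5 to 0.5 */
    x = radius * sin(angle);
    y = radius * cos(angle);
    printf("((%8.5f  %8.5f)   (%3.1f))\n",  x,  y, 1.0);  /* spiral 1 */
    printf("((%8.5f  %8.5f)   (%3.1f))\n", -x, -y, 0.0);  /* spiral 0 */
  }
  return 0;
}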
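For anyone loading this data into their own simulator, here is a minimal sketch of reading the program's output back into memory, assuming the exact "((x-pt y-pt) (class))" line format described in the header comment. The names `struct exemplar`, `read_exemplars`, and `MAX_EXEMPLARS` are hypothetical, not part of the original program:

```c
#include <stdio.h>

/* Hypothetical capacity: the generator writes 2 x 97 = 194 exemplars. */
#define MAX_EXEMPLARS 194

/* One training pattern: inputs (x, y) and the target class (0.0 or 1.0). */
struct exemplar {
    double x, y, target;
};

/* Read lines of the form "((x y) (class))" from fp into ex[], storing at
   most max entries. Returns the number of exemplars read; stops at end of
   file or at the first line that does not match the format. */
int read_exemplars(FILE *fp, struct exemplar *ex, int max)
{
    char line[256];
    int n = 0;

    while (n < max && fgets(line, sizeof line, fp) != NULL) {
        /* A space in a scanf format matches any run of whitespace
           (including none), so this accepts the %8.5f padding. */
        if (sscanf(line, " ((%lf %lf) (%lf))",
                   &ex[n].x, &ex[n].y, &ex[n].target) != 3)
            break;
        n++;
    }
    return n;
}
```

Piping the generator into a program that calls `read_exemplars(stdin, ex, MAX_EXEMPLARS)` should yield all 194 exemplars, alternating between class 1.0 and class 0.0.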

