/* Autor: Burkhard Lenze */
/* Datum: 18.01.2009     */

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Locale;
import java.util.Scanner;

public class backpro1 
{
 /*  globale Variablen  */
 static float[][] x = new float[10][100];/*  Eingabevektorfeld                */
 static float[][] y = new float[10][100];/*  Ausgabevektorfeld                */
 static float[] theta = new float[10];	 /*  Schwellwertvektorfeld            */
 static float[][] w = new float[10][10]; /*  Matrixfeld der Eingangsgewichte  */
 static float[][] g = new float[10][10]; /*  Matrixfeld der Ausgangsgewichte  */
 static float[] x_test = new float[10];	 /*  Aktueller Testeingang            */
 static float[] y_test = new float[10];	 /*  Aktueller Testausgang            */
 static int n;				 /*  Anzahl der Eingabe-Neuronen      */
 static int q;				 /*  Anzahl der verborgenen Neuronen  */
 static int m;		     	         /*  Anzahl der Ausgabe-Neuronen      */
 static int t;				 /*  Anzahl der Assoziationen         */
 static int s;				 /*  Anzahl der Trainingszyklen       */
 static int s_max;			 /*  Maximalzahl der Trainingszyklen  */
 static float epsilon;			 /*  Maximaler Ausgabefehler          */
 static float beta;			 /*  Stauchungsparameter              */
 static float lambda;			 /*  Lernrate                         */
 static char com;			 /*  Menuparameter                    */
	
 /*  Funktionen */
 static void lese_x_test()
 {
  int i;
  System.out.print("\n");
  for(i=1;i<=n;i++)
  {
   System.out.print("x[" + i + "] = ? ");
   Scanner sc = new Scanner(System.in).useLocale(Locale.ENGLISH);
   x_test[i] = sc.nextFloat();
  } 
 }
	
 static void lese_x()
 {
  int i;
  System.out.print("\n");
  for(i=1;i<=n;i++)
  {
   System.out.print("x[" + i + "][" + s +  "] = ? ");
   Scanner sc = new Scanner(System.in).useLocale(Locale.ENGLISH);
   x[i][s] = sc.nextFloat();
  }
 }
	
 static void lese_y()
 {
  int j;
  System.out.print("\n");
  for(j=1;j<=m;j++)
  {
   System.out.print("y[" + j + "][" + s + "] = ? ");
   Scanner sc = new Scanner(System.in).useLocale(Locale.ENGLISH);
   y[j][s] = sc.nextFloat();
  }
 }
	
 static float T(float A)
 {
  float Z;
  Z=(float)(1.0/(1.0 + Math.exp(-(beta*A))));
  return(Z);
 }
	
 static float TS(float A)
 {
  float Z;
  Z=(float)((beta*Math.exp(-(beta*A)))/((1.0f+Math.exp(-(beta*A)))*(1.0f+Math.exp(-(beta*A)))));
  return(Z);
 }
	
 static void lernen()
 {
  int i, j, p, st, z;
  float[] A = new float[10];
  float[] B = new float[10];
  float[] C = new float[10];
  float[] y_ = new float[10];
  float H, err;
  for(s=1;s<=t;s++)
  {
   lese_x();
   lese_y();
  }
  s=0;
  do
  {
   s=s+1;
   for(st=1;st<=t;st++)
   {
	for(p=1;p<=q;p++)
	{
	 H=0.0f;
	 for(i=1;i<=n;i++)
	 {
	  H=H+(w[i][p]*x[i][st]);
	 }
	 A[p]=T(H-theta[p]);
	 B[p]=TS(H-theta[p]);
	}
	for(j=1;j<=m;j++)
	{
	 y_[j]=0.0f;
	 for(p=1;p<=q;p++)
	 {
	  y_[j]=y_[j]+g[p][j]*A[p];
	 }
	}
	for(p=1;p<=q;p++)
	{
	 C[p]=0.0f;
	 for(j=1;j<=m;j++)
	 {
	  C[p]=C[p]+(y[j][st]-y_[j])*g[p][j];
	 }
	}
	for(j=1;j<=m;j++)
	{
	 for(p=1;p<=q;p++)
	 {
	  g[p][j]=g[p][j]+lambda*2.0f*(y[j][st]-y_[j])*A[p];
	 }
	}
	for(p=1;p<=q;p++)
	{
	 for(i=1;i<=n;i++)
	 {
	  w[i][p]=w[i][p]+lambda*2.0f*C[p]*B[p]*x[i][st];
	 }
	 theta[p]=theta[p]-lambda*2.0f*C[p]*B[p];
	}
   }
   err=0.0f;
   for(st=1;st<=t;st++)
   {
    for(j=1;j<=m;j++)
	{
	 y_[j]=0.0f;
	 for(p=1;p<=q;p++)
	 {
	  H=0.0f;
	  for(i=1;i<=n;i++)
	  {
	   H=H+(w[i][p]*x[i][st]);
	  }
	  y_[j]=y_[j]+g[p][j]*T(H-theta[p]);
	 }
	 err=err+(y[j][st]-y_[j])*(y[j][st]-y_[j]);
	}
   }
   z=1;
   if((s_max/100)>=1)
	 z=s_max/100;
   if(s%z==0)
   {
	System.out.print("\nEs wurden " + s + " Lernzyklen durchlaufen! ");
	System.out.print("\nDer summierte quadrierte Fehler lautet: " + err);
   }
   if(err<epsilon)
   {
	System.out.print("\n\nEs wurden " + s +  " Lernzyklen durchlaufen! ");
	System.out.print("\nDer summierte quadrierte Fehler lautet: " + err);
	System.out.print("\n\nFehler kleiner als " + epsilon);
	System.out.print("\n\nAbbruch des Lernalgorithmus!!!\n");
   }
   if(s==s_max)
   {
	System.out.print("\n\nEs wurden " + s +  " Lernzyklen durchlaufen! ");
	System.out.print("\nDer summierte quadrierte Fehler lautet: " + err);
	System.out.print("\n\nAbbruch des Lernalgorithmus!!!\n");
   }
  }
  while((err>=epsilon)&&(s!=s_max));
 }
	
 static void ausfuehren()
 {
  int i, j, p;
  float H;
  lese_x_test();
  for(j=1;j<=m;j++)
  {
   y_test[j]=0.0f;
   for(p=1;p<=q;p++)
   {
	H=0.0f;
	for(i=1;i<=n;i++)
	{
	 H=H+(w[i][p]*x_test[i]);
	}
	y_test[j]=y_test[j]+g[p][j]*T(H-theta[p]);
   }
  }
  System.out.print("\n");
  for(j=1;j<=m;j++)
  {
   System.out.println("y[" + j + "] = " + y_test[j]);
  }
 }
	
 static void veraendern()
 {
  int i, j, p;
  s=0;
  Scanner sc = new Scanner(System.in).useLocale(Locale.ENGLISH);
  System.out.print("\n");
  System.out.print("Anzahl n der Neuronen in der Eingabeschicht? ");
  n = sc.nextInt();
  System.out.print("Anzahl q der Neuronen in der verborgenen Schicht? ");
  q = sc.nextInt();
  System.out.print("Anzahl m der Neuronen in der Ausgabeschicht? ");
  m = sc.nextInt();
  System.out.print("Anzahl t der zu lernenden Assoziationen? ");
  t = sc.nextInt();
  System.out.print("Stauchungsparameter beta der Transferfunktion? ");
  beta = sc.nextFloat();
  System.out.print("Lernrate lambda mit der iteriert werden soll? ");
  lambda = sc.nextFloat();
  System.out.print("Anzahl s_max der maximal zu durchlaufenden Lernzyklen? ");
  s_max = sc.nextInt();
  System.out.print("Zu unterschreitender maximaler Fehler epsilon? ");
  epsilon = sc.nextFloat();
  for(i=1;i<=n;i++){x[i][0] = 0.0f;}
  for(j=1;j<=m;j++){y[j][0] = 0.0f;}
  for(p=1;p<=q;p++)
  {
   for(i=1;i<=n;i++){w[i][p] = 0.0f;}
   theta[p]=0.0f;
  }
  for(j=1;j<=m;j++)
  {
   for(p=1;p<=q;p++)
   {
	g[p][j]=0.0f;
   }
  }
 }
	
 static void zeigen()
 {
  int i, j, p;
  System.out.print("\nDie Gewichte w[i][p] nach " + s + " Lernzyklen:\n");
  for(p=1;p<=q;p++)
  {
   for(i=1;i<=n;i++)
   {
	System.out.print("w[" + i + "][" + p + "] = " + w[i][p] + "  ");
   }
   System.out.print("\n");
  }
  System.out.print("\nDie Schwellwerte theta[p] nach " + s + " Lernzyklen:\n");
  for(p=1;p<=q;p++)
  {
   System.out.print("theta[" + p + "] = " + theta[p] + "\n");
  }
  System.out.print("\nDie Gewichte g[p][j] nach " + s + " Lernzyklen:\n");
  for(j=1;j<=m;j++)
  {
   for(p=1;p<=q;p++)
   {
	System.out.print("g[" + p + "][" + j + "] = " + g[p][j] + "\n");
   }
   System.out.print("\n");
  }
 }
	
 static void speichern()
 {
  int i, j, p;
  String dat;
  BufferedReader f = new BufferedReader(new InputStreamReader(System.in));
  System.out.println("\nWie lautet die Datei für Gewichte und Schwellwerte? ");
  try 
  {
   dat = f.readLine();
   BufferedWriter in = new BufferedWriter(new FileWriter(dat));
   in.write(n + "\n");
   in.write(q + "\n");
   in.write(m + "\n");
   for(p=1;p<=q;p++)
   {
	for(i=1;i<=n;i++)
	{
	 in.write(w[i][p] + "\n");
	}
   }
   for(p=1;p<=q;p++)
   {
    in.write(theta[p] + "\n");
   }
   for(j=1;j<=m;j++)
   {
	for(p=1;p<=q;p++)
	{
	 in.write(g[p][j] + "\n");
	}
   }
   in.close();
  } 
  catch (IOException e) 
  {
   System.out.println("Fehler beim Schreiben der Datei!");
  }
 }
	
 static void einlesen()
 {
  int i, j, p, r;
  String dat;
  BufferedReader f = new BufferedReader(new InputStreamReader(System.in));
  System.out.println("\nWie lautet die Datei für Gewichte und Schwellwerte? ");
  try 
  {
   dat = f.readLine();
   BufferedReader out = new BufferedReader(new FileReader(dat));
   r = Integer.parseInt(out.readLine());
   if (r==n)
   {
	r = Integer.parseInt(out.readLine());
	if(r==q)
	{
	 r = Integer.parseInt(out.readLine());
	 if(r==m)
	 {
	  for(p=1;p<=q;p++)
	  {
	   for(i=1;i<=n;i++)
	   {
		w[i][p] = Float.parseFloat(out.readLine());
	   }
	  }
	  for(p=1;p<=q;p++)
	  {
	   theta[p] = Float.parseFloat(out.readLine());
	  }
	  for(j=1;j<=m;j++)
	  {
	   for(p=1;p<=q;p++)
	   {
		g[p][j] = Float.parseFloat(out.readLine());
	   }
	  }
	 }
	 else
	 {
	  System.out.println("\nFehler! m ist in Datei " + r + " statt " + m);
	 }		
	}
	else
	{
	 System.out.println("\nFehler! q ist in Datei " + r + " statt " + q);
	}
   }
   else
   {
	System.out.println("\nFehler! n ist in Datei " + r + " statt " + n);
   }		
   out.close();
  } 
  catch (FileNotFoundException e) 
  {
   System.out.println("Fehler! Datei konnte nicht gefunden werden!");
  } 
  catch (IOException e) 
  {
   System.out.println("Fehler beim Lesen der Datei!");
  }
 }
	
 public static void main(String[] args) 
 {
  veraendern();
  do
  {
   System.out.println("\nlernen ausfuehren zeigen veraendern speichern einlesen beenden:\n");	
   Scanner sc = new Scanner(System.in).useLocale(Locale.ENGLISH);
   com = sc.next().charAt(0);
   switch(com)
   {
	case 'l': lernen(); break;
	case 'a': ausfuehren(); break;
	case 'z': zeigen(); break;
	case 'v': veraendern(); break;
	case 's': speichern(); break;
	case 'e': einlesen(); break;
	default:  break;
   }
  }
  while(com!='b');
 }
}