]> ruin.nu Git - germs.git/blob - fann/src/fann_train_data.c
Make it possible to build statically against the included fann library.
[germs.git] / fann / src / fann_train_data.c
1 /*
2  * Fast Artificial Neural Network Library (fann) Copyright (C) 2003
3  * Steffen Nissen (lukesky@diku.dk)
4  * 
5  * This library is free software; you can redistribute it and/or modify it 
6  * under the terms of the GNU Lesser General Public License as published
7  * by the Free Software Foundation; either version 2.1 of the License, or
8  * (at your option) any later version.
9  * 
10  * This library is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * Lesser General Public License for more details.
14  * 
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the Free Software
17  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 
18  */
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <stdarg.h>
23 #include <string.h>
24
25 #include "config.h"
26 #include "fann.h"
27
28 /*
29  * Reads training data from a file. 
30  */
31 FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const char *configuration_file)
32 {
33         struct fann_train_data *data;
34         FILE *file = fopen(configuration_file, "r");
35
36         if(!file)
37         {
38                 fann_error(NULL, FANN_E_CANT_OPEN_CONFIG_R, configuration_file);
39                 return NULL;
40         }
41
42         data = fann_read_train_from_fd(file, configuration_file);
43         fclose(file);
44         return data;
45 }
46
47 /*
48  * Save training data to a file 
49  */
50 FANN_EXTERNAL int FANN_API fann_save_train(struct fann_train_data *data, const char *filename)
51 {
52         return fann_save_train_internal(data, filename, 0, 0);
53 }
54
55 /*
56  * Save training data to a file in fixed point algebra. (Good for testing
57  * a network in fixed point) 
58  */
59 FANN_EXTERNAL int FANN_API fann_save_train_to_fixed(struct fann_train_data *data, const char *filename,
60                                                                                                          unsigned int decimal_point)
61 {
62         return fann_save_train_internal(data, filename, 1, decimal_point);
63 }
64
65 /*
66  * deallocate the train data structure. 
67  */
68 FANN_EXTERNAL void FANN_API fann_destroy_train(struct fann_train_data *data)
69 {
70         if(data == NULL)
71                 return;
72         if(data->input != NULL)
73                 fann_safe_free(data->input[0]);
74         if(data->output != NULL)
75                 fann_safe_free(data->output[0]);
76         fann_safe_free(data->input);
77         fann_safe_free(data->output);
78         fann_safe_free(data);
79 }
80
81 /*
82  * Test a set of training data and calculate the MSE 
83  */
84 FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_data *data)
85 {
86         unsigned int i;
87
88         fann_reset_MSE(ann);
89
90         for(i = 0; i != data->num_data; i++)
91         {
92                 fann_test(ann, data->input[i], data->output[i]);
93         }
94
95         return fann_get_MSE(ann);
96 }
97
98 #ifndef FIXEDFANN
99
100 /*
101  * Internal train function 
102  */
103 float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
104 {
105         unsigned int i;
106
107         if(ann->prev_train_slopes == NULL)
108         {
109                 fann_clear_train_arrays(ann);
110         }
111
112         fann_reset_MSE(ann);
113
114         for(i = 0; i < data->num_data; i++)
115         {
116                 fann_run(ann, data->input[i]);
117                 fann_compute_MSE(ann, data->output[i]);
118                 fann_backpropagate_MSE(ann);
119                 fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
120         }
121         fann_update_weights_quickprop(ann, data->num_data, 0, ann->total_connections);
122
123         return fann_get_MSE(ann);
124 }
125
126 /*
127  * Internal train function 
128  */
129 float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
130 {
131         unsigned int i;
132
133         if(ann->prev_train_slopes == NULL)
134         {
135                 fann_clear_train_arrays(ann);
136         }
137
138         fann_reset_MSE(ann);
139
140         for(i = 0; i < data->num_data; i++)
141         {
142                 fann_run(ann, data->input[i]);
143                 fann_compute_MSE(ann, data->output[i]);
144                 fann_backpropagate_MSE(ann);
145                 fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
146         }
147
148         fann_update_weights_irpropm(ann, 0, ann->total_connections);
149
150         return fann_get_MSE(ann);
151 }
152
153 /*
154  * Internal train function 
155  */
156 float fann_train_epoch_batch(struct fann *ann, struct fann_train_data *data)
157 {
158         unsigned int i;
159
160         fann_reset_MSE(ann);
161
162         for(i = 0; i < data->num_data; i++)
163         {
164                 fann_run(ann, data->input[i]);
165                 fann_compute_MSE(ann, data->output[i]);
166                 fann_backpropagate_MSE(ann);
167                 fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
168         }
169
170         fann_update_weights_batch(ann, data->num_data, 0, ann->total_connections);
171
172         return fann_get_MSE(ann);
173 }
174
175 /*
176  * Internal train function 
177  */
178 float fann_train_epoch_incremental(struct fann *ann, struct fann_train_data *data)
179 {
180         unsigned int i;
181
182         fann_reset_MSE(ann);
183
184         for(i = 0; i != data->num_data; i++)
185         {
186                 fann_train(ann, data->input[i], data->output[i]);
187         }
188
189         return fann_get_MSE(ann);
190 }
191
192 /*
193  * Train for one epoch with the selected training algorithm 
194  */
195 FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data)
196 {
197         switch (ann->training_algorithm)
198         {
199         case FANN_TRAIN_QUICKPROP:
200                 return fann_train_epoch_quickprop(ann, data);
201         case FANN_TRAIN_RPROP:
202                 return fann_train_epoch_irpropm(ann, data);
203         case FANN_TRAIN_BATCH:
204                 return fann_train_epoch_batch(ann, data);
205         case FANN_TRAIN_INCREMENTAL:
206                 return fann_train_epoch_incremental(ann, data);
207         }
208         return 0;
209 }
210
211 FANN_EXTERNAL void FANN_API fann_train_on_data(struct fann *ann, struct fann_train_data *data,
212                                                                                            unsigned int max_epochs,
213                                                                                            unsigned int epochs_between_reports,
214                                                                                            float desired_error)
215 {
216         float error;
217         unsigned int i;
218         int desired_error_reached;
219
220 #ifdef DEBUG
221         printf("Training with %s\n", FANN_TRAIN_NAMES[ann->training_algorithm]);
222 #endif
223
224         if(epochs_between_reports && ann->callback == NULL)
225         {
226                 printf("Max epochs %8d. Desired error: %.10f.\n", max_epochs, desired_error);
227         }
228
229         for(i = 1; i <= max_epochs; i++)
230         {
231                 /*
232                  * train 
233                  */
234                 error = fann_train_epoch(ann, data);
235                 desired_error_reached = fann_desired_error_reached(ann, desired_error);
236
237                 /*
238                  * print current output 
239                  */
240                 if(epochs_between_reports &&
241                    (i % epochs_between_reports == 0 || i == max_epochs || i == 1 ||
242                         desired_error_reached == 0))
243                 {
244                         if(ann->callback == NULL)
245                         {
246                                 printf("Epochs     %8d. Current error: %.10f. Bit fail %d.\n", i, error,
247                                            ann->num_bit_fail);
248                         }
249                         else if(((*ann->callback)(ann, data, max_epochs, epochs_between_reports, 
250                                                                           desired_error, i)) == -1)
251                         {
252                                 /*
253                                  * you can break the training by returning -1 
254                                  */
255                                 break;
256                         }
257                 }
258
259                 if(desired_error_reached == 0)
260                         break;
261         }
262 }
263
264 FANN_EXTERNAL void FANN_API fann_train_on_file(struct fann *ann, const char *filename,
265                                                                                            unsigned int max_epochs,
266                                                                                            unsigned int epochs_between_reports,
267                                                                                            float desired_error)
268 {
269         struct fann_train_data *data = fann_read_train_from_file(filename);
270
271         if(data == NULL)
272         {
273                 return;
274         }
275         fann_train_on_data(ann, data, max_epochs, epochs_between_reports, desired_error);
276         fann_destroy_train(data);
277 }
278
279 #endif
280
281 /*
282  * shuffles training data, randomizing the order 
283  */
284 FANN_EXTERNAL void FANN_API fann_shuffle_train_data(struct fann_train_data *train_data)
285 {
286         unsigned int dat = 0, elem, swap;
287         fann_type temp;
288
289         for(; dat < train_data->num_data; dat++)
290         {
291                 swap = (unsigned int) (rand() % train_data->num_data);
292                 if(swap != dat)
293                 {
294                         for(elem = 0; elem < train_data->num_input; elem++)
295                         {
296                                 temp = train_data->input[dat][elem];
297                                 train_data->input[dat][elem] = train_data->input[swap][elem];
298                                 train_data->input[swap][elem] = temp;
299                         }
300                         for(elem = 0; elem < train_data->num_output; elem++)
301                         {
302                                 temp = train_data->output[dat][elem];
303                                 train_data->output[dat][elem] = train_data->output[swap][elem];
304                                 train_data->output[swap][elem] = temp;
305                         }
306                 }
307         }
308 }
309
310 /*
311  * INTERNAL FUNCTION Scales data to a specific range 
312  */
313 void fann_scale_data(fann_type ** data, unsigned int num_data, unsigned int num_elem,
314                                          fann_type new_min, fann_type new_max)
315 {
316         unsigned int dat, elem;
317         fann_type old_min, old_max, temp, old_span, new_span, factor;
318
319         old_min = old_max = data[0][0];
320
321         /*
322          * first calculate min and max 
323          */
324         for(dat = 0; dat < num_data; dat++)
325         {
326                 for(elem = 0; elem < num_elem; elem++)
327                 {
328                         temp = data[dat][elem];
329                         if(temp < old_min)
330                                 old_min = temp;
331                         else if(temp > old_max)
332                                 old_max = temp;
333                 }
334         }
335
336         old_span = old_max - old_min;
337         new_span = new_max - new_min;
338         factor = new_span / old_span;
339
340         for(dat = 0; dat < num_data; dat++)
341         {
342                 for(elem = 0; elem < num_elem; elem++)
343                 {
344                         temp = (data[dat][elem] - old_min) * factor + new_min;
345                         if(temp < new_min)
346                         {
347                                 data[dat][elem] = new_min;
348                                 /*
349                                  * printf("error %f < %f\n", temp, new_min); 
350                                  */
351                         }
352                         else if(temp > new_max)
353                         {
354                                 data[dat][elem] = new_max;
355                                 /*
356                                  * printf("error %f > %f\n", temp, new_max); 
357                                  */
358                         }
359                         else
360                         {
361                                 data[dat][elem] = temp;
362                         }
363                 }
364         }
365 }
366
367 /*
368  * Scales the inputs in the training data to the specified range 
369  */
370 FANN_EXTERNAL void FANN_API fann_scale_input_train_data(struct fann_train_data *train_data,
371                                                                                                                 fann_type new_min, fann_type new_max)
372 {
373         fann_scale_data(train_data->input, train_data->num_data, train_data->num_input, new_min,
374                                         new_max);
375 }
376
377 /*
378  * Scales the inputs in the training data to the specified range 
379  */
380 FANN_EXTERNAL void FANN_API fann_scale_output_train_data(struct fann_train_data *train_data,
381                                                                                                                  fann_type new_min, fann_type new_max)
382 {
383         fann_scale_data(train_data->output, train_data->num_data, train_data->num_output, new_min,
384                                         new_max);
385 }
386
387 /*
388  * Scales the inputs in the training data to the specified range 
389  */
390 FANN_EXTERNAL void FANN_API fann_scale_train_data(struct fann_train_data *train_data,
391                                                                                                   fann_type new_min, fann_type new_max)
392 {
393         fann_scale_data(train_data->input, train_data->num_data, train_data->num_input, new_min,
394                                         new_max);
395         fann_scale_data(train_data->output, train_data->num_data, train_data->num_output, new_min,
396                                         new_max);
397 }
398
399 /*
400  * merges training data into a single struct. 
401  */
402 FANN_EXTERNAL struct fann_train_data *FANN_API fann_merge_train_data(struct fann_train_data *data1,
403                                                                                                                                          struct fann_train_data *data2)
404 {
405         unsigned int i;
406         fann_type *data_input, *data_output;
407         struct fann_train_data *dest =
408                 (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
409
410         if(dest == NULL)
411         {
412                 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
413                 return NULL;
414         }
415
416         if((data1->num_input != data2->num_input) || (data1->num_output != data2->num_output))
417         {
418                 fann_error((struct fann_error*)data1, FANN_E_TRAIN_DATA_MISMATCH);
419                 return NULL;
420         }
421
422         fann_init_error_data((struct fann_error *) dest);
423         dest->error_log = data1->error_log;
424
425         dest->num_data = data1->num_data+data2->num_data;
426         dest->num_input = data1->num_input;
427         dest->num_output = data1->num_output;
428         dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
429         if(dest->input == NULL)
430         {
431                 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
432                 fann_destroy_train(dest);
433                 return NULL;
434         }
435
436         dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
437         if(dest->output == NULL)
438         {
439                 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
440                 fann_destroy_train(dest);
441                 return NULL;
442         }
443
444         data_input = (fann_type *) calloc(dest->num_input * dest->num_data, sizeof(fann_type));
445         if(data_input == NULL)
446         {
447                 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
448                 fann_destroy_train(dest);
449                 return NULL;
450         }
451         memcpy(data_input, data1->input[0], dest->num_input * data1->num_data * sizeof(fann_type));
452         memcpy(data_input + (dest->num_input*data1->num_data), 
453                 data2->input[0], dest->num_input * data2->num_data * sizeof(fann_type));
454
455         data_output = (fann_type *) calloc(dest->num_output * dest->num_data, sizeof(fann_type));
456         if(data_output == NULL)
457         {
458                 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
459                 fann_destroy_train(dest);
460                 return NULL;
461         }
462         memcpy(data_output, data1->output[0], dest->num_output * data1->num_data * sizeof(fann_type));
463         memcpy(data_output + (dest->num_output*data1->num_data), 
464                 data2->output[0], dest->num_output * data2->num_data * sizeof(fann_type));
465
466         for(i = 0; i != dest->num_data; i++)
467         {
468                 dest->input[i] = data_input;
469                 data_input += dest->num_input;
470                 dest->output[i] = data_output;
471                 data_output += dest->num_output;
472         }
473         return dest;
474 }
475
476 /*
477  * return a copy of a fann_train_data struct 
478  */
479 FANN_EXTERNAL struct fann_train_data *FANN_API fann_duplicate_train_data(struct fann_train_data
480                                                                                                                                                  *data)
481 {
482         unsigned int i;
483         fann_type *data_input, *data_output;
484         struct fann_train_data *dest =
485                 (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
486
487         if(dest == NULL)
488         {
489                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
490                 return NULL;
491         }
492
493         fann_init_error_data((struct fann_error *) dest);
494         dest->error_log = data->error_log;
495
496         dest->num_data = data->num_data;
497         dest->num_input = data->num_input;
498         dest->num_output = data->num_output;
499         dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
500         if(dest->input == NULL)
501         {
502                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
503                 fann_destroy_train(dest);
504                 return NULL;
505         }
506
507         dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
508         if(dest->output == NULL)
509         {
510                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
511                 fann_destroy_train(dest);
512                 return NULL;
513         }
514
515         data_input = (fann_type *) calloc(dest->num_input * dest->num_data, sizeof(fann_type));
516         if(data_input == NULL)
517         {
518                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
519                 fann_destroy_train(dest);
520                 return NULL;
521         }
522         memcpy(data_input, data->input[0], dest->num_input * dest->num_data * sizeof(fann_type));
523
524         data_output = (fann_type *) calloc(dest->num_output * dest->num_data, sizeof(fann_type));
525         if(data_output == NULL)
526         {
527                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
528                 fann_destroy_train(dest);
529                 return NULL;
530         }
531         memcpy(data_output, data->output[0], dest->num_output * dest->num_data * sizeof(fann_type));
532
533         for(i = 0; i != dest->num_data; i++)
534         {
535                 dest->input[i] = data_input;
536                 data_input += dest->num_input;
537                 dest->output[i] = data_output;
538                 data_output += dest->num_output;
539         }
540         return dest;
541 }
542
543 FANN_EXTERNAL struct fann_train_data *FANN_API fann_subset_train_data(struct fann_train_data
544                                                                                                                                                  *data, unsigned int pos,
545                                                                                                                                                  unsigned int length)
546 {
547         unsigned int i;
548         fann_type *data_input, *data_output;
549         struct fann_train_data *dest =
550                 (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
551
552         if(dest == NULL)
553         {
554                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
555                 return NULL;
556         }
557         
558         if(pos > data->num_data || pos+length > data->num_data)
559         {
560                 fann_error((struct fann_error*)data, FANN_E_TRAIN_DATA_SUBSET, pos, length, data->num_data);
561                 return NULL;
562         }
563
564         fann_init_error_data((struct fann_error *) dest);
565         dest->error_log = data->error_log;
566
567         dest->num_data = length;
568         dest->num_input = data->num_input;
569         dest->num_output = data->num_output;
570         dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
571         if(dest->input == NULL)
572         {
573                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
574                 fann_destroy_train(dest);
575                 return NULL;
576         }
577
578         dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
579         if(dest->output == NULL)
580         {
581                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
582                 fann_destroy_train(dest);
583                 return NULL;
584         }
585
586         data_input = (fann_type *) calloc(dest->num_input * dest->num_data, sizeof(fann_type));
587         if(data_input == NULL)
588         {
589                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
590                 fann_destroy_train(dest);
591                 return NULL;
592         }
593         memcpy(data_input, data->input[pos], dest->num_input * dest->num_data * sizeof(fann_type));
594
595         data_output = (fann_type *) calloc(dest->num_output * dest->num_data, sizeof(fann_type));
596         if(data_output == NULL)
597         {
598                 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
599                 fann_destroy_train(dest);
600                 return NULL;
601         }
602         memcpy(data_output, data->output[pos], dest->num_output * dest->num_data * sizeof(fann_type));
603
604         for(i = 0; i != dest->num_data; i++)
605         {
606                 dest->input[i] = data_input;
607                 data_input += dest->num_input;
608                 dest->output[i] = data_output;
609                 data_output += dest->num_output;
610         }
611         return dest;
612 }
613
614 FANN_EXTERNAL unsigned int FANN_API fann_length_train_data(struct fann_train_data *data)
615 {
616         return data->num_data;
617 }
618
619 FANN_EXTERNAL unsigned int FANN_API fann_num_input_train_data(struct fann_train_data *data)
620 {
621         return data->num_input;
622 }
623
624 FANN_EXTERNAL unsigned int FANN_API fann_num_output_train_data(struct fann_train_data *data)
625 {
626         return data->num_output;
627 }
628
629 /* INTERNAL FUNCTION
630    Save the train data structure.
631  */
632 int fann_save_train_internal(struct fann_train_data *data, const char *filename,
633                                                           unsigned int save_as_fixed, unsigned int decimal_point)
634 {
635         int retval = 0;
636         FILE *file = fopen(filename, "w");
637
638         if(!file)
639         {
640                 fann_error((struct fann_error *) data, FANN_E_CANT_OPEN_TD_W, filename);
641                 return -1;
642         }
643         retval = fann_save_train_internal_fd(data, file, filename, save_as_fixed, decimal_point);
644         fclose(file);
645         
646         return retval;
647 }
648
649 /* INTERNAL FUNCTION
650    Save the train data structure.
651  */
652 int fann_save_train_internal_fd(struct fann_train_data *data, FILE * file, const char *filename,
653                                                                  unsigned int save_as_fixed, unsigned int decimal_point)
654 {
655         unsigned int num_data = data->num_data;
656         unsigned int num_input = data->num_input;
657         unsigned int num_output = data->num_output;
658         unsigned int i, j;
659         int retval = 0;
660
661 #ifndef FIXEDFANN
662         unsigned int multiplier = 1 << decimal_point;
663 #endif
664
665         fprintf(file, "%u %u %u\n", data->num_data, data->num_input, data->num_output);
666
667         for(i = 0; i < num_data; i++)
668         {
669                 for(j = 0; j < num_input; j++)
670                 {
671 #ifndef FIXEDFANN
672                         if(save_as_fixed)
673                         {
674                                 fprintf(file, "%d ", (int) (data->input[i][j] * multiplier));
675                         }
676                         else
677                         {
678                                 if(((int) floor(data->input[i][j] + 0.5) * 1000000) ==
679                                    ((int) floor(data->input[i][j] * 1000000.0 + 0.5)))
680                                 {
681                                         fprintf(file, "%d ", (int) data->input[i][j]);
682                                 }
683                                 else
684                                 {
685                                         fprintf(file, "%f ", data->input[i][j]);
686                                 }
687                         }
688 #else
689                         fprintf(file, FANNPRINTF " ", data->input[i][j]);
690 #endif
691                 }
692                 fprintf(file, "\n");
693
694                 for(j = 0; j < num_output; j++)
695                 {
696 #ifndef FIXEDFANN
697                         if(save_as_fixed)
698                         {
699                                 fprintf(file, "%d ", (int) (data->output[i][j] * multiplier));
700                         }
701                         else
702                         {
703                                 if(((int) floor(data->output[i][j] + 0.5) * 1000000) ==
704                                    ((int) floor(data->output[i][j] * 1000000.0 + 0.5)))
705                                 {
706                                         fprintf(file, "%d ", (int) data->output[i][j]);
707                                 }
708                                 else
709                                 {
710                                         fprintf(file, "%f ", data->output[i][j]);
711                                 }
712                         }
713 #else
714                         fprintf(file, FANNPRINTF " ", data->output[i][j]);
715 #endif
716                 }
717                 fprintf(file, "\n");
718         }
719         
720         return retval;
721 }
722
723
724 /*
725  * INTERNAL FUNCTION Reads training data from a file descriptor. 
726  */
727 struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
728 {
729         unsigned int num_input, num_output, num_data, i, j;
730         unsigned int line = 1;
731         fann_type *data_input, *data_output;
732         struct fann_train_data *data =
733                 (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
734
735         if(data == NULL)
736         {
737                 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
738                 return NULL;
739         }
740
741         if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
742         {
743                 fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
744                 fann_destroy_train(data);
745                 return NULL;
746         }
747         line++;
748
749         fann_init_error_data((struct fann_error *) data);
750
751         data->num_data = num_data;
752         data->num_input = num_input;
753         data->num_output = num_output;
754         data->input = (fann_type **) calloc(num_data, sizeof(fann_type *));
755         if(data->input == NULL)
756         {
757                 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
758                 fann_destroy_train(data);
759                 return NULL;
760         }
761
762         data->output = (fann_type **) calloc(num_data, sizeof(fann_type *));
763         if(data->output == NULL)
764         {
765                 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
766                 fann_destroy_train(data);
767                 return NULL;
768         }
769
770         data_input = (fann_type *) calloc(num_input * num_data, sizeof(fann_type));
771         if(data_input == NULL)
772         {
773                 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
774                 fann_destroy_train(data);
775                 return NULL;
776         }
777
778         data_output = (fann_type *) calloc(num_output * num_data, sizeof(fann_type));
779         if(data_output == NULL)
780         {
781                 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
782                 fann_destroy_train(data);
783                 return NULL;
784         }
785
786         for(i = 0; i != num_data; i++)
787         {
788                 data->input[i] = data_input;
789                 data_input += num_input;
790
791                 for(j = 0; j != num_input; j++)
792                 {
793                         if(fscanf(file, FANNSCANF " ", &data->input[i][j]) != 1)
794                         {
795                                 fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
796                                 fann_destroy_train(data);
797                                 return NULL;
798                         }
799                 }
800                 line++;
801
802                 data->output[i] = data_output;
803                 data_output += num_output;
804
805                 for(j = 0; j != num_output; j++)
806                 {
807                         if(fscanf(file, FANNSCANF " ", &data->output[i][j]) != 1)
808                         {
809                                 fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
810                                 fann_destroy_train(data);
811                                 return NULL;
812                         }
813                 }
814                 line++;
815         }
816         return data;
817 }
818
819 /*
820  * INTERNAL FUNCTION returns 0 if the desired error is reached and -1 if it is not reached
821  */
822 int fann_desired_error_reached(struct fann *ann, float desired_error)
823 {
824         switch (ann->train_stop_function)
825         {
826         case FANN_STOPFUNC_MSE:
827                 if(fann_get_MSE(ann) <= desired_error)
828                         return 0;
829                 break;
830         case FANN_STOPFUNC_BIT:
831                 if(ann->num_bit_fail <= (unsigned int)desired_error)
832                         return 0;
833                 break;
834         }
835         return -1;
836 }