2 * Fast Artificial Neural Network Library (fann) Copyright (C) 2003
3 * Steffen Nissen (lukesky@diku.dk)
5 * This library is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU Lesser General Public License as published
7 * by the Free Software Foundation; either version 2.1 of the License, or
8 * (at your option) any later version.
10 * This library is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 * Lesser General Public License for more details.
15 * You should have received a copy of the GNU Lesser General Public
16 * License along with this library; if not, write to the Free Software
17 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
29 * Reads training data from a file.
31 FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const char *configuration_file)
33 struct fann_train_data *data;
34 FILE *file = fopen(configuration_file, "r");
38 fann_error(NULL, FANN_E_CANT_OPEN_CONFIG_R, configuration_file);
42 data = fann_read_train_from_fd(file, configuration_file);
48 * Save training data to a file
50 FANN_EXTERNAL int FANN_API fann_save_train(struct fann_train_data *data, const char *filename)
52 return fann_save_train_internal(data, filename, 0, 0);
56 * Save training data to a file in fixed point algebra. (Good for testing
57 * a network in fixed point)
59 FANN_EXTERNAL int FANN_API fann_save_train_to_fixed(struct fann_train_data *data, const char *filename,
60 unsigned int decimal_point)
62 return fann_save_train_internal(data, filename, 1, decimal_point);
66 * deallocate the train data structure.
68 FANN_EXTERNAL void FANN_API fann_destroy_train(struct fann_train_data *data)
72 if(data->input != NULL)
73 fann_safe_free(data->input[0]);
74 if(data->output != NULL)
75 fann_safe_free(data->output[0]);
76 fann_safe_free(data->input);
77 fann_safe_free(data->output);
82 * Test a set of training data and calculate the MSE
84 FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_data *data)
90 for(i = 0; i != data->num_data; i++)
92 fann_test(ann, data->input[i], data->output[i]);
95 return fann_get_MSE(ann);
101 * Internal train function
103 float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
107 if(ann->prev_train_slopes == NULL)
109 fann_clear_train_arrays(ann);
114 for(i = 0; i < data->num_data; i++)
116 fann_run(ann, data->input[i]);
117 fann_compute_MSE(ann, data->output[i]);
118 fann_backpropagate_MSE(ann);
119 fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
121 fann_update_weights_quickprop(ann, data->num_data, 0, ann->total_connections);
123 return fann_get_MSE(ann);
127 * Internal train function
129 float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
133 if(ann->prev_train_slopes == NULL)
135 fann_clear_train_arrays(ann);
140 for(i = 0; i < data->num_data; i++)
142 fann_run(ann, data->input[i]);
143 fann_compute_MSE(ann, data->output[i]);
144 fann_backpropagate_MSE(ann);
145 fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
148 fann_update_weights_irpropm(ann, 0, ann->total_connections);
150 return fann_get_MSE(ann);
154 * Internal train function
156 float fann_train_epoch_batch(struct fann *ann, struct fann_train_data *data)
162 for(i = 0; i < data->num_data; i++)
164 fann_run(ann, data->input[i]);
165 fann_compute_MSE(ann, data->output[i]);
166 fann_backpropagate_MSE(ann);
167 fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
170 fann_update_weights_batch(ann, data->num_data, 0, ann->total_connections);
172 return fann_get_MSE(ann);
176 * Internal train function
178 float fann_train_epoch_incremental(struct fann *ann, struct fann_train_data *data)
184 for(i = 0; i != data->num_data; i++)
186 fann_train(ann, data->input[i], data->output[i]);
189 return fann_get_MSE(ann);
193 * Train for one epoch with the selected training algorithm
195 FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data)
197 switch (ann->training_algorithm)
199 case FANN_TRAIN_QUICKPROP:
200 return fann_train_epoch_quickprop(ann, data);
201 case FANN_TRAIN_RPROP:
202 return fann_train_epoch_irpropm(ann, data);
203 case FANN_TRAIN_BATCH:
204 return fann_train_epoch_batch(ann, data);
205 case FANN_TRAIN_INCREMENTAL:
206 return fann_train_epoch_incremental(ann, data);
211 FANN_EXTERNAL void FANN_API fann_train_on_data(struct fann *ann, struct fann_train_data *data,
212 unsigned int max_epochs,
213 unsigned int epochs_between_reports,
218 int desired_error_reached;
221 printf("Training with %s\n", FANN_TRAIN_NAMES[ann->training_algorithm]);
224 if(epochs_between_reports && ann->callback == NULL)
226 printf("Max epochs %8d. Desired error: %.10f.\n", max_epochs, desired_error);
229 for(i = 1; i <= max_epochs; i++)
234 error = fann_train_epoch(ann, data);
235 desired_error_reached = fann_desired_error_reached(ann, desired_error);
238 * print current output
240 if(epochs_between_reports &&
241 (i % epochs_between_reports == 0 || i == max_epochs || i == 1 ||
242 desired_error_reached == 0))
244 if(ann->callback == NULL)
246 printf("Epochs %8d. Current error: %.10f. Bit fail %d.\n", i, error,
249 else if(((*ann->callback)(ann, data, max_epochs, epochs_between_reports,
250 desired_error, i)) == -1)
253 * you can break the training by returning -1
259 if(desired_error_reached == 0)
264 FANN_EXTERNAL void FANN_API fann_train_on_file(struct fann *ann, const char *filename,
265 unsigned int max_epochs,
266 unsigned int epochs_between_reports,
269 struct fann_train_data *data = fann_read_train_from_file(filename);
275 fann_train_on_data(ann, data, max_epochs, epochs_between_reports, desired_error);
276 fann_destroy_train(data);
282 * shuffles training data, randomizing the order
284 FANN_EXTERNAL void FANN_API fann_shuffle_train_data(struct fann_train_data *train_data)
286 unsigned int dat = 0, elem, swap;
289 for(; dat < train_data->num_data; dat++)
291 swap = (unsigned int) (rand() % train_data->num_data);
294 for(elem = 0; elem < train_data->num_input; elem++)
296 temp = train_data->input[dat][elem];
297 train_data->input[dat][elem] = train_data->input[swap][elem];
298 train_data->input[swap][elem] = temp;
300 for(elem = 0; elem < train_data->num_output; elem++)
302 temp = train_data->output[dat][elem];
303 train_data->output[dat][elem] = train_data->output[swap][elem];
304 train_data->output[swap][elem] = temp;
311 * INTERNAL FUNCTION Scales data to a specific range
313 void fann_scale_data(fann_type ** data, unsigned int num_data, unsigned int num_elem,
314 fann_type new_min, fann_type new_max)
316 unsigned int dat, elem;
317 fann_type old_min, old_max, temp, old_span, new_span, factor;
319 old_min = old_max = data[0][0];
322 * first calculate min and max
324 for(dat = 0; dat < num_data; dat++)
326 for(elem = 0; elem < num_elem; elem++)
328 temp = data[dat][elem];
331 else if(temp > old_max)
336 old_span = old_max - old_min;
337 new_span = new_max - new_min;
338 factor = new_span / old_span;
340 for(dat = 0; dat < num_data; dat++)
342 for(elem = 0; elem < num_elem; elem++)
344 temp = (data[dat][elem] - old_min) * factor + new_min;
347 data[dat][elem] = new_min;
349 * printf("error %f < %f\n", temp, new_min);
352 else if(temp > new_max)
354 data[dat][elem] = new_max;
356 * printf("error %f > %f\n", temp, new_max);
361 data[dat][elem] = temp;
368 * Scales the inputs in the training data to the specified range
370 FANN_EXTERNAL void FANN_API fann_scale_input_train_data(struct fann_train_data *train_data,
371 fann_type new_min, fann_type new_max)
373 fann_scale_data(train_data->input, train_data->num_data, train_data->num_input, new_min,
378 * Scales the inputs in the training data to the specified range
380 FANN_EXTERNAL void FANN_API fann_scale_output_train_data(struct fann_train_data *train_data,
381 fann_type new_min, fann_type new_max)
383 fann_scale_data(train_data->output, train_data->num_data, train_data->num_output, new_min,
388 * Scales the inputs in the training data to the specified range
390 FANN_EXTERNAL void FANN_API fann_scale_train_data(struct fann_train_data *train_data,
391 fann_type new_min, fann_type new_max)
393 fann_scale_data(train_data->input, train_data->num_data, train_data->num_input, new_min,
395 fann_scale_data(train_data->output, train_data->num_data, train_data->num_output, new_min,
400 * merges training data into a single struct.
402 FANN_EXTERNAL struct fann_train_data *FANN_API fann_merge_train_data(struct fann_train_data *data1,
403 struct fann_train_data *data2)
406 fann_type *data_input, *data_output;
407 struct fann_train_data *dest =
408 (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
412 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
416 if((data1->num_input != data2->num_input) || (data1->num_output != data2->num_output))
418 fann_error((struct fann_error*)data1, FANN_E_TRAIN_DATA_MISMATCH);
422 fann_init_error_data((struct fann_error *) dest);
423 dest->error_log = data1->error_log;
425 dest->num_data = data1->num_data+data2->num_data;
426 dest->num_input = data1->num_input;
427 dest->num_output = data1->num_output;
428 dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
429 if(dest->input == NULL)
431 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
432 fann_destroy_train(dest);
436 dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
437 if(dest->output == NULL)
439 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
440 fann_destroy_train(dest);
444 data_input = (fann_type *) calloc(dest->num_input * dest->num_data, sizeof(fann_type));
445 if(data_input == NULL)
447 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
448 fann_destroy_train(dest);
451 memcpy(data_input, data1->input[0], dest->num_input * data1->num_data * sizeof(fann_type));
452 memcpy(data_input + (dest->num_input*data1->num_data),
453 data2->input[0], dest->num_input * data2->num_data * sizeof(fann_type));
455 data_output = (fann_type *) calloc(dest->num_output * dest->num_data, sizeof(fann_type));
456 if(data_output == NULL)
458 fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
459 fann_destroy_train(dest);
462 memcpy(data_output, data1->output[0], dest->num_output * data1->num_data * sizeof(fann_type));
463 memcpy(data_output + (dest->num_output*data1->num_data),
464 data2->output[0], dest->num_output * data2->num_data * sizeof(fann_type));
466 for(i = 0; i != dest->num_data; i++)
468 dest->input[i] = data_input;
469 data_input += dest->num_input;
470 dest->output[i] = data_output;
471 data_output += dest->num_output;
477 * return a copy of a fann_train_data struct
479 FANN_EXTERNAL struct fann_train_data *FANN_API fann_duplicate_train_data(struct fann_train_data
483 fann_type *data_input, *data_output;
484 struct fann_train_data *dest =
485 (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
489 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
493 fann_init_error_data((struct fann_error *) dest);
494 dest->error_log = data->error_log;
496 dest->num_data = data->num_data;
497 dest->num_input = data->num_input;
498 dest->num_output = data->num_output;
499 dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
500 if(dest->input == NULL)
502 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
503 fann_destroy_train(dest);
507 dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
508 if(dest->output == NULL)
510 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
511 fann_destroy_train(dest);
515 data_input = (fann_type *) calloc(dest->num_input * dest->num_data, sizeof(fann_type));
516 if(data_input == NULL)
518 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
519 fann_destroy_train(dest);
522 memcpy(data_input, data->input[0], dest->num_input * dest->num_data * sizeof(fann_type));
524 data_output = (fann_type *) calloc(dest->num_output * dest->num_data, sizeof(fann_type));
525 if(data_output == NULL)
527 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
528 fann_destroy_train(dest);
531 memcpy(data_output, data->output[0], dest->num_output * dest->num_data * sizeof(fann_type));
533 for(i = 0; i != dest->num_data; i++)
535 dest->input[i] = data_input;
536 data_input += dest->num_input;
537 dest->output[i] = data_output;
538 data_output += dest->num_output;
543 FANN_EXTERNAL struct fann_train_data *FANN_API fann_subset_train_data(struct fann_train_data
544 *data, unsigned int pos,
548 fann_type *data_input, *data_output;
549 struct fann_train_data *dest =
550 (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
554 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
558 if(pos > data->num_data || pos+length > data->num_data)
560 fann_error((struct fann_error*)data, FANN_E_TRAIN_DATA_SUBSET, pos, length, data->num_data);
564 fann_init_error_data((struct fann_error *) dest);
565 dest->error_log = data->error_log;
567 dest->num_data = length;
568 dest->num_input = data->num_input;
569 dest->num_output = data->num_output;
570 dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
571 if(dest->input == NULL)
573 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
574 fann_destroy_train(dest);
578 dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
579 if(dest->output == NULL)
581 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
582 fann_destroy_train(dest);
586 data_input = (fann_type *) calloc(dest->num_input * dest->num_data, sizeof(fann_type));
587 if(data_input == NULL)
589 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
590 fann_destroy_train(dest);
593 memcpy(data_input, data->input[pos], dest->num_input * dest->num_data * sizeof(fann_type));
595 data_output = (fann_type *) calloc(dest->num_output * dest->num_data, sizeof(fann_type));
596 if(data_output == NULL)
598 fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
599 fann_destroy_train(dest);
602 memcpy(data_output, data->output[pos], dest->num_output * dest->num_data * sizeof(fann_type));
604 for(i = 0; i != dest->num_data; i++)
606 dest->input[i] = data_input;
607 data_input += dest->num_input;
608 dest->output[i] = data_output;
609 data_output += dest->num_output;
614 FANN_EXTERNAL unsigned int FANN_API fann_length_train_data(struct fann_train_data *data)
616 return data->num_data;
619 FANN_EXTERNAL unsigned int FANN_API fann_num_input_train_data(struct fann_train_data *data)
621 return data->num_input;
624 FANN_EXTERNAL unsigned int FANN_API fann_num_output_train_data(struct fann_train_data *data)
626 return data->num_output;
630 Save the train data structure.
632 int fann_save_train_internal(struct fann_train_data *data, const char *filename,
633 unsigned int save_as_fixed, unsigned int decimal_point)
636 FILE *file = fopen(filename, "w");
640 fann_error((struct fann_error *) data, FANN_E_CANT_OPEN_TD_W, filename);
643 retval = fann_save_train_internal_fd(data, file, filename, save_as_fixed, decimal_point);
650 Save the train data structure.
652 int fann_save_train_internal_fd(struct fann_train_data *data, FILE * file, const char *filename,
653 unsigned int save_as_fixed, unsigned int decimal_point)
655 unsigned int num_data = data->num_data;
656 unsigned int num_input = data->num_input;
657 unsigned int num_output = data->num_output;
662 unsigned int multiplier = 1 << decimal_point;
665 fprintf(file, "%u %u %u\n", data->num_data, data->num_input, data->num_output);
667 for(i = 0; i < num_data; i++)
669 for(j = 0; j < num_input; j++)
674 fprintf(file, "%d ", (int) (data->input[i][j] * multiplier));
678 if(((int) floor(data->input[i][j] + 0.5) * 1000000) ==
679 ((int) floor(data->input[i][j] * 1000000.0 + 0.5)))
681 fprintf(file, "%d ", (int) data->input[i][j]);
685 fprintf(file, "%f ", data->input[i][j]);
689 fprintf(file, FANNPRINTF " ", data->input[i][j]);
694 for(j = 0; j < num_output; j++)
699 fprintf(file, "%d ", (int) (data->output[i][j] * multiplier));
703 if(((int) floor(data->output[i][j] + 0.5) * 1000000) ==
704 ((int) floor(data->output[i][j] * 1000000.0 + 0.5)))
706 fprintf(file, "%d ", (int) data->output[i][j]);
710 fprintf(file, "%f ", data->output[i][j]);
714 fprintf(file, FANNPRINTF " ", data->output[i][j]);
725 * INTERNAL FUNCTION Reads training data from a file descriptor.
727 struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
729 unsigned int num_input, num_output, num_data, i, j;
730 unsigned int line = 1;
731 fann_type *data_input, *data_output;
732 struct fann_train_data *data =
733 (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
737 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
741 if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
743 fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
744 fann_destroy_train(data);
749 fann_init_error_data((struct fann_error *) data);
751 data->num_data = num_data;
752 data->num_input = num_input;
753 data->num_output = num_output;
754 data->input = (fann_type **) calloc(num_data, sizeof(fann_type *));
755 if(data->input == NULL)
757 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
758 fann_destroy_train(data);
762 data->output = (fann_type **) calloc(num_data, sizeof(fann_type *));
763 if(data->output == NULL)
765 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
766 fann_destroy_train(data);
770 data_input = (fann_type *) calloc(num_input * num_data, sizeof(fann_type));
771 if(data_input == NULL)
773 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
774 fann_destroy_train(data);
778 data_output = (fann_type *) calloc(num_output * num_data, sizeof(fann_type));
779 if(data_output == NULL)
781 fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
782 fann_destroy_train(data);
786 for(i = 0; i != num_data; i++)
788 data->input[i] = data_input;
789 data_input += num_input;
791 for(j = 0; j != num_input; j++)
793 if(fscanf(file, FANNSCANF " ", &data->input[i][j]) != 1)
795 fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
796 fann_destroy_train(data);
802 data->output[i] = data_output;
803 data_output += num_output;
805 for(j = 0; j != num_output; j++)
807 if(fscanf(file, FANNSCANF " ", &data->output[i][j]) != 1)
809 fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
810 fann_destroy_train(data);
820 * INTERNAL FUNCTION returns 0 if the desired error is reached and -1 if it is not reached
822 int fann_desired_error_reached(struct fann *ann, float desired_error)
824 switch (ann->train_stop_function)
826 case FANN_STOPFUNC_MSE:
827 if(fann_get_MSE(ann) <= desired_error)
830 case FANN_STOPFUNC_BIT:
831 if(ann->num_bit_fail <= (unsigned int)desired_error)