]> ruin.nu Git - germs.git/blob - src/modelidentifier.cpp
ModelIdentifier::identify implemented and passes test
[germs.git] / src / modelidentifier.cpp
1 /***************************************************************************
2  *   Copyright (C) 2006 by Michael Andreen                                 *
3  *   andreen@student.chalmers.se                                           *
4  *                                                                         *
5  *   This program is free software; you can redistribute it and/or modify  *
6  *   it under the terms of the GNU General Public License as published by  *
7  *   the Free Software Foundation; either version 2 of the License, or     *
8  *   (at your option) any later version.                                   *
9  *                                                                         *
10  *   This program is distributed in the hope that it will be useful,       *
11  *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
12  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
13  *   GNU General Public License for more details.                          *
14  *                                                                         *
15  *   You should have received a copy of the GNU General Public License     *
16  *   along with this program; if not, write to the                         *
17  *   Free Software Foundation, Inc.,                                       *
18  *   51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA          *
19  ***************************************************************************/
20
#include "modelidentifier.h"
#include "genealgorithms.h"

#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <doublefann.h>
25
26 using namespace std;
27
28 ModelIdentifier::ModelIdentifier(std::string ann){
29         _ann = fann_create_from_file(ann.c_str());
30         if(!_ann){
31                 throw invalid_argument("Could not create network");
32         }
33 }
34
35 ModelIdentifier::~ModelIdentifier(){
36         fann_destroy(_ann);
37 }
38
39 std::map<ModelIdentifier::Model,double> ModelIdentifier::identify(const GeneOrder& go){
40         int pos = 0;
41         int neg = 0;
42         for (GeneOrder::iterator g = go.begin(); g != go.end(); ++g){
43                 if (*g >= 0)
44                         ++pos;
45                 else
46                         ++neg;
47         }
48         double length = go.size();
49         vector<double> input(8);
50         input[0] = pos/length;
51         input[1] = neg/length;
52
53         pair<int,int> seqs = longestSequences(go);
54
55         input[2] = seqs.first/length;
56         input[3] = seqs.second/length;
57
58         double cycles = countCycles(go);
59         input[4] = cycles/length;
60
61         vector<Component> comps = findComponents(go);
62
63         pos = 0;
64         neg = 0;
65         int un = 0;
66         for (vector<Component>::iterator c = comps.begin(); c != comps.end(); ++c){
67                 if (c->sign > 0)
68                         ++pos;
69                 else if (c->sign < 0)
70                         ++neg;
71                 else
72                         ++un;
73         }
74
75         input[5] = un/cycles;
76         input[6] = pos/cycles;
77         input[7] = neg/cycles;
78
79         double *output = fann_run(_ann,&input[0]);
80
81         map<Model,double> scores;
82         scores[X] = output[0];
83         scores[Zipper] = output[1];
84         scores[Whirl] = output[2];
85         scores[FatX] = output[3];
86         scores[Cloud] = output[4];
87         return scores;
88 }