X-Git-Url: https://ruin.nu/git/?a=blobdiff_plain;f=src%2Fmodelidentifier.cpp;h=a7c67b82c77c5ed3e01335158755221693b2f09d;hb=HEAD;hp=561eea3741462986cf08bfac8119a4c643f90bd7;hpb=868e1b08fbbd49e489dc1349cb3657521c5c1dd9;p=germs.git diff --git a/src/modelidentifier.cpp b/src/modelidentifier.cpp index 561eea3..a7c67b8 100644 --- a/src/modelidentifier.cpp +++ b/src/modelidentifier.cpp @@ -22,6 +22,8 @@ #include "genealgorithms.h" #include +#include "models.h" +#include "model.h" using namespace std; @@ -36,7 +38,7 @@ ModelIdentifier::~ModelIdentifier(){ fann_destroy(_ann); } -std::map ModelIdentifier::identify(const GeneOrder& go){ +priority_queue > ModelIdentifier::identify(const GeneOrder& go){ int pos = 0; int neg = 0; for (GeneOrder::iterator g = go.begin(); g != go.end(); ++g){ @@ -45,8 +47,44 @@ std::map ModelIdentifier::identify(const GeneOrde else ++neg; } - map scores; - scores[X] = 1; - scores[Whirl] = -1; + double length = go.size(); + vector input(8); + input[0] = pos/length; + input[1] = neg/length; + + pair seqs = longestSequences(go); + + input[2] = seqs.first/length; + input[3] = seqs.second/length; + + double cycles = countCycles(go); + input[4] = cycles/length; + + vector comps = findComponents(go); + + pos = 0; + neg = 0; + int un = 0; + for (vector::iterator c = comps.begin(); c != comps.end(); ++c){ + if (c->sign > 0) + ++pos; + else if (c->sign < 0) + ++neg; + else + ++un; + } + + input[5] = un/cycles; + input[6] = pos/cycles; + input[7] = neg/cycles; + + double *output = fann_run(_ann,&input[0]); + + priority_queue > scores; + scores.push(pair(output[0],Model(new Models::X))); + scores.push(pair(output[1],Model(new Models::Zipper))); + scores.push(pair(output[2],Model(new Models::Whirl))); + scores.push(pair(output[3],Model(new Models::FatX))); + scores.push(pair(output[4],Model(new Models::Cloud))); return scores; }