现在的位置: 首页 > 综合 > 正文

手写识别2

2018年02月20日 ⁄ 综合 ⁄ 共 2937字 ⁄ 字号 评论关闭

经过昨天晚上的努力,自己把最简单的单层神经网实现了,可是误差不收敛,会剧烈震荡,没搞定。今天早上修改了学习速率,结果就能收敛了,而且能识别部分手写字母。还有就是学习的部分收敛条件或者说是终止条件很纠结。

介于简单网络的识别效果,我觉得下一步应该改进训练样本的数量,现在每一类只有一个训练样本,出现分类错误的可能还是很高的。

下面是学习和识别部分的代码:

计算误差的函数

        public double CalcError(double[] w, int[] v, int index )
        {
            double tmp = 0.0f;
            for (int i = 0; i < NumInLayer + 1; i++)
            {
                tmp += w[i] * v[i];
            }
            tmp = index - tmp;
            return tmp;
        }

学习部分的函数

        public void Learn(String name)
        {
            String mgrec = System.AppDomain.CurrentDomain.SetupInformation.ApplicationBase + "mgrec.txt";
            StreamWriter sw = new StreamWriter(fileName, true);
            StreamWriter sw_mg = new StreamWriter(mgrec, true);
            String context = GetBMPContext(bmpDraw);
            sw.WriteLine(name + " " + context);
            sw_mg.WriteLine(name + " " + GetCurrentMouseGesture());
            sw.Close();
            sw_mg.Close();
            //training the NN
            int targetIndex = -1;
            switch (name)
            {
                case "a":
                    targetIndex = 1;
                    break;
                case "b":
                    targetIndex = 2;
                    break;
                case "c":
                    targetIndex = 3;
                    break;
                case "d":
                    targetIndex = 4;
                    break;
                case "e":
                    targetIndex = 5;
                    break;
                case "f":
                    targetIndex = 6;
                    break;
                case "g":
                    targetIndex = 7;
                    break;
                case "h":
                    targetIndex = 8;
                    break;
                case "i":
                    targetIndex = 9;
                    break;
                case "j":
                    targetIndex = 10;
                    break;
                case "k":
                    targetIndex = 11;
                    break;
                case "l":
                    targetIndex = 12;
                    break;
                case "m":
                    targetIndex = 13;
                    break;
                case "n":
                    targetIndex = 14;
                    break;
                case "o":
                    targetIndex = 15;
                    break;
                case "p":
                    targetIndex = 16;
                    break;
                case "q":
                    targetIndex = 17;
                    break;
                case "r":
                    targetIndex = 18;
                    break;
                case "s":
                    targetIndex = 19;
                    break;
                case "t":
                    targetIndex = 20;
                    break;
                case "u":
                    targetIndex = 21;
                    break;
                case "v":
                    targetIndex = 22;
                    break;
                case "w":
                    targetIndex = 23;
                    break;
                case "x":
                    targetIndex = 24;
                    break;
                case "y":
                    targetIndex = 25;
                    break;
                case "z":
                    targetIndex = 26;
                    break;
                default:
                    break;
            }

            int[] v =new int[NumInLayer + 1];
            double err = 10.0f;
            double niu = 0.005;
            double Epsilon = 1e-2;
            double[] w = new double[NumInLayer + 1];
            double[] dw = new double[NumInLayer + 1];
            int iSeed = 9;
            Random rm = new Random(iSeed);

            v[0] = 1;
            w[0] = rm.Next(-100, 100) / 200;
            for (int i = 1; i < NumInLayer + 1; i++ )
            {
                int temp = rm.Next(-100, 100);
                w[i] = temp / 200.0f;
                v[i] =context[i - 1] - '0';
            }

            while (Math.Abs(err) >= Epsilon)
            {
                for (int i = 0; i < NumInLayer + 1; i++)
                    dw[i] = 0.0f;
                err = CalcError(w, v, targetIndex);
                for (int i = 0; i < NumInLayer + 1; i++)
                    dw[i] += niu * err * v[i];
                for (int i = 0; i < NumInLayer + 1; i++)
                    w[i] += dw[i];
            }
            String ww = System.AppDomain.CurrentDomain.SetupInformation.ApplicationBase + "weight.txt";
            StreamWriter sww = new StreamWriter(ww, true);
            sww.WriteLine(name);
            for (int i = 0; i < NumInLayer + 1; i++)
            {
                sww.WriteLine(w[i]);
            }
            sww.Close();

识别部分

        public String Recognise()
        {
            String ww = System.AppDomain.CurrentDomain.SetupInformation.ApplicationBase + "weight.txt";
            String context = GetBMPContext(bmpDraw);
            String res = "";
            StreamReader rww = new StreamReader(ww);
            double[] w = new double[NumInLayer + 1];
            double max = 0.0f;
            int[] v = new int[NumInLayer + 1];
            while (!rww.EndOfStream)
            {
                String tmp = rww.ReadLine();
                double tmpres = 0.0f;
                if (tmp[0] >= 'a' && tmp[0] <= 'z')
                {
                    tmpres += w[0];
                    for (int i = 1; i < NumInLayer + 1; i++)
                    {
                        w[i] = Convert.ToDouble(rww.ReadLine());
                        v[i] = context[i - 1] - '0';
                        tmpres += w[i] * v[i]; 
                    }
                    if (tmpres >= max)
                    {
                        max = tmpres;
                        res = tmp;
                    }
                }
            }
            Console.WriteLine(max);
            if (res == "")
                res = "denied";
            return res;
        } 

   

下一步试试多层网络,多搞些训练样本,还有就是新的识别算法。目前的算法有一个突出的问题,当新实例与样本之间存在一个缩放与偏移,要想办法消除,或者采用别的算法。

工程已上传,有兴趣的可以自己试试。

http://download.csdn.net/detail/clhmw/3847713

抱歉!评论已关闭.