Recognizing Handwritten Digits

2020-04-23 Author: 云代码 member

[python] Code

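import numpy
import scipy.special  # provides expit(), the sigmoid function
# Nothing below plots, and the %matplotlib magic is only valid in Jupyter/IPython,
# so both lines are commented out.
# import matplotlib.pyplot
# %matplotlib inline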
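class neuralNetwork:
    def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate):
        # number of nodes in the input, hidden and output layers
        self.inodes=inputnodes
        self.hnodes=hiddennodes
        self.onodes=outputnodes
        # weight matrices: wih links input -> hidden, who links hidden -> output;
        # initial values are drawn from a normal distribution centred on 0.0
        self.wih=numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.hnodes,self.inodes))
        self.who=numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hnodes))
        self.lr=learningrate
        # the activation function is the sigmoid
        self.activation_function=lambda x:scipy.special.expit(x)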
    def train(self,inputs_list,targets_list):
        # convert the input and target lists to column vectors
        inputs=numpy.array(inputs_list,ndmin=2).T
        targets=numpy.array(targets_list,ndmin=2).T
        # forward pass: input -> hidden -> output
        hidden_inputs=numpy.dot(self.wih,inputs)
        hidden_outputs=self.activation_function(hidden_inputs)
        final_inputs=numpy.dot(self.who,hidden_outputs)
        final_outputs=self.activation_function(final_inputs)
        # output error = target - final output
        output_errors=targets-final_outputs
        # hidden error: output error split back across the hidden -> output weights
        hidden_errors=numpy.dot(self.who.T,output_errors)
        # gradient-descent updates; out*(1.0-out) is the derivative of the sigmoid
        self.who+=self.lr*numpy.dot((output_errors*final_outputs*(1.0-final_outputs)), numpy.transpose(hidden_outputs))
        self.wih+=self.lr*numpy.dot((hidden_errors*hidden_outputs*(1.0-hidden_outputs)), numpy.transpose(inputs))
    def query(self,inputs_list):
        # forward pass only: returns the output-layer activations as a column vector
        inputs=numpy.array(inputs_list,ndmin=2).T
        hidden_inputs=numpy.dot(self.wih,inputs)
        hidden_outputs=self.activation_function(hidden_inputs)
        final_inputs=numpy.dot(self.who,hidden_outputs)
        final_outputs=self.activation_function(final_inputs)
        return final_outputs
# 784 inputs = 28 x 28 pixels per image; 10 outputs = one per digit 0..9
input_nodes=784
hidden_nodes=200
output_nodes=10
learning_rate=0.1
n=neuralNetwork(input_nodes,hidden_nodes,output_nodes,learning_rate)
# load the training records, one CSV line per image
with open("mnist_dataset/mnist_train_100.csv",'r') as training_data_file:
    training_data_list=training_data_file.readlines()
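# epochs = number of passes over the whole training set
epochs=5
for e in range(epochs):
    for record in training_data_list:
        # first value is the label, the rest are pixel values 0..255
        all_values=record.split(',')
        # scale pixels into 0.01..1.00 (numpy.asfarray was removed in NumPy 2.0)
        inputs=(numpy.asarray(all_values[1:],dtype=float)/255.0*0.99)+0.01
        # targets are 0.01 everywhere except 0.99 at the correct digit
        targets=numpy.zeros(output_nodes)+0.01
        targets[int(all_values[0])]=0.99
        n.train(inputs,targets)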
with open("mnist_dataset/mnist_test_10.csv",'r') as test_data_file:
    test_data_list=test_data_file.readlines()
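# scorecard records 1 for each correct answer and 0 for each wrong one
scorecard=[]
for record in test_data_list:
    all_values=record.split(',')
    correct_label=int(all_values[0])
    print(correct_label,"correct label")
    # same scaling as in training (numpy.asfarray was removed in NumPy 2.0)
    inputs=(numpy.asarray(all_values[1:],dtype=float)/255.0*0.99)+0.01
    outputs=n.query(inputs)
    # the index of the largest output is the network's answer
    label=numpy.argmax(outputs)
    print(label,"network's answer")
    if label==correct_label:
        scorecard.append(1)
    else:
        scorecard.append(0)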
scorecard_array=numpy.asarray(scorecard)
print("performance=",scorecard_array.sum()/scorecard_array.size)
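
The training and test loops above both turn one CSV line into a scaled input vector, and the training loop also builds a target vector. The following is a minimal sketch of that preprocessing step that can be run without the MNIST CSV files. The helper name prepare_record and the 4-pixel sample line are assumptions made for illustration; neither appears in the original listing, and a real MNIST record carries 784 pixel values.

```python
import numpy

def prepare_record(record, output_nodes=10):
    """Split one MNIST-style CSV line into (label, inputs, targets).

    Hypothetical helper, not part of the original listing.
    """
    all_values = record.strip().split(',')
    label = int(all_values[0])
    # Scale raw pixel values 0..255 into 0.01..1.00 so no input is exactly zero.
    inputs = numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99 + 0.01
    # Targets: 0.01 everywhere, 0.99 at the index of the correct digit.
    targets = numpy.zeros(output_nodes) + 0.01
    targets[label] = 0.99
    return label, inputs, targets

# Synthetic record: label 3 followed by four made-up pixel values.
label, inputs, targets = prepare_record("3,0,255,128,64\n")
print(label)
print(inputs)
print(targets)
```

Keeping inputs and targets strictly between 0 and 1 matches the sigmoid used by the network, whose output approaches but never reaches 0.0 or 1.0.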


