This is a fully connected DNN classifier for the Iris data set, built with CNTK. The data set is converted from its original format:
5.1,3.5,1.4,0.2,Iris-setosa
into the CNTK text format:
|features 5.1 3.5 1.4 0.2 |label 0:1
where each line contains two columns ("features" and "label"):
features is a dense vector (1x4)
label is a one-hot encoded (1x3) vector in sparse form (index:value)
It can be read using CTFDeserializer in the code below.
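The conversion step itself is not shown here. A minimal sketch of it in plain Python could look like the following. The function names and the CLASS_INDEX mapping are assumptions made for illustration; the class order is inferred from the labels in the converted file (setosa 0, versicolor 1, virginica 2).

```python
# Assumed mapping from class name to one-hot index (inferred from the converted file).
CLASS_INDEX = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}


def csv_line_to_ctf(line):
    """Convert one 'f1,f2,f3,f4,class-name' row into a CNTK text format row."""
    *features, class_name = line.strip().split(",")
    # The label is written in sparse form: "<index>:<value>".
    return "|features {} |label {}:1".format(" ".join(features), CLASS_INDEX[class_name])


def convert(csv_lines):
    """Convert an iterable of CSV rows, skipping blank lines."""
    return [csv_line_to_ctf(line) for line in csv_lines if line.strip()]


print(csv_line_to_ctf("5.1,3.5,1.4,0.2,Iris-setosa"))
# |features 5.1 3.5 1.4 0.2 |label 0:1
```

Writing the result of convert() to a file, one row per line, should give a file in the same layout as the iris.data used below.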
from cntk import *
from cntk.models import *
from cntk.layers import *
import numpy as np
nFeatures = 4
nClasses = 3
miniBatchSize = 150
miniBatchesToTrain = 100 # repeatedly train on the whole dataset, will overfit.
training_progress_output_freq = 10
# helper function to print training status along the way
def print_training_progress(trainer, mb, frequency):
    if mb % frequency == 0:
        training_loss = get_train_loss(trainer)
        eval_crit = get_train_eval_criterion(trainer)
        print("Minibatch: {}, Train Loss: {}, Train Evaluation Criterion: {}".format(
            mb, training_loss, eval_crit))
Model definition is declarative. You define the inputs (feature vector and label), the model network, an SGD learner, and finally the trainer that puts everything together.
fv = input_variable(nFeatures)
label = input_variable(nClasses) # one-hot
model = Sequential([
    Dense(50, activation=relu),
    Dense(50, activation=relu),
    Dense(50, activation=relu),
    Dense(nClasses)
])(fv)
loss = cross_entropy_with_softmax(model, label)
err = classification_error(model, label)
learner = sgd(model.parameters, lr=0.0005)
trainer = Trainer(model, loss, err, [learner])
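To make the two criteria concrete, here is a plain-Python sketch of what they compute for a single sample, assuming the standard definitions of softmax cross-entropy and top-1 classification error. The `_ref` functions are hypothetical stand-ins written for illustration; they are not CNTK code and operate on plain lists rather than CNTK variables.

```python
import math


def cross_entropy_ref(scores, one_hot):
    """Softmax cross-entropy of raw scores against a one-hot target."""
    # log(sum(exp(s))) computed in a numerically stable way.
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    # -sum(t * log(softmax(s))) simplifies to sum(t * (log_sum_exp - s)).
    return sum(t * (log_sum_exp - s) for t, s in zip(one_hot, scores))


def error_ref(scores, one_hot):
    """Return 0.0 if the highest score is at the target index, else 1.0."""
    predicted = max(range(len(scores)), key=lambda i: scores[i])
    actual = max(range(len(one_hot)), key=lambda i: one_hot[i])
    return 0.0 if predicted == actual else 1.0


# Three equal scores mean a uniform softmax, so the loss is log(3).
print(cross_entropy_ref([0.0, 0.0, 0.0], [0, 1, 0]))
print(error_ref([0.1, 2.0, -1.0], [0, 1, 0]))
```

Note that the loss takes the raw output of the last Dense layer, which is why the model above ends without a softmax activation.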
Data can either be read with plain Python code and fed to the trainer as a map of {streamInfo: data}, or converted into the CNTK text format and read using CTFDeserializer.
# epoch_size can be:
# - INFINITELY_REPEAT: the source dataset is looped over indefinitely
# - FULL_DATA_SWEEP: the input is swept through exactly once
reader = MinibatchSource(
    CTFDeserializer(
        "iris.data",
        StreamDefs(
            features=StreamDef(field='features', shape=nFeatures, is_sparse=False),
            label=StreamDef(field='label', shape=nClasses, is_sparse=True))),
    randomize=False,
    epoch_size=INFINITELY_REPEAT)
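For the first option mentioned above (reading the data with plain Python), a rough sketch of a parser for the two-column rows used in this post is given here. It is not a general CNTK text format parser: parse_ctf_line and parse_ctf_lines are hypothetical helpers, and they assume every row has a dense "features" field and a sparse "label" field.

```python
def parse_ctf_line(line, n_classes=3):
    """Parse '|features a b c d |label i:v' into (features, dense one-hot label)."""
    features, label = [], [0.0] * n_classes
    for field in line.strip().split("|"):
        if not field.strip():
            continue  # text before the first '|' is empty
        name, *values = field.split()
        if name == "features":
            features = [float(v) for v in values]
        elif name == "label":
            # Sparse form: each entry is "<index>:<value>".
            for pair in values:
                index, value = pair.split(":")
                label[int(index)] = float(value)
    return features, label


def parse_ctf_lines(lines, n_classes=3):
    """Parse many rows into two parallel lists: features and labels."""
    parsed = [parse_ctf_line(line, n_classes) for line in lines if line.strip()]
    return [f for f, _ in parsed], [l for _, l in parsed]


print(parse_ctf_line("|features 5.1 3.5 1.4 0.2 |label 0:1"))
# ([5.1, 3.5, 1.4, 0.2], [1.0, 0.0, 0.0])
```

The two lists could then be converted to float32 numpy arrays and passed to the trainer in a map keyed by fv and label; the exact array shapes the trainer expects are not covered by this sketch.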
Training simply means getting a minibatch (of the form {input_variable: data, ...}) and feeding it into the trainer.
for mbCount in range(miniBatchesToTrain):
    mb = reader.next_minibatch(miniBatchSize, input_map={fv: reader.streams.features, label: reader.streams.label})
    # when epoch_size is FULL_DATA_SWEEP, this will be empty when the data is exhausted
    if not mb:
        break
    trainer.train_minibatch(mb)
    print_training_progress(trainer, mbCount, training_progress_output_freq)
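The difference between the two epoch_size settings can be illustrated with a toy stand-in for the reader. This is only a simplified model written for this post: the minibatches generator is hypothetical, and it does not reproduce how CNTK handles a partial minibatch at the end of a sweep.

```python
from itertools import islice


def minibatches(samples, batch_size, repeat):
    """Yield consecutive slices of samples; start over at the end if repeat is True."""
    position = 0
    while True:
        if position >= len(samples):
            if not repeat:
                return  # single sweep: the data is exhausted
            position = 0  # infinite repeat: wrap around to the first sample
        yield samples[position:position + batch_size]
        position += batch_size


data = list(range(5))
print(list(minibatches(data, 2, repeat=False)))
# [[0, 1], [2, 3], [4]]
print(list(islice(minibatches(data, 2, repeat=True), 4)))
# [[0, 1], [2, 3], [4], [0, 1]]
```

With miniBatchSize equal to the 150 rows of the data set and randomize=False, each minibatch in the loop above corresponds to one full pass over the data, so 100 minibatches amount to 100 sweeps.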
# predict:
classifier = softmax(model)
np.argmax(classifier.eval([7.2, 3.6, 6.1, 2.5])) # it should give you "2"
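The index returned by argmax can be mapped back to a class name. The sketch below assumes the class order used in the converted data (0 setosa, 1 versicolor, 2 virginica); CLASS_NAMES and predicted_class are hypothetical names, and the probabilities are made-up values, not real model output.

```python
# Assumed order, matching the label indices in the converted data file.
CLASS_NAMES = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]


def predicted_class(probabilities):
    """Return (index, name) of the most probable class."""
    best_index = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return best_index, CLASS_NAMES[best_index]


# Made-up probabilities for illustration only.
print(predicted_class([0.01, 0.09, 0.90]))
# (2, 'Iris-virginica')
```

Because softmax preserves the ordering of its inputs, taking the argmax of the raw model output would give the same index; the softmax is only needed when the probabilities themselves are of interest.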
|features 5.1 3.5 1.4 0.2 |label 0:1
|features 4.9 3.0 1.4 0.2 |label 0:1
|features 4.7 3.2 1.3 0.2 |label 0:1
|features 4.6 3.1 1.5 0.2 |label 0:1
|features 5.0 3.6 1.4 0.2 |label 0:1
|features 5.4 3.9 1.7 0.4 |label 0:1
|features 4.6 3.4 1.4 0.3 |label 0:1
|features 5.0 3.4 1.5 0.2 |label 0:1
|features 4.4 2.9 1.4 0.2 |label 0:1
|features 4.9 3.1 1.5 0.1 |label 0:1
|features 5.4 3.7 1.5 0.2 |label 0:1
|features 4.8 3.4 1.6 0.2 |label 0:1
|features 4.8 3.0 1.4 0.1 |label 0:1
|features 4.3 3.0 1.1 0.1 |label 0:1
|features 5.8 4.0 1.2 0.2 |label 0:1
|features 5.7 4.4 1.5 0.4 |label 0:1
|features 5.4 3.9 1.3 0.4 |label 0:1
|features 5.1 3.5 1.4 0.3 |label 0:1
|features 5.7 3.8 1.7 0.3 |label 0:1
|features 5.1 3.8 1.5 0.3 |label 0:1
|features 5.4 3.4 1.7 0.2 |label 0:1
|features 5.1 3.7 1.5 0.4 |label 0:1
|features 4.6 3.6 1.0 0.2 |label 0:1
|features 5.1 3.3 1.7 0.5 |label 0:1
|features 4.8 3.4 1.9 0.2 |label 0:1
|features 5.0 3.0 1.6 0.2 |label 0:1
|features 5.0 3.4 1.6 0.4 |label 0:1
|features 5.2 3.5 1.5 0.2 |label 0:1
|features 5.2 3.4 1.4 0.2 |label 0:1
|features 4.7 3.2 1.6 0.2 |label 0:1
|features 4.8 3.1 1.6 0.2 |label 0:1
|features 5.4 3.4 1.5 0.4 |label 0:1
|features 5.2 4.1 1.5 0.1 |label 0:1
|features 5.5 4.2 1.4 0.2 |label 0:1
|features 4.9 3.1 1.5 0.1 |label 0:1
|features 5.0 3.2 1.2 0.2 |label 0:1
|features 5.5 3.5 1.3 0.2 |label 0:1
|features 4.9 3.1 1.5 0.1 |label 0:1
|features 4.4 3.0 1.3 0.2 |label 0:1
|features 5.1 3.4 1.5 0.2 |label 0:1
|features 5.0 3.5 1.3 0.3 |label 0:1
|features 4.5 2.3 1.3 0.3 |label 0:1
|features 4.4 3.2 1.3 0.2 |label 0:1
|features 5.0 3.5 1.6 0.6 |label 0:1
|features 5.1 3.8 1.9 0.4 |label 0:1
|features 4.8 3.0 1.4 0.3 |label 0:1
|features 5.1 3.8 1.6 0.2 |label 0:1
|features 4.6 3.2 1.4 0.2 |label 0:1
|features 5.3 3.7 1.5 0.2 |label 0:1
|features 5.0 3.3 1.4 0.2 |label 0:1
|features 7.0 3.2 4.7 1.4 |label 1:1
|features 6.4 3.2 4.5 1.5 |label 1:1
|features 6.9 3.1 4.9 1.5 |label 1:1
|features 5.5 2.3 4.0 1.3 |label 1:1
|features 6.5 2.8 4.6 1.5 |label 1:1
|features 5.7 2.8 4.5 1.3 |label 1:1
|features 6.3 3.3 4.7 1.6 |label 1:1
|features 4.9 2.4 3.3 1.0 |label 1:1
|features 6.6 2.9 4.6 1.3 |label 1:1
|features 5.2 2.7 3.9 1.4 |label 1:1
|features 5.0 2.0 3.5 1.0 |label 1:1
|features 5.9 3.0 4.2 1.5 |label 1:1
|features 6.0 2.2 4.0 1.0 |label 1:1
|features 6.1 2.9 4.7 1.4 |label 1:1
|features 5.6 2.9 3.6 1.3 |label 1:1
|features 6.7 3.1 4.4 1.4 |label 1:1
|features 5.6 3.0 4.5 1.5 |label 1:1
|features 5.8 2.7 4.1 1.0 |label 1:1
|features 6.2 2.2 4.5 1.5 |label 1:1
|features 5.6 2.5 3.9 1.1 |label 1:1
|features 5.9 3.2 4.8 1.8 |label 1:1
|features 6.1 2.8 4.0 1.3 |label 1:1
|features 6.3 2.5 4.9 1.5 |label 1:1
|features 6.1 2.8 4.7 1.2 |label 1:1
|features 6.4 2.9 4.3 1.3 |label 1:1
|features 6.6 3.0 4.4 1.4 |label 1:1
|features 6.8 2.8 4.8 1.4 |label 1:1
|features 6.7 3.0 5.0 1.7 |label 1:1
|features 6.0 2.9 4.5 1.5 |label 1:1
|features 5.7 2.6 3.5 1.0 |label 1:1
|features 5.5 2.4 3.8 1.1 |label 1:1
|features 5.5 2.4 3.7 1.0 |label 1:1
|features 5.8 2.7 3.9 1.2 |label 1:1
|features 6.0 2.7 5.1 1.6 |label 1:1
|features 5.4 3.0 4.5 1.5 |label 1:1
|features 6.0 3.4 4.5 1.6 |label 1:1
|features 6.7 3.1 4.7 1.5 |label 1:1
|features 6.3 2.3 4.4 1.3 |label 1:1
|features 5.6 3.0 4.1 1.3 |label 1:1
|features 5.5 2.5 4.0 1.3 |label 1:1
|features 5.5 2.6 4.4 1.2 |label 1:1
|features 6.1 3.0 4.6 1.4 |label 1:1
|features 5.8 2.6 4.0 1.2 |label 1:1
|features 5.0 2.3 3.3 1.0 |label 1:1
|features 5.6 2.7 4.2 1.3 |label 1:1
|features 5.7 3.0 4.2 1.2 |label 1:1
|features 5.7 2.9 4.2 1.3 |label 1:1
|features 6.2 2.9 4.3 1.3 |label 1:1
|features 5.1 2.5 3.0 1.1 |label 1:1
|features 5.7 2.8 4.1 1.3 |label 1:1
|features 6.3 3.3 6.0 2.5 |label 2:1
|features 5.8 2.7 5.1 1.9 |label 2:1
|features 7.1 3.0 5.9 2.1 |label 2:1
|features 6.3 2.9 5.6 1.8 |label 2:1
|features 6.5 3.0 5.8 2.2 |label 2:1
|features 7.6 3.0 6.6 2.1 |label 2:1
|features 4.9 2.5 4.5 1.7 |label 2:1
|features 7.3 2.9 6.3 1.8 |label 2:1
|features 6.7 2.5 5.8 1.8 |label 2:1
|features 7.2 3.6 6.1 2.5 |label 2:1
|features 6.5 3.2 5.1 2.0 |label 2:1
|features 6.4 2.7 5.3 1.9 |label 2:1
|features 6.8 3.0 5.5 2.1 |label 2:1
|features 5.7 2.5 5.0 2.0 |label 2:1
|features 5.8 2.8 5.1 2.4 |label 2:1
|features 6.4 3.2 5.3 2.3 |label 2:1
|features 6.5 3.0 5.5 1.8 |label 2:1
|features 7.7 3.8 6.7 2.2 |label 2:1
|features 7.7 2.6 6.9 2.3 |label 2:1
|features 6.0 2.2 5.0 1.5 |label 2:1
|features 6.9 3.2 5.7 2.3 |label 2:1
|features 5.6 2.8 4.9 2.0 |label 2:1
|features 7.7 2.8 6.7 2.0 |label 2:1
|features 6.3 2.7 4.9 1.8 |label 2:1
|features 6.7 3.3 5.7 2.1 |label 2:1
|features 7.2 3.2 6.0 1.8 |label 2:1
|features 6.2 2.8 4.8 1.8 |label 2:1
|features 6.1 3.0 4.9 1.8 |label 2:1
|features 6.4 2.8 5.6 2.1 |label 2:1
|features 7.2 3.0 5.8 1.6 |label 2:1
|features 7.4 2.8 6.1 1.9 |label 2:1
|features 7.9 3.8 6.4 2.0 |label 2:1
|features 6.4 2.8 5.6 2.2 |label 2:1
|features 6.3 2.8 5.1 1.5 |label 2:1
|features 6.1 2.6 5.6 1.4 |label 2:1
|features 7.7 3.0 6.1 2.3 |label 2:1
|features 6.3 3.4 5.6 2.4 |label 2:1
|features 6.4 3.1 5.5 1.8 |label 2:1
|features 6.0 3.0 4.8 1.8 |label 2:1
|features 6.9 3.1 5.4 2.1 |label 2:1
|features 6.7 3.1 5.6 2.4 |label 2:1
|features 6.9 3.1 5.1 2.3 |label 2:1
|features 5.8 2.7 5.1 1.9 |label 2:1
|features 6.8 3.2 5.9 2.3 |label 2:1
|features 6.7 3.3 5.7 2.5 |label 2:1
|features 6.7 3.0 5.2 2.3 |label 2:1
|features 6.3 2.5 5.0 1.9 |label 2:1
|features 6.5 3.0 5.2 2.0 |label 2:1
|features 6.2 3.4 5.4 2.3 |label 2:1
|features 5.9 3.0 5.1 1.8 |label 2:1