import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, optimizers, Sequential, layers

def data_pre_process(x, y):
    # Scale pixel values from [0, 255] to [0, 1] and cast labels to int32
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    return x, y

def dataLoader():
    # Fashion-MNIST: 60,000 training and 10,000 test grayscale images of size 28x28
    (x, y), (x_val, y_val) = datasets.fashion_mnist.load_data()
    db = tf.data.Dataset.from_tensor_slices((x, y))
    db = db.map(data_pre_process).batch(128)

    db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    db_val = db_val.map(data_pre_process).batch(128)
    return db, db_val
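
if __name__ == "__main__":
    db, db_val = dataLoader()

    # Fully connected network: 784 -> 256 -> 128 -> 64 -> 32 -> 10.
    # The last layer has no activation, so it outputs raw logits, which is what
    # categorical_crossentropy(..., from_logits=True) expects. A ReLU here would
    # clamp every class score to be non-negative.
    model = Sequential([
        layers.Dense(256, activation=tf.nn.relu),
        layers.Dense(128, activation=tf.nn.relu),
        layers.Dense(64, activation=tf.nn.relu),
        layers.Dense(32, activation=tf.nn.relu),
        layers.Dense(10),
    ])
    model.build(input_shape=[None, 28*28])
    model.summary()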
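
    # Bind the optimizer to a new name so the `optimizers` module is not shadowed;
    # `learning_rate` replaces the deprecated `lr` keyword
    optimizer = optimizers.Adam(learning_rate=1e-3)

    for epoch in range(30):
        for step, (x, y) in enumerate(db):
            # Flatten each 28x28 image into a 784-dimensional vector
            x = tf.reshape(x, [-1, 28*28])
            with tf.GradientTape() as tape:
                logits = model(x)
                y_onehot = tf.one_hot(y, depth=10)
                # MSE is only logged for comparison; training uses cross-entropy
                loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
                loss_ce = tf.reduce_mean(
                    tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))

            grads = tape.gradient(loss_ce, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))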

            if step % 100 == 0:
                print(epoch, step, 'loss:', float(loss_ce), float(loss_mse))

        # Evaluate on the test set at the end of every epoch
        total_correct = 0
        total_num = 0
        for x, y in db_val:
            x = tf.reshape(x, [-1, 28*28])
            logits = model(x)
            prob = tf.nn.softmax(logits, axis=1)
            # Predicted class = index of the highest probability
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)

            correct = tf.equal(pred, y)
            correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
            total_correct += int(correct)
            total_num += x.shape[0]

        acc = total_correct / total_num
        print(epoch, 'test acc:', float(acc))
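
The script trains on `categorical_crossentropy(..., from_logits=True)`. That loss expects unnormalized class scores (logits) and applies the softmax internally. The following is a minimal pure-Python sketch of the per-example quantity involved: CE = logsumexp(z) - z[y]. It is not TensorFlow's actual implementation, and the helper names (`ce_from_logits`, `ce_via_softmax`, `relu`) are hypothetical. The sketch also shows that putting a ReLU on the output layer changes the loss. A ReLU clamps negative scores to zero, and any output unit whose pre-activation stays negative gets zero gradient.

```python
import math


def ce_from_logits(logits, label):
    """Cross-entropy of one example computed directly from raw logits.

    Uses the log-sum-exp trick: CE = logsumexp(z) - z[label].
    """
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def ce_via_softmax(logits, label):
    """Same quantity via an explicit softmax followed by -log(p[label])."""
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[label] / sum(exps))


def relu(values):
    """Element-wise ReLU, mimicking an output layer with activation=relu."""
    return [max(0.0, z) for z in values]


logits = [2.0, -1.0, 0.5]  # raw scores for 3 hypothetical classes
ce_stable = ce_from_logits(logits, 0)
ce_naive = ce_via_softmax(logits, 0)
# ReLU clips the negative score for class 1 to 0, raising the loss for class 0
ce_relu = ce_from_logits(relu(logits), 0)
print(ce_stable, ce_naive, ce_relu)
```

The two formulations agree. The log-sum-exp form is preferred because it stays finite even for large logits.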
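
The evaluation loop applies `tf.nn.softmax` before `tf.argmax`. Softmax divides `exp(z_i)` by the same positive sum for every class, and `exp` is increasing, so it never changes which class has the highest score. Taking `argmax` of the logits directly would therefore give the same predictions. The pure-Python sketch below illustrates this on made-up logits and computes accuracy the same way the script does (correct predictions divided by number of examples). The helpers `softmax`, `argmax` and `accuracy` are hypothetical stand-ins, not TensorFlow APIs.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one example's class scores."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(batch_logits, labels):
    """Fraction of examples whose predicted class matches the label."""
    correct = sum(1 for z, y in zip(batch_logits, labels) if argmax(z) == y)
    return correct / len(labels)


# Hypothetical logits for 4 examples and 3 classes
batch_logits = [[2.0, -1.0, 0.5], [0.1, 0.3, -2.0], [-0.5, 0.0, 1.5], [1.0, 2.0, 0.0]]
labels = [0, 1, 2, 2]

# Softmax preserves the ordering of scores, so predictions are unchanged
same_predictions = all(argmax(z) == argmax(softmax(z)) for z in batch_logits)
acc = accuracy(batch_logits, labels)
print(same_predictions, acc)
```

Here three of the four predictions (classes 0, 1, 2, 1) match the labels. Keeping the softmax in the script is harmless but only needed if calibrated probabilities are wanted.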