A simple neural network exercise; the results look poor
import numpy as np
import torch
import torch.nn as nn
def fizz_buzz_decode(i, label):
    """Slick trick: use the class label to index the four possible outputs.
    The labels encode the game rules for 3, 5, and 15."""
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][label]

def fizz_buzz_encode(i):
    """Map an integer to its FizzBuzz class: 3 = fizzbuzz, 2 = buzz, 1 = fizz, 0 = plain number."""
    if i % 15 == 0:
        return 3
    elif i % 5 == 0:
        return 2
    elif i % 3 == 0:
        return 1
    else:
        return 0

def helper(i):
    return fizz_buzz_decode(i, fizz_buzz_encode(i))
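A quick sanity check that encode and decode round-trip correctly; the expected list follows directly from the rules:

print([helper(i) for i in range(1, 16)])
# ['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz']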
NUM_DIGITS = 10  # width of the binary representation (10 bits cover 0..1023)

def binary_encode(i, num_digits):
    """Convert a decimal integer to a fixed-width binary array, then reverse it so the
    most significant bit comes first.

    `i >> d & 1` extracts bit d of i. For d = 0 this is the classic parity test: the
    lowest bit is 0 for an even number and 1 for an odd one. `a & 1` gives the same
    answer as `a % 2`, but as a single bitwise operation it is typically faster.
    """
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])
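The [::-1] reversal only changes the bit order to most-significant-first; since it is a fixed permutation of the input features, the network learns equally well either way. For example:

print(binary_encode(3, NUM_DIGITS))  # [0 0 0 0 0 0 0 0 1 1]  (3 = 0b0000000011, MSB first)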
# Training data: 101 .. 1023, everything above the 1..100 range we will test on
X_train = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
y_train = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])
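range(101, 1024) yields 923 examples, each a 10-bit vector; a quick shape check:

print(X_train.shape, y_train.shape)  # torch.Size([923, 10]) torch.Size([923])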
# Model: a single hidden layer
NUM_HIDDEN = 100
model = nn.Sequential(
    nn.Linear(NUM_DIGITS, NUM_HIDDEN, bias=False),
    nn.ReLU(),
    nn.Linear(NUM_HIDDEN, 4, bias=False)  # only 4 output classes: number, fizz, buzz, fizzbuzz
)
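With bias disabled in both layers, the parameter count is just the two weight matrices, 10 * 100 + 100 * 4 = 1400; easy to confirm:

print(sum(p.numel() for p in model.parameters()))  # 1400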
# Loss: cross entropy measures how close the predicted class distribution is to the target one
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
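nn.CrossEntropyLoss expects raw logits and applies log-softmax internally, which is why the model above has no softmax layer. A small sketch of the equivalence, using made-up tensors:

import torch.nn.functional as F
logits = torch.randn(2, 4)                  # fake batch: 2 examples, 4 classes
targets = torch.tensor([0, 3])
a = nn.CrossEntropyLoss()(logits, targets)
b = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(a, b))                 # True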
# Training loop
BATCH_SIZE = 128
for epoch in range(2000):
    for start in range(0, len(X_train), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = X_train[start:end]
        batchY = y_train[start:end]

        y_pred = model(batchX)          # forward pass
        loss = loss_fn(y_pred, batchY)

        optimizer.zero_grad()           # clear gradients from the previous step
        loss.backward()                 # backpropagate
        optimizer.step()                # SGD update

    if epoch % 100 == 0:                # print occasionally instead of every batch
        print('epoch =', epoch, ' loss =', loss.item())
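To check how well the network fits the training data, a full-set accuracy pass can run after (or inside) the loop; torch.no_grad() keeps it out of the autograd graph:

with torch.no_grad():
    train_acc = (model(X_train).max(1)[1] == y_train).float().mean().item()
print('train accuracy =', train_acc)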
# Test on 1..100, which the model has never seen.
X_test = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
with torch.no_grad():
    y_test = model(X_test)
print(y_test)  # raw logits, shape (100, 4)
# The displayed numbers must line up with the inputs: both ranges cover 1..100.
# (The original zipped range(0, 101) against predictions for 1..99, shifting
# every printed label one slot to the left.)
predicts = zip(range(1, 101), y_test.max(1)[1].tolist())
print('\n----------------------------')
print([fizz_buzz_decode(i, x) for i, x in predicts])
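Scoring the predictions against the true labels takes two more lines:

y_true = torch.LongTensor([fizz_buzz_encode(i) for i in range(1, 101)])
print('test accuracy =', (y_test.max(1)[1] == y_true).float().mean().item())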
Result (output of the original run, which built the test set with range(1, 100) but zipped the display with range(0, 101), so every label below sits one slot to the left of its true number). Once the shift is undone, i.e. the entry at position k is read as the prediction for k + 1, the labels match the FizzBuzz rules almost everywhere: the 'fizz' in third position (index 2) is really the prediction for 3, and the 'fizzbuzz' at index 14 is the prediction for 15. The model learned far more than the raw printout suggests:
['0', '1', 'fizz', '3', 'buzz', 'fizz', '6', '7', 'fizz', 'buzz', '10', 'fizz', '12', '13', 'fizzbuzz', '15', '16', 'fizz', '18', 'buzz', 'fizz', '21', '22', 'fizz', 'buzz', '25', 'fizz', '27', '28', 'fizzbuzz', '30', '31', 'fizz', '33', 'buzz', 'fizz', '36', '37', 'fizz', 'buzz', '40', 'fizz', '42', '43', 'fizzbuzz', '45', '46', 'fizz', '48', 'buzz', 'fizz', '51', '52', 'fizz', 'buzz', '55', 'fizz', '57', '58', 'fizzbuzz', '60', '61', 'fizz', '63', 'buzz', 'fizz', '66', '67', 'fizz', 'buzz', '70', 'fizz', '72', '73', 'fizzbuzz', '75', '76', 'fizz', '78', 'buzz', 'fizz', '81', '82', 'fizz', 'buzz', '85', 'fizz', '87', '88', 'fizzbuzz', '90', '91', 'fizz', '93', 'buzz', 'fizz', '96', '97', 'fizz']