当前位置: 首页 > 工具软件 > Buzz > 使用案例 >

fizz_buzz

佴博实
2023-12-01

简单的神经网络练习,效果不好

import numpy as np
import torch
import torch.nn as nn


def fizz_buzz_decode(i, label):
    """这个写法有点秀, 输入第几个int -> 找到游戏规则 3, 5, 15"""
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][label]


def fizz_buzz_encode(i):
    if i % 15 == 0:
        return 3
    elif i % 5 == 0:
        return 2
    elif i % 3 == 0:
        return 1
    else:
        return 0

def helper(i):
    return fizz_buzz_decode(i, fizz_buzz_encode(i))

NUM_DIGITS = 10  # 10个digit来表达这个二进制

def binary_encode(i, NUM_DIGITS):
    """把一个十进制的数转为需要位数的 二进制数, 再翻转下
    # 一个整数a, a & 1 这个表达式可以用来判断a的奇偶性。二进制的末位为0表示偶数,最末位为1表示奇数。使用a%2来判断奇偶性和a & 1是一样的作用,但是a & 1要快好多。主要是用来判断末尾数是不是
    """
    return np.array([i >> d & 1 for d in range(NUM_DIGITS)][::-1])


# 训练数据
X_train = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
y_train = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])

# model
NUM_HIDDEN = 100
model = nn.Sequential(
    nn.Linear(NUM_DIGITS, NUM_HIDDEN, bias=False),
    nn.ReLU(),
    nn.Linear(NUM_HIDDEN, 4, bias=False)  # 输出的种类就4中 数字,fizz,buzz,fizzbuzz
)

# loss -> 两个分布的相似度 -> crossentropy
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


# train
BATCH_SIZE = 128
for epoch in range(2000):
    for start in range(0, len(X_train), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = X_train[start:end]
        batchY = y_train[start:end]

        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)
        print('epoch=', epoch, " loss=", loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


X_test = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 100)])

with torch.no_grad():
    y_test = model(X_test)
    print(y_test)
    predicts = zip(range(0, 101), list(y_test.max(1)[1].data.tolist()))
    print('\n----------------------------')
    print([fizz_buzz_decode(i, x) for i, x in predicts])



结果:

['0', '1', 'fizz', '3', 'buzz', 'fizz', '6', '7', 'fizz', 'buzz', '10', 'fizz', '12', '13', 'fizzbuzz', '15', '16', 'fizz', '18', 'buzz', 'fizz', '21', '22', 'fizz', 'buzz', '25', 'fizz', '27', '28', 'fizzbuzz', '30', '31', 'fizz', '33', 'buzz', 'fizz', '36', '37', 'fizz', 'buzz', '40', 'fizz', '42', '43', 'fizzbuzz', '45', '46', 'fizz', '48', 'buzz', 'fizz', '51', '52', 'fizz', 'buzz', '55', 'fizz', '57', '58', 'fizzbuzz', '60', '61', 'fizz', '63', 'buzz', 'fizz', '66', '67', 'fizz', 'buzz', '70', 'fizz', '72', '73', 'fizzbuzz', '75', '76', 'fizz', '78', 'buzz', 'fizz', '81', '82', 'fizz', 'buzz', '85', 'fizz', '87', '88', 'fizzbuzz', '90', '91', 'fizz', '93', 'buzz', 'fizz', '96', '97', 'fizz']

 

 类似资料:

相关阅读

相关文章

相关问答