当前位置: 首页 > 知识库问答 >
问题:

当尝试在Android studio中实现tensorflow lite模型时,我得到一个不兼容的数据类型错误

仇经武
2023-03-14
java.lang.IllegalArgumentException: Cannot convert between a TensorFlowLite tensor with type INT64 and a Java object of type [[F (which is compatible with the TensorFlowLite type FLOAT32).
import io
df = pd.read_csv(io.BytesIO(data_to_load['species_by_location_v4.csv']))

# CREATE X ARRAY
# This is the array containing the explanatory variables (in this case pentad and month)

loc_array = df.iloc[:, 1:3]
print(loc_array)

# create y array (classes to be predicted)

label_array = df.iloc[:, 4]

# get number of distinct classes and convert y array to consecutive integers from 0 to 170 (y_true)

raw_y_true = label_array
mapping_to_numbers = {}
y_true = np.zeros((len(raw_y_true)))
for i, raw_label in enumerate(raw_y_true):
    if raw_label not in mapping_to_numbers:
        mapping_to_numbers[raw_label] = len(mapping_to_numbers)
    y_true[i] = mapping_to_numbers[raw_label]
print(y_true)
# [0. 1. 2. 3. 1. 2.]
print(mapping_to_numbers)

# get number of distinct classes

num_classes = len(mapping_to_numbers)
print(num_classes)


# create simple model

model = tf.keras.Sequential([
  tf.keras.layers.Dense(num_classes, activation='softmax')
])

# compile model

model.compile(
    optimizer = 'adam',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy']
)

# train model

history = model.fit(loc_array, y_true, epochs=500, verbose=False)
print('finished')

# create labels file

labels = '\n'.join(mapping_to_numbers.keys())

with open('labels_locmth.txt', 'w') as f:
  f.write(labels)

!cat labels.txt

# convert to tflite

saved_model_dir = 'save/fine_tuning'
tf.saved_model.save(model, saved_model_dir)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

with open('model_locmth.tflite', 'wb') as f:
  f.write(tflite_model)


temp = {'pentad_unique_key': [63669], 'mth': [2] }
test_x = pd.DataFrame(temp, columns = ['pentad_unique_key', 'mth'])
result = model.predict(test_x, batch_size=None, verbose=0, steps=None)
float[][] inputVal = new float[1][2];
inputVal[0][0] = 63669;
inputVal[0][1] = 2;

float[][] outputs = new float[1][171];

tflite.run(inputVal, outputs);

try{
   tflite = new Interpreter(loadModelFile());
   labelList = loadLabelList();
} catch (Exception ex) {
   ex.printStackTrace();
}

private MappedByteBuffer loadModelFile() throws IOException {

    // Open the model using an input stream and memory map it to load
    AssetFileDescriptor fileDescriptor = this.getAssets().openFd("model_locmth.tflite");
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);

}

共有1个答案

邢璞
2023-03-14

如前所述,我认为问题在于y_true的输入dtype。当您将y_true创建为

y_true = np.zeros((len(raw_y_true)))

它会将dtype显示为np.int64,因为它是默认的dtype。
tflite模型可以使用错误中提到的float32(不是int64或float64),所以您需要将y_true定义为

y_true = np.zeros((len(raw_y_true)),dtype=np.float32)

因此,我只在代码的以下部分更改了上面的行

label_array = df.iloc[:, 4]

raw_y_true = label_array
mapping_to_numbers = {}
y_true = np.zeros((len(raw_y_true)),dtype=np.float32)
for i, raw_label in enumerate(raw_y_true):
    if raw_label not in mapping_to_numbers:
        mapping_to_numbers[raw_label] = len(mapping_to_numbers)
    y_true[i] = mapping_to_numbers[raw_label]
print(y_true)
# [0. 1. 2. 3. 1. 2.]
print(mapping_to_numbers)

print(np.argmax(result)) # 122
 类似资料:
  • 我在尝试从主函数中的点获取最小值/最大值时收到错误/警告。如何计算最大/最小点?有没有更简单的方法?我应该使用结构吗?

  • 我有一个类似Sagemaker的培训脚本, 其中大部分都是我从MNIST的例子中偷来的。 当我训练的时候,一切都很顺利,但是当我试着像, 我得到(完整输出) 在我的S3桶中,我实际上没有看到模型。 我错过了什么?

  • 我正在为我的大学项目制作一个颤动应用程序,我正在添加一个登录和注册页面并通过Firebase进行身份验证,当我点击登录时,调试控制台显示“错误类型'AuthResult'不是类型转换中类型'FirebaseUser'的子类型”,当我在此错误后重新加载应用程序时,它成功登录。 在此次更新后,firebase_auth包更新到0.12.0之前一切都运行良好,方法“signInSusEmailAndPa

  • 我有一个json格式的地图列表,我正试图在列表上呈现“title”。 我通过一个api(http.get)读取数据,然后解析它。 我想在列表中显示标题。 这是我的代码... 获取数据 转换为json 欢迎参加模范班 我得到一个错误说"类型'列表'不是类型'地图的子类型

  • 问题内容: 我在写一些代码,遇到编译错误。这就是我所拥有的: 我以为使用继承时没有正确声明泛型,所以我检查了Oracle的教程,他们在其中编写 上面的声明使用相同的泛型类型,这是我在示例中要完成的工作。似乎假设from 与from 有所不同。否则,它应该能够 将 两者 合并并看到相同的类型。如何正确实现我要实现的目标? ( 即,某物的常量是某物的值,这是同一物的表达 ) 问题答案: 您的变量定义为

  • 我一直试图编译这个简单的警报对话框,以便在用户单击提交按钮时显示。编译代码时会弹出一条错误消息: 错误:(33,74)错误:不兼容的类型: 这个类叫做Login_Activity,它扩展了BaseActivity,它扩展了Activity。