ImageNet初体验
背景
ImageNet 数据集大约有1500万张图片,2.2万类,可以说你能想到,想象不到的图片都能在里面找到。
本文使用 ImageNet来识别图片中的内容。
初体验
# ImageNet 数据集大约有1500万张图片,2.2万类,可以说你能想到,想象不到的图片都能在里面找到。
import numpy as np
from PIL import Image
import matplotlib.pylab as plt
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers, datasets
# url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
# 这里手动下载解压至 datasets/imagenet 目录中,离线加载。
url = "datasets/imagenet"
model = tf.keras.Sequential([
hub.KerasLayer(url, input_shape=(224, 224, 3))
])
# 使用 ImagetNet 来测试任意图片,这里测试一只猫
cat = tf.keras.utils.get_file(
'cat.png', 'http://www.ykjsxy.net/uploads/allimg/200318/1-20031Q011405K.jpg')
cat = Image.open(cat).resize((224, 224))
# cat
result = model.predict(np.array(cat).reshape(1, 224, 224, 3)/255.0)
ans = np.argmax(result[0], axis=-1)
print('result.shape:', result.shape, 'ans:', ans)
# result.shape: (1, 1001) ans: 284
# labels_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
# labels_path = tf.keras.utils.get_file('ImageNetLabels.txt', labels_url)
# imagenet_labels = np.array(open(labels_path).read().splitlines())
# print(imagenet_labels[ans])
# 模型的输出有1001个分类,测试的结果是 284。
# 将下载 ImageNetLabels.txt ,就可以知道 284 代表的分类的名称,可以看到结果是 tiger cat,即老虎猫。