数据集获取方式和数据加强方式
取出數(shù)據(jù)集方法一,模型直取
mnist.load_data()
用mnist.load_data()讀取numpy數(shù)據(jù)直接送入model
喂入numpy
# Load MNIST as NumPy arrays and scale the uint8 pixel values into [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
定義輸入維數(shù)
# A minimal Sequential model: flatten each 28x28 image into a 784-long vector.
model = tf.keras.models.Sequential(
    [tf.keras.layers.Flatten(input_shape=(28, 28))]
)
numpy降維輸入,張量轉(zhuǎn)化為numpy
# Feed a NumPy slice (shape (1, 28, 28)) straight into the model; the
# result is a tf.Tensor.
predictions = model(x_train[:1])
# Calling .numpy() on the output converts the tensor back to a NumPy array.
predictions = model(x_train[:1]).numpy()
predictions
# CIFAR-10 can be loaded the same way as MNIST above.
cifar10.load_data()
第一個例子是 mnist.load_data()讀出的numpy, 容易限定為一張圖片,next(iter(train_ds))[0]只能是一批batch圖片,張量可用序號
用train_ds.take(1)方式取出的是批量圖片張量,張量可用序號
輸入可以是張量或numpy
# Wrap the NumPy arrays in a tf.data pipeline: shuffle with a 10k-element
# buffer, then group examples into batches of 32.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(32)
)
從中取出圖片張量
# Iterate over the dataset one batch at a time; each step receives a batch
# of image tensors and the matching label tensor.
# Fix: the loop body had lost its indentation in the original.
for images, labels in train_ds:
    train_step(images, labels)
GradientTape方法需要model輸入圖片張量
# One gradient-descent step: record the forward pass on the tape, then
# differentiate the loss w.r.t. every trainable variable and apply the update.
# Fix: the `with` body had lost its indentation in the original.
with tf.GradientTape() as tape:
    # training=True is only needed if there are layers with different
    # behavior during training versus inference (e.g. Dropout).
    predictions = model(images, training=True)
    loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
取出數(shù)據(jù)集方法二 image_dataset_from_directory
import pathlib

# Download and unpack the flower_photos archive, then build the training
# split (80%) of batched, resized image tensors from the directory tree.
# Fix: the original used curly "smart" quotes, which are a Python syntax error.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
# Class names are inferred from the sub-directory names.
class_names = train_ds.class_names
print(class_names)
取出圖片方法:
# Show the first 9 images of the first batch in a 3x3 grid.
# Fixes: restored the lost loop indentation and replaced curly quotes
# (a syntax error) with straight quotes.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
train_ds放緩存
# Keep decoded images in memory (cache) and overlap preprocessing with
# training (prefetch); only the training split is shuffled.
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
model.fit直接用train_ds
# model.fit accepts a tf.data.Dataset directly — no manual batching needed.
history = model.fit( train_ds, validation_data=val_ds, epochs=epochs)
train validation已經(jīng)在不同目錄
# Download the cats-vs-dogs archive; train/ and validation/ already live in
# separate directories, so no validation_split is needed here.
# Fix: the original used curly "smart" quotes, which are a Python syntax error.
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)
train_dataset = image_dataset_from_directory(
    train_dir, shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE
)
分出test_dataset數(shù)據(jù)集
# Carve a test split out of the validation set: one fifth of the validation
# batches become the test set, the remainder stays for validation.
val_batches = tf.data.experimental.cardinality(validation_dataset)
n_test_batches = val_batches // 5
test_dataset = validation_dataset.take(n_test_batches)
validation_dataset = validation_dataset.skip(n_test_batches)
取出數(shù)據(jù)集方法三使用昵稱直接tfds下載
此方法可下載大部分數(shù)據(jù)集
# Download tf_flowers through TFDS and split it 80/10/10 in a single call;
# as_supervised=True yields (image, label) pairs.
# Fix: the original used curly "smart" quotes, which are a Python syntax error.
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

# The metadata object carries the label vocabulary: int2str maps a numeric
# label back to its class name.
get_label_name = metadata.features['label'].int2str

image, label = next(iter(train_ds))
print(image.shape)
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))

num_classes = metadata.features['label'].num_classes
數(shù)據(jù)加強層及擴維方式一加層
# Data augmentation expressed as Keras layers: random horizontal flip plus
# random rotation of up to 0.2 * 2*pi.
# Fixes: restored lost indentation and replaced curly quotes (syntax error).
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

# Apply the augmentation pipeline 9 times to the same image to visualize the
# randomness; expand_dims adds the batch axis the layers expect.
for image, _ in train_dataset.take(1):
    plt.figure(figsize=(10, 10))
    first_image = image[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
        plt.imshow(augmented_image[0] / 255)
        plt.axis('off')
數(shù)據(jù)加強層方法二數(shù)據(jù)集函數(shù)
# Reconstructed from a single collapsed line in the original post.
batch_size = 32
AUTOTUNE = tf.data.experimental.AUTOTUNE

data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
    layers.experimental.preprocessing.RandomRotation(0.2),
])


def prepare(ds, shuffle=False, augment=False):
    """Resize/rescale a dataset, optionally shuffle and augment it, then batch and prefetch."""
    # Resize and rescale all datasets
    ds = ds.map(lambda x, y: (resize_and_rescale(x), y),
                num_parallel_calls=AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(1000)
    # Batch all datasets
    ds = ds.batch(batch_size)
    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                    num_parallel_calls=AUTOTUNE)
    # Use buffered prefetching on all datasets
    return ds.prefetch(buffer_size=AUTOTUNE)


train_ds = prepare(train_ds, shuffle=True, augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)

# Augmentation method 3 (next section): dataset-level ops via tf.image.
數(shù)據(jù)加強方法四加子類
class RandomInvert(layers.Layer):
    """Custom augmentation layer that randomly inverts image pixel values."""

    # Fix: the original had `init` — markdown stripped the double
    # underscores from `__init__`, so the constructor was never called.
    def __init__(self, factor=0.5, **kwargs):
        super().__init__(**kwargs)
        # Inversion rate; note call() does not use it directly —
        # presumably random_invert_img applies its own rate. TODO confirm.
        self.factor = factor

    def call(self, x):
        return random_invert_img(x)


_ = plt.imshow(RandomInvert()(image)[0])
生成子類后,用以上所有方式調(diào)用類方法或直接map
# Augmentation can also be applied with Dataset.map; training=True keeps the
# random layers active outside model.fit. NOTE(review): this presumably
# requires resize_and_rescale to accept a `training` kwarg — confirm.
aug_ds = train_ds.map(lambda x, y: (resize_and_rescale(x, training=True), y))
總結(jié)
以上是生活随笔為你收集整理的数据集获取方式和数据加强方式的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python画菱形的代码_Python打
- 下一篇: es安装ik分词器