这个问题该如何解决呢?下面的代码在 TensorFlow 下无法运行。

图片描述

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt

# Restrict CUDA to the device with index 1 only.
# NOTE(review): on a machine with a single GPU, index 1 does not exist and
# TensorFlow will presumably fall back to the CPU — confirm this is intended.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

# Directory scanned by get_files for the apple images.
# NOTE(review): some TensorFlow 1.x builds on Windows reportedly cannot open
# non-ASCII paths via tf.read_file; if a NotFoundError appears, try an
# ASCII-only directory.
train_dir = "G:/苹果测试数据/"

def get_files(file_dir):
    """Collect image paths in *file_dir* and assign a class label to each.

    The class comes from the part of the file name before the first ".":
    ``good`` -> 0, ``medium`` -> 1, anything else -> 2 (bad).
    Sub-directories are ignored; the scan is not recursive.

    Args:
        file_dir: Directory to scan. A trailing path separator is optional.

    Returns:
        ``(image_list, label_list)``: parallel lists of file paths and float
        labels, shuffled together into a random order using ``np.random``.
    """
    good = []
    medium = []
    bad = []
    for file in os.listdir(file_dir):
        # os.path.join works whether or not file_dir ends with a separator.
        path = os.path.join(file_dir, file)
        # os.listdir also lists directories; they are not images and would
        # make tf.read_file fail downstream, so skip anything that isn't a file.
        if not os.path.isfile(path):
            continue
        name = file.split(sep=".")
        if name[0] == "good":
            good.append(path)
        elif name[0] == "medium":
            medium.append(path)
        else:
            bad.append(path)
    print("There are %d good apples\nThere are %d medium apples\nThere are %d bad apples" % (len(good), len(medium), len(bad)))

    image_list = good + medium + bad
    label_list = [0.0] * len(good) + [1.0] * len(medium) + [2.0] * len(bad)

    # Shuffle paths and labels with one shared permutation so each pair stays
    # aligned, without round-tripping the labels through a numpy string array.
    order = np.random.permutation(len(image_list))
    image_list = [image_list[i] for i in order]
    label_list = [label_list[i] for i in order]

    return image_list, label_list



def get_batch(image, label, image_W, image_H, batch_size, capacity, num_threads=64):
    """Build a TF 1.x queue-based pipeline that yields batches of images and labels.

    Args:
        image: Image file paths (e.g. the first list returned by get_files).
        label: Labels aligned with ``image``; cast to int32.
        image_W: Target width in pixels.
        image_H: Target height in pixels.
        batch_size: Number of examples per batch.
        capacity: Maximum number of elements held in the batching queue.
        num_threads: Number of threads enqueuing examples (default 64).

    Returns:
        ``(image_batch, label_batch)`` tensors. ``label_batch`` is reshaped to
        ``[batch_size]``. Each image is decoded as 3-channel JPEG, cropped or
        padded to ``image_H x image_W``, and standardized per image.

    NOTE: relies on TF 1.x APIs (tf.train.slice_input_producer, tf.read_file,
    tf.train.batch) that are not in TensorFlow 2.x's default namespace, and the
    returned tensors only yield data after tf.train.start_queue_runners runs.
    """
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    # Emits one (path, label) pair at a time from the full lists.
    input_queue = tf.train.slice_input_producer([image, label])

    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)
    # Data augmentation could be added here.
    # resize_image_with_crop_or_pad takes (target_height, target_width):
    # height must come first, otherwise non-square targets are transposed.
    image = tf.image.resize_image_with_crop_or_pad(image, image_H, image_W)
    image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
    )

    label_batch = tf.reshape(label_batch, [batch_size])

    return image_batch, label_batch

BATCH_SIZE = 2
CAPACITY = 256
IMG_W = 208
IMG_H = 208
NUM_PREVIEW_BATCHES = 2  # how many batches to pull and display

# NOTE: this demo uses TF 1.x graph-mode APIs (tf.Session, queue runners).
# Under TensorFlow 2.x these are not in the default namespace and eager
# execution is on, which is a likely cause of "cannot run" errors; there, use
# tf.compat.v1 with eager execution disabled, or switch to tf.data.
image_list, label_list = get_files(train_dir)
image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

with tf.Session() as sess:
    i = 0
    coord = tf.train.Coordinator()
    # Start the threads that fill the queues built in get_batch; without
    # them, sess.run on the batch tensors would block waiting for data.
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop() and i < NUM_PREVIEW_BATCHES:
            img, label = sess.run([image_batch, label_batch])

            for j in range(BATCH_SIZE):
                print("label:%d" % label[j])
                # per_image_standardization yields zero-mean/unit-variance
                # floats, but imshow clips float RGB data to [0, 1]; rescale
                # each image into that range so it displays correctly.
                shown = img[j]
                lo, hi = shown.min(), shown.max()
                if hi > lo:
                    shown = (shown - lo) / (hi - lo)
                else:
                    shown = np.zeros_like(shown)
                plt.imshow(shown)
                plt.show()
            i += 1
    except tf.errors.OutOfRangeError:
        print("done!")
    finally:
        coord.request_stop()
    coord.join(threads)

图片描述

阅读 2.5k
撰写回答
你尚未登录,登录后可以:
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题