목차

  1. Feature 목록 얻기
  2. 레코드 순회
  3. 파일로 저장

저번 주말동안 Machine Learning Challenge Korea 2017에 참가했는데, 데이터가 모두 tfrecords 형태로 배포되었다. 익숙치 않기도 하고 다른 프레임워크(e.g. Caffe)를 이용하기 위해선 데이터를 꺼낼 필요가 있었다. 꺼낼 때 Tensorflow를 이용해야 한다는 점이 조금 불편하긴 하지만 코드 자체는 비교적 간단한 편이어서 짧게 적어보았다.

1. Feature 목록 얻기

import tensorflow as tf

def get_tfrecords_feature_list(tfrecords_filename):
    ptr = 0
    record_iterator = tf.python_io.tf_record_iterator(path=tfrecords_filename)

    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)
        return example.features.feature.keys()

    return []

2. 레코드 순회

import tensorflow as tf

def read_tfrecords(tfrecords_filename, is_train_val=False):
    ptr = 0
    record_iterator = tf.python_io.tf_record_iterator(path=tfrecords_filename)

    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)

        img = (example.features.feature['image']
                                      .bytes_list
                                      .value[0])

        if is_train_val:
            label = (example.features.feature['label']
                                            .int64_list
                                            .value[0])
            yield ptr, img, label
        else:
            yield ptr, img

        ptr += 1

3. 파일로 저장

with open('train/labels.txt', 'w') as f_labels:
    for idx, img, label in read_tfrecords('./train.tfrecords', is_train_val=True):
        fn = 'train/{}.png'.format(idx)
        with open(fn, 'wb') as f_img:
            f_img.write(img)
    print >>f_labels, "{} {}".format(fn, label)