
Generating DLRM TFRecord data with TensorFlow 2

陶高峯
2023-12-01

SOK DLRM data

Download the Criteo Terabyte dataset:
https://ailab.criteo.com/ressources/

# Usage: dlrm_raw input_dir output_dir --train {days for training} --test {days for testing}
$ dlrm_raw ./ ./ \
--train 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22 \
--test 23
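
dlrm_raw is expected to produce three binary files per split (label.bin, dense.bin, category.bin) together with a metadata.json describing the raw dtypes; the conversion script below reads exactly these files. A quick existence check, assuming the directory layout used later in this post:

import os

data_dir = "/data2/wuyongyu02/data/terabete/sok"  # adjust to wherever the dlrm_raw output was placed
for name in ["label.bin", "dense.bin", "category.bin", "metadata.json"]:
    path = os.path.join(data_dir, "train", name)
    assert os.path.exists(path), "missing %s" % path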

Convert the binary files to TFRecord

import os
import queue
import concurrent.futures

import numpy as np
import tensorflow as tf
import json


class BinaryDataset:

    def __init__(
        self,
        label_bin,
        dense_bin,
        category_bin,
        batch_size=1,
        drop_last=True,
        global_rank=0,
        global_size=1,
        prefetch=1,
        label_raw_type=np.int32,
        dense_raw_type=np.int32,
        category_raw_type=np.int32,
        log=True,
    ):
        """
        * batch_size : The batch size for the local rank; the total batch size across all ranks is (batch_size * global_size).
        * prefetch   : If prefetch > 1, the dataset can only be read sequentially; random access would return incorrect samples.
        """
        self._check_file(label_bin, dense_bin, category_bin)

        self._batch_size = batch_size
        self._drop_last = drop_last
        self._global_rank = global_rank
        self._global_size = global_size

        # actual number of samples in the binary file
        self._samples_in_all_ranks = os.path.getsize(label_bin) // 4

        # self._num_entries is the number of batches
        self._num_entries = self._samples_in_all_ranks // (batch_size * global_size)

        # number of samples in current rank
        self._num_samples = self._num_entries * batch_size

        if not self._drop_last:
            if (self._samples_in_all_ranks % (batch_size * global_size)) < global_size:
                print("The number of samples in last batch is less than global_size, so the drop_last=False will be ignored.")
                self._drop_last = True
            else:
                # assign the samples in the last batch to different local ranks
                samples_in_last_batch = [(self._samples_in_all_ranks % (batch_size * global_size)) // global_size] * global_size
                for i in range(global_size):
                    if i < (self._samples_in_all_ranks % (batch_size * global_size)) % global_size:
                        samples_in_last_batch[i] += 1
                assert(sum(samples_in_last_batch) == (self._samples_in_all_ranks % (batch_size * global_size)))
                self._samples_in_last_batch = samples_in_last_batch[global_rank]

                # the offset of last batch
                self._last_batch_offset = []
                offset = 0
                for i in range(global_size):
                    self._last_batch_offset.append(offset)
                    offset += samples_in_last_batch[i]

                # correct the values when drop_last=False
                self._num_entries += 1
                self._num_samples = (self._num_entries - 1) * batch_size + self._samples_in_last_batch

        self._prefetch = min(prefetch, self._num_entries)
        if self._prefetch > 1:
            self._prefetch_queue = queue.Queue()
            self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        self._label_file = os.open(label_bin, os.O_RDONLY)
        self._dense_file = os.open(dense_bin, os.O_RDONLY)
        self._category_file = os.open(category_bin, os.O_RDONLY)

        self._label_raw_type = label_raw_type
        self._dense_raw_type = dense_raw_type
        self._category_raw_type = category_raw_type

        self._log = log

    def _check_file(self, label_bin, dense_bin, category_bin):
        # num_samples represents the actual number of samples in the dataset
        num_samples = os.path.getsize(label_bin) // 4
        if num_samples <= 0:
            raise RuntimeError("There must be at least one sample in %s"%label_bin)

        # check file size
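        # bytes per sample: label = 1 int32 = 4 B, dense = 13 int32 = 52 B, category = 26 int32 = 104 B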
        for file, bytes_per_sample in [[label_bin, 4], [dense_bin, 52], [category_bin, 104]]:
            file_size = os.path.getsize(file)
            if file_size % bytes_per_sample != 0:
                raise RuntimeError("The file size of %s should be an integer multiple of %d."%(file, bytes_per_sample))
            if (file_size // bytes_per_sample) != num_samples:
                raise RuntimeError("The number of samples in %s is not equeal to %s"%(dense_bin, label_bin))

    def __del__(self):
        for file in [self._label_file, self._dense_file, self._category_file]:
            if file is not None:
                os.close(file)

    def __len__(self):
        return self._num_entries

    def __getitem__(self, idx):
        if idx >= self._num_entries:
            raise IndexError()

        if self._prefetch <= 1:
            return self._get(idx)

        if idx == 0:
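            # warm up: on the first access, schedule the first `prefetch` batch reads on the background thread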
            for i in range(self._prefetch):
                self._prefetch_queue.put(self._executor.submit(self._get, (i)))

        if idx < (self._num_entries - self._prefetch):
            self._prefetch_queue.put(self._executor.submit(self._get, (idx + self._prefetch)))

        return self._prefetch_queue.get().result()

    def _get(self, idx):
        # calculate the offset & number of the samples to be read
        if not self._drop_last and idx == self._num_entries - 1:
            sample_offset = idx * (self._batch_size * self._global_size) + self._last_batch_offset[self._global_rank]
            batch = self._samples_in_last_batch
        else:
            sample_offset = idx * (self._batch_size * self._global_size) + (self._batch_size * self._global_rank)
            batch = self._batch_size

        # read the data from binary file
        label_raw_data = os.pread(self._label_file, 4 * batch, 4 * sample_offset)
        label = np.frombuffer(label_raw_data, dtype=self._label_raw_type).reshape([batch, 1])

        dense_raw_data = os.pread(self._dense_file, 52 * batch, 52 * sample_offset)
        dense = np.frombuffer(dense_raw_data, dtype=self._dense_raw_type).reshape([batch, 13])

        category_raw_data = os.pread(self._category_file, 104 * batch, 104 * sample_offset)
        category = np.frombuffer(category_raw_data, dtype=self._category_raw_type).reshape([batch, 26])

        # convert numpy data to tensorflow data
        if self._label_raw_type == self._dense_raw_type and self._label_raw_type == self._category_raw_type:
            data = np.concatenate([label, dense, category], axis=1)
            data = tf.convert_to_tensor(data)
            label = tf.cast(data[:, 0:1], dtype=tf.float32)
            dense = tf.cast(data[:, 1:14], dtype=tf.float32)
            category = tf.cast(data[:, 14:40], dtype=tf.int64)
        else:
            label = tf.convert_to_tensor(label, dtype=tf.float32)
            dense = tf.convert_to_tensor(dense, dtype=tf.float32)
            category = tf.convert_to_tensor(category, dtype=tf.int64)
        # preprocess
        if self._log:
            dense = tf.math.log(dense + 3.0)

        return (dense, category), label

data_dir = "/data2/wuyongyu02/data/terabete/sok"

with open(os.path.join(data_dir, 'train/metadata.json'), 'r') as f:
  metadata = json.load(f)
print(metadata)
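
# metadata.json is expected to carry the raw dtypes and the dense-log flag used below,
# e.g. (illustrative values):
# {"label_raw_type": "int32", "dense_raw_type": "int32",
#  "category_raw_type": "int32", "dense_log": true}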

dtype = {'int32': np.int32, 'float32': np.float32}

dataset = BinaryDataset(
            label_bin=os.path.join(data_dir, 'train/label.bin'),
            dense_bin=os.path.join(data_dir, 'train/dense.bin'),
            category_bin=os.path.join(data_dir, 'train/category.bin'),
            batch_size=1,
            drop_last=True,
            global_rank=0,
            global_size=1,
            prefetch=1,
            label_raw_type=dtype[metadata['label_raw_type']],
            dense_raw_type=dtype[metadata['dense_raw_type']],
            category_raw_type=dtype[metadata['category_raw_type']],
            log=metadata['dense_log'],
        )
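
# A quick sanity check: with batch_size=1, a single entry should come back as
# dense (1, 13) float32, category (1, 26) int64, label (1, 1) float32, e.g.:
# (dense, category), label = dataset[0]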

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  # if isinstance(value, type(tf.constant(0))):
  #   value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  if isinstance(value, list):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  if isinstance(value, list):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _float_feature(value):
  if isinstance(value, list):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

COLUMNS = ['spa_fea', 'den_fea']
FEATURES = COLUMNS + ['label']
FEATURES_SIZE = {
    "spa_fea": 26,
    "den_fea": 13,
    "label": 1
}

def serialize_example(samples, labels):
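  # `samples` is the (dense, category) tuple yielded by BinaryDataset; with batch_size=1,
  # `.numpy().tolist()[0]` strips the batch dimension so each feature becomes a flat list.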
  feature = {}
  feature["den_fea"] = _float_feature(samples[0].numpy().tolist()[0])
  feature["spa_fea"] = _int64_feature(samples[1].numpy().tolist()[0])    
  feature["label"] = _float_feature(labels.numpy().tolist()[0])
  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

count = 0
def write_tf_record():
  global count
  write_data = "/data2/wuyongyu02/data/terabete/sok_records/records-part"
  with tf.io.TFRecordWriter(write_data) as writer:
    for i, (samples, labels) in enumerate(dataset):
      writer.write(serialize_example(samples, labels))
      count += 1
      if count % 1000000 == 0:
        print("cur:", count)
    
write_tf_record()
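
# Writing the full training split into a single TFRecord file yields one very large file.
# A minimal sharding sketch (the shard count and file-name pattern here are arbitrary choices):
def write_tf_record_sharded(num_shards=128):
  writers = [
      tf.io.TFRecordWriter(
          "/data2/wuyongyu02/data/terabete/sok_records/records-part-%05d-of-%05d" % (i, num_shards))
      for i in range(num_shards)
  ]
  for i, (samples, labels) in enumerate(dataset):
    writers[i % num_shards].write(serialize_example(samples, labels))
  for w in writers:
    w.close()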

# Test reading the TFRecord file back
def _parse_tf_record(value):
  example_schema = {}
  example_schema["den_fea"] = tf.io.VarLenFeature(dtype=tf.float32)
  example_schema["label"] =  tf.io.VarLenFeature(dtype=tf.float32)
  example_schema["spa_fea"] = tf.io.VarLenFeature(dtype=tf.int64)
  example_schema = dict(sorted(example_schema.items(), key=lambda k: k[0]))
  features_tensor = tf.io.parse_example(value, example_schema)
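  # VarLenFeature parses each feature into a tf.sparse.SparseTensor, hence the densify step below.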
  for k in FEATURES:
      features_tensor[k] = tf.sparse.to_dense(features_tensor[k])
  label_tensor = features_tensor.pop('label')
  return features_tensor, label_tensor

filenames = ["/data2/wuyongyu02/data/terabete/sok_records/records-part"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.repeat(1)
dataset = dataset.batch(1)
dataset = dataset.map(_parse_tf_record, num_parallel_calls=4)
iterator = iter(dataset)

a = iterator.get_next()

print("a:", a)