Feeding the Beast - Data Loading Secrets for Hungry Neural Networks

Published: at 02:53 AM · 11 min read

Introduction

Data loading is a critical component in the machine learning pipeline, especially when dealing with large datasets or distributed computing environments. Efficient data loading can significantly impact the training time and overall performance of your neural networks. In this post, we’ll dive deep into the world of data loading, exploring various techniques and best practices to keep your hungry neural networks well-fed and optimized.

TL;DR

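PyTorch's data-loading tools live in torch.utils.data. A map-style torch.utils.data.Dataset implements __getitem__ and __len__, and torch.utils.data.DataLoader wraps a dataset to batch, shuffle, and load samples in parallel worker processes. For streaming data, resumable iterators that can save and restore their position make checkpointing and fault-tolerant training possible.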

The Foundation: Datasets

PyTorch Datasets

At the heart of any deep learning pipeline lies the dataset. In PyTorch, datasets are represented by the torch.utils.data.Dataset class, which provides an interface for accessing individual samples (typically an input and its label). The companion torchvision package ships several built-in datasets, such as torchvision.datasets.CIFAR10 and torchvision.datasets.MNIST, which make it easy to get started.

Datasets can be categorized into two main types:

Map-Style Datasets

Map-style datasets are the most common type of dataset in PyTorch. They implement the __getitem__ and __len__ methods, allowing for random access to individual samples. This makes them ideal for tasks where you need to shuffle data or access specific items by index.

Let’s create a simple map-style dataset for a collection of images and their labels:

import os
from pathlib import Path
from typing import Callable

from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, image_dir: str | Path, transform: Callable | None = None):
        self.image_dir = image_dir
        self.transform = transform
        # Sort so each index maps to the same file on every run and machine
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))
        self.labels = [int(f.split('_')[0]) for f in self.images]  # Assuming filenames are "label_name.jpg"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Images are loaded lazily: only the requested file is read from disk
        img_path = os.path.join(self.image_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Usage example
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ImageDataset('path/to/image/directory', transform=transform)
print(f"Dataset size: {len(dataset)}")
image, label = dataset[0]
print(f"First item: image of shape {tuple(image.shape)}, label {label}")

In this example, ImageDataset is a map-style dataset that loads images from a directory. The __len__ method returns the total number of images, while __getitem__ loads and returns a specific image and its label based on the given index.

Key features of map-style datasets:

  1. Random access: You can access any item using its index.
  2. Known length: The __len__ method allows you to determine the size of the dataset easily.
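  3. Shuffling: Because any index can be read directly, shuffling only requires permuting indices, which makes random batches cheap to create during training.
  4. Compatibility: Works well with PyTorch's DataLoader for batching and multi-process data loading (via its num_workers argument).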

Map-style datasets are particularly useful when your entire dataset can fit into memory or when you can quickly load individual samples on-demand. They provide flexibility in data access patterns, making them suitable for most deep learning tasks.
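Random access is what makes shuffling cheap: a sampler only needs __len__ to enumerate the valid indices and __getitem__ to fetch them. The pure-Python sketch below mimics, in simplified form, what DataLoader(dataset, batch_size=..., shuffle=True) does for a single epoch. SquaresDataset and shuffled_batches are hypothetical names, and PyTorch's real machinery (samplers, collation, worker processes) is more involved.

```python
import random
from typing import Iterator


class SquaresDataset:
    """Hypothetical toy map-style dataset: item i is the pair (i, i * i)."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> tuple[int, int]:
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx, idx * idx


def shuffled_batches(dataset, batch_size: int, seed: int = 0) -> Iterator[list]:
    """Yield one epoch of randomly ordered batches from a map-style dataset."""
    indices = list(range(len(dataset)))  # __len__ tells us which indices exist
    random.Random(seed).shuffle(indices)  # shuffling only permutes integers
    for start in range(0, len(indices), batch_size):
        # __getitem__ lets us fetch each sample directly by its index
        yield [dataset[i] for i in indices[start:start + batch_size]]


batches = list(shuffled_batches(SquaresDataset(10), batch_size=4, seed=0))
```

Ten items with a batch size of 4 give batches of 4, 4, and 2 samples, and a fixed seed reproduces the same order on every run.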

Iterable-Style Datasets

Iterable-style datasets are an alternative to map-style datasets: instead of random access, they provide an iterator over the samples (in PyTorch, by subclassing torch.utils.data.IterableDataset and implementing __iter__). This makes them particularly suitable for cases where random reads are expensive or impractical, such as data streamed over a network, and where the batch size depends on the fetched data.
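To make "the batch size depends on the fetched data" concrete, here is a sketch of token-budget batching over a stream of variable-length sequences. It is plain Python with a hypothetical token_budget_batches helper, not a PyTorch API; in PyTorch, logic like this would typically live in the __iter__ of an IterableDataset.

```python
from typing import Iterable, Iterator


def token_budget_batches(
    samples: Iterable[list[int]], max_tokens: int
) -> Iterator[list[list[int]]]:
    """Group a stream of token sequences into batches of at most max_tokens tokens.

    How many samples land in a batch depends on the lengths actually read.
    A single sequence longer than max_tokens still gets a batch of its own.
    """
    batch, tokens = [], 0
    for seq in samples:
        # Close the current batch if this sequence would exceed the budget
        if batch and tokens + len(seq) > max_tokens:
            yield batch
            batch, tokens = [], 0
        batch.append(seq)
        tokens += len(seq)
    if batch:
        yield batch  # flush the final, possibly partial, batch


# A one-pass stream: sequences of length 3, 4, 2, 6, 1
stream = iter([[1] * 3, [2] * 4, [3] * 2, [4] * 6, [5] * 1])
batches = list(token_budget_batches(stream, max_tokens=8))
```

With a budget of 8 tokens, this stream yields batches of 2, 2, and 1 sequences (7, 8, and 1 tokens), something a fixed batch size cannot express.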

Because an iterable-style stream can only be read forward, a training job that crashes halfway through has no index to jump back to. Resumable iterators solve this: they expose their current position as state that can be saved and later restored, which is what makes checkpointing and fault-tolerant data loading possible in long-running training jobs.

Let’s start by implementing a basic ResumableIterator class:

class ResumableIterator:
    """An iterator over a list whose position can be saved and restored."""

    def __init__(self, contents: list):
        self.contents = contents
        self.state = 0  # index of the next item to return

    def __iter__(self):
        return self

    def __next__(self):
        if self.state >= len(self.contents):
            raise StopIteration
        item = self.contents[self.state]
        self.state += 1
        return item

    def get_state(self):
        return self.state

    def set_state(self, state):
        # len(contents) is valid too: it is the state of a fully consumed iterator
        if 0 <= state <= len(self.contents):
            self.state = state
        else:
            raise ValueError("Invalid state")

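This ResumableIterator class provides the basic functionality we need:

  1. Standard iteration: __iter__ and __next__ let it be used in a for loop like any Python iterator.
  2. State capture: get_state() returns the index of the next item to be produced.
  3. State restoration: set_state() moves the iterator to a previously saved position, backwards or forwards.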
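To make the checkpointing use case concrete, here is a sketch of a save-and-resume round trip. It includes a condensed copy of ResumableIterator so it runs on its own; the JSON checkpoint layout and the save_checkpoint/load_checkpoint helpers are hypothetical, and a real training job would store model and optimizer state alongside the data position.

```python
import json
import os
import tempfile


class ResumableIterator:
    """Condensed copy of the class above, so this snippet runs on its own."""

    def __init__(self, contents: list):
        self.contents = contents
        self.state = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.state >= len(self.contents):
            raise StopIteration
        self.state += 1
        return self.contents[self.state - 1]

    def get_state(self) -> int:
        return self.state

    def set_state(self, state: int) -> None:
        if not 0 <= state <= len(self.contents):
            raise ValueError("Invalid state")
        self.state = state


def save_checkpoint(iterator, path: str) -> None:
    # Hypothetical format: only the data position is stored here
    with open(path, "w") as f:
        json.dump({"data_state": iterator.get_state()}, f)


def load_checkpoint(iterator, path: str) -> None:
    with open(path) as f:
        iterator.set_state(json.load(f)["data_state"])


data = list(range(10))
path = os.path.join(tempfile.mkdtemp(), "checkpoint.json")

it = ResumableIterator(data)
first = [next(it) for _ in range(4)]  # consume items 0..3, then "crash"
save_checkpoint(it, path)

resumed = ResumableIterator(data)  # e.g. in a freshly restarted process
load_checkpoint(resumed, path)
rest = list(resumed)  # continues from item 4 instead of starting over
```

Because the state is just an integer, it serializes trivially, and the restarted iterator neither repeats items 0 to 3 nor skips any.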

Specialized Iterators for Different Data Sources

In real-world scenarios, we often need to work with various data sources. Let’s explore how we can create specialized iterators for different types of data.

List Iterator

First, let’s create a simple ListIterator that extends our ResumableIterator:

class ListIterator(ResumableIterator):
    def __init__(self, contents: list):
        super().__init__(contents)

This iterator is perfect for in-memory list data and serves as a simple example of how to extend our base ResumableIterator class.

JSONL File Iterator

When dealing with large datasets, it’s common to store data in JSONL (JSON Lines) format, where each line is a valid JSON object. Let’s create an iterator for this format:

import json

class JsonlFileIterator(ResumableIterator):
    def __init__(self, filename):
        self.filename = filename
        # Parse every line up front: each line holds one JSON object
        with open(self.filename, "r") as f:
            self.lines = [json.loads(line) for line in f]
        super().__init__(self.lines)

This iterator loads all the data into memory upon initialization. However, for very large files, this approach might not be memory-efficient. Let’s create a more efficient version:

import json

class EfficientJsonlFileIterator(ResumableIterator):
    def __init__(self, filename: str):
        self.filename = filename
        self.state = 0  # number of lines consumed so far
        self.file = open(self.filename, "r")

    def __next__(self):
        if self.file.closed:  # already exhausted
            raise StopIteration
        line = self.file.readline()
        if not line:
            self.file.close()
            raise StopIteration
        self.state += 1
        return json.loads(line)

    def set_state(self, state: int):
        if not isinstance(state, int) or state < 0:
            raise ValueError("Invalid state")
        if state == self.state:
            return
        if state < self.state or self.file.closed:
            # A text file can't be read backwards: reopen it and start from line 0
            self.file.close()
            self.file = open(self.filename, "r")
            self.state = 0
        # Skip forward line by line until we reach the requested position
        for _ in range(state - self.state):
            if not self.file.readline():
                self.file.close()
                raise ValueError("Invalid state")
            self.state += 1

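This EfficientJsonlFileIterator reads the file line by line, so its memory use stays constant regardless of file size. The trade-off is that set_state has to re-read every line up to the target position (reopening the file first when moving backwards), so restoring a position deep inside a large file takes time proportional to that position.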
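If resume speed matters more than having a line number as state, one alternative (our assumption, not part of the iterator above) is to checkpoint the file's byte offset with tell() and restore it with seek(), which turns restoring a position into a single seek. The hypothetical OffsetJsonlIterator below opens the file in binary mode so offsets are plain byte counts; json.loads accepts bytes directly.

```python
import json
import os
import tempfile


class OffsetJsonlIterator:
    """Hypothetical JSONL iterator whose state is a byte offset, not a line count."""

    def __init__(self, filename: str):
        self.filename = filename
        # Binary mode: tell() and seek() work with plain byte offsets
        self.file = open(filename, "rb")

    def __iter__(self):
        return self

    def __next__(self):
        line = self.file.readline()
        if not line:
            raise StopIteration
        return json.loads(line)  # json.loads accepts bytes

    def get_state(self) -> int:
        return self.file.tell()  # byte offset of the next unread line

    def set_state(self, state: int) -> None:
        self.file.seek(state)  # constant-time jump, no re-reading

    def close(self) -> None:
        self.file.close()


# Usage: write a small JSONL file, read two records, then resume in a new iterator
path = os.path.join(tempfile.mkdtemp(), "data.jsonl")
with open(path, "w") as f:
    for i in range(5):
        f.write(json.dumps({"id": i}) + "\n")

it = OffsetJsonlIterator(path)
first_two = [next(it) for _ in range(2)]
saved = it.get_state()
it.close()

resumed = OffsetJsonlIterator(path)
resumed.set_state(saved)
rest = list(resumed)
resumed.close()
```

The catch is that an offset is only meaningful for the exact same file bytes: if the file is rewritten between runs, a saved offset can land mid-line, so it is worth storing something like the file size next to it as a sanity check.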

Multi-File Iterator

In many cases, our dataset might be split across multiple files. Let’s create an iterator that can handle multiple JSONL files:

class MultiJsonlFileIterator:
    """Iterates over several JSONL files in order, as one logical stream."""

    def __init__(self, filenames: list[str]):
        self.filenames = filenames
        self.num_files = len(filenames)

        self.state = 0  # global position across all files
        self.file_ends = {}  # file index -> global position where that file ends
        self.total_items = 0  # furthest global position reached so far

        self.curr_file_idx = 0
        self.curr_file_iter = JsonlFileIterator(filenames[0])

    def __iter__(self):
        return self

    def __next__(self):
        try:
            out = next(self.curr_file_iter)
            self.state += 1
            self.total_items = max(self.total_items, self.state)
            return out
        except StopIteration as e:
            # Check if we have more files to read
            if self.curr_file_idx + 1 == self.num_files:
                # No more files to read
                raise StopIteration from e

            # Record where the current file ends, then move on to the next one
            self.file_ends[self.curr_file_idx] = self.state
            self.curr_file_idx += 1
            self.curr_file_iter = JsonlFileIterator(self.filenames[self.curr_file_idx])
            return next(self)

    def get_state(self):
        return self.state

    def set_state(self, state: int):
        # Only positions that have already been reached can be mapped to a file
        if state < 0 or state > self.total_items:
            raise ValueError("state must be between 0 and the number of items seen so far")
        # Find the first finished file that ends after `state`; if none does,
        # the position lies in the file that was being read furthest along
        file_idx, prev_end = len(self.file_ends), 0
        for idx, end in self.file_ends.items():
            if state < end:
                file_idx = idx
                break
            prev_end = end
        self.curr_file_idx = file_idx
        self.curr_file_iter = JsonlFileIterator(self.filenames[file_idx])
        self.curr_file_iter.set_state(state - prev_end)  # position within that file
        self.state = state

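This iterator handles multiple files transparently, moving to the next file when it reaches the end of the current one. It also records the global position at which each file ends, which is what lets set_state translate a global position back into a file and a position within that file. The catch is that a position can only be restored after it has been reached at least once, since file lengths are discovered while iterating.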
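The core of set_state is mapping a global position to a (file, position-within-file) pair. If per-file item counts are known up front (say, from a small index written when the files were created, which is an assumption here), the lookup can be a binary search over cumulative file ends instead of a linear scan. locate is a hypothetical helper:

```python
import bisect
from itertools import accumulate


def locate(global_state: int, file_lengths: list[int]) -> tuple[int, int]:
    """Map a global item position to (file_index, position_within_file).

    Assumes file_lengths (items per file) is known up front and non-empty.
    """
    ends = list(accumulate(file_lengths))  # global position where each file ends
    if not 0 <= global_state <= ends[-1]:
        raise ValueError("state out of range")
    # First file whose end lies strictly after global_state; bisect_right
    # skips empty files, whose end equals the previous file's end
    file_idx = bisect.bisect_right(ends, global_state)
    if file_idx == len(ends):
        # Exactly at the end of the data: point at the end of the last file
        file_idx = len(ends) - 1
    start = ends[file_idx - 1] if file_idx > 0 else 0
    return file_idx, global_state - start


lengths = [3, 0, 4]  # file 1 is empty
print(locate(3, lengths))  # (2, 0): position 3 is the first item of file 2
```

Each lookup costs O(log F) for F files, which matters once a dataset is split into thousands of shards.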

Advanced Techniques for Data Loading

As we move into more complex scenarios, we need to consider techniques that can handle distributed training and improve performance. Let’s explore some advanced data loading techniques.

Sharded Data Iterator

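In distributed training scenarios, we often need to shard our data across multiple workers (ranks). Here's an implementation built around a DataPartitioner class that assigns shard files to ranks, along with helpers that launch one process per rank: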

import contextlib
import glob
import os
from typing import Callable

import torch.distributed as dist
import torch.multiprocessing as mp

# Helpers for running distributed processes
def init_process(
    rank: int, world_size: int, fn: Callable[[int, int], None], backend: str = "gloo"
):
    """Initialize the distributed environment.

    "gloo" runs on CPU; use "nccl" for multi-GPU training.
    """
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size)


def run_dist_fn(fn: Callable[[int, int], None], num_ranks: int = 4):
    """Launch one process per rank and wait for all of them to finish."""
    processes = []
    with contextlib.suppress(RuntimeError):  # the start method may already be set
        mp.set_start_method("spawn")
    for rank in range(num_ranks):
        p = mp.Process(target=init_process, args=(rank, num_ranks, fn))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


class DataPartitioner:
    def __init__(
        self,
        shard_pattern: str,
        world_size: int,
        seed: int = 42,  # unused here; reserved for shuffling shard order
    ):
        self.shard_pattern = shard_pattern
        # Sorted so that every rank computes the same assignment
        self.shards = sorted(glob.glob(shard_pattern))
        self.world_size = world_size
        self.seed = seed

        if len(self.shards) < world_size:
            raise ValueError("need at least one shard file per rank")
        # Round-robin assignment: every shard goes to exactly one rank,
        # even when len(shards) is not divisible by world_size
        self.partition_inds = [
            list(range(r, len(self.shards), world_size)) for r in range(world_size)
        ]

    def use(self, partition: int):
        shard_files = [self.shards[i] for i in self.partition_inds[partition]]
        # MultiJsonlFileIterator is defined earlier in this post
        return MultiJsonlFileIterator(shard_files)

def run_dataloader(rank: int, world_size: int):
    data_partitioner = DataPartitioner("./shards/*.jsonl", world_size)
    partition = data_partitioner.use(rank)
    for item in partition:
        print(f"Rank {rank}: {item}")

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 1
    run_dist_fn(run_dataloader, world_size)

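This partitioning gives each rank a disjoint subset of the shard files, which is crucial in distributed training to avoid processing the same data on multiple workers. Because assignment happens at the file level, ranks can still end up with different numbers of items when shards differ in size, so it helps to write shards of roughly equal size.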
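Assigning whole files to ranks only balances the work when shards are similar in size. An alternative (not what DataPartitioner does) is to shard at the item level: every rank reads the same stream and keeps every world_size-th item, starting at its own rank. The hypothetical shard_stream helper below does this with itertools.islice; each rank still pays to read the full stream, but per-rank item counts differ by at most one.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard_stream(items: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Keep every world_size-th item, starting at position `rank`.

    All ranks iterate over the same stream, so reading is duplicated,
    but the resulting subsets are disjoint and together cover every item.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return islice(items, rank, None, world_size)


# Usage: 10 items split across 3 ranks
parts = [list(shard_stream(range(10), r, 3)) for r in range(3)]
# parts == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```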

Parallel Data Loading

To further speed up data loading, we can read and parse several files at once. Here's an example using a process pool from Python's concurrent.futures module:

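import itertools
import json
import multiprocessing
from collections import deque
from concurrent.futures import ProcessPoolExecutor


def _load_jsonl(filename: str) -> list:
    """Read and parse one JSONL file; runs inside a worker process."""
    with open(filename, "r") as f:
        return [json.loads(line) for line in f]


class ParallelMultiJsonlFileIterator(ResumableIterator):
    def __init__(
        self,
        filenames: list[str],
        num_workers: int | None = None,
        prefetch: int | None = None,
    ):
        self.filenames = filenames
        self.num_workers = num_workers or multiprocessing.cpu_count()
        # How many files may be parsed ahead of the consumer at once
        self.prefetch = prefetch or self.num_workers
        self.pool = ProcessPoolExecutor(max_workers=self.num_workers)
        self.state = 0
        self.iterator = self._create_iterator()

    def _create_iterator(self):
        files = iter(self.filenames)
        # Start parsing the first `prefetch` files in parallel
        pending = deque(
            self.pool.submit(_load_jsonl, name)
            for name in itertools.islice(files, self.prefetch)
        )
        while pending:
            # Wait for the oldest file so items come out in file order
            items = pending.popleft().result()
            # Keep the window full: submit the next unread file, if any
            next_name = next(files, None)
            if next_name is not None:
                pending.append(self.pool.submit(_load_jsonl, next_name))
            yield from items

    def __next__(self):
        try:
            item = next(self.iterator)
        except StopIteration as e:
            self.pool.shutdown(wait=False)
            raise StopIteration from e

        self.state += 1
        return item

    # ... (get_state and set_state methods)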

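This parallel iterator uses multiple processes to load data concurrently, which can speed up data loading when there are many files to parse or when each item needs CPU-heavy preprocessing. Separate processes sidestep Python's global interpreter lock (GIL), at the cost of pickling every parsed item back to the main process.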
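When the bottleneck is waiting on storage rather than parsing, threads are often enough, since CPython releases the GIL during blocking file reads, and they avoid pickling results between processes. Here is a minimal sketch with concurrent.futures.ThreadPoolExecutor; load_jsonl and threaded_jsonl_items are hypothetical helper names.

```python
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator


def load_jsonl(filename: str) -> list:
    """Read and parse one JSONL file, skipping blank lines."""
    with open(filename, "r") as f:
        return [json.loads(line) for line in f if line.strip()]


def threaded_jsonl_items(filenames: list[str], num_workers: int = 8) -> Iterator[dict]:
    """Read files concurrently on threads and yield items in file order.

    Executor.map submits every file up front and returns results in input
    order, so memory grows with how far reading runs ahead of the consumer.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        for items in pool.map(load_jsonl, filenames):
            yield from items


# Usage: three small files, read concurrently
tmp = tempfile.mkdtemp()
paths = []
for i in range(3):
    path = os.path.join(tmp, f"part-{i}.jsonl")
    with open(path, "w") as f:
        for row in range(2):
            f.write(json.dumps({"file": i, "row": row}) + "\n")
    paths.append(path)

items = list(threaded_jsonl_items(paths, num_workers=3))
```

Even though the files are read concurrently, items come back in the order the files were listed, which keeps the stream deterministic.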

Caching for Repeated Access

In scenarios where you need to access the same data multiple times (e.g., multiple epochs in training), implementing a caching mechanism can significantly improve performance. Here’s a simple cached iterator:

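from collections import OrderedDict

class CachedIterator(ResumableIterator):
    def __init__(self, iterator: ResumableIterator, cache_size: int = 100):
        self.iterator = iterator
        self.cache_size = cache_size
        self.cache = OrderedDict()  # index -> item, least recently used first
        self.state = 0

    def _get_item(self, index):
        if index in self.cache:
            self.cache.move_to_end(index)  # mark as most recently used
            return self.cache[index]
        # Seek only if the wrapped iterator isn't already at this position
        if self.iterator.get_state() != index:
            try:
                self.iterator.set_state(index)
            except ValueError as e:
                # The wrapped iterator rejects positions past the end of its data
                raise StopIteration from e
        item = next(self.iterator)  # raises StopIteration at the end of the data
        self.cache[index] = item
        if len(self.cache) > self.cache_size:
            self.cache.popitem(last=False)  # evict the least recently used item
        return item

    def __next__(self):
        item = self._get_item(self.state)
        self.state += 1
        return item

    # ... (get_state and set_state methods)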

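This cached iterator wraps another iterator and keeps items it has already read in memory, so revisiting them (for example, after rewinding with set_state) skips the underlying file read and parse. One caveat: with a bounded least-recently-used (LRU) cache, sequentially re-reading a dataset larger than the cache evicts every item before it is reused, so caching pays off mainly when the revisited data fits in the cache.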
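The LRU caveat is easy to check with functools.lru_cache applied to a plain function keyed by index. The sketch below uses a hypothetical in-memory DATA list as a stand-in for records that are expensive to read, and counts cache hits over three sequential epochs: once with a cache large enough for the whole dataset, and once with a cache one item too small.

```python
from functools import lru_cache

# Hypothetical stand-in for records that are expensive to read from disk
DATA = [{"id": i} for i in range(5)]


def make_cached_loader(maxsize: int):
    """Build a fresh index -> record loader memoized with functools.lru_cache."""

    @lru_cache(maxsize=maxsize)
    def load_record(index: int) -> dict:
        # Cached records are shared objects: callers must not mutate them
        return DATA[index]

    return load_record


def run_epochs(load_record, num_epochs: int) -> None:
    # Sequential passes over the whole dataset, like repeated training epochs
    for _ in range(num_epochs):
        for i in range(len(DATA)):
            load_record(i)


big = make_cached_loader(maxsize=8)  # all 5 records fit
run_epochs(big, num_epochs=3)
small = make_cached_loader(maxsize=4)  # one record too small
run_epochs(small, num_epochs=3)
# big: 5 misses, then 10 hits; small: 15 misses and 0 hits
```

With maxsize=4, each new index evicts exactly the record that will be needed soonest, so the cache never helps during sequential epochs even though it holds 80% of the data.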

Conclusion

Efficient data loading is a crucial aspect of training large-scale neural networks. By implementing resumable iterators, we can create flexible and efficient data loading pipelines that can handle various data sources and scenarios.

We’ve covered several key concepts and implementations:

  1. Basic resumable iterators
  2. Specialized iterators for different data sources (lists, JSONL files, multi-file datasets)
  3. Sharded data iterators for distributed training
  4. Parallel data loading for improved performance
  5. Caching mechanisms for repeated data access

These tools and techniques provide a solid foundation for building efficient data loading pipelines for your machine learning projects. Remember to always profile and benchmark your specific use case to determine which optimizations provide the most benefit.

As you continue to work with large-scale datasets and complex neural network architectures, keep exploring new ways to optimize your data loading pipeline. It’s an essential step in improving the overall efficiency and performance of your machine learning workflows.

