grain.DataLoaderIterator
- class grain.DataLoaderIterator(data_loader, state, validate_state=True)
DataLoader iterator providing get/set state functionality.
This is the only iterator we expose to users. It wraps the underlying MultipleProcessIterator. To set state, it recreates the underlying iterator from scratch with the new state.
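The "recreate on set state" pattern can be illustrated with a small, self-contained sketch. It does not use Grain: RestartableIterator, its integer position state, and the _create_iterator helper are hypothetical names chosen for illustration only, and the real state format is not shown here.

```python
from typing import List


class RestartableIterator:
    """Hypothetical stand-in for a wrapper iterator; not Grain's implementation."""

    def __init__(self, data: List[int], state: int = 0):
        self._data = data
        self._create_iterator(state)

    def _create_iterator(self, state: int) -> None:
        # Discard any existing underlying iterator and build a fresh one
        # that starts at the position recorded in `state`.
        self._position = state
        self._iterator = iter(self._data[state:])

    def __iter__(self) -> "RestartableIterator":
        return self

    def __next__(self) -> int:
        value = next(self._iterator)
        self._position += 1
        return value

    def get_state(self) -> int:
        # Assumption: the state is just the number of elements consumed.
        return self._position

    def set_state(self, state: int) -> None:
        # Setting state never mutates the old iterator; it replaces it.
        self._create_iterator(state)


it = RestartableIterator([10, 20, 30, 40])
next(it)
next(it)
saved = it.get_state()       # two elements consumed so far
third = next(it)             # consumes 30
it.set_state(saved)          # rewind by rebuilding the underlying iterator
replayed = next(it)          # yields 30 again
```

Rebuilding the underlying iterator keeps set_state simple, since no partially consumed internal buffers need to be reconciled with the new state.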
Checkpointing for DataLoaderIterator: DataLoaderIterator uses GrainPool, which distributes RecordMetadata from produced records among worker processes in a round-robin fashion. Generally, some workers can process more elements than others at a given training step. The checkpointing logic is as follows:
1) With each output batch produced, GrainPool emits the worker_index of the worker that processed the batch.
2) DataLoaderIterator keeps track of the last_seen_index at each worker.
3) When restoring from a state, DataLoaderIterator checks what the minimum last_seen_index is (among the last seen indices for all workers) and which worker processed that index. GrainPool is instructed to start distributing indices to the next worker.
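The bookkeeping described above can be modeled with a toy simulation. This is only a sketch under stated assumptions, not Grain's code: simulate and min_last_seen are hypothetical helpers, batch i is assumed to come from worker i % num_workers, and -1 is an assumed sentinel for "nothing seen yet". How the real GrainPool chooses the worker to resume from is not reproduced here.

```python
from typing import Dict, Tuple


def simulate(num_workers: int, num_batches: int) -> Dict[int, int]:
    """Toy model: batch i is produced by worker i % num_workers."""
    # Assumption: -1 means the worker has not produced anything yet.
    last_seen_index: Dict[int, int] = {w: -1 for w in range(num_workers)}
    for batch_index in range(num_batches):
        # The pool reports which worker produced the batch.
        worker_index = batch_index % num_workers
        # The iterator records the last index seen at that worker.
        last_seen_index[worker_index] = batch_index
    return last_seen_index


def min_last_seen(last_seen_index: Dict[int, int]) -> Tuple[int, int]:
    """Return (worker_index, index) for the minimum last seen index."""
    worker_index = min(last_seen_index, key=last_seen_index.get)
    return worker_index, last_seen_index[worker_index]


state = simulate(num_workers=3, num_batches=4)
print(state)                 # {0: 3, 1: 1, 2: 2}
print(min_last_seen(state))  # (1, 1)
```

In this toy model, worker 0 has processed one more batch than the others after four batches, which is the uneven progress that the per-worker last_seen_index is meant to capture.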
- Parameters:
data_loader (DataLoader)
state (_IteratorState | None)
validate_state (bool)
- __init__(data_loader, state, validate_state=True)
- Parameters:
data_loader (DataLoader)
state (dict[str, Any] | None)
validate_state (bool)
Methods
__init__(data_loader, state[, validate_state])
get_state()
load(directory) Loads the iterator state from a directory.
save(directory) Saves the iterator state to a directory.
set_state(state) Sets the state for the underlying iterator.