Transformations
The Grain Transforms interface defines transformations that are applied to data. Local transformations (such as map, random map, and filter) receive a single element and apply custom changes to it. Global transformations (such as batching) operate across multiple elements; for batching, one must provide the batch size.
The Grain core transforms interface code is here.
Map Transform
Map Transform is for 1:1 transformations of elements. Elements can be of any
type; it is the user's responsibility to ensure that the inputs the
transformation receives match its signature.
Example of a transformation that implements Map Transform (for elements of
type int):
import grain

class PlusOne(grain.transforms.Map):
    def map(self, x: int) -> int:
        return x + 1
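To illustrate the 1:1 contract without requiring Grain, the sketch below uses a hypothetical stand-in base class `Map` (not Grain's own class) with the same `map` method name, plus a hypothetical helper `apply_map` that calls the transform once per element. Because every input yields exactly one output, the output has the same length as the input.

```python
import abc
from typing import Any, Iterable, List


class Map(abc.ABC):
    """Hypothetical stand-in for grain.transforms.Map, for illustration only."""

    @abc.abstractmethod
    def map(self, element: Any) -> Any:
        ...


class PlusOne(Map):
    def map(self, x: int) -> int:
        return x + 1


def apply_map(transform: Map, elements: Iterable[Any]) -> List[Any]:
    # 1:1: every input element produces exactly one output element.
    return [transform.map(e) for e in elements]


print(apply_map(PlusOne(), [1, 2, 3]))  # [2, 3, 4]
```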
MapWithIndex Transform
MapWithIndex Transform is similar to Map Transform in that it performs 1:1
transformations of elements, but it also receives the index (position) of the
element as the first argument. This is useful for pairing elements with an index
key or for keeping the index as metadata alongside the actual data.
Example of a transformation that implements MapWithIndex Transform (for
elements of type int):
import grain

class PlusOneWithIndexKey(grain.transforms.MapWithIndex):
    def map_with_index(self, i: int, x: int) -> tuple[int, int]:
        # Pair the transformed value with its index key.
        return (x + 1, i)
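The following sketch shows the argument order in action without needing Grain. It assumes the index is simply the element's 0-based position in the input, produced here with `enumerate`; the helper `apply_map_with_index` is hypothetical, and in Grain the index comes from the dataset itself.

```python
from typing import Any, Iterable, List, Tuple


class PlusOneWithIndexKey:
    """Plain-class mirror of the example above, so the sketch runs without Grain."""

    def map_with_index(self, i: int, x: int) -> Tuple[int, int]:
        return (x + 1, i)


def apply_map_with_index(transform: Any, elements: Iterable[int]) -> List[Tuple[int, int]]:
    # Assumption: index = 0-based position. The index is passed FIRST, then the element.
    return [transform.map_with_index(i, x) for i, x in enumerate(elements)]


print(apply_map_with_index(PlusOneWithIndexKey(), [10, 20, 30]))
# [(11, 0), (21, 1), (31, 2)]
```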
RandomMap Transform
RandomMap Transform is for 1:1 random transformations of elements. The
interface passes a np.random.Generator as a parameter to the random_map
method, which should draw all of its randomness from that generator.
Example of a RandomMap Transform:
import grain
import numpy as np

class PlusRandom(grain.transforms.RandomMap):
    def random_map(self, x: int, rng: np.random.Generator) -> int:
        return x + int(rng.integers(100_000))
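Drawing randomness only from the supplied generator is what makes a random map reproducible. The sketch below assumes a hypothetical seeding scheme in which each element gets its own generator derived from a base seed and the element's index via `np.random.default_rng([seed, index])`; Grain's actual seed derivation may differ, and `apply_random_map` is a hypothetical helper. Running it twice with the same seed gives identical outputs.

```python
from typing import Iterable, List

import numpy as np


class PlusRandom:
    """Plain-class mirror of PlusRandom above, runnable without Grain."""

    def random_map(self, x: int, rng: np.random.Generator) -> int:
        return x + int(rng.integers(100_000))


def apply_random_map(transform, elements: Iterable[int], seed: int) -> List[int]:
    # Hypothetical scheme: one generator per element, seeded from (seed, index),
    # so each output depends only on the seed and the element's position.
    return [
        transform.random_map(x, np.random.default_rng([seed, i]))
        for i, x in enumerate(elements)
    ]


a = apply_random_map(PlusRandom(), [0, 1, 2], seed=42)
b = apply_random_map(PlusRandom(), [0, 1, 2], seed=42)
print(a == b)  # True: same seed, same random outputs
```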
FlatMap Transform
FlatMap Transform is for splitting individual elements into multiple elements
(1:N transformations). max_fan_out is the maximum number of elements that a
single input element can produce. Please consult the code for details.
Example of a FlatMap Transform:
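import dataclasses
import grain

@dataclasses.dataclass
class FlatMapTransformExample(grain.experimental.FlatMapTransform):
    max_fan_out: int  # upper bound on outputs per input element

    def flat_map(self, element: int):
        for _ in range(self.max_fan_out):  # repeat the element max_fan_out times
            yield element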
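The sketch below illustrates how a flat map expands each input into several outputs and how max_fan_out acts as an upper bound. It runs without Grain: `RepeatElement` is a plain-class mirror of the example above, and `apply_flat_map`, including its error when the bound is exceeded, is a hypothetical helper rather than Grain's own implementation.

```python
import dataclasses
from typing import Iterable, Iterator, List


@dataclasses.dataclass
class RepeatElement:
    """Plain-class mirror of FlatMapTransformExample, runnable without Grain."""

    max_fan_out: int

    def flat_map(self, element: int) -> Iterator[int]:
        for _ in range(self.max_fan_out):
            yield element


def apply_flat_map(transform, elements: Iterable[int]) -> List[int]:
    out: List[int] = []
    for element in elements:
        splits = list(transform.flat_map(element))
        # max_fan_out bounds how many outputs a single input may produce.
        if len(splits) > transform.max_fan_out:
            raise ValueError(
                f"flat_map produced {len(splits)} elements, "
                f"more than max_fan_out={transform.max_fan_out}"
            )
        out.extend(splits)
    return out


print(apply_flat_map(RepeatElement(max_fan_out=2), [1, 2]))  # [1, 1, 2, 2]
```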
Filter Transform
Filter Transform is for filtering individual elements. Elements for which the
filter method returns False are removed.
Example of a Filter Transform that removes all even elements:
import grain

class RemoveEvenElements(grain.transforms.Filter):
    def filter(self, element: int) -> bool:
        return element % 2 != 0  # keep odd elements only
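Unlike a map, a filter can change the number of elements. The sketch below shows this without Grain: `RemoveEvenElements` is a plain-class mirror of the example above, and `apply_filter` is a hypothetical helper that keeps only the elements for which `filter` returns True.

```python
from typing import Iterable, List


class RemoveEvenElements:
    """Plain-class mirror of the Filter example, runnable without Grain."""

    def filter(self, element: int) -> bool:
        return element % 2 != 0


def apply_filter(transform, elements: Iterable[int]) -> List[int]:
    # Keep only the elements for which filter() returns True.
    return [e for e in elements if transform.filter(e)]


print(apply_filter(RemoveEvenElements(), range(10)))  # [1, 3, 5, 7, 9]
```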
Batch
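To apply the Batch transform, pass
grain.transforms.Batch(batch_size=batch_size, drop_remainder=drop_remainder)
as one of the transformations. drop_remainder controls whether a final batch
with fewer than batch_size elements is dropped.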
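The effect of drop_remainder can be sketched in plain Python. The `batch` helper below is hypothetical and collects elements into lists, whereas Grain's Batch combines the elements into stacked arrays; only the grouping and remainder behavior are meant to carry over.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch(elements: Iterable[T], batch_size: int, drop_remainder: bool) -> Iterator[List[T]]:
    """Hypothetical list-based sketch of batching semantics."""
    current: List[T] = []
    for element in elements:
        current.append(element)
        if len(current) == batch_size:
            yield current
            current = []
    # A final, smaller batch is emitted only when drop_remainder is False.
    if current and not drop_remainder:
        yield current


print(list(batch(range(5), batch_size=2, drop_remainder=False)))  # [[0, 1], [2, 3], [4]]
print(list(batch(range(5), batch_size=2, drop_remainder=True)))   # [[0, 1], [2, 3]]
```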
Note: The batch size passed to the Batch transform is the global batch size if
batching happens before sharding, and the per-host batch size if it happens
after sharding. Typical usage with IndexSampler batches after sharding.
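As a worked example of the note above: when batching after sharding, each host batches only its own shard, so the batch size to pass is the per-host size. The hypothetical helper below assumes the data is sharded evenly across hosts and that the global batch size divides evenly by the number of hosts.

```python
def per_host_batch_size(global_batch_size: int, num_hosts: int) -> int:
    """Hypothetical helper: batch size each host uses when batching after sharding.

    Assumes even sharding and a global batch size divisible by num_hosts.
    """
    if global_batch_size % num_hosts != 0:
        raise ValueError("global_batch_size must be divisible by num_hosts")
    return global_batch_size // num_hosts


# Batching after sharding on 8 hosts with a global batch of 256 elements:
print(per_host_batch_size(256, 8))  # 32
```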