Dataset

class ssm.CopyDataset(sequence_len, mem_tokens, vocab_size, selective=True, batch_size=64)

Bases: IterableDataset

An infinite dataset for the copy and selective copy tasks. Samples are generated on the fly rather than precomputed and stored.
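The on-the-fly behavior follows the usual `torch.utils.data.IterableDataset` pattern: `__iter__` loops forever and samples each item only when it is requested. The sketch below is a hypothetical stand-in (`EndlessTokenStream` is not part of `ssm`) that yields random token batches. It assumes, based only on the `batch_size` constructor argument, that `CopyDataset` yields whole batches, which is why the `DataLoader` here is built with `batch_size=None`.

```python
import itertools
from typing import Iterator

import torch
from torch.utils.data import DataLoader, IterableDataset


class EndlessTokenStream(IterableDataset):
    """Hypothetical stand-in: an IterableDataset with no end that samples each batch on demand."""

    def __init__(self, sequence_len: int, vocab_size: int, batch_size: int = 64):
        super().__init__()
        self.sequence_len = sequence_len
        self.vocab_size = vocab_size
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[torch.Tensor]:
        # Infinite generator: nothing is precomputed or stored.
        while True:
            yield torch.randint(0, self.vocab_size, (self.batch_size, self.sequence_len))


stream = EndlessTokenStream(sequence_len=16, vocab_size=10, batch_size=4)

# The dataset already yields whole batches, so disable DataLoader's own batching.
loader = DataLoader(stream, batch_size=None)

# An infinite stream never raises StopIteration; bound the loop explicitly.
batches = list(itertools.islice(loader, 3))
```

Because such a stream never ends, a training loop has to be bounded by a step count (here via `itertools.islice`) rather than by epochs.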

generate_data()

Generate N input-output pairs. Each input is a sequence of tokens; the corresponding output is a copy of that sequence, with a special token marking where the copy starts.

Parameters:

N (int) – Number of samples to generate.

Returns:

Tuple of input and output tensors.

Return type:

tuple(torch.Tensor, torch.Tensor)
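One common way to lay out a copy-task sample is sketched below. Everything here is an assumption rather than the `ssm` implementation: the reserved token ids (`BLANK = 0`, `DELIM = 1`), the helper name `generate_copy_batch`, and the layout itself, which places `mem_tokens` data tokens first, pads with blanks up to `sequence_len`, and ends with one delimiter signalling where the copy should begin. The target is the data tokens themselves.

```python
import torch

BLANK = 0  # hypothetical id of the filler token
DELIM = 1  # hypothetical id of the "start copying now" marker


def generate_copy_batch(n: int, sequence_len: int, mem_tokens: int, vocab_size: int):
    """Return (inputs, targets) for n copy-task samples (illustrative layout only)."""
    if mem_tokens > sequence_len:
        raise ValueError("mem_tokens must not exceed sequence_len")
    # Data tokens are drawn from [2, vocab_size) to avoid the two reserved ids.
    data = torch.randint(2, vocab_size, (n, mem_tokens))
    # Blank padding fills the rest of the memory region.
    blanks = torch.full((n, sequence_len - mem_tokens), BLANK, dtype=torch.long)
    # A single delimiter marks where the model should start reproducing the data.
    delim = torch.full((n, 1), DELIM, dtype=torch.long)
    inputs = torch.cat([data, blanks, delim], dim=1)  # shape (n, sequence_len + 1)
    targets = data.clone()  # shape (n, mem_tokens)
    return inputs, targets


inputs, targets = generate_copy_batch(n=4, sequence_len=10, mem_tokens=3, vocab_size=8)
```

Keeping the data tokens contiguous at the start is what makes this the plain copy task: the model only needs to remember a fixed block, not decide which tokens matter.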

static make_selective(t)

Randomly permute the input tensor to create a selective copy dataset.

Parameters:

t (torch.Tensor) – Input tensor to be permuted.

Returns:

Permuted tensor.

Return type:

torch.Tensor
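In a selective copy task the data tokens typically sit at random positions among filler tokens, so the model must pick out which tokens to remember instead of copying a fixed contiguous block. The sketch below assumes `t` is a 2-D `(batch, length)` tensor and permutes each row independently. The filler id `BLANK = 0` and the helper `selective_targets`, which recovers the target as the non-blank tokens in their new order of appearance, are hypothetical and not part of the documented API.

```python
import torch

BLANK = 0  # hypothetical id of the filler token


def make_selective(t: torch.Tensor) -> torch.Tensor:
    """Shuffle every row of a 2-D tensor with its own random permutation (sketch)."""
    # argsort of i.i.d. uniform noise gives an independent uniform permutation per row.
    perm = torch.rand(t.shape, device=t.device).argsort(dim=1)
    return torch.gather(t, 1, perm)


def selective_targets(shuffled: torch.Tensor, mem_tokens: int, blank: int = BLANK) -> torch.Tensor:
    """Non-blank tokens of each row, kept in their order of appearance."""
    # Boolean indexing flattens in row-major order; every row holds exactly
    # mem_tokens non-blank entries, so the result reshapes cleanly per row.
    return shuffled[shuffled != blank].view(shuffled.size(0), mem_tokens)


torch.manual_seed(0)
data = torch.randint(2, 10, (4, 3))  # 3 data tokens per row, ids >= 2
memory = torch.cat([data, torch.full((4, 5), BLANK, dtype=torch.long)], dim=1)
shuffled = make_selective(memory)
targets = selective_targets(shuffled, mem_tokens=3)
```

Drawing the permutations through `argsort` keeps the operation vectorized; looping over rows with `torch.randperm` would be equivalent but slower for large batches.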