Dataset
- class ssm.CopyDataset(sequence_len, mem_tokens, vocab_size, selective=True, batch_size=64)
Bases: IterableDataset
An infinite dataset for the copy and selective copy tasks. Samples are generated on the fly rather than stored.
- generate_data()
Generate a dataset of input-output pairs. The input is a sequence of tokens; the output is a copy of the input sequence, with a dedicated marker token indicating where the copy starts.
- Parameters:
N (int) – Number of samples to generate.
- Returns:
Tuple of input and output tensors.
- Return type:
tuple(torch.Tensor, torch.Tensor)
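To make the input/output layout concrete, here is a minimal, dependency-free sketch of one copy-task sample. It is not the library's implementation: the function name make_copy_sample, the token conventions (0 as a blank token, vocab_size - 1 as the start-of-copy marker), and the exact placement of blanks are all assumptions chosen for illustration, and plain Python lists stand in for torch.Tensor.

```python
import random

BLANK = 0  # assumed padding token; the real dataset may use a different id


def make_copy_sample(sequence_len, mem_tokens, vocab_size, rng=random):
    """Build one (input, target) pair for a plain copy task (hypothetical layout)."""
    if mem_tokens > sequence_len:
        raise ValueError("mem_tokens must not exceed sequence_len")
    if vocab_size < 3:
        raise ValueError("vocab_size must leave room for blank, data and marker tokens")
    marker = vocab_size - 1  # assumed start-of-copy token
    # Tokens to memorise, drawn from the ids strictly between BLANK and marker.
    data = [rng.randrange(1, marker) for _ in range(mem_tokens)]
    # Layout: data, blanks up to sequence_len, marker, then one blank slot per token to recall.
    inputs = data + [BLANK] * (sequence_len - mem_tokens) + [marker] + [BLANK] * mem_tokens
    return inputs, data


x, y = make_copy_sample(sequence_len=8, mem_tokens=3, vocab_size=10, rng=random.Random(0))
```

Under these assumptions, x has sequence_len + 1 + mem_tokens entries and y holds the mem_tokens values the model is expected to reproduce after the marker.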
- static make_selective(t)
Randomly permute the input tensor to create a selective copy dataset.
- Parameters:
t (torch.Tensor) – Input tensor to be permuted.
- Returns:
Permuted tensor.
- Return type:
torch.Tensor
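The permutation step can be illustrated with a small sketch on a plain list. The helper name make_selective_sketch is hypothetical and the code only mirrors the documented behaviour (a random permutation of the input); it makes no claim about how make_selective handles batches or tensor dimensions.

```python
import random


def make_selective_sketch(tokens, rng=random):
    """Return a randomly permuted copy of `tokens`; the input list is left unchanged."""
    return rng.sample(tokens, k=len(tokens))


original = [3, 7, 5, 0, 0, 0, 0, 0]  # three data tokens followed by assumed blank tokens
shuffled = make_selective_sketch(original, rng=random.Random(0))
```

After the shuffle the data tokens sit at arbitrary positions among the blanks, which is what forces a model to select them by content rather than by position. With tensors, a comparable effect could be obtained by indexing with torch.randperm, though whether the library does so is not stated here.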