Trainer

class ssm.Trainer(model, dataset, steps, metric_tracker, device=None, test_steps=0, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_params={'lr': 0.001}, scheduler_class=None, scheduler_params=None)
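The signature takes an optimizer class plus a dict of keyword arguments rather than a ready-made optimizer. This minimal sketch shows one plausible way such a pair could be turned into objects. It assumes the optimizer is built as `optimizer_class(model.parameters(), **optimizer_params)` and the scheduler as `scheduler_class(optimizer, **scheduler_params)`. That wiring is an assumption, not confirmed by this page. The `StepLR` settings are illustrative only, since both scheduler arguments default to `None`.

```python
import torch
from torch import nn

# Hypothetical stand-in for the model passed to Trainer.
model = nn.Linear(4, 2)

# Defaults copied from the Trainer signature.
optimizer_class = torch.optim.Adam
optimizer_params = {"lr": 0.001}

# Illustrative scheduler settings (the Trainer defaults both to None).
scheduler_class = torch.optim.lr_scheduler.StepLR
scheduler_params = {"step_size": 10, "gamma": 0.5}

# Assumed wiring: the optimizer gets the model parameters plus optimizer_params,
# and the scheduler (if any) wraps the optimizer.
optimizer = optimizer_class(model.parameters(), **optimizer_params)
scheduler = (
    scheduler_class(optimizer, **(scheduler_params or {}))
    if scheduler_class is not None
    else None
)
```

Passing a class plus keyword arguments lets the caller swap optimizers without constructing one before the model's parameters are known.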

Bases: object

_count_parameters()

Count the number of trainable and non-trainable parameters in the model.

Returns: The number of trainable parameters.
Return type: int
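A parameter is usually considered trainable when its `requires_grad` flag is set. This minimal sketch counts both groups with `Tensor.numel()`. The helper name `count_parameters` is hypothetical. Unlike the documented method, which returns only the trainable count, the sketch returns both numbers so the split is visible.

```python
from torch import nn

def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (trainable, non_trainable) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, non_trainable

# Linear(4, 3) has 4*3 + 3 = 15 parameters; Linear(3, 2) has 3*2 + 2 = 8.
model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))

# Freeze the first layer so its parameters count as non-trainable.
for p in model[0].parameters():
    p.requires_grad = False

trainable, non_trainable = count_parameters(model)
print(trainable, non_trainable)  # 8 15
```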

compute_metrics(output, target)

Compute the loss and accuracy metrics.

Parameters:
output (torch.Tensor): The model output.
target (torch.Tensor): The ground-truth labels.
Returns: The loss and accuracy values.
Return type: tuple
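A minimal sketch of a loss-and-accuracy pair for classification follows. It assumes the model outputs raw logits of shape `(batch, classes)` and that the loss is `nn.CrossEntropyLoss`. Neither assumption is stated on this page. The loss is kept as a tensor so it can still be back-propagated, and accuracy is returned as a plain float.

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()  # assumed loss for a classification task

def compute_metrics(output: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, float]:
    loss = loss_fn(output, target)            # scalar tensor, keeps the autograd graph
    preds = output.argmax(dim=1)              # predicted class per sample
    accuracy = (preds == target).float().mean().item()
    return loss, accuracy

output = torch.tensor([[2.0, 0.5], [0.1, 1.5], [3.0, -1.0]])  # logits
target = torch.tensor([0, 1, 1])
loss, accuracy = compute_metrics(output, target)
# Predictions are [0, 1, 0], so 2 of 3 are correct.
```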

fit()

Train the model using gradient accumulation.
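Gradient accumulation runs several forward and backward passes before each optimizer step, so gradients from several micro-batches add up and emulate a larger batch. The sketch below shows the general technique on toy data. The name `accumulation_steps` is hypothetical: the `Trainer` signature has no such argument, and how `ssm.Trainer` chooses its accumulation interval is not documented here. Dividing each loss by `accumulation_steps` makes the summed gradient an average over micro-batches.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(4, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

accumulation_steps = 4  # hypothetical: micro-batches per optimizer step
steps = 8               # total micro-batches in this toy run

# Random inputs and labels standing in for the dataset.
inputs = torch.randn(steps, 16, 4)
labels = torch.randint(0, 2, (steps, 16))

model.train()
optimizer.zero_grad()
optimizer_steps = 0
for step in range(steps):
    output = model(inputs[step])
    # Scale the loss so the accumulated gradient is a mean, not a sum.
    loss = loss_fn(output, labels[step]) / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # apply the accumulated update
        optimizer.zero_grad()  # start the next accumulation window
        optimizer_steps += 1
```

With 8 micro-batches and an interval of 4, the optimizer updates the weights twice.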

model_summary()

Print a summary of the model, including the number of parameters and the architecture.
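A minimal sketch of such a summary combines the module's string form, which PyTorch renders as a tree of submodules, with parameter counts. The function name and output format are illustrative, not the actual output of `model_summary()`. It returns the text so it can be checked as well as printed.

```python
from torch import nn

def model_summary(model: nn.Module) -> str:
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    lines = [
        str(model),  # nn.Module's repr lists the architecture
        f"Trainable parameters: {trainable:,}",
        f"Non-trainable parameters: {total - trainable:,}",
    ]
    return "\n".join(lines)

# Linear(4, 3) has 15 parameters, Linear(3, 2) has 8, ReLU has none.
summary = model_summary(nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)))
print(summary)
```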

move_to_device()

Move the model and loss function to the specified device.
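The loss function is moved along with the model because loss modules can hold tensors of their own. For example, `nn.CrossEntropyLoss` stores optional class weights as a buffer. This minimal sketch uses `nn.Module.to()`, which moves parameters and buffers and returns the module. The CPU device and the class weights are illustrative choices.

```python
import torch
from torch import nn

device = torch.device("cpu")  # e.g. the result of a device check
model = nn.Linear(4, 2)
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]))

# Module.to() moves parameters and registered buffers to the target device.
model = model.to(device)
loss_fn = loss_fn.to(device)  # the class-weight buffer moves as well
```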

static set_device()

Determine the device to use for training (CPU or GPU). This method checks whether CUDA is available and, on macOS, whether Metal Performance Shaders (MPS) are available. If neither is available, it falls back to the CPU.

Returns: The device to use for training.
Return type: torch.device
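A minimal sketch of this fallback chain uses `torch.cuda.is_available()` and `torch.backends.mps.is_available()`. The docstring does not state the order of the checks, so trying CUDA before MPS is an assumption. On macOS the CUDA check simply returns `False`.

```python
import torch

def set_device() -> torch.device:
    if torch.cuda.is_available():          # NVIDIA GPU
        return torch.device("cuda")
    if torch.backends.mps.is_available():  # Apple silicon GPU via Metal
        return torch.device("mps")
    return torch.device("cpu")             # fallback

device = set_device()
print(device)
```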

test()

Test the model.
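A typical evaluation pass puts the model in eval mode and disables gradient tracking. It then averages the loss and accuracy over the test batches. This is a minimal sketch on toy data. It assumes the `test_steps` constructor argument caps the number of test batches, which this page does not confirm. The batches are random stand-ins for the dataset.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
loss_fn = nn.CrossEntropyLoss()

test_steps = 3  # assumption: the number of test batches to evaluate
# Random batches standing in for the dataset's test split.
batches = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(5)]

model.eval()  # disables dropout; batch norm uses running statistics
total_loss, correct, seen = 0.0, 0, 0
with torch.no_grad():  # no gradients are needed for evaluation
    for inputs, labels in batches[:test_steps]:
        output = model(inputs)
        # Weight each batch loss by its size to get a per-sample average.
        total_loss += loss_fn(output, labels).item() * labels.size(0)
        correct += (output.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)

avg_loss = total_loss / seen
accuracy = correct / seen
```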