Trainer
- class ssm.Trainer(model, dataset, steps, metric_tracker, device=None, test_steps=0, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_params={'lr': 0.001}, scheduler_class=None, scheduler_params=None)
Bases: object
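The constructor takes an optimizer class plus a keyword dictionary, not a ready-made optimizer instance. The sketch below shows one common way such arguments are combined. It assumes the Trainer builds the optimizer as `optimizer_class(model.parameters(), **optimizer_params)` and, when a scheduler is given, the scheduler as `scheduler_class(optimizer, **scheduler_params)`. The helper `build_optimizer` is hypothetical and not part of `ssm`.

```python
import torch
from torch import nn


def build_optimizer(model, optimizer_class=torch.optim.Adam,
                    optimizer_params=None, scheduler_class=None,
                    scheduler_params=None):
    """Hypothetical helper: instantiate an optimizer (and optional scheduler)
    from a class plus keyword arguments, mirroring the Trainer defaults."""
    # Assumption: same default as the documented optimizer_params={'lr': 0.001}.
    if optimizer_params is None:
        optimizer_params = {"lr": 0.001}
    optimizer = optimizer_class(model.parameters(), **optimizer_params)
    scheduler = None
    if scheduler_class is not None:
        # Assumption: the scheduler wraps the optimizer built above.
        scheduler = scheduler_class(optimizer, **(scheduler_params or {}))
    return optimizer, scheduler


model = nn.Linear(4, 2)
opt, sched = build_optimizer(
    model,
    scheduler_class=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 10, "gamma": 0.5},
)
print(type(opt).__name__, opt.param_groups[0]["lr"], type(sched).__name__)
```

Passing a class and a dict lets the optimizer be created after the model exists (for example, after it has been moved to the training device), which is one likely reason for this signature.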
- _count_parameters()
Count the number of trainable and non-trainable parameters in the model.
Returns: The number of trainable parameters.
Return type: int
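A parameter is usually treated as trainable when its `requires_grad` flag is set. The sketch below computes both counts this way. The documented method returns only the trainable count, so the hypothetical `count_parameters` here returns both values only for illustration.

```python
from torch import nn


def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Hypothetical helper: return (trainable, non_trainable) element counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, non_trainable


model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
# Freeze the first layer (10*5 weights + 5 biases) so it counts as non-trainable.
for p in model[0].parameters():
    p.requires_grad = False

trainable, frozen = count_parameters(model)
print(trainable, frozen)  # 12 55
```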
- compute_metrics(output, target)
Compute the loss and accuracy metrics.
Parameters:
output (torch.Tensor): The model output.
target (torch.Tensor): The ground-truth labels.
Returns: The loss and accuracy values.
Return type: tuple
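The documentation does not state which loss is used. The sketch below assumes a classification setup: `output` holds logits of shape (batch, num_classes), `target` holds class indices, the loss is cross-entropy, and accuracy is the fraction of rows whose arg-max matches the label. The standalone function and its optional `loss_fn` argument are illustrative, not the method's actual signature.

```python
import torch
from torch import nn


def compute_metrics(output: torch.Tensor, target: torch.Tensor, loss_fn=None):
    """Sketch: return (loss, accuracy) for a batch of class logits."""
    # Assumption: cross-entropy unless another loss module is supplied.
    loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()
    loss = loss_fn(output, target)  # kept as a tensor so it can be backpropagated
    predictions = output.argmax(dim=1)
    accuracy = (predictions == target).float().mean().item()
    return loss, accuracy


logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2]])
labels = torch.tensor([0, 1, 1])
loss, acc = compute_metrics(logits, labels)
print(f"loss={loss.item():.4f} accuracy={acc:.3f}")
```

Here the predictions are [0, 1, 0] against labels [0, 1, 1], so the accuracy is 2/3.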
- fit()
Train the model using gradient accumulation.
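Gradient accumulation runs several micro-batches forward and backward before each optimizer step. This simulates a larger effective batch without the memory cost. The documentation does not describe the accumulation interval or loss scaling, so the loop below is a generic sketch under those assumptions. `fit_with_accumulation` and `accumulation_steps` are hypothetical names, and their relation to the constructor's `steps` argument is not documented here.

```python
import torch
from torch import nn


def fit_with_accumulation(model, batches, loss_fn, optimizer, accumulation_steps=4):
    """Sketch: step the optimizer once every `accumulation_steps` micro-batches.
    Returns the number of optimizer steps taken."""
    model.train()
    optimizer.zero_grad()
    num_batches = 0
    steps_taken = 0
    for inputs, targets in batches:
        num_batches += 1
        loss = loss_fn(model(inputs), targets)
        # Scale so the summed gradients approximate the mean over the effective batch.
        (loss / accumulation_steps).backward()
        if num_batches % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            steps_taken += 1
    # Apply any leftover gradients from a final, incomplete accumulation group.
    if num_batches % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()
        steps_taken += 1
    return steps_taken


torch.manual_seed(0)
model = nn.Linear(3, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 3), torch.randint(0, 2, (8,))) for _ in range(10)]
n_steps = fit_with_accumulation(model, batches, nn.CrossEntropyLoss(), optimizer,
                                accumulation_steps=4)
print(n_steps)  # 3: after batches 4 and 8, plus one for the last 2 batches
```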
- model_summary()
Print a summary of the model, including the number of parameters and the architecture.
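A summary like this can be assembled from the module's own `repr` plus `named_parameters()`. The sketch below returns the text instead of printing it so that it can be checked. The per-parameter table layout is an assumption and may not match what `model_summary` actually prints.

```python
from torch import nn


def model_summary(model: nn.Module) -> str:
    """Sketch: module repr, a per-parameter table, and parameter totals."""
    rows = [str(model), "", f"{'name':<12}{'shape':<10}{'count':>8}  trainable"]
    trainable = total = 0
    for name, param in model.named_parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
        rows.append(f"{name:<12}{str(tuple(param.shape)):<10}{count:>8}  {param.requires_grad}")
    rows.append(f"Total: {total}  trainable: {trainable}  non-trainable: {total - trainable}")
    return "\n".join(rows)


model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
summary = model_summary(model)
print(summary)
```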
- move_to_device()
Move the model and loss function to the specified device.
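Both the model and the loss are `nn.Module` instances, so `.to(device)` moves their parameters and buffers in place. Moving the loss matters when it holds tensors, such as class weights. The helper `move_to_device` below is a hypothetical stand-in that takes the device explicitly. The real method presumably uses the device stored on the Trainer.

```python
import torch
from torch import nn


def move_to_device(device: torch.device, *modules: nn.Module) -> None:
    """Hypothetical helper: move each module's parameters and buffers in place."""
    for module in modules:
        module.to(device)


device = torch.device("cpu")  # e.g. the result of a device-selection step
model = nn.Linear(4, 2)
# Class weights are stored as a buffer, so .to() moves them with the loss module.
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]))
move_to_device(device, model, loss_fn)

# Tensor.to is not in place: it returns a new tensor, so batch data must be reassigned.
inputs = torch.randn(3, 4).to(device)
loss = loss_fn(model(inputs), torch.tensor([0, 1, 1], device=device))
```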
- static set_device()
Determine the device to use for training (CPU or GPU). This method checks for CUDA and, on macOS, Metal Performance Shaders (MPS). If neither is available, it falls back to the CPU.
Returns: The device to use for training.
Return type: torch.device
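The sketch below performs this kind of device selection with public PyTorch calls. Checking CUDA before MPS is an assumption, because the documentation does not state the order. `torch.backends.mps` only exists in PyTorch 1.12 and later, so the code guards the attribute lookup.

```python
import torch


def select_device() -> torch.device:
    """Sketch: prefer CUDA, then Apple MPS, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Assumption: MPS is checked second; guard for PyTorch builds without torch.backends.mps.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = select_device()
print(device)
```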
- test()
Test the model.
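Evaluation loops in PyTorch typically switch the model to eval mode and disable gradient tracking. The sketch below follows that pattern. Its hypothetical `max_steps` argument caps the number of batches and only loosely mirrors the constructor's `test_steps` parameter, whose exact behavior is not documented here. The use of cross-entropy and arg-max accuracy is also an assumption.

```python
import torch
from torch import nn


@torch.no_grad()  # no autograd graph is built during evaluation
def evaluate(model, batches, loss_fn, max_steps=None):
    """Sketch: return (mean loss, accuracy) over at most `max_steps` batches."""
    model.eval()  # disables dropout and freezes batch-norm statistics
    total_loss, correct, seen = 0.0, 0, 0
    for step, (inputs, targets) in enumerate(batches):
        if max_steps is not None and step >= max_steps:
            break
        outputs = model(inputs)
        # Weight each batch's mean loss by its size to get a per-sample average.
        total_loss += loss_fn(outputs, targets).item() * targets.size(0)
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        seen += targets.size(0)
    if seen == 0:
        return float("nan"), float("nan")
    return total_loss / seen, correct / seen


torch.manual_seed(0)
model = nn.Linear(3, 2)
batches = [(torch.randn(8, 3), torch.randint(0, 2, (8,))) for _ in range(5)]
mean_loss, accuracy = evaluate(model, batches, nn.CrossEntropyLoss(), max_steps=2)
print(f"loss={mean_loss:.4f} accuracy={accuracy:.3f}")
```

The sketch leaves the model in eval mode. A training loop that calls it between epochs would need to call `model.train()` again.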