5 cool PyTorch Functions that are beginner-friendly
PyTorch is a Python library with a wide variety of functions and operations, used mostly for deep learning. Data is stored in tensors. A tensor is an n-dimensional array: a single number, a vector, and a matrix are its 0-, 1-, and 2-dimensional cases. Let's look at a few simple functions that can come in handy.
- torch.randn
- torch.cat
- torch.reshape
- torch.unbind
- torch.take
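To make the definition of a tensor concrete, here is a small sketch that builds tensors of increasing dimension. The values and variable names are arbitrary choices for illustration.

```python
import torch

scalar = torch.tensor(3.0)                 # 0-D: a single number
vector = torch.tensor([1.0, 2.0, 3.0])     # 1-D: a vector
matrix = torch.tensor([[1.0, 2.0],
                       [3.0, 4.0]])        # 2-D: a matrix
cube = torch.zeros(2, 3, 4)                # 3-D: the n-dimensional case

# Print the number of dimensions and the shape of each tensor
for t in (scalar, vector, matrix, cube):
    print(t.dim(), tuple(t.shape))
# 0 ()
# 1 (3,)
# 2 (2, 2)
# 3 (2, 3, 4)
```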
torch.randn
This generates a tensor filled with random numbers drawn from a normal distribution with mean 0 and variance 1 (the standard normal distribution). The shape of the resulting tensor is determined by the size argument.
The size can be passed either as separate integers, e.g. torch.randn(2, 3), or as a single collection such as a tuple, e.g. torch.randn((2, 3)).
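A minimal sketch of both ways of passing the size. The seed and the sample size of 100,000 are arbitrary choices made here so that the sample mean and variance land close to 0 and 1.

```python
import torch

torch.manual_seed(0)  # fix the seed so the run is reproducible

a = torch.randn(2, 3)     # size given as separate integers
b = torch.randn((2, 3))   # the same size given as one tuple
print(a.shape, b.shape)   # torch.Size([2, 3]) torch.Size([2, 3])

# A large sample should have a mean near 0 and a variance near 1
sample = torch.randn(100_000)
print(sample.mean().item(), sample.var().item())
```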
requires_grad specifies whether autograd should record operations on the returned tensor; it defaults to False.
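A short sketch of what requires_grad enables: once autograd tracks the tensor, calling backward() fills in its gradient. The function being differentiated (a sum of squares) is just an illustrative choice.

```python
import torch

z = torch.randn(3)                      # requires_grad defaults to False
x = torch.randn(3, requires_grad=True)  # autograd records operations on x

y = (x ** 2).sum()   # y = x0^2 + x1^2 + x2^2
y.backward()         # stores dy/dx = 2 * x in x.grad

print(z.requires_grad, x.requires_grad)          # False True
print(torch.allclose(x.grad, 2 * x.detach()))    # True
```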
You cannot pass two tuples as arguments, because together they do not define a valid shape; PyTorch raises an error instead.
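A sketch of the failing call. The exact exception type and message can depend on the PyTorch version, so this example catches both TypeError and RuntimeError rather than assuming one.

```python
import torch

rejected = False
try:
    # Two separate tuples do not describe a single shape
    torch.randn((2, 3), (4, 5))
except (TypeError, RuntimeError) as err:
    rejected = True
    print("randn raised", type(err).__name__)
```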
torch.cat
This function concatenates a given sequence of tensors along a dimension. To do this, all tensors must have the same shape except in the concatenating dimension, or be empty 1-D tensors of size (0,).
This works because all the tensors have the same shape.
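A minimal sketch with three tensors of the same shape (2, 3). The fill values 0, 1, and 7 are arbitrary and only make it easy to see where each input ends up.

```python
import torch

t1 = torch.zeros(2, 3)
t2 = torch.ones(2, 3)
t3 = torch.full((2, 3), 7.0)

out = torch.cat((t1, t2, t3))  # joined along dim 0 by default
print(out.shape)               # torch.Size([6, 3])
```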
The dim argument selects the dimension along which the tensors are concatenated; it defaults to 0.
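A sketch of concatenating along dim=1. The shapes are chosen here so that the two tensors differ only in the concatenating dimension, which is allowed.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 4)   # same size as a in dim 0, different in dim 1

out = torch.cat((a, b), dim=1)  # columns of b are appended after those of a
print(out.shape)                # torch.Size([2, 7])
```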
This fails because the tensors' shapes do not match outside the concatenating dimension.
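A sketch of a failing call. The shapes (2, 3) and (3, 4) are illustrative. They differ in dim 1 as well as in dim 0, so concatenating along dim 0 is rejected.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(3, 4)   # differs from a in both dimensions

failed = False
try:
    torch.cat((a, b), dim=0)
except RuntimeError as err:
    failed = True
    print("cat failed:", err)
```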
torch.reshape
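This function returns a tensor with the same elements as the input but a different shape. It only works when the new shape holds exactly as many elements as the original. When possible, the result is a view of the input; otherwise it is a copy.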
Here a tensor of shape (3, 4) becomes a tensor of shape (2, 6), since both hold 3 × 4 = 2 × 6 = 12 elements.
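A minimal sketch of this reshape. The input is built with torch.arange, an illustrative choice that makes the row-major order of the elements easy to follow.

```python
import torch

x = torch.arange(12).reshape(3, 4)  # 12 elements laid out as 3 x 4
y = torch.reshape(x, (2, 6))        # the same 12 elements as 2 x 6

print(y.shape)     # torch.Size([2, 6])
print(y.tolist())  # [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]
```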
This fails because a tensor of shape (4, 4) needs 16 elements, while a tensor of shape (3, 4) has only 12.
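A sketch of the failing reshape. It uses the same illustrative (3, 4) input built from torch.arange.

```python
import torch

x = torch.arange(12).reshape(3, 4)

failed = False
try:
    torch.reshape(x, (4, 4))  # needs 16 elements, but x has 12
except RuntimeError as err:
    failed = True
    print("reshape failed:", err)
```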
torch.unbind
This removes a dimension from a tensor and returns a tuple of all the slices along that dimension.
By default, dim is 0, so a 2-D tensor is split into its rows.
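A minimal sketch comparing the default dim=0 with dim=1 on an illustrative 2 x 3 tensor.

```python
import torch

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

rows = torch.unbind(x)         # dim=0 by default: one tensor per row
cols = torch.unbind(x, dim=1)  # one tensor per column

print(len(rows), rows[0].tolist())  # 2 [0, 1, 2]
print(len(cols), cols[0].tolist())  # 3 [0, 3]
```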
torch.take
This function extracts the elements at the given indices from a tensor. The input is treated as if it were flattened to 1-D, and the result has the same shape as the indices.
A simple use case is picking elements by their position in the flattened tensor.
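A minimal sketch. The 2 x 3 input is an arbitrary example. The second call shows that the result takes the shape of the index tensor.

```python
import torch

src = torch.tensor([[4, 3, 5],
                    [6, 7, 8]])
# Indices refer to positions in the flattened tensor [4, 3, 5, 6, 7, 8]
picked = torch.take(src, torch.tensor([0, 2, 5]))
print(picked.tolist())  # [4, 5, 8]

# A 2 x 2 index tensor gives a 2 x 2 result
grid = torch.take(src, torch.tensor([[0, 1], [4, 5]]))
print(grid.tolist())    # [[4, 3], [7, 8]]
```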
You cannot access elements out of range: an index past the last element of the flattened tensor raises an error.
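A sketch of an out-of-range index. The exception type has differed across PyTorch versions, so this example catches both IndexError and RuntimeError instead of assuming one.

```python
import torch

src = torch.tensor([[4, 3, 5],
                    [6, 7, 8]])  # 6 elements, so valid indices are 0..5

failed = False
try:
    torch.take(src, torch.tensor([6]))
except (IndexError, RuntimeError) as err:
    failed = True
    print("take failed:", type(err).__name__)
```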
Conclusion
This wraps up five simple, useful PyTorch functions: torch.randn, torch.cat, torch.reshape, torch.unbind, and torch.take.