# C++ API Reference

## Tensor Class

### Constructor

```cpp
Tensor(const std::vector<size_t>& shape, DType dtype = DType::Float32, Device device = Device::CPU);
```
### Factory Methods

#### zeros

```cpp
static Tensor zeros(const std::vector<size_t>& shape, DType dtype = DType::Float32, Device device = Device::CPU);
```

#### ones

```cpp
static Tensor ones(const std::vector<size_t>& shape, DType dtype = DType::Float32, Device device = Device::CPU);
```

#### rand

Note: unlike `zeros` and `ones`, `rand` does not take a `dtype` argument.

```cpp
static Tensor rand(const std::vector<size_t>& shape, Device device = Device::CPU);
```
### Operators

#### Addition

```cpp
Tensor operator+(const Tensor& other) const;
```

#### Subtraction

```cpp
Tensor operator-(const Tensor& other) const;
```

#### Multiplication

```cpp
Tensor operator*(const Tensor& other) const;
```

#### Division

```cpp
Tensor operator/(const Tensor& other) const;
```
### Methods

#### sum

```cpp
Tensor sum(const std::vector<int>& axes = {}, bool keepdims = false) const;
```

#### mean

```cpp
Tensor mean(const std::vector<int>& axes = {}, bool keepdims = false) const;
```

#### matmul

```cpp
Tensor matmul(const Tensor& other) const;
```
print