Skip to content

C++ API Reference

Tensor Class

Constructor

Tensor(const std::vector<size_t>& shape, DType dtype = DType::Float32, Device device = Device::CPU);

Factory Methods

zeros

static Tensor zeros(const std::vector<size_t>& shape, DType dtype = DType::Float32, Device device = Device::CPU);

ones

static Tensor ones(const std::vector<size_t>& shape, DType dtype = DType::Float32, Device device = Device::CPU);

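rand

Unlike zeros and ones, rand takes no DType parameter, so callers cannot choose the element type; this signature does not state which dtype the result has.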

static Tensor rand(const std::vector<size_t>& shape, Device device = Device::CPU);

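Operators

Each binary operator below is a const member function. It takes the right-hand Tensor by const reference and returns the result as a new Tensor by value.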

Addition

Tensor operator+(const Tensor& other) const;

Subtraction

Tensor operator-(const Tensor& other) const;

Multiplication

Tensor operator*(const Tensor& other) const;

Division

Tensor operator/(const Tensor& other) const;

Methods

sum

Tensor sum(const std::vector<int>& axes = {}, bool keepdims = false) const;

mean

Tensor mean(const std::vector<int>& axes = {}, bool keepdims = false) const;

matmul

Tensor matmul(const Tensor& other) const;

print

void print() const;