-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
348 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,342 @@ | ||
const std = @import("std"); | ||
const builtin = @import("builtin"); | ||
const assert = std.debug.assert; | ||
const expectEqual = std.testing.expectEqual; | ||
const expectEqualDeep = std.testing.expectEqualDeep; | ||
|
||
/// Creates a Matrix with elements of type T and size rows times cols. | ||
pub fn Matrix(comptime T: type, comptime rows: usize, comptime cols: usize) type { | ||
assert(@typeInfo(T) == .Float); | ||
return struct { | ||
const Self = @This(); | ||
comptime rows: usize = rows, | ||
comptime cols: usize = cols, | ||
items: [rows][cols]T = undefined, | ||
|
||
/// Sets all elements to value. | ||
pub fn setAll(value: T) Self { | ||
var self = Self{}; | ||
for (&self.items) |*row| { | ||
for (row) |*col| { | ||
col.* = value; | ||
} | ||
} | ||
return self; | ||
} | ||
|
||
/// Returns an identity matrix of the matrix size. | ||
pub fn identity() Self { | ||
var self = Self{}; | ||
for (0..self.rows) |r| { | ||
for (0..self.cols) |c| { | ||
if (r == c) { | ||
self.items[r][c] = 1; | ||
} else { | ||
self.items[r][c] = 0; | ||
} | ||
} | ||
} | ||
return self; | ||
} | ||
|
||
/// Returns a matrix filled with random numbers. | ||
pub fn random(seed: ?u64) Self { | ||
const s: u64 = blk: { | ||
if (seed) |value| { | ||
break :blk value; | ||
} else { | ||
break :blk @truncate(@as(u128, @bitCast(std.time.nanoTimestamp()))); | ||
} | ||
}; | ||
var prng = std.rand.DefaultPrng.init(s); | ||
var rand = prng.random(); | ||
var self = Self{}; | ||
for (0..self.rows) |r| { | ||
for (0..self.cols) |c| { | ||
self.items[r][c] = rand.float(T); | ||
} | ||
} | ||
return self; | ||
} | ||
|
||
/// Returns the rows and columns as a struct. | ||
pub fn shape(self: Self) struct { usize, usize } { | ||
return .{ | ||
self.rows, | ||
self.cols, | ||
}; | ||
} | ||
|
||
/// Reshapes the matrix to a new shape. | ||
pub fn reshape(self: Self, comptime new_rows: usize, comptime new_cols: usize) Matrix(T, new_rows, new_cols) { | ||
comptime assert(rows * cols == new_rows * new_cols); | ||
var matrix = Matrix(T, new_rows, new_cols){}; | ||
for (0..new_rows) |r| { | ||
for (0..new_cols) |c| { | ||
const idx = r * new_cols + c; | ||
matrix.items[r][c] = self.at(idx / cols, @mod(idx, cols)); | ||
} | ||
} | ||
return matrix; | ||
} | ||
|
||
/// Returns a string representation of the matrix, for printing. | ||
pub fn toString(self: Self) [rows * cols * @bitSizeOf(T)]u8 { | ||
var print_buffer: [rows * cols * @bitSizeOf(T)]u8 = undefined; | ||
var printed: usize = 0; | ||
var written: []u8 = undefined; | ||
for (self.items) |row| { | ||
for (row) |val| { | ||
written = std.fmt.bufPrint(print_buffer[printed..], " {}", .{val}) catch unreachable; | ||
printed += written.len; | ||
} | ||
written = std.fmt.bufPrint(print_buffer[printed..], "\n", .{}) catch unreachable; | ||
printed += written.len; | ||
} | ||
return print_buffer; | ||
} | ||
|
||
/// Retrieves the element at position row, col in the matrix. | ||
pub fn at(self: Self, row: usize, col: usize) T { | ||
assert(row < self.rows); | ||
assert(col < self.cols); | ||
return self.items[row][col]; | ||
} | ||
|
||
/// Sets the element at row, col to val. | ||
pub fn set(self: *Self, row: usize, col: usize, val: T) void { | ||
assert(row < self.rows); | ||
assert(col < self.cols); | ||
self.items[row][col] = val; | ||
} | ||
|
||
/// Computes the trace (i.e. sum of the diagonal elements). | ||
pub fn trace(self: Self) T { | ||
assert(self.cols == self.rows); | ||
var val: T = 0; | ||
for (0..self.cols) |i| { | ||
val += self.items[i][i]; | ||
} | ||
return val; | ||
} | ||
|
||
/// Adds an offset to all matrix values. | ||
pub fn offset(self: Self, value: T) Self { | ||
var matrix: Self = undefined; | ||
for (0..rows) |r| { | ||
for (0..cols) |c| { | ||
matrix.items[r][c] = value + self.items[r][c]; | ||
} | ||
} | ||
return matrix; | ||
} | ||
|
||
/// Scales all matrix values. | ||
pub fn scale(self: Self, value: T) Self { | ||
var matrix: Self = undefined; | ||
for (0..rows) |r| { | ||
for (0..cols) |c| { | ||
matrix.items[r][c] = value * self.items[r][c]; | ||
} | ||
} | ||
return matrix; | ||
} | ||
|
||
/// Applies a unary function to all matrix values. | ||
pub fn apply(self: Self, comptime unaryFn: fn (arg: T) T) Self { | ||
var matrix: Self = undefined; | ||
for (0..rows) |r| { | ||
for (0..cols) |c| { | ||
matrix.items[r][c] = unaryFn(self.items[r][c]); | ||
} | ||
} | ||
return matrix; | ||
} | ||
|
||
/// Sets the sub-matrix at positon row, col to sub_matrix. | ||
pub fn setSubMatrix(self: *Self, row: usize, col: usize, matrix: anytype) void { | ||
assert(matrix.rows + row <= self.rows); | ||
assert(matrix.cols + col <= self.cols); | ||
for (0..matrix.rows) |r| { | ||
for (0..matrix.cols) |c| { | ||
self.items[row + r][col + c] = matrix.items[r][c]; | ||
} | ||
} | ||
} | ||
|
||
/// Sets the elements in the row. | ||
pub fn setRow(self: *Self, row: usize, values: [cols]T) void { | ||
assert(row < self.rows); | ||
for (0..self.cols) |c| { | ||
self.items[row][c] = values[c]; | ||
} | ||
} | ||
|
||
/// Sets the elements in the column. | ||
pub fn setCol(self: *Self, col: usize, values: [rows]T) void { | ||
assert(col < self.cols); | ||
for (0..self.rows) |r| { | ||
self.items[r][col] = values[r]; | ||
} | ||
} | ||
|
||
/// Returns the elements in the row as a row Matrix. | ||
pub fn getRow(self: Self, row: usize) Matrix(T, 1, cols) { | ||
assert(row < self.rows); | ||
var matrix = Matrix(T, 1, cols){}; | ||
for (0..self.cols) |c| { | ||
matrix.items[0][c] = self.items[row][c]; | ||
} | ||
return matrix; | ||
} | ||
|
||
/// Returns the elements in the column as a column Matrix. | ||
pub fn getCol(self: Self, col: usize) Matrix(T, rows, 1) { | ||
assert(col < self.cols); | ||
var matrix = Matrix(T, rows, 1){}; | ||
for (0..self.rows) |r| { | ||
matrix.items[r][0] = self.items[r][col]; | ||
} | ||
return matrix; | ||
} | ||
|
||
/// Transposes the matrix. | ||
pub fn transpose(self: Self) Matrix(T, cols, rows) { | ||
var m = Matrix(T, cols, rows){}; | ||
for (0..self.rows) |r| { | ||
for (0..self.cols) |c| { | ||
m.items[c][r] = self.items[r][c]; | ||
} | ||
} | ||
return m; | ||
} | ||
|
||
/// Adds a matrix. | ||
pub fn add(self: Self, other: Self) Self { | ||
var result: @TypeOf(self) = undefined; | ||
for (0..self.rows) |r| { | ||
for (0..self.cols) |c| { | ||
result.items[r][c] = self.items[r][c] + other.items[r][c]; | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
/// Performs pointwise multiplication | ||
pub fn times(self: Self, other: Self) Self { | ||
var result: @TypeOf(self) = undefined; | ||
for (0..self.rows) |r| { | ||
for (0..self.cols) |c| { | ||
result.items[r][c] = self.items[r][c] * other.items[r][c]; | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
/// Performs the dot (or internal product) of two matrices. | ||
pub fn dot(self: Self, other: anytype) Matrix(T, self.rows, other.cols) { | ||
comptime assert(self.cols == other.rows); | ||
var result = Matrix(T, self.rows, other.cols).setAll(0); | ||
for (0..self.rows) |r| { | ||
for (0..other.cols) |c| { | ||
for (0..self.cols) |k| { | ||
result.items[r][c] += self.items[r][k] * other.items[k][c]; | ||
} | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
/// If the matrix only contains one element, it returns it, otherwise it fails to compile. | ||
pub fn item(self: Self) T { | ||
comptime assert(self.rows == 1 and self.cols == 1); | ||
return self.items[0][0]; | ||
} | ||
|
||
/// Sums all the elements in a matrix | ||
pub fn sum(self: Self) T { | ||
var accum: T = 0; | ||
for (self.items) |row| { | ||
for (row) |col| { | ||
accum += col; | ||
} | ||
} | ||
return accum; | ||
} | ||
|
||
/// Computes the norm of the matrix as the square root of the sum of its squared values. | ||
pub fn norm(self: Self) T { | ||
var sum_sq: T = 0; | ||
for (self.items) |row| { | ||
for (row) |col| { | ||
sum_sq += col * col; | ||
} | ||
} | ||
return @sqrt(sum_sq); | ||
} | ||
}; | ||
} | ||
|
||
test "identity" { | ||
const eye = Matrix(f32, 3, 3).identity(); | ||
try expectEqual(eye.sum(), 3); | ||
for (0..eye.rows) |r| { | ||
for (0..eye.cols) |c| { | ||
if (r == c) { | ||
try expectEqual(eye.at(r, c), 1); | ||
} else { | ||
try expectEqual(eye.at(r, c), 0); | ||
} | ||
} | ||
} | ||
} | ||
|
||
test "setAll" { | ||
const zeros = Matrix(f32, 3, 3).setAll(0); | ||
try expectEqual(zeros.sum(), 0); | ||
const ones = Matrix(f32, 3, 3).setAll(1); | ||
const shape = ones.shape(); | ||
try expectEqual(ones.sum(), @as(f32, @floatFromInt(shape[0] * shape[1]))); | ||
} | ||
|
||
test "shape" { | ||
const matrix = Matrix(f32, 4, 5){}; | ||
const shape = matrix.shape(); | ||
try expectEqual(shape[0], 4); | ||
try expectEqual(shape[1], 5); | ||
} | ||
|
||
test "scale" { | ||
const seed: u64 = @truncate(@as(u128, @bitCast(std.time.nanoTimestamp()))); | ||
const a = Matrix(f32, 4, 3).random(seed); | ||
const b = Matrix(f32, 4, 3).random(seed).scale(std.math.pi); | ||
try expectEqualDeep(a.shape(), b.shape()); | ||
for (0..a.rows) |r| { | ||
for (0..a.cols) |c| { | ||
try expectEqual(std.math.pi * a.at(r, c), b.at(r, c)); | ||
} | ||
} | ||
} | ||
|
||
test "apply" { | ||
var a = Matrix(f32, 3, 4).random(null); | ||
|
||
const f = struct { | ||
fn f(x: f32) f32 { | ||
return @sin(x); | ||
} | ||
}.f; | ||
|
||
var b = a.apply(f); | ||
try expectEqualDeep(a.shape(), b.shape()); | ||
for (0..a.rows) |r| { | ||
for (0..a.cols) |c| { | ||
try expectEqual(@sin(a.at(r, c)), b.at(r, c)); | ||
} | ||
} | ||
} | ||
|
||
test "norm" { | ||
var matrix = Matrix(f32, 3, 4).random(null); | ||
try expectEqual(matrix.norm(), @sqrt(matrix.times(matrix).sum())); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
test "color unit tests" { | ||
_ = @import("color.zig"); | ||
} | ||
|
||
test "matrix unit tests" { | ||
_ = @import("matrix.zig"); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters