Apple的MLX是一个用于机器学习的推理框架,其最大的特点是能是用Apple的统一内存进行LLM推理。想到明年可能做AI编译器的相关工作,而手边有一台学校发的M3 Macbook Air,就想着尝鲜试下

环境配置

MLX手册给的是使用Python安装,这很明显不是我想要的,看了下HomeBrew有收入,那就直接HomeBrew安装

brew install mlx

但这个环境只能运行Python和C++版本的MLX,想要运行C版本的MLX,需要拉取Git,手动配置CMake手动安装(也许后面可以试试以后把MLX-C打包上传HomeBrew?)

git clone https://github.com/ml-explore/mlx-c

配置CMake和Ninja

mkdir build && cd build
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
sudo ninja install

但我在编译的时候出现报错,告诉我找不到Metal API,网上找资料被告知需要去AppStore安装完整版XCode,完成第一次运行后才能使用

sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
sudo xcodebuild -runFirstLaunch

在安装完XCode以后,还需要再安装MetalToolchain才行

xcodebuild -downloadComponent MetalToolchain

CMake项目环境配置

CMAKE_PREFIX_PATH里把编译好的MLX-C给加上

cmake .. -G Ninja -DCMAKE_PREFIX_PATH="$(brew --prefix mlx);/usr/local"

相对应的CMakeLists.txt加上对应的find_package

find_package(MLX CONFIG REQUIRED)
find_package(MLXC CONFIG REQUIRED)

测试程序

#include <stdio.h>
#include <mlx/c/mlx.h>
void print_array(const char* msg, mlx_array arr) {
mlx_string str = mlx_string_new();
mlx_array_tostring(&str, arr);
printf("%s\n%s\n", msg, mlx_string_data(str));
mlx_string_free(str);
}
int main() {
// 打印 MLX 版本
mlx_string version = mlx_string_new();
mlx_version(&version);
printf("MLX C API 版本: %s\n\n", mlx_string_data(version));
// 获取默认 GPU stream
mlx_stream stream = mlx_default_gpu_stream_new();
// 创建浮点数数组
float data_a[] = {1.0f, 2.0f, 3.0f, 4.0f};
float data_b[] = {5.0f, 6.0f, 7.0f, 8.0f};
int shape[] = {4};
// 创建 MLX 数组
mlx_array a = mlx_array_new_data(data_a, shape, 1, MLX_FLOAT32);
mlx_array b = mlx_array_new_data(data_b, shape, 1, MLX_FLOAT32);
mlx_array c = mlx_array_new();
// 执行加法运算
mlx_add(&c, a, b, stream);
print_array("a + b =", c);
// 矩阵运算示例
int shape_2d[] = {2, 2};
mlx_array x = mlx_array_new();
mlx_array y = mlx_array_new();
mlx_array z = mlx_array_new();
mlx_reshape(&x, a, shape_2d, 2, stream);
mlx_reshape(&y, b, shape_2d, 2, stream);
mlx_matmul(&z, x, y, stream);
print_array("矩阵乘法结果:", z);
// 释放资源
mlx_array_free(a);
mlx_array_free(b);
mlx_array_free(c);
mlx_array_free(x);
mlx_array_free(y);
mlx_array_free(z);
mlx_stream_free(stream);
mlx_string_free(version);
return 0;
}