Apple Silicon Metal vs NVIDIA CUDA
- Author: Shashank Shekhar (@sshkhr16)
My predicament of being GPU Poor in terms of NVIDIA GPUs, while owning an Apple M1 Max MacBook, led me to explore the Apple Silicon GPUs and the Metal framework for GPU programming. However, I found that the Metal documentation and examples are heavily geared towards graphics programming, and do not capture the ability of Apple Silicon GPUs to perform GPGPU (General Purpose GPU) computations very well. Since my goal is to make GPUs Go Brrr, i.e. implement computational workloads relevant to deep learning, I decided to document the key concepts in the Apple Silicon GPU hardware and software ecosystem for myself, as well as for any other folks who are not well versed in graphics programming but are looking to use their Apple GPUs for machine learning or HPC workloads.
This is a relatively high level overview of the Apple GPU architecture and Metal programming API, and how it compares to the NVIDIA GPU architecture and CUDA API. It is aimed mostly at folks in machine learning and scientific computing, perhaps someone who is somewhat familiar with CUDA and looking to run computational workloads on the Apple Silicon GPUs.
It doesn't cover several low-level details, including but not limited to: NVIDIA's tensor cores, Apple's Neural Engines, interfacing CPU and GPU programming, resource optimization, or almost anything that is relevant for graphics workloads.
Comparing the Device Architectures
Since I started learning about CUDA and parallel programming from the Programming Massively Parallel Processors book (abbreviated as PMPP), I thought it would be a nice parallel to compare the GPU architecture and memory model of Apple Silicon family with that of NVIDIA based on the illustrations shown in PMPP.

The architecture of a CUDA-enabled GPU is illustrated as shown above (reproduced from Fig 4.1 of PMPP). The SM in the figure is the Streaming Multiprocessor, which is the basic compute unit of a CUDA GPU. Each SM contains multiple ALUs (Arithmetic Logic Units), schedulers, and caches (grouped together as "Control" in the image). Work submitted to the GPU is organized as a grid of thread blocks (the white rectangles containing the eight green squares), each of which is in turn made up of individual threads (the little green squares in the image). The threads in a block can communicate with each other through shared memory, while threads in different blocks cannot. The GPU also comes with a global memory space, which is accessible by all threads but has a higher latency than shared memory.
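To see these quantities on your own card, here is a small CUDA host-side sketch (my addition, not from the original post) that queries them through cudaGetDeviceProperties:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);  // properties of device 0

    // The quantities discussed above: SM count, per-block shared memory,
    // warp size, and the size of the global memory space.
    printf("SMs (streaming multiprocessors): %d\n", prop.multiProcessorCount);
    printf("Shared memory per block:         %zu bytes\n", prop.sharedMemPerBlock);
    printf("Warp size:                       %d threads\n", prop.warpSize);
    printf("Global memory:                   %zu bytes\n", prop.totalGlobalMem);
    return 0;
}
```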

Based on my understanding of an Apple Silicon GPU, the architecture is similar to that of a CUDA GPU, but with one key difference. The Apple Silicon GPU has a core which is largely equivalent to the SM in CUDA. Each core contains multiple ALUs and caches, and work is organized as a grid of threadgroups (instead of blocks), which are in turn made up of threads. The threads in a threadgroup can communicate with each other through threadgroup memory, which is similar to shared memory in CUDA. The key difference between the two architectures arises in the global memory: unlike CUDA GPUs, which have a separate global device memory space, Apple uses a unified memory architecture where the CPU and GPU share the same memory.
While the unified memory architecture (UMA) allows for more efficient data sharing between the CPU and GPU, it also means that the memory access patterns need to be carefully managed to avoid contention. Apple's close control over the software and hardware stack in their computers allows them to utilize this unified memory architecture effectively for various applications.
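As a quick sanity check (a sketch I have added, not from Apple's docs), metal-cpp exposes this property directly on the device object via hasUnifiedMemory():

```cpp
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include <Metal/Metal.hpp>
#include <cstdio>

int main() {
    // Grab the default GPU (on Apple Silicon Macs there is exactly one).
    MTL::Device* device = MTL::CreateSystemDefaultDevice();

    // hasUnifiedMemory() reports whether the GPU shares memory with the CPU;
    // it returns true on all Apple Silicon machines.
    printf("GPU: %s, unified memory: %s\n",
           device->name()->utf8String(),
           device->hasUnifiedMemory() ? "yes" : "no");

    device->release();
    return 0;
}
```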
In CUDA terminology, device refers to the GPU, while host refers to the CPU. Wherever you read the terms device and host in this article, mentally replace them with GPU and CPU respectively.
For illustration purposes, the grid, blocks (threadgroups) and threads are shown as 1D/2D arrays in the figures above, but they can also be 3D arrays in both CUDA and Metal.
A dictionary for translating Apple Silicon GPU terms to NVIDIA
Based on our brief discussion of the two device architectures, I'm sure you must have noticed that several of the hardware and software abstractions in both the Apple and NVIDIA GPU ecosystems are similar. Just from the terms we have seen until now, we can already start to map the terms used in Metal to those more familiar to CUDA practitioners. The table below summarizes the Apple GPU terms we have encountered thus far and their equivalents in NVIDIA GPUs:
Apple Silicon Term | NVIDIA Equivalent | Description |
---|---|---|
GPU Core | Streaming Multiprocessor | Basic compute unit containing multiple ALUs, schedulers, and caches |
Grid | Grid | Overall structure of work to be processed by the GPU |
Threadgroup | Thread Block | Group of threads that can synchronize and share memory |
Thread | Thread | Individual execution unit that processes a single element of work |
SIMD-group | Warp | Group of 32 threads executed in lockstep |
Threadgroup Memory | Shared Memory | Fast memory accessible by all threads in a threadgroup/block |
Device Memory | Global Memory | Main GPU memory accessible by all threads |
Unified Memory | CUDA Unified Memory | Shared memory space accessible by both CPU and GPU |
The Programming Models: CUDA vs Metal
Since it is so fundamental to the difference between the two GPUs, we will revisit the differences in the memory hierarchy of Apple Silicon and NVIDIA GPUs in a bit. But first, let's take a brief look at the programming models of both architectures. NVIDIA developed its own superset of C (and later C++) called CUDA to program its GPUs, starting all the way back in 2007. Apple came out with the Metal language and API in 2014, initially in Swift and Objective-C, and later in C++ via the metal-cpp project.
I have used metal-cpp in this article, but frankly the documentation for it on Apple's developer website is very sparse. However, I chose to still use it to draw a parallel to CUDA, and also because I'm unfamiliar with Objective-C and Swift. If you are writing hardware-accelerated code for Apple Silicon, metal-cpp might end up being limiting for you, since the Metal Performance Shaders framework, which has "highly optimized compute and graphics shaders", is not available (or at least not documented) in metal-cpp. CUDA has equivalent libraries with optimized kernels, like CUTLASS, which are both open source and quite well documented.
The Metal Shading Language is a C++-like language that allows developers to write GPU code in a familiar syntax. Metal kernels (also referred to as "shaders", since Metal was originally written for graphics programming) are written in .metal files, which follow a syntax very similar to that of C++.
Let us visualize the kernel code with a practical example of a simple kernel that adds two arrays of floats. This example is borrowed from one of the only four ML workloads documented in the Metal documentation (seriously, Apple 😠). The example below shows the same kernel written in both Metal and CUDA.
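Here is what the two kernels look like, modelled on the add_arrays function from Apple's Performing calculations on a GPU sample; the CUDA version is my own straightforward hand translation of it.

```metal
// add.metal — Metal Shading Language kernel
kernel void add_arrays(device const float* inA,
                       device const float* inB,
                       device float*       result,
                       uint index [[thread_position_in_grid]])
{
    // Each thread adds one pair of elements; the grid supplies the index.
    result[index] = inA[index] + inB[index];
}
```

```cuda
// add.cu — equivalent CUDA kernel
__global__ void add_arrays(const float* inA, const float* inB,
                           float* result, int n)
{
    // Compute this thread's global index from block and thread indices.
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < n) {               // guard against out-of-bounds threads
        result[index] = inA[index] + inB[index];
    }
}
```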
Differences in host code and kernel dispatch
While the kernel code looks suspiciously similar, the way we dispatch the kernel is quite different in Metal and CUDA. In CUDA, we can use the handy <<<...>>> syntax to specify the grid and block dimensions, and CUDA largely takes care of the kernel dispatch. Metal, on the other hand, is a somewhat lower-level library, and we need to explicitly set up and take care of a few more steps. I have written out a minimal host code setup for both CUDA and Metal below, where we utilize the kernels written above.
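First, a minimal CUDA host-side sketch (assuming the add_arrays CUDA kernel above lives in the same .cu file; error handling omitted for brevity):

```cuda
#include <cuda_runtime.h>

int main() {
    const int n = 1 << 20;
    const size_t size = n * sizeof(float);

    // Allocate and fill host memory.
    float *hA = new float[n], *hB = new float[n], *hC = new float[n];
    for (int i = 0; i < n; ++i) { hA[i] = 1.0f; hB[i] = 2.0f; }

    // Allocate device memory and copy inputs host -> device.
    float *dA, *dB, *dC;
    cudaMalloc(&dA, size); cudaMalloc(&dB, size); cudaMalloc(&dC, size);
    cudaMemcpy(dA, hA, size, cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB, size, cudaMemcpyHostToDevice);

    // Launch: enough 256-thread blocks to cover all n elements.
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    add_arrays<<<blocks, threads>>>(dA, dB, dC, n);

    // Copy the result back device -> host.
    cudaMemcpy(hC, dC, size, cudaMemcpyDeviceToHost);

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    delete[] hA; delete[] hB; delete[] hC;
    return 0;
}
```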
As you can see, in CUDA the key steps after writing our kernel are: allocate host and device memory, copy the data from host to device, launch the kernel, and copy the results back to the host. While Metal spares us the copying back and forth thanks to the unified memory architecture, we have to do a lot more setup work before we can call our kernel.
Calling a Metal kernel (after writing the GPU function in Metal Shading Language) involves creating a compute pipeline from the compiled function, allocating device-accessible buffers for input and output data, constructing a command buffer to hold instructions, encoding the kernel dispatch call into that command buffer, and finally submitting the command buffer to the GPU for execution.
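And a minimal metal-cpp sketch of those steps, assuming the add_arrays kernel above has been compiled into the app's default Metal library (error handling and releases of intermediate objects omitted):

```cpp
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include <Metal/Metal.hpp>

int main() {
    const NS::UInteger n = 1 << 20;
    const NS::UInteger size = n * sizeof(float);

    // 1. Device, compiled kernel function, and compute pipeline.
    MTL::Device* device = MTL::CreateSystemDefaultDevice();
    MTL::Library* library = device->newDefaultLibrary();
    MTL::Function* fn = library->newFunction(
        NS::String::string("add_arrays", NS::UTF8StringEncoding));
    NS::Error* error = nullptr;
    MTL::ComputePipelineState* pipeline = device->newComputePipelineState(fn, &error);

    // 2. Device-accessible buffers (shared storage: visible to CPU and GPU).
    MTL::Buffer* bufA = device->newBuffer(size, MTL::ResourceStorageModeShared);
    MTL::Buffer* bufB = device->newBuffer(size, MTL::ResourceStorageModeShared);
    MTL::Buffer* bufC = device->newBuffer(size, MTL::ResourceStorageModeShared);
    float* a = static_cast<float*>(bufA->contents());
    float* b = static_cast<float*>(bufB->contents());
    for (NS::UInteger i = 0; i < n; ++i) { a[i] = 1.0f; b[i] = 2.0f; }

    // 3. Command queue and a command buffer to hold the work.
    MTL::CommandQueue* queue = device->newCommandQueue();
    MTL::CommandBuffer* cmdBuf = queue->commandBuffer();

    // 4. Encode the kernel dispatch into the command buffer.
    MTL::ComputeCommandEncoder* enc = cmdBuf->computeCommandEncoder();
    enc->setComputePipelineState(pipeline);
    enc->setBuffer(bufA, 0, 0);
    enc->setBuffer(bufB, 0, 1);
    enc->setBuffer(bufC, 0, 2);
    MTL::Size gridSize(n, 1, 1);
    MTL::Size threadgroupSize(pipeline->maxTotalThreadsPerThreadgroup(), 1, 1);
    enc->dispatchThreads(gridSize, threadgroupSize);
    enc->endEncoding();

    // 5. Submit to the GPU and wait; the results in bufC are already CPU-visible.
    cmdBuf->commit();
    cmdBuf->waitUntilCompleted();

    device->release();
    return 0;
}
```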
As noted in another blog, I am unsure why some of this boilerplate could not be wrapped into an easy kernel dispatcher function, similar to how CUDA does it. I suspect it is because Apple wants to keep the Metal API as low-level as possible and give developers more control over the GPU resources. As a beginner to GPU programming, I had to spend a bit longer learning about command queues, command buffers, data buffers, and pipeline states. But they let you keep really close control over your CPU and GPU compute cycles and memory; anyone interested in how that helps with writing performant code can check out this video from Apple: Optimize Metal performance for Apple Silicon Macs.
Parallels in the programming APIs
Similar to how we can map a lot of the concepts from the NVIDIA GPU system architecture to Apple Silicon, a lot of the programming constructs in CUDA and Metal have parallels. While I have tried my best to be accurate in translating Metal to CUDA programming constructs, I did use Claude's help to write the table, so please do let me know if you find any errors.
Metal API/Concept | CUDA Equivalent | Description |
---|---|---|
Memory Spaces | ||
device | (global) | Main GPU memory (default for parameters in CUDA) |
threadgroup | __shared__ | Memory shared within a threadgroup/block |
constant | __constant__ | Read-only memory space for constants |
thread | (local variables) | Thread-local memory (automatic variables) |
Kernel Functions | ||
kernel void functionName() | __global__ void functionName() | Defines a function that runs on the GPU |
[[thread_position_in_grid]] | blockIdx.x * blockDim.x + threadIdx.x | Getting the global thread index |
[[threadgroup_position_in_grid]] | blockIdx | Getting the threadgroup/block position |
[[thread_position_in_threadgroup]] | threadIdx | Getting the thread position within a threadgroup/block |
[[threads_per_threadgroup]] | blockDim | Getting the dimensions of a threadgroup/block |
[[threads_per_grid]] | gridDim * blockDim | Getting the total thread dimensions |
Synchronization | ||
threadgroup_barrier(mem_flags::mem_none) | __syncthreads() | Synchronize threads within a threadgroup/block |
threadgroup_barrier(mem_flags::mem_threadgroup) | __syncthreads() | Sync with memory visibility for threadgroup memory |
simdgroup_barrier(mem_flags::mem_none) | __syncwarp() | Synchronize threads within a SIMD-group/warp |
Memory Management | ||
MTL::Buffer* buffer = device->newBuffer() | cudaMalloc(&ptr, size) | Allocate memory on the GPU |
buffer->contents() | cudaMemcpy(dst, src, size, direction) | Access/transfer GPU memory |
MTL::ResourceStorageModeShared | cudaMallocManaged() | Unified memory accessible by CPU and GPU |
Execution Configuration | ||
MTL::Size gridSize(x, y, z) | dim3 gridSize(x, y, z) | Specifying grid dimensions |
MTL::Size threadgroupSize(x, y, z) | dim3 blockSize(x, y, z) | Specifying threadgroup/block dimensions |
computeEncoder->dispatchThreads(gridSize, threadgroupSize) | kernel<<<gridSize, blockSize>>>() | Launching kernel with dimensions |
Resource Management | ||
buffer->release() | cudaFree(ptr) | Free GPU memory |
GPU Selection | ||
MTL::CreateSystemDefaultDevice() | cudaSetDevice(0) | Select the default GPU |
MTL::CopyAllDevices() | cudaGetDeviceCount() + iterate | Get all available GPUs |
Revisiting the memory models
Before we wrap up this comparison of Apple Silicon and NVIDIA GPUs, it is worthwhile to take a closer second look at the memory models of both architectures since it is one of the biggest differences between the two GPUs. The NVIDIA memory model is illustrated in the figure below (reproduced from Fig 5.2 of PMPP), and shows the different types of memory available on an NVIDIA GPU.

As shown in the diagram, the model consists of a device grid containing multiple blocks, each holding multiple threads. Each thread has access to its own registers for fast temporary storage. Threads within the same block can communicate through shared memory, which provides higher bandwidth and lower latency than global memory but is limited to intra-block access. All threads across the entire grid can access global memory, albeit with higher latency. Additionally, constant memory provides read-only storage that is optimized for broadcast access when all threads read the same value simultaneously.
This memory hierarchy is fundamental to CUDA programming as it guides developers in optimizing data locality – placing frequently accessed data in faster memory spaces (registers and shared memory) while using global memory for data that needs to be accessed across blocks. The host (CPU) can transfer data to and from the device's global and constant memories but cannot directly access registers or shared memory, creating a clear separation between host and device memory spaces.
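As a toy illustration of that locality principle (a sketch of mine, not from PMPP), the kernel below stages a tile of the input in shared memory and reads a scaling coefficient from constant memory; it assumes a block size of 256 to match the tile:

```cuda
// A small constant-memory array, broadcast efficiently to all threads.
__constant__ float scale[1];

__global__ void scaled_shift(const float* in, float* out, int n)
{
    // One tile of the input staged in fast, block-local shared memory.
    __shared__ float tile[256];

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) tile[threadIdx.x] = in[i];   // global -> shared (slow -> fast)
    __syncthreads();                        // make the tile visible to the whole block

    if (i < n) {
        // Neighbouring threads reuse the tile instead of re-reading global memory.
        float left = (threadIdx.x > 0) ? tile[threadIdx.x - 1] : tile[threadIdx.x];
        out[i] = scale[0] * (tile[threadIdx.x] + left);
    }
}
```

The host would initialize scale with cudaMemcpyToSymbol before launching, and launch with 256-thread blocks to match the tile size.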

Similar to an NVIDIA GPU, the Apple Silicon GPUs also contain registers and threadgroup-local shared memory. However, unlike NVIDIA GPUs, there is no separate device global memory or constant memory that the CPU and GPU shuttle data to and from. Instead, the unified memory serves as the global memory for the GPU and is shared between the CPU and GPU. The constant buffers are essentially cache-optimized regions of the unified memory itself. Memory coherency is managed by both the Apple hardware and the Metal API (as we saw earlier in the allocation of device buffers).
The Apple developer article Choosing a Resource Storage Mode for Apple GPUs illustrates how Metal lets you control where buffers and textures are allocated in system memory, and how the CPU and GPU access them:
Storage Mode | Definition / Memory Location | CPU Access | GPU Access | Common Use Cases |
---|---|---|---|---|
MTLStorageModeShared | System (unified) memory accessible by both CPU and GPU. | Read/Write: CPU can directly read/write. Data is CPU–GPU coherent. | Read/Write: GPU can directly access the same shared memory. | Default for buffers/textures on Apple GPUs. Frequent CPU updates to GPU data. Ideal when CPU and GPU both need to read/write the resource. |
MTLStorageModePrivate | System (unified) memory allocated for GPU‐only access. | No direct CPU access: must use blit/encode operations to copy or fill. | Read/Write: GPU can fully access, read, and modify the data. | Render targets and intermediate resources. Large textures that don't require CPU readback. Optimized for GPU‐only usage to improve performance. |
MTLStorageModeMemoryless | Tile memory (on‐chip GPU memory), allocated per pass. | No CPU access: ephemeral resource. | Read/Write by GPU within a single render or compute pass only. | Depth, stencil, or color buffers that are only needed temporarily within a pass. Very fast access with reduced power usage, freed at end of pass. |
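To make the first two rows concrete, here is a small metal-cpp sketch of mine that uploads data into a private (GPU-only) buffer by staging it through a shared buffer and issuing a blit copy; memoryless storage applies to textures within a single pass, so it is not shown:

```cpp
#include <Metal/Metal.hpp>
#include <cstring>

// Upload hostData into a GPU-only (private) buffer via a shared staging buffer.
MTL::Buffer* uploadToPrivateBuffer(MTL::Device* device, MTL::CommandQueue* queue,
                                   const void* hostData, size_t size)
{
    // Shared: the default on Apple Silicon — CPU writes are directly GPU-visible.
    MTL::Buffer* staging = device->newBuffer(size, MTL::ResourceStorageModeShared);
    std::memcpy(staging->contents(), hostData, size);

    // Private: GPU-only storage; the CPU cannot read its contents directly,
    // so we fill it with a blit pass that copies from the shared staging buffer.
    MTL::Buffer* gpuOnly = device->newBuffer(size, MTL::ResourceStorageModePrivate);
    MTL::CommandBuffer* cmdBuf = queue->commandBuffer();
    MTL::BlitCommandEncoder* blit = cmdBuf->blitCommandEncoder();
    blit->copyFromBuffer(staging, 0, gpuOnly, 0, size);
    blit->endEncoding();
    cmdBuf->commit();
    cmdBuf->waitUntilCompleted();

    staging->release();
    return gpuOnly;
}
```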
On-device memory
The physical memory available on the NVIDIA CUDA streaming multiprocessors and Apple Silicon Metal cores is often referred to as on-device memory. This includes the registers, shared memory, and caches (if available) on the GPU. Let us take a closer look at the NVIDIA on-device memory hierarchy.

This memory hierarchy illustration is similar to the one in Fig 7.10 from PMPP, but actually reproduced from NVIDIA's developer blog CUDA Refresher: The CUDA Programming Model. In the NVIDIA CUDA memory hierarchy, each thread has access to private registers that are invisible to other threads and managed by the compiler. At the SM level, a fast on-chip scratchpad memory serves dual purposes as both L1 cache and shared memory (SMEM), enabling all threads in a CUDA block to share data, with physical resources distributed among blocks running on the same SM. Each SM also contains read-only memory components including instruction cache, constant memory, texture memory, and RO cache that kernels can access but not modify. The L2 cache represents a higher level in the hierarchy, shared across all SMs and accessible by every thread in every CUDA block. The global memory represents the DRAM sitting in the GPU.
There is a trade-off between memory access speed and storage capacity, due to the power and size constraints introduced by larger/faster memory devices. The higher we go in the memory hierarchy, the larger the memory but the slower the access time. The thread registers are the fastest on-chip memory, but each thread has only a very limited number of registers available to it. On the other hand, the global device memory is the largest on-device memory available to the GPU, but it has much lower bandwidth and much higher latency than the on-chip memories.
As an aside, the constant memory (not depicted here for brevity) is usually implemented as a dedicated area of device global memory backed by a small, fast read-only cache shared among all threads in a block. It is optimized for broadcast access, meaning that if all threads in a block read the same value from constant memory, the read will be very fast. For a much more thorough explanation of the memory hierarchy, I would suggest reading the technical report Dissecting the NVIDIA Volta GPU Architecture via Microbenchmarking.

I have tried to recreate an equivalent memory hierarchy system diagram for Apple Silicon GPUs above, as it helps illustrate several key differences from NVIDIA's approach. In the Apple unified system memory architecture, each GPU core has separate and dedicated memory components rather than the configurable shared resources found in NVIDIA SMs. The memory hierarchy starts with fast registers at the top, providing thread-local storage. Below this sits dedicated shared memory, which allows threads within the same threadgroup to communicate efficiently. Unlike NVIDIA's combined L1/shared memory approach, Apple implements a separate but small L1 cache for data alongside a dedicated instruction cache (not depicted). The L2 cache (much smaller than NVIDIA's) is shared across all cores, but uniquely Apple adds a substantial System Level Cache (SLC) between L2 and main memory — a layer absent in NVIDIA GPUs. The SLC, like the unified memory below it in the memory hierarchy, is shared between the GPU and CPU.
An apples 🤭 to oranges comparison
Since the memory hierarchy and its consequent access patterns are very important for writing performant GPU code, it helps to lay out the relevant numbers for Apple vs NVIDIA GPUs. Given below is a high-level comparison of key GPU memory-hierarchy and architectural details for the Apple M1 Max GPU versus the NVIDIA GeForce RTX 3090. Both models are somewhat dated now (released in 2020-21) and are roughly comparable in price: here in Canada the RTX 3090 sells for anywhere between CAD $2K-$4K, while an M1 Max MacBook Pro/Mac Studio sells in a similar CAD $2K-$4K range (prices gathered from various resellers, since the Apple online store doesn't seem to sell new ones anymore). While it is hard to compare pricing since the Apple M1 Max is not available as a discrete GPU, I think the table below is still useful.
Memory Feature | Apple M1 Max GPU | NVIDIA GeForce RTX 3090 (GA102) |
---|---|---|
Memory Type | Unified LPDDR5 (shared across CPU, GPU, Neural Engine, etc.) | GDDR6X (dedicated GPU memory) |
Memory Capacity | 32 GB or 64 GB | 24 GB |
Memory Bus Width | 512‐bit interface to LPDDR5 | 384‐bit interface to GDDR6X |
Peak Memory Bandwidth | Up to ~400 GB/s | ~936 GB/s |
L1 Data Cache | 8 KB per core | 128 KB combined L1/Shared Memory per SM (configurable) |
Shared Memory | ~60 KB per core | 128 KB combined L1/Shared Memory per SM (configurable) |
L2 Cache (Total) | 512 KB | 6 MB |
L3 Cache (SLC) | 48 MB | N/A |
Register File Size | ~208 KB per core | ~256 KB per SM (65,536 × 32‐bit registers) |
A Brief Note on Constraints
Since my primary concern is with writing kernels to perform parallel computation for machine learning workloads on the GPUs, I thought it would be useful to summarize some of the key performance and programming constraints for both Apple and NVIDIA GPUs. From my (limited) experience writing performant CUDA kernels, these bottlenecks tend to become very important when thinking about optimizing parallel algorithms on GPUs - with different bottlenecks becoming important at different scales. For a more detailed discussion of these bottlenecks, I would recommend reading Chapter 6 - Performance Considerations in PMPP.
Feature | Apple M1 Max GPU | NVIDIA GeForce RTX 3090 (GA102) |
---|---|---|
Architecture / Generation | Apple 7 (M1 Max) | Ampere |
Warp/SIMD Size | 32 threads per SIMD-group | 32 threads per warp |
Execution Model | SIMD – 32 lanes share instruction stream | SIMT – each thread has its own context |
Compute Units | 32 GPU cores | 82 Streaming Multiprocessors (SMs) |
ALUs per Compute Unit | 128 ALUs per core | 128 FP32 ALUs per SM + 4 Tensor Cores |
Total ALUs | ~4,096 ALUs | ~10,496 FP32 ALUs |
Clock Frequency | 1.296 GHz | 1.7 GHz (boost up to 1.8 GHz) |
Theoretical FP32 Performance | 10.617 TFLOPS | 35.6 TFLOPS |
Low-Precision Math | FP16: 10.617 TFLOPS | FP16: 71.2 TFLOPS (via Tensor Cores) |
Max Threads Per Block/Group | 1024 threads | 1024 threads |
Thread Dimensions | 1024 × 1024 × 1024 | 1024 × 1024 × 64 |
Grid/Dispatch Dimensions | 2³²-1 × 65535 × 65535 | 2³¹-1 × 65535 × 65535 |
Registers Per Block/Group | 65536 | 65536 |
Max Blocks/Threadgroups Per CU | 24 threadgroups per core | 16 blocks per SM |
Registers Per Thread | Up to 128 | Up to 255 |
Max Threads Per Core/SM | 384-3072 | 256-1536 |
Instruction Cache | 12 KB | 32 KB |
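One place these limits matter in practice is when choosing dispatch sizes. Rather than hard-coding the numbers from the table, Metal lets you query the per-pipeline limits at runtime; here is a small sketch of mine using the documented threadExecutionWidth and maxTotalThreadsPerThreadgroup properties of MTL::ComputePipelineState:

```cpp
#include <Metal/Metal.hpp>
#include <algorithm>

// Pick a 1D threadgroup size for `pipeline` that respects the hardware limits
// and is a multiple of the SIMD-group (warp) width.
MTL::Size pickThreadgroupSize(MTL::ComputePipelineState* pipeline, NS::UInteger n)
{
    NS::UInteger simdWidth  = pipeline->threadExecutionWidth();           // 32 on Apple GPUs
    NS::UInteger maxThreads = pipeline->maxTotalThreadsPerThreadgroup();  // at most 1024

    // Round the cap down to a multiple of the SIMD width, then clamp to n.
    NS::UInteger tg = std::min((maxThreads / simdWidth) * simdWidth, n);
    return MTL::Size(std::max<NS::UInteger>(tg, simdWidth), 1, 1);
}
```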
What's Next
Hopefully I have managed to convey my understanding of the Apple GPU architecture, memory hierarchy, and the Metal programming model to you. My discussion of bottlenecks and performance constraints might come across as a little out of place in this blog, but it comes in handy for a more practical follow-up guide on writing Metal kernels for GPU operations. I am currently trying to implement and systematically improve upon a square fp32 matmul kernel in Metal, similar in spirit to siboehm's great article How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog. By writing my kernels in Metal, I have already been able to achieve much better throughput than WebGPU kernels running on Apple Silicon for matrix multiplication workloads. However, I haven't been able to match or get close enough to the optimized MPSMatrixMultiplication kernel from the Metal Performance Shaders library. I will keep you posted about my experiments in my next blog!
References
For those looking to learn more about the Apple GPU architecture and Metal, here are some great references for further reading on the Apple Silicon GPUs and Metal that I came across while writing this article:
- Alyssa Rosenzweig's 4-part series Dissecting the Apple M1 GPU
- Dougall Johnson's work on reverse engineering the Apple G13 GPU architecture (used by M1)
- Philip Turner's Apple GPU microarchitecture metal-benchmarks
- Hübner et al's 2025 paper Apple vs. Oranges: Evaluating the Apple Silicon M-Series SoCs for HPC Performance and Efficiency
- Brian Vogel's notes on benchmarking scientific computing (matmul) using Metal: Performance Testing on M1 CPU and GPU
Besides these references, here is a handy list of all the documents relevant to Apple Metal and NVIDIA CUDA that were referred to earlier, in the order they appeared:
- Metal
- Metal documentation
- Metal sample code library
- GPUs Go Brrr - Hazy Research Group Stanford Blog
- Programming Massively Parallel Processors
- CUDA
- metal-cpp
- Metal Performance Shaders
- CUTLASS
- Performing calculations on a GPU
- Andrew Chan: Thoughts on Metal (vs. CUDA)
- Optimize Metal performance for Apple Silicon Macs
- Choosing a Resource Storage Mode for Apple GPUs
- CUDA Refresher: The CUDA Programming Model
- Dissecting the NVIDIA Volta GPU Architecture via Microbenchmarking
- Simon Boehm: How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog
- Zach Nussbaum: Optimizing a WebGPU Matmul Kernel
- MPSMatrixMultiplication