Custom Kubernetes Schedulers for GPU-Intensive ML Workloads

Goh Ling Yong

The Inadequacy of the Default Scheduler for ML Platforms

The kube-scheduler component in Kubernetes is responsible for one of the most critical control plane functions: assigning pending pods to worker nodes. It operates in a two-phase cycle: filtering and scoring. First, it filters out nodes that cannot satisfy a pod's hard requirements (e.g., CPU, memory, GPU requests, taints/tolerations). Then it scores the remaining feasible nodes with a set of scoring plugins (least-allocated, balanced resource allocation, and so on) and selects the one with the highest score.

For stateless web applications, this model is highly effective. However, for large-scale distributed machine learning training, this general-purpose approach reveals critical limitations:

  • Topology Ignorance: A pod requesting 4 NVIDIA A100 GPUs might be scheduled on a node with 8 GPUs. The default scheduler is unaware if those 4 GPUs are interconnected via high-speed NVLink or are split across different PCIe root complexes or NUMA nodes. This placement decision can introduce a severe communication bottleneck, degrading training performance by orders of magnitude for models that rely on frequent gradient exchange.
  • Lack of Gang Scheduling: Distributed training frameworks like Horovod or PyTorch's DistributedDataParallel often require a fixed number of worker pods to start simultaneously. The default scheduler processes pods individually. It might successfully schedule 5 out of 8 required workers, with the remaining 3 pending due to resource fragmentation. The 5 scheduled pods consume valuable GPU resources while sitting idle, waiting for their peers. This leads to resource deadlock and significant waste, a condition unacceptable in costly GPU environments.
  • Inflexible Scoring: The default scoring plugins are designed for generic cluster health. They do not intrinsically value co-locating related pods or understand the specific performance characteristics of ML hardware. We need a mechanism to explicitly favor nodes that offer the best performance profile for a given ML job.
To overcome these limitations, we must move beyond simple nodeAffinity and podAntiAffinity rules and implement a custom scheduler tailored to the unique demands of ML workloads. This involves using the Kubernetes scheduler framework to build custom plugins that encode our domain-specific scheduling logic.

    Architecture of a Custom Scheduler

    A custom scheduler is not a complete replacement for the kube-scheduler but rather an extension of its core logic. We leverage the scheduler framework (the k8s.io/kubernetes/pkg/scheduler/framework package), which exposes a series of extension points in the scheduling cycle. Our custom logic, written in Go, hooks into these points.

    Key Extension Points for ML Workloads:

    * Filter: Runs during the filtering phase and lets a plugin disqualify a node for a given pod. We will use this to enforce strict GPU topology requirements.

    * Score: After filtering, this plugin assigns a score (integer) to each feasible node. The node with the highest cumulative score is chosen. We'll use this to rank nodes based on how well their GPU topology matches the pod's needs.

    * Permit: This is a crucial point for gang scheduling. It runs after a pod has been deemed schedulable for a node but before it is bound. A plugin at this stage can Wait or Deny a pod. This allows us to hold a pod in a waiting state until all members of its group are also ready to be scheduled.
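
    For reference, these extension points correspond to small Go interfaces in the k8s.io/kubernetes/pkg/scheduler/framework package. A condensed excerpt is shown below (it matches the signatures used by the plugins later in this article; exact signatures can change between Kubernetes minor releases, so check the release you build against):

    go
    // Condensed from k8s.io/kubernetes/pkg/scheduler/framework (illustrative excerpt).
    type FilterPlugin interface {
    	Plugin
    	Filter(ctx context.Context, state *CycleState, pod *v1.Pod, nodeInfo *NodeInfo) *Status
    }
    
    type ScorePlugin interface {
    	Plugin
    	Score(ctx context.Context, state *CycleState, p *v1.Pod, nodeName string) (int64, *Status)
    	ScoreExtensions() ScoreExtensions
    }
    
    type PermitPlugin interface {
    	Plugin
    	// A Wait status parks the pod for at most the returned duration; another
    	// plugin invocation (or a timeout) later allows or rejects it.
    	Permit(ctx context.Context, state *CycleState, p *v1.Pod, nodeName string) (*Status, time.Duration)
    }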

    Our custom scheduler will be a standalone Go binary, packaged in a Docker container, and deployed within the Kubernetes cluster as a Deployment. It will watch for unscheduled pods with a specific schedulerName and execute our custom scheduling logic.

    Here is the boilerplate for a custom scheduler's main.go:

    go
    // main.go
    package main
    
    import (
    	"os"
    
    	"k8s.io/component-base/cli"
    	"k8s.io/kubernetes/cmd/kube-scheduler/app"
    
    	// Import your custom plugins here so their Name and New symbols are available.
    	"my.company.com/ml-scheduler/plugins/gangschedule"
    	"my.company.com/ml-scheduler/plugins/topologyscore"
    )
    
    func main() {
    	// The command boilerplate is provided by the Kubernetes scheduler app package.
    	// We create a new scheduler command and register our custom plugins with it.
    	// Registration alone does not activate them; each plugin must also be enabled
    	// per profile in the KubeSchedulerConfiguration (see scheduler-config.yaml below).
    	command := app.NewSchedulerCommand(
    		app.WithPlugin(topologyscore.Name, topologyscore.New),
    		app.WithPlugin(gangschedule.Name, gangschedule.New),
    	)
    
    	code := cli.Run(command)
    	os.Exit(code)
    }
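
    Building this binary means importing k8s.io/kubernetes as a library, which requires a replace directive for every k8s.io staging module that ends up in the build graph. An abridged, illustrative go.mod sketch follows (the module path is taken from the imports above; the versions are assumptions and must match the Kubernetes release you target):

    go
    // go.mod (abridged, illustrative)
    module my.company.com/ml-scheduler
    
    go 1.21
    
    require (
    	k8s.io/api v0.28.4
    	k8s.io/apimachinery v0.28.4
    	k8s.io/component-base v0.28.4
    	k8s.io/klog/v2 v2.100.1
    	k8s.io/kubernetes v1.28.4
    )
    
    // k8s.io/kubernetes pins its staging modules to v0.0.0, so each one in the
    // build graph needs an explicit replace pointing at the matching v0.x release.
    replace (
    	k8s.io/api => k8s.io/api v0.28.4
    	k8s.io/apimachinery => k8s.io/apimachinery v0.28.4
    	k8s.io/apiserver => k8s.io/apiserver v0.28.4
    	k8s.io/client-go => k8s.io/client-go v0.28.4
    	k8s.io/component-base => k8s.io/component-base v0.28.4
    	k8s.io/kube-scheduler => k8s.io/kube-scheduler v0.28.4
    	// ...and so on for the remaining k8s.io/* staging modules...
    )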

    Implementing a Topology-Aware GPU Scheduler

    Our first goal is to ensure pods requesting multiple GPUs are placed on nodes where those GPUs are interconnected with NVLink for maximum bandwidth.

    Prerequisites:

  • NVIDIA Device Plugin for Kubernetes: This must be running in the cluster to expose GPUs as a schedulable resource (nvidia.com/gpu).
  • Node Feature Discovery (NFD): We use NFD to inspect hardware on each node and apply labels. We'll configure it to detect NVLink connectivity between GPUs and apply a label like gpu-topology.my.company.com/nvlink-pairs: "4" to nodes that have 4 pairs of NVLink-connected GPUs. An example of the resulting node object is shown below.
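
    During bring-up, it helps to see what a correctly labeled GPU node looks like. The excerpt below is illustrative (the node name is hypothetical); the nvlink-pairs label comes from the NFD configuration above, while the nvidia.com/gpu capacity is exposed by the device plugin. For testing, the label can also be applied by hand with kubectl label node:

    yaml
    # Excerpt of `kubectl get node gpu-node-01 -o yaml` on a labeled 8-GPU node.
    # Manual equivalent while developing the NFD rule:
    #   kubectl label node gpu-node-01 gpu-topology.my.company.com/nvlink-pairs=4
    apiVersion: v1
    kind: Node
    metadata:
      name: gpu-node-01
      labels:
        gpu-topology.my.company.com/nvlink-pairs: "4"
    status:
      capacity:
        nvidia.com/gpu: "8"   # advertised by the NVIDIA device plugin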
    Pod Specification:

    We'll use an annotation on the Pod to specify the topology requirement:

    yaml
    apiVersion: v1
    kind: Pod
    metadata:
      name: distributed-training-worker-0
      annotations:
        gpu-topology.my.company.com/min-nvlink-gpus: "4"
    spec:
      schedulerName: ml-scheduler
      containers:
      - name: cuda-container
        image: nvidia/cuda:11.4.0-base-ubuntu20.04
        command: ["sleep", "3600"]
        resources:
          limits:
            nvidia.com/gpu: 4

    Implementation of Filter and Score Plugins:

    We will create a single plugin that implements both the Filter and Score extension points.

    go
    // plugins/topologyscore/topology_score.go
    package topologyscore
    
    import (
    	"context"
    	"fmt"
    	"strconv"
    
    	"k8s.io/api/core/v1"
    	"k8s.io/apimachinery/pkg/runtime"
    	"k8s.io/klog/v2"
    	"k8s.io/kubernetes/pkg/scheduler/framework"
    )
    
    const (
    	Name = "TopologyScore"
    	AnnotationMinNVLinkGPUs = "gpu-topology.my.company.com/min-nvlink-gpus"
    	LabelNVLinkPairs      = "gpu-topology.my.company.com/nvlink-pairs"
    )
    
    type TopologyScore struct{}
    
    var _ framework.FilterPlugin = &TopologyScore{}
    var _ framework.ScorePlugin = &TopologyScore{}
    
    func (ts *TopologyScore) Name() string {
    	return Name
    }
    
    // New initializes a new plugin and returns it.
    func New(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
    	return &TopologyScore{}, nil
    }
    
    // Filter plugin implementation
    func (ts *TopologyScore) Filter(ctx context.Context, state *framework.CycleState, pod *v1.Pod, nodeInfo *framework.NodeInfo) *framework.Status {
    	// Get the number of GPUs requested by the pod
    	requestedGPUs := getRequestedGPUs(pod)
    	if requestedGPUs == 0 {
    		return framework.NewStatus(framework.Success)
    	}
    
    	// Check if the pod has the topology annotation
    	minNVLinkGPUsStr, ok := pod.Annotations[AnnotationMinNVLinkGPUs]
    	if !ok {
    		// No topology requirement, so this plugin doesn't filter
    		return framework.NewStatus(framework.Success)
    	}
    
    	minNVLinkGPUs, err := strconv.Atoi(minNVLinkGPUsStr)
    	if err != nil {
    		klog.Errorf("Invalid value for annotation %s: %v", AnnotationMinNVLinkGPUs, err)
    		// A malformed annotation is a problem with the pod spec, not the node or the
    		// scheduler, so mark the pod unschedulable rather than failing the scheduling cycle.
    		return framework.NewStatus(framework.UnschedulableAndUnresolvable, "invalid topology annotation")
    	}
    
    	// The pod must request at least the number of GPUs specified in the annotation
    	if requestedGPUs < int64(minNVLinkGPUs) {
    		return framework.NewStatus(framework.Unschedulable, fmt.Sprintf("pod requests %d GPUs but requires %d with NVLink", requestedGPUs, minNVLinkGPUs))
    	}
    
    	// Check the node's labels for NVLink topology
    	node := nodeInfo.Node()
    	if node.Labels == nil {
    		return framework.NewStatus(framework.Unschedulable, "node has no topology labels")
    	}
    
    	nvlinkPairsStr, ok := node.Labels[LabelNVLinkPairs]
    	if !ok {
    		return framework.NewStatus(framework.Unschedulable, "node missing nvlink-pairs label")
    	}
    
    	nvlinkPairs, err := strconv.Atoi(nvlinkPairsStr)
    	if err != nil {
    		return framework.NewStatus(framework.Error, "invalid nvlink-pairs label on node")
    	}
    
    	// A node with N pairs can support a pod needing up to 2*N GPUs on NVLink.
    	// This is a simplified model; a real implementation might need more granular info.
    	if int64(nvlinkPairs*2) < requestedGPUs {
    		return framework.NewStatus(framework.Unschedulable, fmt.Sprintf("node only has %d NVLink-connected GPUs, pod needs %d", nvlinkPairs*2, requestedGPUs))
    	}
    
    	return framework.NewStatus(framework.Success)
    }
    
    // Score plugin implementation
    func (ts *TopologyScore) Score(ctx context.Context, state *framework.CycleState, p *v1.Pod, nodeName string) (int64, *framework.Status) {
    	// Score only receives the node name. To inspect node labels here, a real
    	// implementation would keep the framework.Handle passed to New and read the
    	// node from handle.SnapshotSharedLister().NodeInfos().Get(nodeName), or reuse
    	// data cached in CycleState during the Filter phase.
    	requestedGPUs := getRequestedGPUs(p)
    	if requestedGPUs == 0 {
    		return 0, framework.NewStatus(framework.Success)
    	}
    
    	// Simple scoring logic: give a uniformly high score to any GPU-requesting pod
    	// on a node that passed the filter. A more advanced scorer would rank nodes
    	// with more NVLink pairs than the minimum higher (e.g., prefer an 8-pair node
    	// over a 4-pair node for a 4-GPU pod), leaving smaller nodes for smaller jobs.
    	return 100, framework.NewStatus(framework.Success)
    }
    
    func (ts *TopologyScore) ScoreExtensions() framework.ScoreExtensions {
    	return nil // No normalization needed for this simple scorer
    }
    
    func getRequestedGPUs(pod *v1.Pod) int64 {
    	var count int64
    	for _, container := range pod.Spec.Containers {
    		if val, ok := container.Resources.Limits["nvidia.com/gpu"]; ok {
    			count += val.Value()
    		}
    	}
    	return count
    }
    
    // Note: no init()-based self-registration is required. The plugin is registered
    // with the scheduler command in main.go via app.WithPlugin(Name, New) and then
    // enabled per-profile in the KubeSchedulerConfiguration.

    This plugin ensures that only nodes with sufficient NVLink-connected GPUs are considered, and it gives them a high score, prioritizing them over other nodes.
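
    Because the plugin is plain Go, the Filter logic can be unit-tested without a cluster. A minimal sketch, assuming the test file sits next to the plugin (node name and values are illustrative):

    go
    // plugins/topologyscore/topology_score_test.go
    package topologyscore
    
    import (
    	"context"
    	"testing"
    
    	v1 "k8s.io/api/core/v1"
    	"k8s.io/apimachinery/pkg/api/resource"
    	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    	"k8s.io/kubernetes/pkg/scheduler/framework"
    )
    
    func TestFilterRejectsNodeWithoutEnoughNVLinkGPUs(t *testing.T) {
    	// Pod asks for 4 GPUs, all of which must be NVLink-connected.
    	pod := &v1.Pod{
    		ObjectMeta: metav1.ObjectMeta{
    			Annotations: map[string]string{AnnotationMinNVLinkGPUs: "4"},
    		},
    		Spec: v1.PodSpec{
    			Containers: []v1.Container{{
    				Resources: v1.ResourceRequirements{
    					Limits: v1.ResourceList{"nvidia.com/gpu": resource.MustParse("4")},
    				},
    			}},
    		},
    	}
    
    	// Node advertises only one NVLink pair (2 GPUs), so the pod must be rejected.
    	node := &v1.Node{ObjectMeta: metav1.ObjectMeta{
    		Name:   "gpu-node-01",
    		Labels: map[string]string{LabelNVLinkPairs: "1"},
    	}}
    	nodeInfo := framework.NewNodeInfo()
    	nodeInfo.SetNode(node)
    
    	status := (&TopologyScore{}).Filter(context.Background(), nil, pod, nodeInfo)
    	if status.IsSuccess() {
    		t.Fatalf("expected node to be filtered out, got status: %v", status)
    	}
    }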

    Implementing Gang Scheduling with a `Permit` Plugin

    For distributed training, we need to ensure an entire group of pods is scheduled atomically.

    1. Define a PodGroup CRD:

    First, we define a Custom Resource Definition that represents a group of pods that must be co-scheduled.

    yaml
    # crd/podgroup.yaml
    apiVersion: apiextensions.k8s.io/v1
    kind: CustomResourceDefinition
    metadata:
      name: podgroups.my.company.com
    spec:
      group: my.company.com
      versions:
        - name: v1alpha1
          served: true
          storage: true
          schema:
            openAPIV3Schema:
              type: object
              properties:
                spec:
                  type: object
                  properties:
                    minMember:
                      type: integer
                      description: The minimal number of pods to be co-scheduled.
                    scheduleTimeoutSeconds:
                      type: integer
                      default: 180
                      description: Timeout for waiting for the minMember pods.
      scope: Namespaced
      names:
        plural: podgroups
        singular: podgroup
        kind: PodGroup
        shortNames:
        - pg
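
    On the Go side, this CRD maps to API types along the following lines. This is an illustrative sketch (the package path matches the commented-out import in the gang-scheduling plugin below); in practice you would run it through k8s.io/code-generator or controller-gen to produce deepcopy functions and a typed clientset:

    go
    // pkg/apis/podgroup/v1alpha1/types.go (illustrative)
    package v1alpha1
    
    import (
    	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    )
    
    // +genclient
    // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
    
    // PodGroup declares a set of pods that must be scheduled together.
    type PodGroup struct {
    	metav1.TypeMeta   `json:",inline"`
    	metav1.ObjectMeta `json:"metadata,omitempty"`
    
    	Spec PodGroupSpec `json:"spec,omitempty"`
    }
    
    // PodGroupSpec mirrors the openAPIV3Schema defined in the CRD above.
    type PodGroupSpec struct {
    	// MinMember is the minimal number of pods to be co-scheduled.
    	MinMember int32 `json:"minMember"`
    	// ScheduleTimeoutSeconds bounds how long members wait for the rest of the group.
    	ScheduleTimeoutSeconds int32 `json:"scheduleTimeoutSeconds,omitempty"`
    }
    
    // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
    
    // PodGroupList is a list of PodGroup resources.
    type PodGroupList struct {
    	metav1.TypeMeta `json:",inline"`
    	metav1.ListMeta `json:"metadata,omitempty"`
    	Items           []PodGroup `json:"items"`
    }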

    2. Pod and PodGroup Specification:

    Users will create a PodGroup resource and then label their pods to associate them with it.

    yaml
    # podgroup-example.yaml
    apiVersion: my.company.com/v1alpha1
    kind: PodGroup
    metadata:
      name: my-distributed-job
      namespace: ml-jobs
    spec:
      minMember: 4
    ---
    # pod-example.yaml
    apiVersion: v1
    kind: Pod
    metadata:
      name: distributed-worker-0
      namespace: ml-jobs
      labels:
        pod-group.my.company.com/name: my-distributed-job
    spec:
      schedulerName: ml-scheduler
      # ... container spec

    3. Implementation of the Permit Plugin:

    The Permit plugin is the core of this logic. It holds pods in a Wait state until the minMember count is reached.

    go
    // plugins/gangschedule/gang_schedule.go
    package gangschedule
    
    import (
    	"context"
    	"fmt"
    	"sync"
    	"time"
    
    	"k8s.io/api/core/v1"
    	"k8s.io/apimachinery/pkg/runtime"
    	"k8s.io/klog/v2"
    	"k8s.io/kubernetes/pkg/scheduler/framework"
    
        // You would need to import your generated clientset for the PodGroup CRD
        // podgroupv1alpha1 "my.company.com/ml-scheduler/pkg/apis/podgroup/v1alpha1"
        // clientset "my.company.com/ml-scheduler/pkg/client/clientset/versioned"
    )
    
    const (
    	Name = "GangSchedule"
    	PodGroupNameLabel = "pod-group.my.company.com/name"
    )
    
    // This is a simplified in-memory store for pod groups.
    // A production implementation should be more robust and potentially
    // use the PodGroup CRD status field to coordinate.
    type PodGroupInfo struct {
    	name string
    	namespace string
    	minMember int
    	pods []*v1.Pod
    	creationTime time.Time
    	mutex sync.Mutex
    }
    
    type GangSchedule struct {
    	handle framework.Handle
    	podGroupInfos sync.Map // Map[string]*PodGroupInfo
        // podGroupClient clientset.Interface
    }
    
    var _ framework.PermitPlugin = &GangSchedule{}
    
    func (gs *GangSchedule) Name() string {
    	return Name
    }
    
    func New(_ runtime.Object, h framework.Handle) (framework.Plugin, error) {
        // In a real implementation, initialize the CRD client here.
    	return &GangSchedule{
    		handle: h,
    	}, nil
    }
    
    // Permit plugin implementation
    func (gs *GangSchedule) Permit(ctx context.Context, state *framework.CycleState, pod *v1.Pod, nodeName string) (*framework.Status, time.Duration) {
    	podGroupName, ok := pod.Labels[PodGroupNameLabel]
    	if !ok {
    		// Not part of a gang, allow immediately
    		return framework.NewStatus(framework.Success), 0
    	}
    
    	groupKey := fmt.Sprintf("%s/%s", pod.Namespace, podGroupName)
    
    	// Get or create the PodGroupInfo for this group
    	val, _ := gs.podGroupInfos.LoadOrStore(groupKey, &PodGroupInfo{
    		name: podGroupName,
    		namespace: pod.Namespace,
    		creationTime: time.Now(),
    		// In a real implementation, fetch minMember from the PodGroup CRD
    		minMember: 4, // Hardcoded for example
    	})
    	pgInfo := val.(*PodGroupInfo)
    
    	pgInfo.mutex.Lock()
    	defer pgInfo.mutex.Unlock()
    
    	// Add the current pod to the group
    	pgInfo.pods = append(pgInfo.pods, pod)
    
    	klog.Infof("Pod %s/%s added to group %s. Current count: %d, required: %d", pod.Namespace, pod.Name, groupKey, len(pgInfo.pods), pgInfo.minMember)
    
    	if len(pgInfo.pods) < pgInfo.minMember {
    		klog.Infof("Group %s not ready, pod %s/%s will wait.", groupKey, pod.Namespace, pod.Name)
    		// Group not ready, tell the framework to wait. Timeout is crucial.
    		// The timeout should be fetched from the PodGroup CRD.
    		return framework.NewStatus(framework.Wait), 180 * time.Second
    	}
    
    	// The group is ready! Allow all waiting pods in this group to proceed.
    	klog.Infof("Group %s is now ready with %d pods. Allowing all to proceed.", groupKey, len(pgInfo.pods))
    	for _, waitingPod := range pgInfo.pods {
    		// The framework handle allows us to signal waiting pods.
    		gs.handle.IterateOverWaitingPods(func(wp framework.WaitingPod) {
    			if wp.GetPod().UID == waitingPod.UID {
    				klog.Infof("Allowing waiting pod %s/%s", wp.GetPod().Namespace, wp.GetPod().Name)
    				wp.Allow(gs.Name())
    			}
    		})
    	}
    
    	// Clean up the group info
    	gs.podGroupInfos.Delete(groupKey)
    
    	// Allow the current pod to proceed
    	return framework.NewStatus(framework.Success), 0
    }

    This Permit plugin intercepts any pod carrying the pod-group.my.company.com/name label. It adds the pod to an in-memory group and checks whether the group has reached its minMember size. If not, it returns a Wait status, holding the pod at the Permit stage. Once the threshold is met, it iterates over the pods the framework is currently holding at that stage and explicitly allows every member of the now-complete group to proceed to the binding phase.
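
    One gap worth closing before production: if a member of the group is rejected, preempted, or times out while waiting, its siblings should not keep holding resources. Below is a sketch of how the same gangschedule package could handle this via the Reserve/Unreserve extension point, under the same simplified in-memory model; the plugin would also need to be enabled under the reserve extension point in the scheduler configuration shown in the next section:

    go
    var _ framework.ReservePlugin = &GangSchedule{}
    
    // Reserve has nothing to do for this plugin; it exists so that Unreserve is
    // invoked when a gang member fails after being assumed on a node.
    func (gs *GangSchedule) Reserve(ctx context.Context, state *framework.CycleState, pod *v1.Pod, nodeName string) *framework.Status {
    	return framework.NewStatus(framework.Success)
    }
    
    // Unreserve is called when an assumed or waiting pod is rejected, times out,
    // or is preempted. Rejecting the remaining waiting members prevents a partial
    // launch and releases their resources promptly.
    func (gs *GangSchedule) Unreserve(ctx context.Context, state *framework.CycleState, pod *v1.Pod, nodeName string) {
    	podGroupName, ok := pod.Labels[PodGroupNameLabel]
    	if !ok {
    		return
    	}
    	// Drop the in-memory bookkeeping for the group.
    	gs.podGroupInfos.Delete(fmt.Sprintf("%s/%s", pod.Namespace, podGroupName))
    
    	// Reject every sibling still parked at the Permit stage.
    	gs.handle.IterateOverWaitingPods(func(wp framework.WaitingPod) {
    		if wp.GetPod().Namespace == pod.Namespace &&
    			wp.GetPod().Labels[PodGroupNameLabel] == podGroupName {
    			wp.Reject(Name, "a member of the pod group failed; rejecting the whole group")
    		}
    	})
    }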

    Production Deployment and Configuration

    1. Scheduler Configuration:

    We need a KubeSchedulerConfiguration file that defines a dedicated ml-scheduler profile and enables our custom plugins at the relevant extension points. Default plugins remain enabled for a profile unless they are explicitly listed under disabled, so the profile below only adds what we need.

    yaml
    # scheduler-config.yaml
    apiVersion: kubescheduler.config.k8s.io/v1
    kind: KubeSchedulerConfiguration
    leaderElection:
      leaderElect: true
    # No clientConnection.kubeconfig is set: running in-cluster, the scheduler
    # authenticates with the API server via its ServiceAccount (see the Deployment below).
    profiles:
      - schedulerName: ml-scheduler
        plugins:
          # Our custom plugins are enabled here
          filter:
            enabled:
              - name: "TopologyScore"
          score:
            enabled:
              - name: "TopologyScore"
          permit:
            enabled:
              - name: "GangSchedule"
          # Default plugins remain enabled unless explicitly disabled;
          # a few useful ones are listed here for clarity
          queueSort:
            enabled:
              - name: "PrioritySort"
          preFilter:
            enabled:
              - name: "NodeResourcesFit"
              - name: "NodePorts"
          bind:
            enabled:
              - name: "DefaultBinder"

    2. Deployment Manifest:

    The custom scheduler runs as a Deployment in the kube-system namespace.

    yaml
    # custom-scheduler-deployment.yaml
    apiVersion: v1
    kind: ServiceAccount
    metadata:
      name: ml-scheduler
      namespace: kube-system
    ---
    apiVersion: rbac.authorization.k8s.io/v1
    kind: ClusterRoleBinding
    metadata:
      name: ml-scheduler-as-kube-scheduler
    subjects:
    - kind: ServiceAccount
      name: ml-scheduler
      namespace: kube-system
    roleRef:
      kind: ClusterRole
      name: system:kube-scheduler
      apiGroup: rbac.authorization.k8s.io
    ---
    # Additional role to access our PodGroup CRDs
    apiVersion: rbac.authorization.k8s.io/v1
    kind: ClusterRole
    metadata:
      name: podgroup-reader
    rules:
    - apiGroups: ["my.company.com"]
      resources: ["podgroups"]
      verbs: ["get", "list", "watch"]
    ---
    apiVersion: rbac.authorization.k8s.io/v1
    kind: ClusterRoleBinding
    metadata:
      name: ml-scheduler-podgroup-reader
    subjects:
    - kind: ServiceAccount
      name: ml-scheduler
      namespace: kube-system
    roleRef:
      kind: ClusterRole
      name: podgroup-reader
      apiGroup: rbac.authorization.k8s.io
    ---
    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: ml-scheduler
      namespace: kube-system
      labels:
        app: ml-scheduler
    spec:
      replicas: 2 # For High Availability
      selector:
        matchLabels:
          app: ml-scheduler
      template:
        metadata:
          labels:
            app: ml-scheduler
        spec:
          serviceAccountName: ml-scheduler
          containers:
          - name: ml-scheduler
            image: my.company.com/ml-scheduler:v1.0.0
            command:
            - /bin/kube-scheduler
            args:
            - --config=/etc/kubernetes/scheduler-config.yaml
            - --v=4   # verbose logging
            volumeMounts:
            - name: scheduler-config
              mountPath: /etc/kubernetes
          volumes:
          - name: scheduler-config
            configMap:
              name: ml-scheduler-config

    Advanced Considerations and Edge Cases

    * High Availability: The deployment runs with two replicas. The leaderElect: true setting in the KubeSchedulerConfiguration ensures only one instance is active at any time, preventing multiple schedulers from trying to bind the same pod.

    * State Management in Permit Plugin: The in-memory map in our GangSchedule plugin is volatile state held by a single process. If the active scheduler pod dies, all waiting pod group information is lost. A production-grade implementation must persist this state. A common pattern is to update the status subresource of the PodGroup CRD with the list of admitted pod UIDs; when a new scheduler takes over, it reconciles its in-memory view from the CRDs.

    * Interaction with Cluster Autoscaler: The Cluster Autoscaler simulates scheduling to decide whether a new node is needed, but its simulation only runs the in-tree plugins, so it cannot fully reproduce custom logic such as gang scheduling or our topology filter. In practice, give the autoscaler a scheduler configuration profile as close to ours as possible and validate scale-up behavior explicitly: 8 pending workers should trigger the addition of a new 8-GPU node rather than being spread across existing, smaller nodes (for example by sizing GPU node groups appropriately or over-provisioning them).

    * Preemption: Our gang scheduling logic must interact correctly with preemption. If a high-priority pod needs resources, it might preempt one of the pods from our waiting group. The Permit plugin needs logic to handle this: it should reject the entire group if one of its members is preempted, preventing a partial launch.

    * Performance Benchmarking: To validate the scheduler, we must benchmark. A typical test involves:

    1. Baseline: Deploy 20 distributed jobs (4 pods each) on a cluster using the default scheduler. Measure the time-to-train (from first pod creation to job completion) and GPU utilization over time. Observe high idle GPU time due to partial scheduling.

    2. Test: Deploy the same 20 jobs using the ml-scheduler. Measure the same metrics. The expected outcome is a lower average time-to-train and consistently higher GPU utilization, as jobs are scheduled in atomic blocks and placed on topology-aware nodes.

    By moving beyond the default scheduler, we transform Kubernetes from a generic container orchestrator into a high-performance, specialized platform for machine learning, directly addressing the performance and efficiency bottlenecks that plague large-scale ML infrastructure.
