Developer Tutorial
Hands-on Implementation

Building Your Own Music Separation Pipeline: A Developer's Guide

Transform theoretical knowledge into production-ready code. Build, evaluate, and deploy a complete music source separation system using modern tools and best practices.

JewelMusic Development Team
February 14, 2025
28 min read
Developer Guide · Music Separation

From Theory to Practice

You've learned the mathematics, studied the architectures, and understood the trade-offs. Now it's time to build something real. This comprehensive guide will take you through creating a production-ready music source separation pipeline from scratch.

By the end of this tutorial, you'll have a complete system that can separate vocals from any song, with proper evaluation metrics, error handling, and deployment considerations. We'll use Python, PyTorch, and modern MLOps practices to create something you can actually ship.

🎯 What We'll Build

Core Components

  • Data preprocessing pipeline
  • Neural network training framework
  • Evaluation and metrics system
  • Real-time inference API

Production Features

  • Docker containerization
  • Model versioning and rollback
  • Monitoring and logging
  • Performance optimization

Project Setup and Architecture

Project Structure
music-separation-pipeline/
├── src/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── dataset.py          # Dataset classes and loaders
│   │   ├── preprocessing.py    # Audio preprocessing utilities
│   │   └── augmentation.py     # Data augmentation techniques
│   ├── models/
│   │   ├── __init__.py
│   │   ├── unet.py            # U-Net architecture
│   │   ├── demucs.py          # Demucs implementation
│   │   └── base.py            # Base model class
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py         # Training loop and utilities
│   │   └── losses.py          # Loss functions
│   ├── inference/
│   │   ├── __init__.py
│   │   ├── separator.py       # Inference engine
│   │   └── api.py             # REST API server
│   └── evaluation/
│       ├── __init__.py
│       ├── metrics.py         # SDR, SIR, SAR calculations
│       └── benchmark.py       # Evaluation scripts
├── configs/
│   ├── model_config.yaml      # Model hyperparameters
│   ├── training_config.yaml   # Training settings
│   └── api_config.yaml        # API configuration
├── docker/
│   ├── Dockerfile
│   └── docker-compose.yml
├── tests/
├── notebooks/                 # Jupyter notebooks for exploration
├── requirements.txt
├── setup.py
└── README.md

Design Principles

  • Modular architecture for easy extension
  • Configuration-driven development
  • Clear separation of concerns
  • Comprehensive testing coverage

Key Dependencies

  • PyTorch: Deep learning framework
  • torchaudio: Audio processing
  • librosa: Audio analysis
  • FastAPI: High-performance web API

Environment Setup
# Create virtual environment
python -m venv music-separation-env
source music-separation-env/bin/activate  # On Windows: music-separation-env\Scripts\activate

# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install librosa soundfile numpy scipy
pip install fastapi uvicorn pydantic
pip install wandb tensorboard  # For experiment tracking
pip install pytest black flake8  # Development tools

# Install project in development mode
pip install -e .

# Verify installation
python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')"
python -c "import librosa; print(f'librosa {librosa.__version__}')"

Data Pipeline: Processing Audio for ML

Dataset Implementation
# src/data/dataset.py
import torch
from torch.utils.data import Dataset
import torchaudio
import librosa
import numpy as np
from pathlib import Path
from typing import Tuple, Optional

class MusicSeparationDataset(Dataset):
    """
    Dataset for music source separation training.
    Expects directory structure:
    dataset/
      ├── train/
      │   ├── mixture/
      │   │   ├── song1.wav
      │   │   └── song2.wav
      │   ├── vocals/
      │   │   ├── song1.wav
      │   │   └── song2.wav
      │   └── accompaniment/
      │       ├── song1.wav
      │       └── song2.wav
    """
    
    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        sample_rate: int = 44100,
        segment_length: float = 6.0,  # seconds
        normalize: bool = True,
        transforms=None
    ):
        self.data_dir = Path(data_dir) / split
        self.sample_rate = sample_rate
        self.segment_samples = int(segment_length * sample_rate)
        self.normalize = normalize
        self.transforms = transforms
        
        # Find all mixture files
        self.mixture_files = list((self.data_dir / "mixture").glob("*.wav"))
        self.mixture_files.sort()
        
        # Verify corresponding files exist
        self._verify_files()
    
    def _verify_files(self):
        """Ensure all stems exist for each mixture"""
        verified_files = []
        for mix_file in self.mixture_files:
            stem_name = mix_file.stem
            vocals_path = self.data_dir / "vocals" / f"{stem_name}.wav"
            accomp_path = self.data_dir / "accompaniment" / f"{stem_name}.wav"
            
            if vocals_path.exists() and accomp_path.exists():
                verified_files.append(mix_file)
            else:
                print(f"Warning: Missing stems for {stem_name}")
        
        self.mixture_files = verified_files
        print(f"Dataset loaded: {len(self.mixture_files)} songs")
    
    def __len__(self) -> int:
        return len(self.mixture_files)
    
    def __getitem__(self, idx: int) -> dict:
        mix_file = self.mixture_files[idx]
        stem_name = mix_file.stem
        
        # Load audio files
        mixture, sr = torchaudio.load(mix_file)
        vocals, _ = torchaudio.load(
            self.data_dir / "vocals" / f"{stem_name}.wav"
        )
        accompaniment, _ = torchaudio.load(
            self.data_dir / "accompaniment" / f"{stem_name}.wav"
        )
        
        # Resample if necessary
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            mixture = resampler(mixture)
            vocals = resampler(vocals)
            accompaniment = resampler(accompaniment)
        
        # Convert to mono if stereo
        if mixture.shape[0] > 1:
            mixture = torch.mean(mixture, dim=0, keepdim=True)
            vocals = torch.mean(vocals, dim=0, keepdim=True)
            accompaniment = torch.mean(accompaniment, dim=0, keepdim=True)
        
        # Extract random segment
        audio_length = mixture.shape[1]
        if audio_length > self.segment_samples:
            start = torch.randint(0, audio_length - self.segment_samples, (1,))
            mixture = mixture[:, start:start + self.segment_samples]
            vocals = vocals[:, start:start + self.segment_samples]
            accompaniment = accompaniment[:, start:start + self.segment_samples]
        else:
            # Pad if too short
            pad_amount = self.segment_samples - audio_length
            mixture = torch.nn.functional.pad(mixture, (0, pad_amount))
            vocals = torch.nn.functional.pad(vocals, (0, pad_amount))
            accompaniment = torch.nn.functional.pad(accompaniment, (0, pad_amount))
        
        # Normalize
        if self.normalize:
            max_val = torch.max(torch.abs(mixture))
            if max_val > 0:
                mixture = mixture / max_val
                vocals = vocals / max_val
                accompaniment = accompaniment / max_val
        
        # Apply transforms
        if self.transforms:
            mixture = self.transforms(mixture)
            vocals = self.transforms(vocals)
            accompaniment = self.transforms(accompaniment)
        
        return {
            "mixture": mixture.squeeze(0),  # [samples]
            "vocals": vocals.squeeze(0),
            "accompaniment": accompaniment.squeeze(0),
            "filename": stem_name
        }
Audio Preprocessing Utilities
# src/data/preprocessing.py
import torch
import torchaudio
import librosa
import numpy as np
from typing import Optional, Tuple

class AudioPreprocessor:
    """Handles STFT and spectrogram operations"""
    
    def __init__(
        self,
        n_fft: int = 2048,
        hop_length: int = 512,
        win_length: Optional[int] = None,
        window: str = "hann",
        sample_rate: int = 44100
    ):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length or n_fft
        self.window = window
        self.sample_rate = sample_rate
        
        # Create STFT transform
        self.stft = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=self.win_length,
            window_fn=torch.hann_window,
            power=None  # Return complex spectrogram
        )
        
        # Create inverse STFT
        self.istft = torchaudio.transforms.InverseSpectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=self.win_length,
            window_fn=torch.hann_window
        )
    
    def waveform_to_spectrogram(self, waveform: torch.Tensor) -> torch.Tensor:
        """Convert waveform to a complex spectrogram (real/imag stacked in last dim)"""
        # waveform: [channels, samples] or [samples]
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        
        # Apply STFT; power=None yields a complex tensor of shape [channels, freq, time]
        spec = self.stft(waveform)
        # Stack real and imaginary parts: [channels, freq, time, 2]
        return torch.view_as_real(spec)
    
    def spectrogram_to_waveform(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """Convert complex spectrogram back to waveform"""
        # spectrogram: [channels, freq, time, 2] -> complex [channels, freq, time]
        complex_spec = torch.view_as_complex(spectrogram.contiguous())
        return self.istft(complex_spec)
    
    def magnitude_phase_split(self, complex_spec: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split complex spectrogram into magnitude and phase"""
        # complex_spec: [channels, freq, time, 2]
        real = complex_spec[..., 0]
        imag = complex_spec[..., 1]
        
        magnitude = torch.sqrt(real**2 + imag**2)
        phase = torch.atan2(imag, real)
        
        return magnitude, phase
    
    def magnitude_phase_combine(self, magnitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
        """Combine magnitude and phase into complex spectrogram"""
        real = magnitude * torch.cos(phase)
        imag = magnitude * torch.sin(phase)
        
        complex_spec = torch.stack([real, imag], dim=-1)
        return complex_spec
    
    def create_masks(self, mixture_mag: torch.Tensor, source_mags: list) -> list:
        """Create soft masks for source separation"""
        # Add small epsilon to avoid division by zero
        eps = 1e-8
        total_magnitude = sum(source_mags) + eps
        
        masks = []
        for source_mag in source_mags:
            mask = source_mag / total_magnitude
            masks.append(mask)
        
        return masks

# Data augmentation transforms
class AudioAugmentation:
    """Audio data augmentation techniques"""
    
    @staticmethod
    def add_noise(waveform: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
        """Add Gaussian noise to waveform"""
        noise = torch.randn_like(waveform) * noise_level
        return waveform + noise
    
    @staticmethod
    def time_stretch(waveform: torch.Tensor, rate: float, sample_rate: int) -> torch.Tensor:
        """Time stretching using librosa"""
        waveform_np = waveform.numpy()
        stretched = librosa.effects.time_stretch(waveform_np, rate=rate)
        return torch.from_numpy(stretched)
    
    @staticmethod
    def pitch_shift(waveform: torch.Tensor, n_steps: int, sample_rate: int) -> torch.Tensor:
        """Pitch shifting using librosa"""
        waveform_np = waveform.numpy()
        shifted = librosa.effects.pitch_shift(waveform_np, sr=sample_rate, n_steps=n_steps)
        return torch.from_numpy(shifted)
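
The preprocessor is easiest to sanity-check with a quick round trip. The snippet below is a minimal sketch (the file path is a placeholder): load a clip, convert it to a complex spectrogram, split magnitude and phase, then reconstruct and compare against the original.
# Round-trip sanity check for AudioPreprocessor ("example.wav" is a placeholder path)
import torch
import torchaudio
from src.data.preprocessing import AudioPreprocessor

preprocessor = AudioPreprocessor(n_fft=2048, hop_length=512)

waveform, sr = torchaudio.load("example.wav")             # [channels, samples]
spec = preprocessor.waveform_to_spectrogram(waveform)     # [channels, freq, time, 2]
magnitude, phase = preprocessor.magnitude_phase_split(spec)

# Recombine and invert; lengths may differ by a few samples due to STFT framing
reconstructed = preprocessor.spectrogram_to_waveform(
    preprocessor.magnitude_phase_combine(magnitude, phase)
)
min_len = min(waveform.shape[-1], reconstructed.shape[-1])
print(torch.mean(torch.abs(waveform[..., :min_len] - reconstructed[..., :min_len])))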

Model Implementation: U-Net for Source Separation

U-Net Architecture
# src/models/unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

class ConvBlock(nn.Module):
    """Basic convolutional block with normalization and activation"""
    
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2  # "same" padding for odd kernel sizes
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)

class DownBlock(nn.Module):
    """Encoder block with max pooling"""
    
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2)
    
    def forward(self, x):
        skip = self.conv(x)
        x = self.pool(skip)
        return x, skip

class UpBlock(nn.Module):
    """Decoder block with skip connections"""
    
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)
    
    def forward(self, x, skip):
        x = self.up(x)
        # Handle size mismatch
        diffY = skip.size()[2] - x.size()[2]
        diffX = skip.size()[3] - x.size()[3]
        x = F.pad(x, [diffX // 2, diffX - diffX // 2,
                     diffY // 2, diffY - diffY // 2])
        
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)

class UNet(nn.Module):
    """U-Net for music source separation"""
    
    def __init__(
        self,
        in_channels: int = 1,
        n_sources: int = 2,
        base_channels: int = 64,
        n_layers: int = 4
    ):
        super().__init__()
        self.n_sources = n_sources
        
        # Encoder path
        self.encoder = nn.ModuleList()
        enc_channels = [in_channels] + [base_channels * (2**i) for i in range(n_layers)]
        
        for i in range(n_layers):
            self.encoder.append(DownBlock(enc_channels[i], enc_channels[i+1]))
        
        # Bottleneck doubles the deepest channel count
        self.bottleneck = ConvBlock(enc_channels[-1], enc_channels[-1] * 2)
        
        # Decoder path mirrors the encoder so channel counts line up with the skips
        # e.g. for base_channels=64, n_layers=4: [1024, 512, 256, 128, 64]
        dec_channels = [enc_channels[-1] * 2] + list(reversed(enc_channels[1:]))
        self.decoder = nn.ModuleList()
        
        for i in range(n_layers):
            self.decoder.append(UpBlock(dec_channels[i], dec_channels[i+1]))
        
        # Final output layers - separate 1x1 conv head for each source
        self.output_layers = nn.ModuleList([
            nn.Conv2d(dec_channels[-1], 1, 1) for _ in range(n_sources)
        ])
        
        # Activation for masks
        self.activation = nn.Sigmoid()
    
    def forward(self, x):
        # x: [batch, 1, freq, time] - magnitude spectrogram
        
        # Encoder path
        skips = []
        for encoder in self.encoder:
            x, skip = encoder(x)
            skips.append(skip)
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Decoder path
        skips.reverse()
        for decoder, skip in zip(self.decoder, skips):
            x = decoder(x, skip)
        
        # Generate masks for each source
        masks = []
        for output_layer in self.output_layers:
            mask = self.activation(output_layer(x))
            masks.append(mask)
        
        # Stack masks: [batch, n_sources, freq, time]
        masks = torch.cat(masks, dim=1)
        
        # Normalize masks to sum to 1
        masks = masks / (torch.sum(masks, dim=1, keepdim=True) + 1e-8)
        
        return masks

class MagnitudeUNet(nn.Module):
    """Complete model for magnitude spectrogram separation"""
    
    def __init__(
        self,
        n_fft: int = 2048,
        hop_length: int = 512,
        n_sources: int = 2,
        **unet_kwargs
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_sources = n_sources
        
        # STFT parameters
        self.stft = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            power=None  # Complex spectrogram
        )
        
        self.istft = torchaudio.transforms.InverseSpectrogram(
            n_fft=n_fft,
            hop_length=hop_length
        )
        
        # U-Net for mask prediction
        self.unet = UNet(in_channels=1, n_sources=n_sources, **unet_kwargs)
    
    def forward(self, waveform):
        # waveform: [batch, samples]
        n_samples = waveform.size(-1)
        
        # Convert to complex spectrogram: [batch, freq, time]
        complex_spec = self.stft(waveform)
        magnitude = torch.abs(complex_spec)
        phase = torch.angle(complex_spec)
        
        # Predict masks: [batch, n_sources, freq, time]
        masks = self.unet(magnitude.unsqueeze(1))
        
        # Apply masks to the mixture magnitude
        separated_mags = masks * magnitude.unsqueeze(1)  # [batch, n_sources, freq, time]
        
        # Reconstruct each source with the mixture phase and invert to waveform
        separated_waveforms = []
        for i in range(self.n_sources):
            mag = separated_mags[:, i]  # [batch, freq, time]
            source_spec = torch.polar(mag, phase)  # complex spectrogram
            source_waveform = self.istft(source_spec, length=n_samples)
            separated_waveforms.append(source_waveform)
        
        return torch.stack(separated_waveforms, dim=1)  # [batch, n_sources, samples]
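
Before wiring the model into the training loop, it helps to run a quick shape check. This is a minimal smoke test on random audio, not a real separation:
# Smoke test: random "audio" in, per-source waveforms out
import torch
from src.models.unet import MagnitudeUNet

model = MagnitudeUNet(n_fft=2048, hop_length=512, n_sources=2)
model.eval()

with torch.no_grad():
    mixture = torch.randn(1, 44100 * 3)   # 3 seconds at 44.1 kHz
    stems = model(mixture)                # [batch, n_sources, samples]

print(stems.shape)  # expected: torch.Size([1, 2, 132300])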

Training Pipeline: From Data to Model

Training Implementation
# src/training/trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
import wandb
from tqdm import tqdm
import json

class SourceSeparationTrainer:
    """Training pipeline for source separation models"""
    
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: dict,
        device: str = "cuda"
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        
        # Training setup
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config.get("weight_decay", 1e-5)
        )
        
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            patience=config.get("lr_patience", 5),
            factor=0.5
        )
        
        self.criterion = nn.L1Loss()  # Mean Absolute Error
        
        # Tracking
        self.best_val_loss = float("inf")
        self.train_losses = []
        self.val_losses = []
        
        # Experiment tracking
        if config.get("use_wandb", False):
            wandb.init(project="music-separation", config=config)
    
    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            # Move data to device
            mixture = batch["mixture"].to(self.device)
            vocals = batch["vocals"].to(self.device)
            accompaniment = batch["accompaniment"].to(self.device)
            
            # Target: [batch, n_sources, samples]
            target = torch.stack([vocals, accompaniment], dim=1)
            
            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(mixture)  # [batch, n_sources, samples]
            
            # Compute loss
            loss = self.criterion(output, target)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            if self.config.get("grad_clip", False):
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), 
                    self.config["grad_clip"]
                )
            
            self.optimizer.step()
            
            # Update metrics
            total_loss += loss.item()
            
            # Update progress bar
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "avg_loss": f"{total_loss/(batch_idx+1):.4f}"
            })
            
            # Log to wandb
            if self.config.get("use_wandb", False):
                wandb.log({
                    "train/batch_loss": loss.item(),
                    "train/learning_rate": self.optimizer.param_groups[0]["lr"]
                })
        
        avg_loss = total_loss / len(self.train_loader)
        return avg_loss
    
    def validate(self, epoch: int) -> float:
        """Validate the model"""
        self.model.eval()
        total_loss = 0.0
        
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                # Move data to device
                mixture = batch["mixture"].to(self.device)
                vocals = batch["vocals"].to(self.device)
                accompaniment = batch["accompaniment"].to(self.device)
                
                target = torch.stack([vocals, accompaniment], dim=1)
                
                # Forward pass
                output = self.model(mixture)
                loss = self.criterion(output, target)
                
                total_loss += loss.item()
        
        avg_loss = total_loss / len(self.val_loader)
        
        # Log validation metrics
        if self.config.get("use_wandb", False):
            wandb.log({
                "val/loss": avg_loss,
                "epoch": epoch
            })
        
        return avg_loss
    
    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Save model checkpoint"""
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
            "train_losses": self.train_losses,
            "val_losses": self.val_losses
        }
        
        # Save latest checkpoint
        checkpoint_dir = Path(self.config["checkpoint_dir"])
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        torch.save(checkpoint, checkpoint_dir / "latest_checkpoint.pt")
        
        # Save best model
        if is_best:
            torch.save(checkpoint, checkpoint_dir / "best_model.pt")
            print(f"Saved new best model with val_loss: {val_loss:.4f}")
    
    def train(self, num_epochs: int):
        """Complete training loop"""
        print(f"Starting training for {num_epochs} epochs...")
        
        for epoch in range(1, num_epochs + 1):
            # Training phase
            train_loss = self.train_epoch(epoch)
            self.train_losses.append(train_loss)
            
            # Validation phase
            val_loss = self.validate(epoch)
            self.val_losses.append(val_loss)
            
            # Learning rate scheduling
            self.scheduler.step(val_loss)
            
            # Check for best model
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
            
            # Save checkpoint (always save when a new best model appears)
            if is_best or epoch % self.config.get("save_every", 5) == 0:
                self.save_checkpoint(epoch, val_loss, is_best)
            
            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, "
                  f"Val Loss: {val_loss:.4f}")
        
        print("Training completed!")
        return self.train_losses, self.val_losses

# Loss functions
class MultiScaleLoss(nn.Module):
    """Multi-scale L1 loss for better detail preservation"""
    
    def __init__(self, scales: list = [2048, 1024, 512]):
        super().__init__()
        self.scales = scales
        self.criterion = nn.L1Loss()
    
    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        total_loss = 0.0
        
        # Original scale loss
        total_loss += self.criterion(prediction, target)
        
        # Multi-scale losses
        for scale in self.scales:
            # Average pooling to create different scales
            pred_scaled = F.avg_pool1d(prediction, scale, stride=scale//2)
            target_scaled = F.avg_pool1d(target, scale, stride=scale//2)
            
            scale_loss = self.criterion(pred_scaled, target_scaled)
            total_loss += 0.3 * scale_loss  # Weight for multi-scale terms
        
        return total_loss
Training Script Example
# train.py
import yaml
import torch
from torch.utils.data import DataLoader
from src.data.dataset import MusicSeparationDataset
from src.models.unet import MagnitudeUNet
from src.training.trainer import SourceSeparationTrainer

def main():
    # Load configuration
    with open("configs/training_config.yaml", "r") as f:
        config = yaml.safe_load(f)
    
    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    
    # Create datasets
    train_dataset = MusicSeparationDataset(
        data_dir=config["data_dir"],
        split="train",
        **config["dataset"]
    )
    
    val_dataset = MusicSeparationDataset(
        data_dir=config["data_dir"],
        split="val",
        **config["dataset"]
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=config["num_workers"],
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=config["num_workers"],
        pin_memory=True
    )
    
    # Create model
    model = MagnitudeUNet(**config["model"])
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Create trainer
    trainer = SourceSeparationTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device
    )
    
    # Start training
    trainer.train(config["num_epochs"])

if __name__ == "__main__":
    main()
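
train.py assumes a configs/training_config.yaml that provides the keys read above (data_dir, dataset, model, and the trainer settings). Here is one possible layout with illustrative values; adjust paths and hyperparameters to your setup.
# configs/training_config.yaml (illustrative values)
data_dir: /data/musdb18_processed
checkpoint_dir: checkpoints/

dataset:
  sample_rate: 44100
  segment_length: 6.0
  normalize: true

model:
  n_fft: 2048
  hop_length: 512
  n_sources: 2
  base_channels: 64
  n_layers: 4

batch_size: 8
num_workers: 4
num_epochs: 100
learning_rate: 0.0003
weight_decay: 0.00001
lr_patience: 5
grad_clip: 5.0
save_every: 5
use_wandb: false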

Evaluation and Metrics

Evaluation Metrics Implementation
# src/evaluation/metrics.py
import torch
import numpy as np
from typing import Dict, List, Tuple
from tqdm import tqdm
import museval

class SeparationMetrics:
    """Comprehensive evaluation metrics for source separation"""
    
    @staticmethod
    def sdr_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Signal-to-Distortion Ratio loss (differentiable)
        Higher values = better separation
        """
        # Ensure same length
        min_len = min(prediction.size(-1), target.size(-1))
        prediction = prediction[..., :min_len]
        target = target[..., :min_len]
        
        # Compute SDR
        target_energy = torch.sum(target**2, dim=-1) + 1e-8
        noise_energy = torch.sum((prediction - target)**2, dim=-1) + 1e-8
        
        sdr = 10 * torch.log10(target_energy / noise_energy)
        return -sdr.mean()  # Negative for loss (we want to minimize)
    
    @staticmethod
    def compute_sdr_sir_sar(
        prediction: np.ndarray,
        target: np.ndarray,
        mixture: np.ndarray
    ) -> Dict[str, float]:
        """
        Compute SDR, SIR, SAR using museval library
        
        Args:
            prediction: Estimated source [samples]
            target: True source [samples]
            mixture: Original mixture [samples]
        
        Returns:
            Dictionary with SDR, SIR, SAR values
        """
        # Ensure same length
        min_len = min(len(prediction), len(target), len(mixture))
        prediction = prediction[:min_len]
        target = target[:min_len]
        mixture = mixture[:min_len]
        
        # Stack for museval (expects [sources, samples])
        estimates = np.array([prediction])
        references = np.array([target])
        
        # Compute metrics (museval returns per-window SDR, ISR, SIR, SAR arrays)
        sdr, isr, sir, sar = museval.evaluate(references, estimates)
        
        return {
            "SDR": float(np.nanmedian(sdr[0])),
            "SIR": float(np.nanmedian(sir[0])),
            "SAR": float(np.nanmedian(sar[0]))
        }
    
    @staticmethod
    def batch_evaluate(
        model: torch.nn.Module,
        dataloader,
        device: str = "cuda"
    ) -> Tuple[Dict[str, float], Dict[str, List[float]]]:
        """Evaluate model on entire dataset"""
        
        model.eval()
        all_metrics = {"SDR": [], "SIR": [], "SAR": []}
        
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                mixture = batch["mixture"].to(device)
                vocals = batch["vocals"].cpu().numpy()
                accompaniment = batch["accompaniment"].cpu().numpy()
                
                # Get predictions
                output = model(mixture)  # [batch, 2, samples]
                pred_vocals = output[:, 0].cpu().numpy()
                pred_accompaniment = output[:, 1].cpu().numpy()
                
                # Compute metrics for each sample in batch
                for i in range(mixture.size(0)):
                    # Vocals metrics
                    vocals_metrics = SeparationMetrics.compute_sdr_sir_sar(
                        pred_vocals[i], vocals[i], mixture[i].cpu().numpy()
                    )
                    
                    # Accompaniment metrics  
                    accomp_metrics = SeparationMetrics.compute_sdr_sir_sar(
                        pred_accompaniment[i], accompaniment[i], mixture[i].cpu().numpy()
                    )
                    
                    # Average metrics across sources
                    for metric in ["SDR", "SIR", "SAR"]:
                        avg_metric = (vocals_metrics[metric] + accomp_metrics[metric]) / 2
                        all_metrics[metric].append(avg_metric)
        
        # Compute summary statistics
        summary = {}
        for metric, values in all_metrics.items():
            summary[f"{metric}_mean"] = np.mean(values)
            summary[f"{metric}_std"] = np.std(values)
            summary[f"{metric}_median"] = np.median(values)
        
        return summary, all_metrics

class PerceptualEvaluator:
    """Tools for perceptual quality assessment"""
    
    @staticmethod
    def spectral_convergence(prediction: torch.Tensor, target: torch.Tensor) -> float:
        """Measure spectral similarity"""
        # Compute spectrograms
        pred_spec = torch.stft(prediction, n_fft=2048, return_complex=True)
        target_spec = torch.stft(target, n_fft=2048, return_complex=True)
        
        # Magnitude spectrograms
        pred_mag = torch.abs(pred_spec)
        target_mag = torch.abs(target_spec)
        
        # Spectral convergence
        numerator = torch.norm(target_mag - pred_mag, p="fro")
        denominator = torch.norm(target_mag, p="fro")
        
        return (numerator / (denominator + 1e-8)).item()
    
    @staticmethod
    def magnitude_loss(prediction: torch.Tensor, target: torch.Tensor) -> float:
        """Log magnitude loss"""
        pred_spec = torch.stft(prediction, n_fft=2048, return_complex=True)
        target_spec = torch.stft(target, n_fft=2048, return_complex=True)
        
        pred_mag = torch.abs(pred_spec)
        target_mag = torch.abs(target_spec)
        
        # Log magnitude
        pred_log = torch.log(pred_mag + 1e-8)
        target_log = torch.log(target_mag + 1e-8)
        
        return torch.nn.functional.l1_loss(pred_log, target_log).item()
Benchmark Script
# benchmark.py - Evaluate trained model
import torch
import yaml
import json
from pathlib import Path
from torch.utils.data import DataLoader
from src.data.dataset import MusicSeparationDataset
from src.models.unet import MagnitudeUNet
from src.evaluation.metrics import SeparationMetrics

def benchmark_model(model_path: str, config_path: str, data_dir: str):
    """Run comprehensive benchmark on trained model"""
    
    # Load config and model
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Load trained model
    model = MagnitudeUNet(**config["model"])
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    
    # Create test dataset
    test_dataset = MusicSeparationDataset(
        data_dir=data_dir,
        split="test",
        **config["dataset"]
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,  # Process one song at a time for detailed analysis
        shuffle=False,
        num_workers=4
    )
    
    # Run evaluation
    print("Running evaluation...")
    summary, detailed_metrics = SeparationMetrics.batch_evaluate(
        model, test_loader, device
    )
    
    # Print results
    print("\n" + "="*50)
    print("BENCHMARK RESULTS")
    print("="*50)
    
    for metric, value in summary.items():
        print(f"{metric}: {value:.3f}")
    
    # Save detailed results
    results = {
        "summary": summary,
        "detailed_metrics": {k: [float(x) for x in v] for k, v in detailed_metrics.items()},
        "model_path": model_path,
        "config": config
    }
    
    results_path = Path(model_path).parent / "benchmark_results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    
    print(f"\nDetailed results saved to: {results_path}")
    return summary, detailed_metrics

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to trained model")
    parser.add_argument("--config", required=True, help="Path to config file")
    parser.add_argument("--data", required=True, help="Path to test data")
    
    args = parser.parse_args()
    
    benchmark_model(args.model, args.config, args.data)
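
A typical invocation looks like this (paths are placeholders for your own checkpoint, config, and test split):
python benchmark.py \
    --model checkpoints/best_model.pt \
    --config configs/training_config.yaml \
    --data /data/musdb18_processed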

Deployment: Production-Ready API

FastAPI Implementation
# src/inference/api.py
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
import io
import tempfile
import asyncio
from pathlib import Path
import logging
from typing import Optional
from src.models.unet import MagnitudeUNet
from src.inference.separator import AudioSeparator

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model instance
separator: Optional[AudioSeparator] = None

app = FastAPI(
    title="Music Source Separation API",
    description="AI-powered vocal separation from music",
    version="1.0.0"
)

@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    global separator
    
    try:
        model_path = "models/best_model.pt"
        separator = AudioSeparator(model_path)
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy", 
        "model_loaded": separator is not None,
        "gpu_available": torch.cuda.is_available()
    }

@app.post("/separate")
async def separate_audio(
    audio_file: UploadFile = File(...),
    output_format: str = "wav"
):
    """
    Separate vocals from music
    
    Args:
        audio_file: Input audio file (wav, mp3, etc.)
        output_format: Output format (wav, mp3)
    
    Returns:
        ZIP file containing separated stems
    """
    if separator is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    
    # Validate file format
    allowed_formats = {".wav", ".mp3", ".flac", ".m4a"}
    file_suffix = Path(audio_file.filename).suffix.lower()
    
    if file_suffix not in allowed_formats:
        raise HTTPException(
            status_code=400, 
            detail=f"Unsupported format: {file_suffix}"
        )
    
    try:
        # Save uploaded file temporarily
        with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp_file:
            content = await audio_file.read()
            tmp_file.write(content)
            tmp_path = tmp_file.name
        
        # Perform separation
        vocals, accompaniment = await asyncio.get_event_loop().run_in_executor(
            None, separator.separate_file, tmp_path
        )
        
        # Create output streams
        vocals_buffer = io.BytesIO()
        accomp_buffer = io.BytesIO()
        
        # Save separated audio
        torchaudio.save(vocals_buffer, vocals, separator.sample_rate, format=output_format)
        torchaudio.save(accomp_buffer, accompaniment, separator.sample_rate, format=output_format)
        
        # Create ZIP response
        import zipfile
        zip_buffer = io.BytesIO()
        
        with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
            vocals_buffer.seek(0)
            accomp_buffer.seek(0)
            
            zip_file.writestr(f"vocals.{output_format}", vocals_buffer.read())
            zip_file.writestr(f"accompaniment.{output_format}", accomp_buffer.read())
        
        zip_buffer.seek(0)
        
        # Cleanup
        Path(tmp_path).unlink()
        
        return StreamingResponse(
            io.BytesIO(zip_buffer.read()),
            media_type="application/zip",
            headers={"Content-Disposition": "attachment; filename=separated_audio.zip"}
        )
    
    except Exception as e:
        logger.error(f"Separation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")

@app.post("/separate/preview")
async def separate_preview(
    audio_file: UploadFile = File(...),
    duration: float = 30.0
):
    """
    Quick preview separation (first N seconds)
    """
    if separator is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    
    try:
        # Process only first N seconds for preview
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            content = await audio_file.read()
            tmp_file.write(content)
            tmp_path = tmp_file.name
        
        # Load and trim audio, then remove the temporary file
        waveform, sr = torchaudio.load(tmp_path)
        Path(tmp_path).unlink()
        max_samples = int(duration * sr)
        if waveform.size(1) > max_samples:
            waveform = waveform[:, :max_samples]
        
        # Perform separation
        vocals, accompaniment = await asyncio.get_event_loop().run_in_executor(
            None, separator.separate_tensor, waveform
        )
        
        # Return JSON with basic info
        return {
            "duration_processed": waveform.size(1) / sr,
            "sample_rate": sr,
            "vocals_energy": float(torch.mean(torch.abs(vocals))),
            "accompaniment_energy": float(torch.mean(torch.abs(accompaniment))),
            "separation_quality": "estimated_good"  # Could implement quality metric
        }
    
    except Exception as e:
        logger.error(f"Preview failed: {e}")
        raise HTTPException(status_code=500, detail=f"Preview failed: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
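
The API imports AudioSeparator from src/inference/separator.py, which isn't listed elsewhere in this guide. Below is a minimal sketch of what that wrapper might look like, assuming the checkpoint format produced by the trainer above; the method names (separate_file, separate_tensor) and the sample_rate attribute are exactly what api.py expects.
# src/inference/separator.py (minimal sketch; adapt to your checkpoint format)
import torch
import torchaudio
from src.models.unet import MagnitudeUNet

class AudioSeparator:
    """Thin inference wrapper around a trained MagnitudeUNet checkpoint"""
    
    def __init__(self, model_path: str, sample_rate: int = 44100, device: str = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.sample_rate = sample_rate
        
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model = MagnitudeUNet(**checkpoint["config"]["model"])
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(self.device).eval()
    
    def separate_tensor(self, waveform: torch.Tensor):
        """Separate a [channels, samples] tensor into (vocals, accompaniment)"""
        # Downmix to mono; the resulting [1, samples] tensor doubles as a batch of one
        if waveform.dim() == 2 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        mixture = waveform.to(self.device)
        
        # Note: very long files may need chunked processing to fit in GPU memory
        with torch.no_grad():
            stems = self.model(mixture)  # [1, n_sources, samples]
        
        vocals = stems[:, 0].cpu()           # [1, samples]
        accompaniment = stems[:, 1].cpu()    # [1, samples]
        return vocals, accompaniment
    
    def separate_file(self, path: str):
        """Load an audio file, resample if needed, and separate it"""
        waveform, sr = torchaudio.load(path)
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
        return self.separate_tensor(waveform)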
Docker Deployment
# Dockerfile
FROM python:3.9-slim

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    libsndfile1 \
    ffmpeg \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy source code
COPY src/ src/
COPY configs/ configs/
COPY models/ models/

# Set environment variables
ENV PYTHONPATH=/app
ENV MODEL_PATH=/app/models/best_model.pt

# Expose port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
  CMD curl -f http://localhost:8000/health || exit 1

# Run API server
CMD ["uvicorn", "src.inference.api:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'

services:
  music-separation-api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./models:/app/models
      - ./logs:/app/logs
    environment:
      - MODEL_PATH=/app/models/best_model.pt
      - LOG_LEVEL=INFO
    restart: unless-stopped
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
  
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - music-separation-api
    restart: unless-stopped

# Build and run:
# docker-compose up --build -d

Production Best Practices

Model Versioning and Rollback

Implement proper model versioning with Blue-Green deployments. Always keep the previous working model available for instant rollback. Use semantic versioning for model releases.
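
As a sketch of the loading side, a hypothetical layout like models/v1.2.0/best_model.pt makes rollback a matter of pointing the service at an earlier directory; the helper below tries versions newest-first and falls back automatically if a checkpoint is missing or corrupt.
# Hypothetical versioned-model loader (directory layout is an assumption)
from pathlib import Path
import torch

def load_versioned_model(models_dir: str = "models", versions: list = None):
    """Try versions newest-first; return (version, checkpoint) for the first that loads"""
    root = Path(models_dir)
    versions = versions or sorted((p.name for p in root.iterdir() if p.is_dir()), reverse=True)
    
    for version in versions:
        checkpoint_path = root / version / "best_model.pt"
        try:
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            return version, checkpoint
        except (FileNotFoundError, RuntimeError) as exc:
            print(f"Skipping {version}: {exc}")
    
    raise RuntimeError("No loadable model version found")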

Monitoring and Observability

Monitor inference latency, GPU utilization, and model accuracy drift. Use tools like Prometheus + Grafana for metrics, and implement proper logging with structured data for debugging.
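
As one example, a Prometheus scrape endpoint can be added to the FastAPI service with the prometheus_client package (an extra dependency, not installed above); the metric names here are illustrative.
# In src/inference/api.py (where `app` is defined), something along these lines:
import time
from fastapi import Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST

SEPARATION_REQUESTS = Counter(
    "separation_requests_total", "Number of separation requests", ["status"]
)
SEPARATION_LATENCY = Histogram(
    "separation_latency_seconds", "End-to-end separation latency"
)

@app.get("/metrics")
async def metrics():
    """Prometheus scrape endpoint"""
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)

# Inside the /separate handler, wrap the work:
#     start = time.time()
#     try:
#         ...  # run separation
#         SEPARATION_REQUESTS.labels(status="success").inc()
#     except Exception:
#         SEPARATION_REQUESTS.labels(status="error").inc()
#         raise
#     finally:
#         SEPARATION_LATENCY.observe(time.time() - start)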

Resource Optimization

Use model quantization and ONNX Runtime for production inference. Implement request batching and GPU memory management. Consider using TensorRT for NVIDIA GPUs to achieve 2-5x speedup.
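
For instance, the mask-predicting U-Net can be exported to ONNX and served with ONNX Runtime or TensorRT. This is only a sketch: whether the surrounding STFT/ISTFT also export cleanly depends on your torch and opset versions, so many deployments keep those in PyTorch and export just the network below.
# Export the U-Net mask predictor to ONNX (sketch; shapes and opset are assumptions)
import torch
from src.models.unet import UNet

unet = UNet(in_channels=1, n_sources=2)
unet.eval()

dummy_magnitude = torch.randn(1, 1, 1025, 256)  # [batch, 1, freq, time] for n_fft=2048
torch.onnx.export(
    unet,
    dummy_magnitude,
    "unet_masks.onnx",
    input_names=["magnitude"],
    output_names=["masks"],
    dynamic_axes={"magnitude": {0: "batch"}, "masks": {0: "batch"}},
    opset_version=17,
)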

Security and Rate Limiting

Implement proper authentication, input validation, and rate limiting. Use HTTPS in production and validate all audio inputs for malicious content. Consider implementing API keys for access control.
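
A minimal in-memory rate limiter can be written as a FastAPI dependency, as sketched below; for anything beyond a single instance you would back it with a shared store (e.g. Redis) or enforce limits at the gateway.
# Simple sliding-window rate limiter (in-memory, single-instance sketch)
import time
from collections import defaultdict, deque
from fastapi import HTTPException, Request

WINDOW_SECONDS = 60
MAX_REQUESTS_PER_WINDOW = 10
_request_log = defaultdict(deque)  # client IP -> recent request timestamps

async def rate_limit(request: Request):
    now = time.time()
    history = _request_log[request.client.host]
    
    # Drop timestamps that have fallen out of the window
    while history and now - history[0] > WINDOW_SECONDS:
        history.popleft()
    
    if len(history) >= MAX_REQUESTS_PER_WINDOW:
        raise HTTPException(status_code=429, detail="Rate limit exceeded")
    history.append(now)

# Attach it to an endpoint in api.py:
#     from fastapi import Depends
#     @app.post("/separate", dependencies=[Depends(rate_limit)])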

Testing Strategy

Implement unit tests for data processing, integration tests for the API, and regression tests for model accuracy. Use golden datasets for consistent evaluation across model versions.
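
For example, the segment-padding logic from the dataset can be covered by a small pytest that needs no audio files; similar tests can pin down STFT round trips and API response codes.
# tests/test_padding.py - mirrors the padding branch of MusicSeparationDataset
import torch

def test_short_segment_is_padded_to_target_length():
    segment_samples = 44100 * 6
    audio = torch.randn(1, 44100 * 2)  # 2-second clip, shorter than the 6-second target
    
    pad_amount = segment_samples - audio.shape[1]
    padded = torch.nn.functional.pad(audio, (0, pad_amount))
    
    assert padded.shape == (1, segment_samples)
    assert torch.equal(padded[:, :audio.shape[1]], audio)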

Scalability Considerations

Design for horizontal scaling with load balancers and multiple GPU instances. Implement proper queuing for long-running separation tasks. Consider using cloud services like AWS Batch or Kubernetes for orchestration.
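
One lightweight queuing pattern, sketched below against the api.py globals (app, separator), is a submit-then-poll flow; the in-memory job store is for illustration only, since a real deployment would use a shared queue or task runner so jobs survive restarts and scale across workers.
# Async job pattern: submit a separation job, poll for its status (sketch)
import uuid
from fastapi import BackgroundTasks

jobs = {}  # job_id -> {"status": ...}

def run_separation_job(job_id: str, input_path: str):
    try:
        vocals, accompaniment = separator.separate_file(input_path)
        # ... write stems to disk or object storage and record their location ...
        jobs[job_id]["status"] = "done"
    except Exception as exc:
        jobs[job_id] = {"status": "failed", "error": str(exc)}

@app.post("/separate/async")
async def submit_job(background_tasks: BackgroundTasks, input_path: str):
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "queued"}
    background_tasks.add_task(run_separation_job, job_id, input_path)
    return {"job_id": job_id}

@app.get("/jobs/{job_id}")
async def job_status(job_id: str):
    return jobs.get(job_id, {"status": "unknown"})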

Next Steps and Extensions

🔮 Advanced Features to Implement

Model Improvements

  • Implement Demucs waveform model
  • Add 4-stem separation (drums, bass, vocals, other)
  • Experiment with Transformer architectures
  • Try diffusion models for generation

Production Features

  • Real-time streaming separation
  • Web interface for file uploads
  • Batch processing capabilities
  • Integration with cloud storage

Learning Resources

Papers to Read

  • "Music Source Separation in the Waveform Domain"
  • "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking"
  • "Hybrid Transformers for Music Source Separation"

Code Repositories

  • facebook/demucs - Official Demucs implementation
  • sigsep/open-unmix-pytorch - Open-Unmix baseline
  • asteroid-team/asteroid - Toolkit for audio separation
