Building Your Own Music Separation Pipeline: A Developer's Guide
Transform theoretical knowledge into production-ready code. Build, evaluate, and deploy a complete music source separation system using modern tools and best practices.

From Theory to Practice
You've learned the mathematics, studied the architectures, and understood the trade-offs. Now it's time to build something real. This comprehensive guide will take you through creating a production-ready music source separation pipeline from scratch.
By the end of this tutorial, you'll have a complete system that can separate vocals from any song, with proper evaluation metrics, error handling, and deployment considerations. We'll use Python, PyTorch, and modern MLOps practices to create something you can actually ship.
🎯 What We'll Build
Core Components
- • Data preprocessing pipeline
- • Neural network training framework
- • Evaluation and metrics system
- • Real-time inference API
Production Features
- • Docker containerization
- • Model versioning and rollback
- • Monitoring and logging
- • Performance optimization
Project Setup and Architecture
music-separation-pipeline/
├── src/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── dataset.py          # Dataset classes and loaders
│   │   ├── preprocessing.py    # Audio preprocessing utilities
│   │   └── augmentation.py     # Data augmentation techniques
│   ├── models/
│   │   ├── __init__.py
│   │   ├── unet.py             # U-Net architecture
│   │   ├── demucs.py           # Demucs implementation
│   │   └── base.py             # Base model class
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py          # Training loop and utilities
│   │   └── losses.py           # Loss functions
│   ├── inference/
│   │   ├── __init__.py
│   │   ├── separator.py        # Inference engine
│   │   └── api.py              # REST API server
│   └── evaluation/
│       ├── __init__.py
│       ├── metrics.py          # SDR, SIR, SAR calculations
│       └── benchmark.py        # Evaluation scripts
├── configs/
│   ├── model_config.yaml       # Model hyperparameters
│   ├── training_config.yaml    # Training settings
│   └── api_config.yaml         # API configuration
├── docker/
│   ├── Dockerfile
│   └── docker-compose.yml
├── tests/
├── notebooks/                  # Jupyter notebooks for exploration
├── requirements.txt
├── setup.py
└── README.md
Design Principles
- • Modular architecture for easy extension
- • Configuration-driven development
- • Clear separation of concerns
- • Comprehensive testing coverage
Key Dependencies
- • PyTorch: Deep learning framework
- • torchaudio: Audio processing
- • librosa: Audio analysis
- • FastAPI: High-performance web API
# Create virtual environment
python -m venv music-separation-env
source music-separation-env/bin/activate  # On Windows: music-separation-env\Scripts\activate

# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install librosa soundfile numpy scipy
pip install fastapi uvicorn pydantic
pip install wandb tensorboard   # For experiment tracking
pip install pytest black flake8 # Development tools

# Install project in development mode
pip install -e .

# Verify installation
python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')"
python -c "import librosa; print(f'librosa {librosa.__version__}')"
Data Pipeline: Processing Audio for ML
# src/data/dataset.py
import torch
from torch.utils.data import Dataset
import torchaudio
import librosa
import numpy as np
from pathlib import Path
from typing import Tuple, Optional


class MusicSeparationDataset(Dataset):
    """
    Dataset for music source separation training.

    Expects directory structure:
    dataset/
    ├── train/
    │   ├── mixture/
    │   │   ├── song1.wav
    │   │   └── song2.wav
    │   ├── vocals/
    │   │   ├── song1.wav
    │   │   └── song2.wav
    │   └── accompaniment/
    │       ├── song1.wav
    │       └── song2.wav
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        sample_rate: int = 44100,
        segment_length: float = 6.0,  # seconds
        normalize: bool = True,
        transforms=None
    ):
        self.data_dir = Path(data_dir) / split
        self.sample_rate = sample_rate
        self.segment_samples = int(segment_length * sample_rate)
        self.normalize = normalize
        self.transforms = transforms

        # Find all mixture files
        self.mixture_files = list((self.data_dir / "mixture").glob("*.wav"))
        self.mixture_files.sort()

        # Verify corresponding files exist
        self._verify_files()

    def _verify_files(self):
        """Ensure all stems exist for each mixture"""
        verified_files = []

        for mix_file in self.mixture_files:
            stem_name = mix_file.stem
            vocals_path = self.data_dir / "vocals" / f"{stem_name}.wav"
            accomp_path = self.data_dir / "accompaniment" / f"{stem_name}.wav"

            if vocals_path.exists() and accomp_path.exists():
                verified_files.append(mix_file)
            else:
                print(f"Warning: Missing stems for {stem_name}")

        self.mixture_files = verified_files
        print(f"Dataset loaded: {len(self.mixture_files)} songs")

    def __len__(self) -> int:
        return len(self.mixture_files)

    def __getitem__(self, idx: int) -> dict:
        mix_file = self.mixture_files[idx]
        stem_name = mix_file.stem

        # Load audio files
        mixture, sr = torchaudio.load(mix_file)
        vocals, _ = torchaudio.load(
            self.data_dir / "vocals" / f"{stem_name}.wav"
        )
        accompaniment, _ = torchaudio.load(
            self.data_dir / "accompaniment" / f"{stem_name}.wav"
        )

        # Resample if necessary
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            mixture = resampler(mixture)
            vocals = resampler(vocals)
            accompaniment = resampler(accompaniment)

        # Convert to mono if stereo
        if mixture.shape[0] > 1:
            mixture = torch.mean(mixture, dim=0, keepdim=True)
            vocals = torch.mean(vocals, dim=0, keepdim=True)
            accompaniment = torch.mean(accompaniment, dim=0, keepdim=True)

        # Extract random segment
        audio_length = mixture.shape[1]
        if audio_length > self.segment_samples:
            start = int(torch.randint(0, audio_length - self.segment_samples, (1,)))
            mixture = mixture[:, start:start + self.segment_samples]
            vocals = vocals[:, start:start + self.segment_samples]
            accompaniment = accompaniment[:, start:start + self.segment_samples]
        else:
            # Pad if too short
            pad_amount = self.segment_samples - audio_length
            mixture = torch.nn.functional.pad(mixture, (0, pad_amount))
            vocals = torch.nn.functional.pad(vocals, (0, pad_amount))
            accompaniment = torch.nn.functional.pad(accompaniment, (0, pad_amount))

        # Normalize
        if self.normalize:
            max_val = torch.max(torch.abs(mixture))
            if max_val > 0:
                mixture = mixture / max_val
                vocals = vocals / max_val
                accompaniment = accompaniment / max_val

        # Apply transforms
        if self.transforms:
            mixture = self.transforms(mixture)
            vocals = self.transforms(vocals)
            accompaniment = self.transforms(accompaniment)

        return {
            "mixture": mixture.squeeze(0),  # [samples]
            "vocals": vocals.squeeze(0),
            "accompaniment": accompaniment.squeeze(0),
            "filename": stem_name
        }
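Before wiring the dataset into training, it is worth a quick shape check. The sketch below (the data_dir path is a placeholder; point it at wherever you prepared your stems) loads one batch and prints the tensor shapes you should expect:

# Quick dataset smoke test (paths are placeholders)
from torch.utils.data import DataLoader
from src.data.dataset import MusicSeparationDataset

dataset = MusicSeparationDataset(data_dir="data/musdb_prepared", split="train",
                                 sample_rate=44100, segment_length=6.0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
# Each stem is [batch, segment_samples] = [4, 264600] at 44.1 kHz with 6 s segments
print(batch["mixture"].shape, batch["vocals"].shape, batch["accompaniment"].shape)
print(batch["filename"])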
# src/data/preprocessing.py
import torch
import torchaudio
import librosa
import numpy as np
from typing import Tuple, Optional


class AudioPreprocessor:
    """Handles STFT and spectrogram operations"""

    def __init__(
        self,
        n_fft: int = 2048,
        hop_length: int = 512,
        win_length: Optional[int] = None,
        window: str = "hann",
        sample_rate: int = 44100
    ):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length or n_fft
        self.window = window
        self.sample_rate = sample_rate

        # Create STFT transform (power=None returns a complex-valued spectrogram)
        self.stft = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=self.win_length,
            window_fn=torch.hann_window,
            power=None
        )

        # Create inverse STFT
        self.istft = torchaudio.transforms.InverseSpectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=self.win_length,
            window_fn=torch.hann_window
        )

    def waveform_to_spectrogram(self, waveform: torch.Tensor) -> torch.Tensor:
        """Convert waveform to complex spectrogram"""
        # waveform: [channels, samples] or [samples]
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Apply STFT -> complex tensor [channels, freq, time]
        spec = self.stft(waveform)
        return spec

    def spectrogram_to_waveform(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """Convert complex spectrogram back to waveform"""
        # spectrogram: complex tensor [channels, freq, time]
        waveform = self.istft(spectrogram)
        return waveform

    def magnitude_phase_split(self, complex_spec: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split complex spectrogram into magnitude and phase"""
        magnitude = complex_spec.abs()
        phase = complex_spec.angle()
        return magnitude, phase

    def magnitude_phase_combine(self, magnitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
        """Combine magnitude and phase into a complex spectrogram"""
        return torch.polar(magnitude, phase)

    def create_masks(self, mixture_mag: torch.Tensor, source_mags: list) -> list:
        """Create soft masks for source separation"""
        # Add small epsilon to avoid division by zero
        eps = 1e-8
        total_magnitude = sum(source_mags) + eps

        masks = []
        for source_mag in source_mags:
            mask = source_mag / total_magnitude
            masks.append(mask)

        return masks


# Data augmentation transforms
class AudioAugmentation:
    """Audio data augmentation techniques"""

    @staticmethod
    def add_noise(waveform: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
        """Add Gaussian noise to waveform"""
        noise = torch.randn_like(waveform) * noise_level
        return waveform + noise

    @staticmethod
    def time_stretch(waveform: torch.Tensor, rate: float, sample_rate: int) -> torch.Tensor:
        """Time stretching using librosa"""
        waveform_np = waveform.numpy()
        stretched = librosa.effects.time_stretch(waveform_np, rate=rate)
        return torch.from_numpy(stretched)

    @staticmethod
    def pitch_shift(waveform: torch.Tensor, n_steps: int, sample_rate: int) -> torch.Tensor:
        """Pitch shifting using librosa"""
        waveform_np = waveform.numpy()
        shifted = librosa.effects.pitch_shift(waveform_np, sr=sample_rate, n_steps=n_steps)
        return torch.from_numpy(shifted)
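A cheap sanity check for the preprocessor is an STFT round trip: analysis followed by synthesis should reproduce the waveform almost exactly, and a large error usually means mismatched window or hop settings. A minimal sketch, assuming the AudioPreprocessor defined above:

# STFT -> ISTFT round trip should be (near-)lossless
import torch
from src.data.preprocessing import AudioPreprocessor

pre = AudioPreprocessor(n_fft=2048, hop_length=512)

waveform = torch.randn(1, 44100)                      # 1 second of noise as a stand-in signal
spec = pre.waveform_to_spectrogram(waveform)          # complex [channels, freq, time]
magnitude, phase = pre.magnitude_phase_split(spec)
reconstructed = pre.spectrogram_to_waveform(pre.magnitude_phase_combine(magnitude, phase))

# Compare over the overlapping region; a large error indicates a parameter mismatch
error = torch.mean(torch.abs(waveform[..., :reconstructed.shape[-1]] - reconstructed))
print(f"round-trip error: {error:.2e}")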
Model Implementation: U-Net for Source Separation
# src/models/unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio


class ConvBlock(nn.Module):
    """Basic convolutional block with normalization and activation"""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class DownBlock(nn.Module):
    """Encoder block with max pooling"""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        skip = self.conv(x)
        x = self.pool(skip)
        return x, skip


class UpBlock(nn.Module):
    """Decoder block with skip connections"""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x, skip):
        x = self.up(x)

        # Handle size mismatch between upsampled features and the skip connection
        diffY = skip.size()[2] - x.size()[2]
        diffX = skip.size()[3] - x.size()[3]
        x = F.pad(x, [diffX // 2, diffX - diffX // 2,
                      diffY // 2, diffY - diffY // 2])

        x = torch.cat([skip, x], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """U-Net for music source separation"""

    def __init__(
        self,
        in_channels: int = 1,
        n_sources: int = 2,
        base_channels: int = 64,
        n_layers: int = 4
    ):
        super().__init__()
        self.n_sources = n_sources

        # Encoder path
        self.encoder = nn.ModuleList()
        enc_channels = [in_channels] + [base_channels * (2 ** i) for i in range(n_layers)]
        for i in range(n_layers):
            self.encoder.append(DownBlock(enc_channels[i], enc_channels[i + 1]))

        # Bottleneck doubles the deepest channel count
        self.bottleneck = ConvBlock(enc_channels[-1], enc_channels[-1] * 2)

        # Decoder path mirrors the encoder so channel counts line up with the skips
        self.decoder = nn.ModuleList()
        dec_channels = [enc_channels[-1] * 2] + \
                       [base_channels * (2 ** i) for i in reversed(range(n_layers))]
        for i in range(n_layers):
            self.decoder.append(UpBlock(dec_channels[i], dec_channels[i + 1]))

        # Final output layers - separate head for each source
        self.output_layers = nn.ModuleList([
            nn.Conv2d(dec_channels[-1], 1, 1) for _ in range(n_sources)
        ])

        # Activation for masks
        self.activation = nn.Sigmoid()

    def forward(self, x):
        # x: [batch, 1, freq, time] - magnitude spectrogram

        # Encoder path
        skips = []
        for encoder in self.encoder:
            x, skip = encoder(x)
            skips.append(skip)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path
        skips.reverse()
        for decoder, skip in zip(self.decoder, skips):
            x = decoder(x, skip)

        # Generate masks for each source
        masks = []
        for output_layer in self.output_layers:
            mask = self.activation(output_layer(x))
            masks.append(mask)

        # Stack masks: [batch, n_sources, freq, time]
        masks = torch.cat(masks, dim=1)

        # Normalize masks to sum to 1
        masks = masks / (torch.sum(masks, dim=1, keepdim=True) + 1e-8)

        return masks


class MagnitudeUNet(nn.Module):
    """Complete model for magnitude spectrogram separation"""

    def __init__(
        self,
        n_fft: int = 2048,
        hop_length: int = 512,
        n_sources: int = 2,
        **unet_kwargs
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_sources = n_sources

        # STFT / inverse STFT (power=None returns a complex spectrogram)
        self.stft = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            power=None
        )
        self.istft = torchaudio.transforms.InverseSpectrogram(
            n_fft=n_fft,
            hop_length=hop_length
        )

        # U-Net for mask prediction
        self.unet = UNet(in_channels=1, n_sources=n_sources, **unet_kwargs)

    def forward(self, waveform):
        # waveform: [batch, samples]
        n_samples = waveform.size(-1)

        # Convert to complex spectrogram: [batch, freq, time]
        complex_spec = self.stft(waveform)
        magnitude = complex_spec.abs()
        phase = complex_spec.angle()

        # Predict masks: [batch, n_sources, freq, time]
        masks = self.unet(magnitude.unsqueeze(1))

        # Apply masks to the mixture magnitude
        separated_mags = masks * magnitude.unsqueeze(1)  # [batch, n_sources, freq, time]

        # Reconstruct each source with the original mixture phase
        separated_waveforms = []
        for i in range(self.n_sources):
            source_spec = torch.polar(separated_mags[:, i], phase)  # complex [batch, freq, time]
            source_waveform = self.istft(source_spec, length=n_samples)
            separated_waveforms.append(source_waveform)

        return torch.stack(separated_waveforms, dim=1)  # [batch, n_sources, samples]
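Before committing to a long training run, a smoke test that pushes random audio through MagnitudeUNet catches shape and channel mistakes early. A minimal sketch (untrained model, random CPU input, so the output is noise; only the shapes matter here):

# Model smoke test: verify [batch, n_sources, samples] output
import torch
from src.models.unet import MagnitudeUNet

model = MagnitudeUNet(n_fft=2048, hop_length=512, n_sources=2)
model.eval()

with torch.no_grad():
    mixture = torch.randn(2, 44100 * 3)   # batch of 2, three seconds at 44.1 kHz
    separated = model(mixture)

print(separated.shape)   # expected: torch.Size([2, 2, 132300])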
Training Pipeline: From Data to Model
# src/training/trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
import wandb
from tqdm import tqdm


class SourceSeparationTrainer:
    """Training pipeline for source separation models"""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: dict,
        device: str = "cuda"
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device

        # Training setup
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config.get("weight_decay", 1e-5)
        )

        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            patience=config.get("lr_patience", 5),
            factor=0.5
        )

        self.criterion = nn.L1Loss()  # Mean Absolute Error

        # Tracking
        self.best_val_loss = float("inf")
        self.train_losses = []
        self.val_losses = []

        # Experiment tracking
        if config.get("use_wandb", False):
            wandb.init(project="music-separation", config=config)

    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            # Move data to device
            mixture = batch["mixture"].to(self.device)
            vocals = batch["vocals"].to(self.device)
            accompaniment = batch["accompaniment"].to(self.device)

            # Target: [batch, n_sources, samples]
            target = torch.stack([vocals, accompaniment], dim=1)

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(mixture)  # [batch, n_sources, samples]

            # Compute loss
            loss = self.criterion(output, target)

            # Backward pass
            loss.backward()

            # Gradient clipping
            if self.config.get("grad_clip", False):
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config["grad_clip"]
                )

            self.optimizer.step()

            # Update metrics
            total_loss += loss.item()

            # Update progress bar
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "avg_loss": f"{total_loss / (batch_idx + 1):.4f}"
            })

            # Log to wandb
            if self.config.get("use_wandb", False):
                wandb.log({
                    "train/batch_loss": loss.item(),
                    "train/learning_rate": self.optimizer.param_groups[0]["lr"]
                })

        avg_loss = total_loss / len(self.train_loader)
        return avg_loss

    def validate(self, epoch: int) -> float:
        """Validate the model"""
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                # Move data to device
                mixture = batch["mixture"].to(self.device)
                vocals = batch["vocals"].to(self.device)
                accompaniment = batch["accompaniment"].to(self.device)

                target = torch.stack([vocals, accompaniment], dim=1)

                # Forward pass
                output = self.model(mixture)
                loss = self.criterion(output, target)

                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)

        # Log validation metrics
        if self.config.get("use_wandb", False):
            wandb.log({
                "val/loss": avg_loss,
                "epoch": epoch
            })

        return avg_loss

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Save model checkpoint"""
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
            "train_losses": self.train_losses,
            "val_losses": self.val_losses
        }

        # Save latest checkpoint
        checkpoint_dir = Path(self.config["checkpoint_dir"])
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, checkpoint_dir / "latest_checkpoint.pt")

        # Save best model
        if is_best:
            torch.save(checkpoint, checkpoint_dir / "best_model.pt")
            print(f"Saved new best model with val_loss: {val_loss:.4f}")

    def train(self, num_epochs: int):
        """Complete training loop"""
        print(f"Starting training for {num_epochs} epochs...")

        for epoch in range(1, num_epochs + 1):
            # Training phase
            train_loss = self.train_epoch(epoch)
            self.train_losses.append(train_loss)

            # Validation phase
            val_loss = self.validate(epoch)
            self.val_losses.append(val_loss)

            # Learning rate scheduling
            self.scheduler.step(val_loss)

            # Check for best model
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss

            # Save at regular intervals and whenever a new best model appears
            if is_best or epoch % self.config.get("save_every", 5) == 0:
                self.save_checkpoint(epoch, val_loss, is_best)

            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, "
                  f"Val Loss: {val_loss:.4f}")

        print("Training completed!")
        return self.train_losses, self.val_losses


# Loss functions
class MultiScaleLoss(nn.Module):
    """Multi-scale L1 loss for better detail preservation"""

    def __init__(self, scales: list = [2048, 1024, 512]):
        super().__init__()
        self.scales = scales
        self.criterion = nn.L1Loss()

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        total_loss = 0.0

        # Original scale loss
        total_loss += self.criterion(prediction, target)

        # Multi-scale losses on average-pooled waveforms
        for scale in self.scales:
            pred_scaled = F.avg_pool1d(prediction, scale, stride=scale // 2)
            target_scaled = F.avg_pool1d(target, scale, stride=scale // 2)
            scale_loss = self.criterion(pred_scaled, target_scaled)
            total_loss += 0.3 * scale_loss  # Weight for multi-scale terms

        return total_loss
# train.py
import yaml
import torch
from torch.utils.data import DataLoader

from src.data.dataset import MusicSeparationDataset
from src.models.unet import MagnitudeUNet
from src.training.trainer import SourceSeparationTrainer


def main():
    # Load configuration
    with open("configs/training_config.yaml", "r") as f:
        config = yaml.safe_load(f)

    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Create datasets
    train_dataset = MusicSeparationDataset(
        data_dir=config["data_dir"],
        split="train",
        **config["dataset"]
    )

    val_dataset = MusicSeparationDataset(
        data_dir=config["data_dir"],
        split="val",
        **config["dataset"]
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=config["num_workers"],
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=config["num_workers"],
        pin_memory=True
    )

    # Create model
    model = MagnitudeUNet(**config["model"])
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create trainer
    trainer = SourceSeparationTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device
    )

    # Start training
    trainer.train(config["num_epochs"])


if __name__ == "__main__":
    main()
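train.py expects configs/training_config.yaml to supply the keys read by the trainer and the dataset. The sketch below lists those keys with illustrative values; the paths and hyperparameters are assumptions to be tuned for your data and hardware:

# configs/training_config.yaml (illustrative values, not a recommendation)
data_dir: data/musdb_prepared
checkpoint_dir: checkpoints/unet_baseline

dataset:
  sample_rate: 44100
  segment_length: 6.0
  normalize: true

model:
  n_fft: 2048
  hop_length: 512
  n_sources: 2
  base_channels: 64
  n_layers: 4

batch_size: 8
num_workers: 4
num_epochs: 100
learning_rate: 0.0003
weight_decay: 0.00001
lr_patience: 5
grad_clip: 5.0
save_every: 5
use_wandb: false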
Evaluation and Metrics
# src/evaluation/metrics.py
import torch
import numpy as np
from typing import Dict, List, Tuple
from tqdm import tqdm
import museval


class SeparationMetrics:
    """Comprehensive evaluation metrics for source separation"""

    @staticmethod
    def sdr_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Signal-to-Distortion Ratio loss (differentiable).
        Higher SDR means better separation, so the negative mean is returned as a loss.
        """
        # Ensure same length
        min_len = min(prediction.size(-1), target.size(-1))
        prediction = prediction[..., :min_len]
        target = target[..., :min_len]

        # Compute SDR
        target_energy = torch.sum(target ** 2, dim=-1) + 1e-8
        noise_energy = torch.sum((prediction - target) ** 2, dim=-1) + 1e-8
        sdr = 10 * torch.log10(target_energy / noise_energy)

        return -sdr.mean()  # Negative for loss (we want to minimize)

    @staticmethod
    def compute_sdr_sir_sar(
        prediction: np.ndarray,
        target: np.ndarray,
        mixture: np.ndarray
    ) -> Dict[str, float]:
        """
        Compute SDR, SIR, SAR using the museval library.

        Args:
            prediction: Estimated source [samples]
            target: True source [samples]
            mixture: Original mixture [samples]

        Returns:
            Dictionary with SDR, SIR, SAR values (medians over evaluation windows)
        """
        # Ensure same length
        min_len = min(len(prediction), len(target), len(mixture))
        prediction = prediction[:min_len]
        target = target[:min_len]
        mixture = mixture[:min_len]

        # museval expects [sources, samples, channels]. Use the residual
        # (mixture minus the target/estimate) as a second source so that
        # SIR is well defined even when evaluating a single stem.
        references = np.stack([target, mixture - target])[..., np.newaxis]
        estimates = np.stack([prediction, mixture - prediction])[..., np.newaxis]

        sdr, isr, sir, sar = museval.evaluate(references, estimates)

        # Row 0 corresponds to the source of interest
        return {
            "SDR": float(np.nanmedian(sdr[0])),
            "SIR": float(np.nanmedian(sir[0])),
            "SAR": float(np.nanmedian(sar[0]))
        }

    @staticmethod
    def batch_evaluate(
        model: torch.nn.Module,
        dataloader,
        device: str = "cuda"
    ) -> Tuple[Dict[str, float], Dict[str, List[float]]]:
        """Evaluate model on an entire dataset"""
        model.eval()
        all_metrics = {"SDR": [], "SIR": [], "SAR": []}

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                mixture = batch["mixture"].to(device)
                vocals = batch["vocals"].cpu().numpy()
                accompaniment = batch["accompaniment"].cpu().numpy()

                # Get predictions
                output = model(mixture)  # [batch, 2, samples]
                pred_vocals = output[:, 0].cpu().numpy()
                pred_accompaniment = output[:, 1].cpu().numpy()

                # Compute metrics for each sample in batch
                for i in range(mixture.size(0)):
                    mixture_np = mixture[i].cpu().numpy()

                    # Vocals metrics
                    vocals_metrics = SeparationMetrics.compute_sdr_sir_sar(
                        pred_vocals[i], vocals[i], mixture_np
                    )

                    # Accompaniment metrics
                    accomp_metrics = SeparationMetrics.compute_sdr_sir_sar(
                        pred_accompaniment[i], accompaniment[i], mixture_np
                    )

                    # Average metrics across sources
                    for metric in ["SDR", "SIR", "SAR"]:
                        avg_metric = (vocals_metrics[metric] + accomp_metrics[metric]) / 2
                        all_metrics[metric].append(avg_metric)

        # Compute summary statistics
        summary = {}
        for metric, values in all_metrics.items():
            summary[f"{metric}_mean"] = float(np.mean(values))
            summary[f"{metric}_std"] = float(np.std(values))
            summary[f"{metric}_median"] = float(np.median(values))

        return summary, all_metrics


class PerceptualEvaluator:
    """Tools for perceptual quality assessment"""

    @staticmethod
    def spectral_convergence(prediction: torch.Tensor, target: torch.Tensor) -> float:
        """Measure spectral similarity"""
        window = torch.hann_window(2048, device=prediction.device)

        # Compute spectrograms
        pred_spec = torch.stft(prediction, n_fft=2048, window=window, return_complex=True)
        target_spec = torch.stft(target, n_fft=2048, window=window, return_complex=True)

        # Magnitude spectrograms
        pred_mag = torch.abs(pred_spec)
        target_mag = torch.abs(target_spec)

        # Spectral convergence
        numerator = torch.norm(target_mag - pred_mag, p="fro")
        denominator = torch.norm(target_mag, p="fro")

        return (numerator / (denominator + 1e-8)).item()

    @staticmethod
    def magnitude_loss(prediction: torch.Tensor, target: torch.Tensor) -> float:
        """Log magnitude loss"""
        window = torch.hann_window(2048, device=prediction.device)

        pred_spec = torch.stft(prediction, n_fft=2048, window=window, return_complex=True)
        target_spec = torch.stft(target, n_fft=2048, window=window, return_complex=True)

        pred_mag = torch.abs(pred_spec)
        target_mag = torch.abs(target_spec)

        # Log magnitude
        pred_log = torch.log(pred_mag + 1e-8)
        target_log = torch.log(target_mag + 1e-8)

        return torch.nn.functional.l1_loss(pred_log, target_log).item()
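A quick way to validate the metric code is to feed it estimates whose quality you already know. The sketch below uses the differentiable sdr_loss on a perfect and a deliberately noisy estimate; the noisy one should come out near -20, i.e. roughly 20 dB of SDR:

# Sanity check for sdr_loss on synthetic signals
import torch
from src.evaluation.metrics import SeparationMetrics

target = torch.randn(1, 44100)
clean_estimate = target.clone()
noisy_estimate = target + 0.1 * torch.randn_like(target)   # noise at ~1% of the signal energy

# sdr_loss returns negative SDR, so lower loss = better separation
print(SeparationMetrics.sdr_loss(clean_estimate, target))   # large negative value, limited only by the epsilon
print(SeparationMetrics.sdr_loss(noisy_estimate, target))   # approximately -20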
# benchmark.py - Evaluate trained model
import torch
import yaml
import json
from pathlib import Path
from torch.utils.data import DataLoader

from src.data.dataset import MusicSeparationDataset
from src.models.unet import MagnitudeUNet
from src.evaluation.metrics import SeparationMetrics


def benchmark_model(model_path: str, config_path: str, data_dir: str):
    """Run comprehensive benchmark on trained model"""

    # Load config and model
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load trained model
    model = MagnitudeUNet(**config["model"])
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)

    # Create test dataset
    test_dataset = MusicSeparationDataset(
        data_dir=data_dir,
        split="test",
        **config["dataset"]
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=1,  # Process one song at a time for detailed analysis
        shuffle=False,
        num_workers=4
    )

    # Run evaluation
    print("Running evaluation...")
    summary, detailed_metrics = SeparationMetrics.batch_evaluate(
        model, test_loader, device
    )

    # Print results
    print("\n" + "=" * 50)
    print("BENCHMARK RESULTS")
    print("=" * 50)
    for metric, value in summary.items():
        print(f"{metric}: {value:.3f}")

    # Save detailed results
    results = {
        "summary": summary,
        "detailed_metrics": {k: [float(x) for x in v] for k, v in detailed_metrics.items()},
        "model_path": model_path,
        "config": config
    }

    results_path = Path(model_path).parent / "benchmark_results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nDetailed results saved to: {results_path}")
    return summary, detailed_metrics


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to trained model")
    parser.add_argument("--config", required=True, help="Path to config file")
    parser.add_argument("--data", required=True, help="Path to test data")

    args = parser.parse_args()
    benchmark_model(args.model, args.config, args.data)
Deployment: Production-Ready API
# src/inference/api.py
import asyncio
import io
import logging
import tempfile
import zipfile
from pathlib import Path
from typing import Optional

import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse

from src.inference.separator import AudioSeparator

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model instance
separator: Optional[AudioSeparator] = None

app = FastAPI(
    title="Music Source Separation API",
    description="AI-powered vocal separation from music",
    version="1.0.0"
)


@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    global separator
    try:
        model_path = "models/best_model.pt"
        separator = AudioSeparator(model_path)
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": separator is not None,
        "gpu_available": torch.cuda.is_available()
    }


@app.post("/separate")
async def separate_audio(
    audio_file: UploadFile = File(...),
    output_format: str = "wav"
):
    """
    Separate vocals from music.

    Args:
        audio_file: Input audio file (wav, mp3, etc.)
        output_format: Output format (wav, mp3)

    Returns:
        ZIP file containing separated stems
    """
    if separator is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # Validate file format
    allowed_formats = {".wav", ".mp3", ".flac", ".m4a"}
    file_suffix = Path(audio_file.filename).suffix.lower()
    if file_suffix not in allowed_formats:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {file_suffix}"
        )

    try:
        # Save uploaded file temporarily
        with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp_file:
            content = await audio_file.read()
            tmp_file.write(content)
            tmp_path = tmp_file.name

        # Perform separation in a worker thread so the event loop stays responsive
        vocals, accompaniment = await asyncio.get_event_loop().run_in_executor(
            None, separator.separate_file, tmp_path
        )

        # Create output streams
        vocals_buffer = io.BytesIO()
        accomp_buffer = io.BytesIO()

        # Save separated audio
        torchaudio.save(vocals_buffer, vocals, separator.sample_rate, format=output_format)
        torchaudio.save(accomp_buffer, accompaniment, separator.sample_rate, format=output_format)

        # Create ZIP response
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
            vocals_buffer.seek(0)
            accomp_buffer.seek(0)
            zip_file.writestr(f"vocals.{output_format}", vocals_buffer.read())
            zip_file.writestr(f"accompaniment.{output_format}", accomp_buffer.read())

        zip_buffer.seek(0)

        # Cleanup
        Path(tmp_path).unlink()

        return StreamingResponse(
            io.BytesIO(zip_buffer.read()),
            media_type="application/zip",
            headers={"Content-Disposition": "attachment; filename=separated_audio.zip"}
        )

    except Exception as e:
        logger.error(f"Separation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")


@app.post("/separate/preview")
async def separate_preview(
    audio_file: UploadFile = File(...),
    duration: float = 30.0
):
    """Quick preview separation (first N seconds)"""
    if separator is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        # Process only the first N seconds for preview
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            content = await audio_file.read()
            tmp_file.write(content)
            tmp_path = tmp_file.name

        # Load and trim audio
        waveform, sr = torchaudio.load(tmp_path)
        Path(tmp_path).unlink()  # Temp file no longer needed

        max_samples = int(duration * sr)
        if waveform.size(1) > max_samples:
            waveform = waveform[:, :max_samples]

        # Perform separation
        vocals, accompaniment = await asyncio.get_event_loop().run_in_executor(
            None, separator.separate_tensor, waveform
        )

        # Return JSON with basic info
        return {
            "duration_processed": waveform.size(1) / sr,
            "sample_rate": sr,
            "vocals_energy": float(torch.mean(torch.abs(vocals))),
            "accompaniment_energy": float(torch.mean(torch.abs(accompaniment))),
            "separation_quality": "estimated_good"  # Could implement a real quality metric
        }

    except Exception as e:
        logger.error(f"Preview failed: {e}")
        raise HTTPException(status_code=500, detail=f"Preview failed: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
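From the client side, /separate is a plain multipart upload. A minimal sketch using the requests library, assuming the server is running locally on port 8000 and that my_song.wav is a placeholder for your own file:

# Hypothetical client for the /separate endpoint
import requests

with open("my_song.wav", "rb") as f:
    response = requests.post(
        "http://localhost:8000/separate",
        files={"audio_file": ("my_song.wav", f, "audio/wav")},
        timeout=600,  # separating a full song can take a while
    )

response.raise_for_status()
with open("separated_audio.zip", "wb") as out:
    out.write(response.content)
print("Saved separated stems to separated_audio.zip")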
# Dockerfile
FROM python:3.9-slim

# Install system dependencies (curl is needed for the HEALTHCHECK below)
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    libsndfile1 \
    ffmpeg \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy source code
COPY src/ src/
COPY configs/ configs/
COPY models/ models/

# Set environment variables
ENV PYTHONPATH=/app
ENV MODEL_PATH=/app/models/best_model.pt

# Expose port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run API server
CMD ["uvicorn", "src.inference.api:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'

services:
  music-separation-api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./models:/app/models
      - ./logs:/app/logs
    environment:
      - MODEL_PATH=/app/models/best_model.pt
      - LOG_LEVEL=INFO
    restart: unless-stopped
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - music-separation-api
    restart: unless-stopped

# Build and run:
#   docker-compose up --build -d
Production Best Practices
Model Versioning and Rollback
Implement proper model versioning with Blue-Green deployments. Always keep the previous working model available for instant rollback. Use semantic versioning for model releases.
Monitoring and Observability
Monitor inference latency, GPU utilization, and model accuracy drift. Use tools like Prometheus + Grafana for metrics, and implement proper logging with structured data for debugging.
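As a lightweight starting point (not a full Prometheus setup), an HTTP middleware on the FastAPI app can log per-request latency in a structured form that is easy to ship to whatever metrics backend you choose. A sketch that could be added to src/inference/api.py, reusing the app and logger defined there:

# Latency-logging middleware for the FastAPI app above
import time

@app.middleware("http")
async def log_request_latency(request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = (time.perf_counter() - start) * 1000
    # Structured key=value logging; a full setup would also export these as metrics
    logger.info(
        "method=%s path=%s status=%s latency_ms=%.1f",
        request.method, request.url.path, response.status_code, elapsed_ms,
    )
    return response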
Resource Optimization
Use model quantization and ONNX Runtime for production inference. Implement request batching and GPU memory management. Consider TensorRT on NVIDIA GPUs, which can yield a 2-5x speedup.
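How much quantization helps this conv-heavy model is something to measure rather than assume. One common pattern is to export only the mask-predicting U-Net to ONNX and keep the STFT/ISTFT in PyTorch, since FFT operators do not always export cleanly. A sketch under that assumption (the dynamic time axis may need extra care because of the size-matching padding in the decoder):

# Sketch: export the U-Net mask predictor to ONNX; STFT/ISTFT stay in PyTorch
import torch
from src.models.unet import UNet

unet = UNet(in_channels=1, n_sources=2)
unet.eval()

# Dummy magnitude spectrogram: [batch, 1, freq bins, time frames] for n_fft=2048
dummy_magnitude = torch.randn(1, 1, 1025, 256)
torch.onnx.export(
    unet,
    dummy_magnitude,
    "unet_masks.onnx",
    input_names=["magnitude"],
    output_names=["masks"],
    dynamic_axes={"magnitude": {0: "batch", 3: "time"}, "masks": {0: "batch", 3: "time"}},
    opset_version=17,
)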
Security and Rate Limiting
Implement proper authentication, input validation, and rate limiting. Use HTTPS in production and validate all audio inputs for malicious content. Consider implementing API keys for access control.
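Input validation can start as simply as rejecting oversized uploads before they reach the model. A sketch of a helper that could replace the bare await audio_file.read() calls in the API above (the 50 MB cap is an assumption to tune for your longest expected song):

# Sketch: cap upload size before processing (uses HTTPException from the API module)
MAX_UPLOAD_BYTES = 50 * 1024 * 1024  # assumed 50 MB limit

async def read_limited(upload_file, limit: int = MAX_UPLOAD_BYTES) -> bytes:
    content = await upload_file.read()
    if len(content) > limit:
        raise HTTPException(status_code=413, detail="File too large")
    return content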
Testing Strategy
Implement unit tests for data processing, integration tests for the API, and regression tests for model accuracy. Use golden datasets for consistent evaluation across model versions.
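As a concrete example of a unit test, the U-Net's mask normalization gives an invariant that is cheap to check: masks across sources should sum to one at every time-frequency bin. A sketch, assuming the tests live under tests/ and use a small model so they run quickly on CPU:

# tests/test_unet_masks.py
import torch
from src.models.unet import UNet

def test_masks_sum_to_one():
    # Small configuration so the test runs quickly on CPU
    model = UNet(in_channels=1, n_sources=2, base_channels=16, n_layers=3)
    model.eval()
    with torch.no_grad():
        masks = model(torch.randn(1, 1, 257, 64))   # [batch, 1, freq, time]
    mask_sum = masks.sum(dim=1)
    assert torch.allclose(mask_sum, torch.ones_like(mask_sum), atol=1e-4)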
Scalability Considerations
Design for horizontal scaling with load balancers and multiple GPU instances. Implement proper queuing for long-running separation tasks. Consider using cloud services like AWS Batch or Kubernetes for orchestration.
Next Steps and Extensions
Model Improvements
- • Implement Demucs waveform model
- • Add 4-stem separation (drums, bass, vocals, other)
- • Experiment with Transformer architectures
- • Try diffusion models for generation
Production Features
- • Real-time streaming separation
- • Web interface for file uploads
- • Batch processing capabilities
- • Integration with cloud storage
Papers to Read
- • "Music Source Separation in the Waveform Domain"
- • "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking"
- • "Hybrid Transformers for Music Source Separation"
Code Repositories
- • facebook/demucs - Official Demucs implementation
- • sigsep/open-unmix-pytorch - Open-Unmix baseline
- • asteroid-team/asteroid - Toolkit for audio separation