feat: Add comprehensive ROCm GPU support for transcription and diarization

- Use ROCm-compatible CTranslate2 fork when ROCm hardware is detected
- Map 'rocm' device to 'cuda' for faster-whisper compatibility
- Add configurable GPU architecture support (RDNA2/RDNA3)
- Improve ROCm detection with fallback mechanisms
- Update Docker configurations with proper environment variables
- Add debugging and setup scripts for troubleshooting ROCm issues

This resolves the 'unsupported device rocm' error and enables full AMD GPU
acceleration for both transcription and speaker diarization workflows while
maintaining backward compatibility with CUDA and CPU devices.
This commit is contained in:
SpirusNox
2025-09-16 11:33:36 -05:00
parent 9b5afeb929
commit 6615130786
6 changed files with 236 additions and 22 deletions

View File

@@ -51,8 +51,8 @@ ENV PYTHONUNBUFFERED=1 \
UPLOAD_DIR=/app/data/uploads \
PUID=1000 \
PGID=1000 \
HSA_OVERRIDE_GFX_VERSION=10.3.0 \
PYTORCH_ROCM_ARCH=gfx1030 \
HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION:-11.0.0} \
PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx1100} \
ROCM_PATH=/opt/rocm
WORKDIR /app

64
debugging/debug-rocm.sh Normal file
View File

@@ -0,0 +1,64 @@
#!/bin/bash
# Debug script for ROCm setup in Scriberr
echo "=== ROCm Debug Script ==="
echo "Current directory: $(pwd)"
echo "User: $(whoami)"
echo "Date: $(date)"
# Check environment variables
echo ""
echo "=== Environment Variables ==="
echo "HSA_OVERRIDE_GFX_VERSION: ${HSA_OVERRIDE_GFX_VERSION:-not set}"
echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH:-not set}"
echo "ROCM_PATH: ${ROCM_PATH:-not set}"
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-not set}"
# Check if ROCm is available
echo ""
echo "=== ROCm Detection ==="
if command -v python3 &> /dev/null; then
echo "Testing ROCm with system Python3:"
python3 -c "import torch; print('PyTorch version:', torch.__version__); print('CUDA available:', torch.cuda.is_available()); print('ROCm available:', hasattr(torch, 'hip') and torch.hip.is_available() if hasattr(torch, 'hip') else False)"
else
echo "Python3 not found in PATH"
fi
# Check if uv is available
echo ""
echo "=== UV Check ==="
if command -v uv &> /dev/null; then
echo "UV found: $(uv --version)"
else
echo "UV not found in PATH"
fi
# Check WhisperX environment
echo ""
echo "=== WhisperX Environment ==="
WHISPERX_PATH="./whisperx-env/WhisperX"
if [ -d "$WHISPERX_PATH" ]; then
echo "WhisperX directory exists at: $WHISPERX_PATH"
if [ -f "$WHISPERX_PATH/pyproject.toml" ]; then
echo "pyproject.toml exists"
echo "Current ctranslate2 dependency:"
grep -i ctranslate2 "$WHISPERX_PATH/pyproject.toml" || echo "No ctranslate2 dependency found"
else
echo "pyproject.toml not found"
fi
else
echo "WhisperX directory not found at: $WHISPERX_PATH"
fi
# Check if we can run WhisperX
echo ""
echo "=== WhisperX Test ==="
if [ -d "$WHISPERX_PATH" ] && command -v uv &> /dev/null; then
echo "Testing WhisperX import:"
uv run --project "$WHISPERX_PATH" python -c "import whisperx; print('WhisperX import successful')" 2>&1
else
echo "Cannot test WhisperX (missing directory or uv)"
fi
echo ""
echo "=== Debug Complete ==="

76
debugging/setup-rocm.sh Normal file
View File

@@ -0,0 +1,76 @@
#!/bin/bash
# Manual ROCm setup script for Scriberr
echo "=== Manual ROCm Setup for Scriberr ==="
# Set default environment variables if not set
export HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION:-11.0.0}
export PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx1100}
export ROCM_PATH=${ROCM_PATH:-/opt/rocm}
echo "Using HSA_OVERRIDE_GFX_VERSION: $HSA_OVERRIDE_GFX_VERSION"
echo "Using PYTORCH_ROCM_ARCH: $PYTORCH_ROCM_ARCH"
echo "Using ROCM_PATH: $ROCM_PATH"
# Create WhisperX environment directory
WHISPERX_ENV="./whisperx-env"
WHISPERX_PATH="$WHISPERX_ENV/WhisperX"
echo ""
echo "Setting up WhisperX environment at: $WHISPERX_PATH"
# Remove existing environment if it exists
if [ -d "$WHISPERX_ENV" ]; then
echo "Removing existing environment..."
rm -rf "$WHISPERX_ENV"
fi
# Create directory
mkdir -p "$WHISPERX_ENV"
# Clone WhisperX
echo "Cloning WhisperX repository..."
cd "$WHISPERX_ENV"
git clone https://github.com/m-bain/WhisperX.git
# Check if ROCm is available
echo ""
echo "Checking ROCm availability..."
ROCM_AVAILABLE=false
if python3 -c "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())" 2>/dev/null | grep -q "True"; then
ROCM_AVAILABLE=true
echo "ROCm detected - using ROCm-compatible ctranslate2 fork"
else
echo "ROCm not detected - using standard ctranslate2"
fi
# Update pyproject.toml
echo ""
echo "Updating pyproject.toml dependencies..."
if [ "$ROCM_AVAILABLE" = true ]; then
# Use ROCm fork
sed -i 's/ctranslate2<4.5.0/ctranslate2 @ git+https:\/\/github.com\/arlo-phoenix\/CTranslate2.git@rocm/' "$WHISPERX_PATH/pyproject.toml"
sed -i 's/ctranslate2==4.6.0/ctranslate2 @ git+https:\/\/github.com\/arlo-phoenix\/CTranslate2.git@rocm/' "$WHISPERX_PATH/pyproject.toml"
else
# Use standard ctranslate2
sed -i 's/ctranslate2<4.5.0/ctranslate2==4.6.0/' "$WHISPERX_PATH/pyproject.toml"
fi
# Add yt-dlp if not present
if ! grep -q "yt-dlp" "$WHISPERX_PATH/pyproject.toml"; then
echo "Adding yt-dlp dependency..."
sed -i 's/"transformers>=4.48.0",/"transformers>=4.48.0",\n "yt-dlp",/' "$WHISPERX_PATH/pyproject.toml"
fi
# Install dependencies
echo ""
echo "Installing dependencies with uv sync..."
cd "$WHISPERX_PATH"
uv sync --all-extras --dev --native-tls
echo ""
echo "Setup complete! You can now test with:"
echo "uv run --project $WHISPERX_PATH python -c \"import whisperx; print('WhisperX ready')\""
echo ""
echo "For transcription, use:"
echo "uv run --project $WHISPERX_PATH python -m whisperx <audio_file> --device cuda"

View File

@@ -17,8 +17,8 @@ services:
- /dev/kfd
- /dev/dri
environment:
- HSA_OVERRIDE_GFX_VERSION=10.3.0
- PYTORCH_ROCM_ARCH=gfx1030
- HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION:-11.0.0}
- PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx1100}
- ROCM_PATH=/opt/rocm
volumes:

View File

@@ -1585,14 +1585,29 @@ func (h *Handler) isCudaAvailable() bool {
// isRocmAvailable checks if ROCm is available on the system
func (h *Handler) isRocmAvailable() bool {
// Try to run a simple PyTorch command to check ROCm availability
// First try using the WhisperX environment if available
whisperXPath := filepath.Join(h.config.WhisperXEnv, "WhisperX")
cmd := exec.Command(h.config.UVPath, "run", "--native-tls", "--project", whisperXPath, "python", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
output, err := cmd.Output()
if err != nil {
return false
if _, err := os.Stat(whisperXPath); err == nil {
cmd := exec.Command(h.config.UVPath, "run", "--native-tls", "--project", whisperXPath, "python", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
output, err := cmd.Output()
if err == nil && strings.TrimSpace(string(output)) == "True" {
return true
}
}
return strings.TrimSpace(string(output)) == "True"
// Fallback: try system Python
cmd := exec.Command("python3", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
output, err := cmd.Output()
if err == nil && strings.TrimSpace(string(output)) == "True" {
return true
}
// Check for test mode environment variable
if os.Getenv("SCRIBERR_TEST_ROCM") == "true" {
return true
}
return false
}
// Helper functions

View File

@@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
@@ -151,8 +152,8 @@ func (ws *WhisperXService) ProcessJobWithProcess(ctx context.Context, jobID stri
// Set ROCm environment variables if ROCm device is selected
if device == "rocm" {
cmd.Env = append(cmd.Env,
"PYTORCH_ROCM_ARCH=gfx1030",
"HSA_OVERRIDE_GFX_VERSION=10.3.0",
fmt.Sprintf("PYTORCH_ROCM_ARCH=%s", ws.getRocmArch()),
fmt.Sprintf("HSA_OVERRIDE_GFX_VERSION=%s", ws.getRocmVersion()),
"ROCM_PATH=/opt/rocm",
"CUDA_VISIBLE_DEVICES=", // Disable CUDA when using ROCm
)
@@ -278,12 +279,15 @@ func (ws *WhisperXService) updateWhisperXDependencies(whisperxPath string) error
// Check if ROCm is available and use ROCm-compatible ctranslate2 fork
if ws.isRocmAvailable() {
// Replace ctranslate2 dependency with ROCm-compatible fork
content = strings.ReplaceAll(content, "ctranslate2<4.5.0", "ctranslate2 @ git+https://github.com/arlo-phoenix/CTranslate2.git@rocm")
content = strings.ReplaceAll(content, "ctranslate2==4.6.0", "ctranslate2 @ git+https://github.com/arlo-phoenix/CTranslate2.git@rocm")
fmt.Printf("DEBUG: ROCm detected, using ROCm-compatible ctranslate2 fork\n")
// Replace any ctranslate2 dependency with ROCm-compatible fork using regex
re := regexp.MustCompile(`ctranslate2[^,\]\n]*`)
content = re.ReplaceAllString(content, "ctranslate2 @ git+https://github.com/arlo-phoenix/CTranslate2.git@rocm")
} else {
// Use standard ctranslate2 for CUDA/CPU
content = strings.ReplaceAll(content, "ctranslate2<4.5.0", "ctranslate2==4.6.0")
fmt.Printf("DEBUG: ROCm not detected, using standard ctranslate2\n")
// Replace any ctranslate2 dependency with standard version using regex
re := regexp.MustCompile(`ctranslate2[^,\]\n]*`)
content = re.ReplaceAllString(content, "ctranslate2==4.6.0")
}
// Add yt-dlp if not already present
@@ -311,8 +315,8 @@ func (ws *WhisperXService) uvSyncWhisperX(whisperxPath string) error {
// Set environment variables for ROCm if available
if ws.isRocmAvailable() {
cmd.Env = append(os.Environ(),
"PYTORCH_ROCM_ARCH=gfx1030",
"HSA_OVERRIDE_GFX_VERSION=10.3.0",
fmt.Sprintf("PYTORCH_ROCM_ARCH=%s", ws.getRocmArch()),
fmt.Sprintf("HSA_OVERRIDE_GFX_VERSION=%s", ws.getRocmVersion()),
"ROCM_PATH=/opt/rocm",
)
}
@@ -629,6 +633,48 @@ func (ws *WhisperXService) isCudaAvailable() bool {
return strings.TrimSpace(string(output)) == "True"
}
// getRocmArch returns the appropriate ROCm architecture string
func (ws *WhisperXService) getRocmArch() string {
// Check environment variable first
if arch := os.Getenv("PYTORCH_ROCM_ARCH"); arch != "" {
return arch
}
// Try to detect GPU architecture
cmd := exec.Command("python3", "-c", "import torch; print(torch.hip.get_device_properties(0).gcnArchName if torch.hip.is_available() and torch.hip.device_count() > 0 else 'gfx1100')")
output, err := cmd.Output()
if err == nil {
arch := strings.TrimSpace(string(output))
// Map common architectures
switch arch {
case "gfx1100", "gfx1101", "gfx1102": // RDNA3
return "gfx1100"
case "gfx1030", "gfx1031", "gfx1032": // RDNA2
return "gfx1030"
default:
return arch
}
}
// Default fallback to RDNA3
return "gfx1100"
}
// getRocmVersion returns the appropriate HSA override version
func (ws *WhisperXService) getRocmVersion() string {
// Check environment variable first
if version := os.Getenv("HSA_OVERRIDE_GFX_VERSION"); version != "" {
return version
}
// Default based on architecture
arch := ws.getRocmArch()
if strings.HasPrefix(arch, "gfx11") {
return "11.0.0" // RDNA3
}
return "10.3.0" // RDNA2
}
// isRocmAvailable checks if ROCm is available
func (ws *WhisperXService) isRocmAvailable() bool {
// Check for test mode environment variable
@@ -636,10 +682,23 @@ func (ws *WhisperXService) isRocmAvailable() bool {
return true
}
// Try system Python first (more reliable)
cmd := exec.Command("python3", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
output, err := cmd.Output()
if err != nil {
return false
if err == nil && strings.TrimSpace(string(output)) == "True" {
return true
}
return strings.TrimSpace(string(output)) == "True"
// Fallback: try with uv if system Python fails
envPath := ws.getEnvPath()
whisperxPath := filepath.Join(envPath, "WhisperX")
if _, err := os.Stat(whisperxPath); err == nil {
cmd := exec.Command("uv", "run", "--native-tls", "--project", whisperxPath, "python", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
output, err := cmd.Output()
if err == nil && strings.TrimSpace(string(output)) == "True" {
return true
}
}
return false
}