mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-03-03 03:57:01 +00:00
feat: Add comprehensive ROCm GPU support for transcription and diarization
- Use ROCm-compatible CTranslate2 fork when ROCm hardware is detected - Map 'rocm' device to 'cuda' for faster-whisper compatibility - Add configurable GPU architecture support (RDNA2/RDNA3) - Improve ROCm detection with fallback mechanisms - Update Docker configurations with proper environment variables - Add debugging and setup scripts for troubleshooting ROCm issues This resolves the 'unsupported device rocm' error and enables full AMD GPU acceleration for both transcription and speaker diarization workflows while maintaining backward compatibility with CUDA and CPU devices.
This commit is contained in:
@@ -51,8 +51,8 @@ ENV PYTHONUNBUFFERED=1 \
|
||||
UPLOAD_DIR=/app/data/uploads \
|
||||
PUID=1000 \
|
||||
PGID=1000 \
|
||||
HSA_OVERRIDE_GFX_VERSION=10.3.0 \
|
||||
PYTORCH_ROCM_ARCH=gfx1030 \
|
||||
HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION:-11.0.0} \
|
||||
PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx1100} \
|
||||
ROCM_PATH=/opt/rocm
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
64
debugging/debug-rocm.sh
Normal file
64
debugging/debug-rocm.sh
Normal file
@@ -0,0 +1,64 @@
|
||||
#!/bin/bash

# Debug script for ROCm setup in Scriberr.
#
# Prints the ROCm-related environment variables, checks whether the PyTorch
# visible to the system python3 is a usable ROCm (HIP) build, reports whether
# uv is installed, and inspects/imports the WhisperX project expected at
# ./whisperx-env/WhisperX. Run it from the directory that contains
# whisperx-env. The script only reports; it does not modify anything.

echo "=== ROCm Debug Script ==="
echo "Current directory: $(pwd)"
echo "User: $(whoami)"
echo "Date: $(date)"

# Check environment variables
echo ""
echo "=== Environment Variables ==="
echo "HSA_OVERRIDE_GFX_VERSION: ${HSA_OVERRIDE_GFX_VERSION:-not set}"
echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH:-not set}"
echo "ROCM_PATH: ${ROCM_PATH:-not set}"
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-not set}"

# Check if ROCm is available
echo ""
echo "=== ROCm Detection ==="
if command -v python3 &> /dev/null; then
    echo "Testing ROCm with system Python3:"
    # ROCm builds of PyTorch reuse the torch.cuda API and identify themselves
    # through torch.version.hip; there is no torch.hip module, so probing it
    # would report False even on a working ROCm install.
    python3 -c "import torch; print('PyTorch version:', torch.__version__); print('CUDA available:', torch.cuda.is_available()); print('ROCm available:', getattr(torch.version, 'hip', None) is not None and torch.cuda.is_available())" \
        || echo "PyTorch could not be imported by system Python3"
else
    echo "Python3 not found in PATH"
fi

# Check if uv is available
echo ""
echo "=== UV Check ==="
if command -v uv &> /dev/null; then
    echo "UV found: $(uv --version)"
else
    echo "UV not found in PATH"
fi

# Check WhisperX environment
echo ""
echo "=== WhisperX Environment ==="
WHISPERX_PATH="./whisperx-env/WhisperX"
if [ -d "$WHISPERX_PATH" ]; then
    echo "WhisperX directory exists at: $WHISPERX_PATH"
    if [ -f "$WHISPERX_PATH/pyproject.toml" ]; then
        echo "pyproject.toml exists"
        echo "Current ctranslate2 dependency:"
        grep -i ctranslate2 "$WHISPERX_PATH/pyproject.toml" || echo "No ctranslate2 dependency found"
    else
        echo "pyproject.toml not found"
    fi
else
    echo "WhisperX directory not found at: $WHISPERX_PATH"
fi

# Check if we can run WhisperX
echo ""
echo "=== WhisperX Test ==="
if [ -d "$WHISPERX_PATH" ] && command -v uv &> /dev/null; then
    echo "Testing WhisperX import:"
    # stderr is merged into stdout so import tracebacks appear inline.
    uv run --project "$WHISPERX_PATH" python -c "import whisperx; print('WhisperX import successful')" 2>&1
else
    echo "Cannot test WhisperX (missing directory or uv)"
fi

echo ""
echo "=== Debug Complete ==="
|
||||
76
debugging/setup-rocm.sh
Normal file
76
debugging/setup-rocm.sh
Normal file
@@ -0,0 +1,76 @@
|
||||
#!/bin/bash

# Manual ROCm setup script for Scriberr.
#
# Recreates ./whisperx-env from scratch: clones WhisperX into
# ./whisperx-env/WhisperX, rewrites its ctranslate2 dependency (ROCm fork when
# a ROCm build of PyTorch is detected, the standard 4.6.0 release otherwise),
# adds yt-dlp, and installs everything with `uv sync`.
#
# WARNING: an existing ./whisperx-env directory is deleted.
# Requires git, uv and GNU sed (for `sed -i` and `\n` in replacements).

# Abort on the first failing step instead of patching or syncing a
# half-created checkout.
set -euo pipefail

echo "=== Manual ROCm Setup for Scriberr ==="

# Fail before anything is deleted if a required tool is missing.
for tool in git uv; do
    if ! command -v "$tool" &> /dev/null; then
        echo "Error: $tool not found in PATH" >&2
        exit 1
    fi
done

# Set default environment variables if not set
export HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION:-11.0.0}
export PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx1100}
export ROCM_PATH=${ROCM_PATH:-/opt/rocm}

echo "Using HSA_OVERRIDE_GFX_VERSION: $HSA_OVERRIDE_GFX_VERSION"
echo "Using PYTORCH_ROCM_ARCH: $PYTORCH_ROCM_ARCH"
echo "Using ROCM_PATH: $ROCM_PATH"

# Create WhisperX environment directory
WHISPERX_ENV="./whisperx-env"
WHISPERX_PATH="$WHISPERX_ENV/WhisperX"
PYPROJECT="$WHISPERX_PATH/pyproject.toml"

echo ""
echo "Setting up WhisperX environment at: $WHISPERX_PATH"

# Remove existing environment if it exists
if [ -d "$WHISPERX_ENV" ]; then
    echo "Removing existing environment..."
    rm -rf "$WHISPERX_ENV"
fi

# Create directory
mkdir -p "$WHISPERX_ENV"

# Clone WhisperX directly into its target directory. The working directory is
# deliberately left unchanged so the relative $WHISPERX_PATH used below keeps
# pointing at the checkout.
echo "Cloning WhisperX repository..."
git clone https://github.com/m-bain/WhisperX.git "$WHISPERX_PATH"

# Check if ROCm is available
echo ""
echo "Checking ROCm availability..."
ROCM_AVAILABLE=false
# ROCm builds of PyTorch set torch.version.hip and expose the GPU through the
# torch.cuda API; there is no torch.hip module. A missing python3 or torch is
# treated as "ROCm not detected" rather than as a fatal error.
rocm_probe="$(python3 -c "import torch; print(getattr(torch.version, 'hip', None) is not None and torch.cuda.is_available())" 2>/dev/null || true)"
if [[ "$rocm_probe" == *True* ]]; then
    ROCM_AVAILABLE=true
    echo "ROCm detected - using ROCm-compatible ctranslate2 fork"
else
    echo "ROCm not detected - using standard ctranslate2"
fi

# Update pyproject.toml
echo ""
echo "Updating pyproject.toml dependencies..."
if [ "$ROCM_AVAILABLE" = true ]; then
    # Use ROCm fork
    sed -i 's|ctranslate2<4.5.0|ctranslate2 @ git+https://github.com/arlo-phoenix/CTranslate2.git@rocm|' "$PYPROJECT"
    sed -i 's|ctranslate2==4.6.0|ctranslate2 @ git+https://github.com/arlo-phoenix/CTranslate2.git@rocm|' "$PYPROJECT"
else
    # Use standard ctranslate2
    sed -i 's|ctranslate2<4.5.0|ctranslate2==4.6.0|' "$PYPROJECT"
fi

# Add yt-dlp if not present
if ! grep -q "yt-dlp" "$PYPROJECT"; then
    echo "Adding yt-dlp dependency..."
    sed -i 's/"transformers>=4.48.0",/"transformers>=4.48.0",\n "yt-dlp",/' "$PYPROJECT"
fi

# Install dependencies
echo ""
echo "Installing dependencies with uv sync..."
# Run in a subshell so the caller-relative paths printed below stay valid.
(cd "$WHISPERX_PATH" && uv sync --all-extras --dev --native-tls)

echo ""
echo "Setup complete! You can now test with:"
echo "uv run --project $WHISPERX_PATH python -c \"import whisperx; print('WhisperX ready')\""
echo ""
echo "For transcription, use:"
echo "uv run --project $WHISPERX_PATH python -m whisperx <audio_file> --device cuda"
|
||||
@@ -17,8 +17,8 @@ services:
|
||||
- /dev/kfd
|
||||
- /dev/dri
|
||||
environment:
|
||||
- HSA_OVERRIDE_GFX_VERSION=10.3.0
|
||||
- PYTORCH_ROCM_ARCH=gfx1030
|
||||
- HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION:-11.0.0}
|
||||
- PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx1100}
|
||||
- ROCM_PATH=/opt/rocm
|
||||
|
||||
volumes:
|
||||
|
||||
@@ -1585,14 +1585,29 @@ func (h *Handler) isCudaAvailable() bool {
|
||||
|
||||
// isRocmAvailable checks if ROCm is available on the system
|
||||
func (h *Handler) isRocmAvailable() bool {
|
||||
// Try to run a simple PyTorch command to check ROCm availability
|
||||
// First try using the WhisperX environment if available
|
||||
whisperXPath := filepath.Join(h.config.WhisperXEnv, "WhisperX")
|
||||
cmd := exec.Command(h.config.UVPath, "run", "--native-tls", "--project", whisperXPath, "python", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return false
|
||||
if _, err := os.Stat(whisperXPath); err == nil {
|
||||
cmd := exec.Command(h.config.UVPath, "run", "--native-tls", "--project", whisperXPath, "python", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
|
||||
output, err := cmd.Output()
|
||||
if err == nil && strings.TrimSpace(string(output)) == "True" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(string(output)) == "True"
|
||||
|
||||
// Fallback: try system Python
|
||||
cmd := exec.Command("python3", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
|
||||
output, err := cmd.Output()
|
||||
if err == nil && strings.TrimSpace(string(output)) == "True" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for test mode environment variable
|
||||
if os.Getenv("SCRIBERR_TEST_ROCM") == "true" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -151,8 +152,8 @@ func (ws *WhisperXService) ProcessJobWithProcess(ctx context.Context, jobID stri
|
||||
// Set ROCm environment variables if ROCm device is selected
|
||||
if device == "rocm" {
|
||||
cmd.Env = append(cmd.Env,
|
||||
"PYTORCH_ROCM_ARCH=gfx1030",
|
||||
"HSA_OVERRIDE_GFX_VERSION=10.3.0",
|
||||
fmt.Sprintf("PYTORCH_ROCM_ARCH=%s", ws.getRocmArch()),
|
||||
fmt.Sprintf("HSA_OVERRIDE_GFX_VERSION=%s", ws.getRocmVersion()),
|
||||
"ROCM_PATH=/opt/rocm",
|
||||
"CUDA_VISIBLE_DEVICES=", // Disable CUDA when using ROCm
|
||||
)
|
||||
@@ -278,12 +279,15 @@ func (ws *WhisperXService) updateWhisperXDependencies(whisperxPath string) error
|
||||
|
||||
// Check if ROCm is available and use ROCm-compatible ctranslate2 fork
|
||||
if ws.isRocmAvailable() {
|
||||
// Replace ctranslate2 dependency with ROCm-compatible fork
|
||||
content = strings.ReplaceAll(content, "ctranslate2<4.5.0", "ctranslate2 @ git+https://github.com/arlo-phoenix/CTranslate2.git@rocm")
|
||||
content = strings.ReplaceAll(content, "ctranslate2==4.6.0", "ctranslate2 @ git+https://github.com/arlo-phoenix/CTranslate2.git@rocm")
|
||||
fmt.Printf("DEBUG: ROCm detected, using ROCm-compatible ctranslate2 fork\n")
|
||||
// Replace any ctranslate2 dependency with ROCm-compatible fork using regex
|
||||
re := regexp.MustCompile(`ctranslate2[^,\]\n]*`)
|
||||
content = re.ReplaceAllString(content, "ctranslate2 @ git+https://github.com/arlo-phoenix/CTranslate2.git@rocm")
|
||||
} else {
|
||||
// Use standard ctranslate2 for CUDA/CPU
|
||||
content = strings.ReplaceAll(content, "ctranslate2<4.5.0", "ctranslate2==4.6.0")
|
||||
fmt.Printf("DEBUG: ROCm not detected, using standard ctranslate2\n")
|
||||
// Replace any ctranslate2 dependency with standard version using regex
|
||||
re := regexp.MustCompile(`ctranslate2[^,\]\n]*`)
|
||||
content = re.ReplaceAllString(content, "ctranslate2==4.6.0")
|
||||
}
|
||||
|
||||
// Add yt-dlp if not already present
|
||||
@@ -311,8 +315,8 @@ func (ws *WhisperXService) uvSyncWhisperX(whisperxPath string) error {
|
||||
// Set environment variables for ROCm if available
|
||||
if ws.isRocmAvailable() {
|
||||
cmd.Env = append(os.Environ(),
|
||||
"PYTORCH_ROCM_ARCH=gfx1030",
|
||||
"HSA_OVERRIDE_GFX_VERSION=10.3.0",
|
||||
fmt.Sprintf("PYTORCH_ROCM_ARCH=%s", ws.getRocmArch()),
|
||||
fmt.Sprintf("HSA_OVERRIDE_GFX_VERSION=%s", ws.getRocmVersion()),
|
||||
"ROCM_PATH=/opt/rocm",
|
||||
)
|
||||
}
|
||||
@@ -629,6 +633,48 @@ func (ws *WhisperXService) isCudaAvailable() bool {
|
||||
return strings.TrimSpace(string(output)) == "True"
|
||||
}
|
||||
|
||||
// getRocmArch returns the appropriate ROCm architecture string
|
||||
func (ws *WhisperXService) getRocmArch() string {
|
||||
// Check environment variable first
|
||||
if arch := os.Getenv("PYTORCH_ROCM_ARCH"); arch != "" {
|
||||
return arch
|
||||
}
|
||||
|
||||
// Try to detect GPU architecture
|
||||
cmd := exec.Command("python3", "-c", "import torch; print(torch.hip.get_device_properties(0).gcnArchName if torch.hip.is_available() and torch.hip.device_count() > 0 else 'gfx1100')")
|
||||
output, err := cmd.Output()
|
||||
if err == nil {
|
||||
arch := strings.TrimSpace(string(output))
|
||||
// Map common architectures
|
||||
switch arch {
|
||||
case "gfx1100", "gfx1101", "gfx1102": // RDNA3
|
||||
return "gfx1100"
|
||||
case "gfx1030", "gfx1031", "gfx1032": // RDNA2
|
||||
return "gfx1030"
|
||||
default:
|
||||
return arch
|
||||
}
|
||||
}
|
||||
|
||||
// Default fallback to RDNA3
|
||||
return "gfx1100"
|
||||
}
|
||||
|
||||
// getRocmVersion returns the appropriate HSA override version
|
||||
func (ws *WhisperXService) getRocmVersion() string {
|
||||
// Check environment variable first
|
||||
if version := os.Getenv("HSA_OVERRIDE_GFX_VERSION"); version != "" {
|
||||
return version
|
||||
}
|
||||
|
||||
// Default based on architecture
|
||||
arch := ws.getRocmArch()
|
||||
if strings.HasPrefix(arch, "gfx11") {
|
||||
return "11.0.0" // RDNA3
|
||||
}
|
||||
return "10.3.0" // RDNA2
|
||||
}
|
||||
|
||||
// isRocmAvailable checks if ROCm is available
|
||||
func (ws *WhisperXService) isRocmAvailable() bool {
|
||||
// Check for test mode environment variable
|
||||
@@ -636,10 +682,23 @@ func (ws *WhisperXService) isRocmAvailable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try system Python first (more reliable)
|
||||
cmd := exec.Command("python3", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return false
|
||||
if err == nil && strings.TrimSpace(string(output)) == "True" {
|
||||
return true
|
||||
}
|
||||
return strings.TrimSpace(string(output)) == "True"
|
||||
|
||||
// Fallback: try with uv if system Python fails
|
||||
envPath := ws.getEnvPath()
|
||||
whisperxPath := filepath.Join(envPath, "WhisperX")
|
||||
if _, err := os.Stat(whisperxPath); err == nil {
|
||||
cmd := exec.Command("uv", "run", "--native-tls", "--project", whisperxPath, "python", "-c", "import torch; print(hasattr(torch, 'hip') and torch.hip.is_available())")
|
||||
output, err := cmd.Output()
|
||||
if err == nil && strings.TrimSpace(string(output)) == "True" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user