Fix auto device detection in Voxtral

This commit is contained in:
rishikanthc
2025-12-31 14:56:52 -08:00
parent 7911f6b283
commit c317ca609d
2 changed files with 47 additions and 62 deletions

View File

@@ -35,8 +35,9 @@ def transcribe_audio(
Dictionary containing transcription results
"""
# Determine device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
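# NOTE: the "auto" guard is disabled, so any `device` value passed in is overwritten by the detection below.
# if device == "auto":
# device = "cuda" if torch.cuda.is_available() else "cpu"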
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading Voxtral model on {device}...", file=sys.stderr)
@@ -57,9 +58,7 @@ def transcribe_audio(
# Prepare transcription request using the proper method
inputs = processor.apply_transcription_request(
language=language,
audio=audio_path,
model_id=model_id
language=language, audio=audio_path, model_id=model_id
)
# Move inputs to device with correct dtype
@@ -76,8 +75,7 @@ def transcribe_audio(
# Decode only the newly generated tokens (skip the input prompt)
decoded_outputs = processor.batch_decode(
outputs[:, inputs.input_ids.shape[1]:],
skip_special_tokens=True
outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True
)
transcription_text = decoded_outputs[0]
@@ -94,7 +92,7 @@ def transcribe_audio(
"start": 0.0,
"end": 0.0, # Duration unknown without audio analysis
"text": transcription_text,
"words": [] # Voxtral doesn't provide word-level timestamps
"words": [], # Voxtral doesn't provide word-level timestamps
}
],
"language": language,
@@ -104,7 +102,7 @@ def transcribe_audio(
# Write output
output_file = Path(output_path)
with output_file.open('w', encoding='utf-8') as f:
with output_file.open("w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"Results written to {output_path}", file=sys.stderr)
@@ -116,40 +114,29 @@ def main():
parser = argparse.ArgumentParser(
description="Transcribe audio using Voxtral-mini model"
)
parser.add_argument("audio_path", type=str, help="Path to input audio file")
parser.add_argument("output_path", type=str, help="Path to output JSON file")
parser.add_argument(
"audio_path",
type=str,
help="Path to input audio file"
)
parser.add_argument(
"output_path",
type=str,
help="Path to output JSON file"
)
parser.add_argument(
"--language",
type=str,
default="en",
help="Language code (default: en)"
"--language", type=str, default="en", help="Language code (default: en)"
)
parser.add_argument(
"--model-id",
type=str,
default="mistralai/Voxtral-mini",
help="HuggingFace model ID (default: mistralai/Voxtral-mini)"
help="HuggingFace model ID (default: mistralai/Voxtral-mini)",
)
parser.add_argument(
"--device",
type=str,
default="auto",
choices=["cpu", "cuda", "auto"],
help="Device to use (default: auto)"
help="Device to use (default: auto)",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=8192,
help="Maximum number of tokens to generate (default: 8192)"
help="Maximum number of tokens to generate (default: 8192)",
)
args = parser.parse_args()
@@ -166,6 +153,7 @@ def main():
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
import traceback
traceback.print_exc(file=sys.stderr)
sys.exit(1)

View File

@@ -29,11 +29,13 @@ def split_audio_file(audio_path, chunk_duration_secs=1500):
end_sample = min(start_sample + chunk_samples, len(audio))
chunk_audio = audio[start_sample:end_sample]
start_time = start_sample / sr
chunks.append({
'audio': chunk_audio,
'start_time': start_time,
'duration': len(chunk_audio) / sr
})
chunks.append(
{
"audio": chunk_audio,
"start_time": start_time,
"duration": len(chunk_audio) / sr,
}
)
return chunks, sr
@@ -51,8 +53,9 @@ def transcribe_buffered(
Transcribe long audio by splitting into chunks and merging results.
"""
# Determine device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
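# NOTE: the "auto" guard is disabled, so any `device` value passed in is overwritten by the detection below.
# if device == "auto":
# device = "cuda" if torch.cuda.is_available() else "cpu"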
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading Voxtral model on {device}...", file=sys.stderr)
@@ -77,18 +80,19 @@ def transcribe_buffered(
full_text = []
for i, chunk_info in enumerate(chunks):
print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...", file=sys.stderr)
print(
f"Transcribing chunk {i + 1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...",
file=sys.stderr,
)
# Save chunk to temporary file
chunk_path = f"/tmp/voxtral_chunk_{i}.wav"
sf.write(chunk_path, chunk_info['audio'], sr)
sf.write(chunk_path, chunk_info["audio"], sr)
try:
# Prepare transcription request for this chunk
inputs = processor.apply_transcription_request(
language=language,
audio=chunk_path,
model_id=model_id
language=language, audio=chunk_path, model_id=model_id
)
# Move inputs to device with correct dtype
@@ -103,14 +107,15 @@ def transcribe_buffered(
# Decode only the newly generated tokens (skip the input prompt)
decoded_outputs = processor.batch_decode(
outputs[:, inputs.input_ids.shape[1]:],
skip_special_tokens=True
outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True
)
chunk_text = decoded_outputs[0]
full_text.append(chunk_text)
print(f"Chunk {i+1} complete: {len(chunk_text)} characters", file=sys.stderr)
print(
f"Chunk {i + 1} complete: {len(chunk_text)} characters", file=sys.stderr
)
finally:
# Clean up temp file
@@ -119,7 +124,9 @@ def transcribe_buffered(
# Concatenate all chunks
final_text = " ".join(full_text)
print(f"Transcription complete: {len(final_text)} characters total", file=sys.stderr)
print(
f"Transcription complete: {len(final_text)} characters total", file=sys.stderr
)
# Prepare output in Scriberr format
# Note: Voxtral doesn't provide word-level timestamps
@@ -131,7 +138,7 @@ def transcribe_buffered(
"start": 0.0,
"end": 0.0, # Duration unknown without audio analysis
"text": final_text,
"words": [] # Voxtral doesn't provide word-level timestamps
"words": [], # Voxtral doesn't provide word-level timestamps
}
],
"language": language,
@@ -144,7 +151,7 @@ def transcribe_buffered(
# Write output
output_file_path = Path(output_file)
with output_file_path.open('w', encoding='utf-8') as f:
with output_file_path.open("w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"Results written to {output_file}", file=sys.stderr)
@@ -156,46 +163,35 @@ def main():
parser = argparse.ArgumentParser(
description="Transcribe long audio using Voxtral with chunking"
)
parser.add_argument("audio_path", type=str, help="Path to input audio file")
parser.add_argument("output_path", type=str, help="Path to output JSON file")
parser.add_argument(
"audio_path",
type=str,
help="Path to input audio file"
)
parser.add_argument(
"output_path",
type=str,
help="Path to output JSON file"
)
parser.add_argument(
"--language",
type=str,
default="en",
help="Language code (default: en)"
"--language", type=str, default="en", help="Language code (default: en)"
)
parser.add_argument(
"--model-id",
type=str,
default="mistralai/Voxtral-mini",
help="HuggingFace model ID (default: mistralai/Voxtral-mini)"
help="HuggingFace model ID (default: mistralai/Voxtral-mini)",
)
parser.add_argument(
"--device",
type=str,
default="auto",
choices=["cpu", "cuda", "auto"],
help="Device to use (default: auto)"
help="Device to use (default: auto)",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=8192,
help="Maximum number of tokens to generate per chunk (default: 8192)"
help="Maximum number of tokens to generate per chunk (default: 8192)",
)
parser.add_argument(
"--chunk-len",
type=float,
default=1500,
help="Chunk duration in seconds (default: 1500 = 25 minutes)"
help="Chunk duration in seconds (default: 1500 = 25 minutes)",
)
args = parser.parse_args()
@@ -217,6 +213,7 @@ def main():
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
import traceback
traceback.print_exc(file=sys.stderr)
sys.exit(1)