mirror of
https://github.com/rishikanthc/Scriberr.git
synced 2026-06-28 14:55:46 +00:00
fix auto device detection in voxtral
This commit is contained in:
@@ -35,8 +35,9 @@ def transcribe_audio(
|
||||
Dictionary containing transcription results
|
||||
"""
|
||||
# Determine device
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
# if device == "auto":
|
||||
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
print(f"Loading Voxtral model on {device}...", file=sys.stderr)
|
||||
|
||||
@@ -57,9 +58,7 @@ def transcribe_audio(
|
||||
|
||||
# Prepare transcription request using the proper method
|
||||
inputs = processor.apply_transcription_request(
|
||||
language=language,
|
||||
audio=audio_path,
|
||||
model_id=model_id
|
||||
language=language, audio=audio_path, model_id=model_id
|
||||
)
|
||||
|
||||
# Move inputs to device with correct dtype
|
||||
@@ -76,8 +75,7 @@ def transcribe_audio(
|
||||
|
||||
# Decode only the newly generated tokens (skip the input prompt)
|
||||
decoded_outputs = processor.batch_decode(
|
||||
outputs[:, inputs.input_ids.shape[1]:],
|
||||
skip_special_tokens=True
|
||||
outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
|
||||
transcription_text = decoded_outputs[0]
|
||||
@@ -94,7 +92,7 @@ def transcribe_audio(
|
||||
"start": 0.0,
|
||||
"end": 0.0, # Duration unknown without audio analysis
|
||||
"text": transcription_text,
|
||||
"words": [] # Voxtral doesn't provide word-level timestamps
|
||||
"words": [], # Voxtral doesn't provide word-level timestamps
|
||||
}
|
||||
],
|
||||
"language": language,
|
||||
@@ -104,7 +102,7 @@ def transcribe_audio(
|
||||
|
||||
# Write output
|
||||
output_file = Path(output_path)
|
||||
with output_file.open('w', encoding='utf-8') as f:
|
||||
with output_file.open("w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"Results written to {output_path}", file=sys.stderr)
|
||||
@@ -116,40 +114,29 @@ def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Transcribe audio using Voxtral-mini model"
|
||||
)
|
||||
parser.add_argument("audio_path", type=str, help="Path to input audio file")
|
||||
parser.add_argument("output_path", type=str, help="Path to output JSON file")
|
||||
parser.add_argument(
|
||||
"audio_path",
|
||||
type=str,
|
||||
help="Path to input audio file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_path",
|
||||
type=str,
|
||||
help="Path to output JSON file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default="en",
|
||||
help="Language code (default: en)"
|
||||
"--language", type=str, default="en", help="Language code (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
type=str,
|
||||
default="mistralai/Voxtral-mini",
|
||||
help="HuggingFace model ID (default: mistralai/Voxtral-mini)"
|
||||
help="HuggingFace model ID (default: mistralai/Voxtral-mini)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["cpu", "cuda", "auto"],
|
||||
help="Device to use (default: auto)"
|
||||
help="Device to use (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="Maximum number of tokens to generate (default: 8192)"
|
||||
help="Maximum number of tokens to generate (default: 8192)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -166,6 +153,7 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@@ -29,11 +29,13 @@ def split_audio_file(audio_path, chunk_duration_secs=1500):
|
||||
end_sample = min(start_sample + chunk_samples, len(audio))
|
||||
chunk_audio = audio[start_sample:end_sample]
|
||||
start_time = start_sample / sr
|
||||
chunks.append({
|
||||
'audio': chunk_audio,
|
||||
'start_time': start_time,
|
||||
'duration': len(chunk_audio) / sr
|
||||
})
|
||||
chunks.append(
|
||||
{
|
||||
"audio": chunk_audio,
|
||||
"start_time": start_time,
|
||||
"duration": len(chunk_audio) / sr,
|
||||
}
|
||||
)
|
||||
|
||||
return chunks, sr
|
||||
|
||||
@@ -51,8 +53,9 @@ def transcribe_buffered(
|
||||
Transcribe long audio by splitting into chunks and merging results.
|
||||
"""
|
||||
# Determine device
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
# if device == "auto":
|
||||
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
print(f"Loading Voxtral model on {device}...", file=sys.stderr)
|
||||
|
||||
@@ -77,18 +80,19 @@ def transcribe_buffered(
|
||||
full_text = []
|
||||
|
||||
for i, chunk_info in enumerate(chunks):
|
||||
print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...", file=sys.stderr)
|
||||
print(
|
||||
f"Transcribing chunk {i + 1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Save chunk to temporary file
|
||||
chunk_path = f"/tmp/voxtral_chunk_{i}.wav"
|
||||
sf.write(chunk_path, chunk_info['audio'], sr)
|
||||
sf.write(chunk_path, chunk_info["audio"], sr)
|
||||
|
||||
try:
|
||||
# Prepare transcription request for this chunk
|
||||
inputs = processor.apply_transcription_request(
|
||||
language=language,
|
||||
audio=chunk_path,
|
||||
model_id=model_id
|
||||
language=language, audio=chunk_path, model_id=model_id
|
||||
)
|
||||
|
||||
# Move inputs to device with correct dtype
|
||||
@@ -103,14 +107,15 @@ def transcribe_buffered(
|
||||
|
||||
# Decode only the newly generated tokens (skip the input prompt)
|
||||
decoded_outputs = processor.batch_decode(
|
||||
outputs[:, inputs.input_ids.shape[1]:],
|
||||
skip_special_tokens=True
|
||||
outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
|
||||
chunk_text = decoded_outputs[0]
|
||||
full_text.append(chunk_text)
|
||||
|
||||
print(f"Chunk {i+1} complete: {len(chunk_text)} characters", file=sys.stderr)
|
||||
print(
|
||||
f"Chunk {i + 1} complete: {len(chunk_text)} characters", file=sys.stderr
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
@@ -119,7 +124,9 @@ def transcribe_buffered(
|
||||
|
||||
# Concatenate all chunks
|
||||
final_text = " ".join(full_text)
|
||||
print(f"Transcription complete: {len(final_text)} characters total", file=sys.stderr)
|
||||
print(
|
||||
f"Transcription complete: {len(final_text)} characters total", file=sys.stderr
|
||||
)
|
||||
|
||||
# Prepare output in Scriberr format
|
||||
# Note: Voxtral doesn't provide word-level timestamps
|
||||
@@ -131,7 +138,7 @@ def transcribe_buffered(
|
||||
"start": 0.0,
|
||||
"end": 0.0, # Duration unknown without audio analysis
|
||||
"text": final_text,
|
||||
"words": [] # Voxtral doesn't provide word-level timestamps
|
||||
"words": [], # Voxtral doesn't provide word-level timestamps
|
||||
}
|
||||
],
|
||||
"language": language,
|
||||
@@ -144,7 +151,7 @@ def transcribe_buffered(
|
||||
|
||||
# Write output
|
||||
output_file_path = Path(output_file)
|
||||
with output_file_path.open('w', encoding='utf-8') as f:
|
||||
with output_file_path.open("w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"Results written to {output_file}", file=sys.stderr)
|
||||
@@ -156,46 +163,35 @@ def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Transcribe long audio using Voxtral with chunking"
|
||||
)
|
||||
parser.add_argument("audio_path", type=str, help="Path to input audio file")
|
||||
parser.add_argument("output_path", type=str, help="Path to output JSON file")
|
||||
parser.add_argument(
|
||||
"audio_path",
|
||||
type=str,
|
||||
help="Path to input audio file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_path",
|
||||
type=str,
|
||||
help="Path to output JSON file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default="en",
|
||||
help="Language code (default: en)"
|
||||
"--language", type=str, default="en", help="Language code (default: en)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
type=str,
|
||||
default="mistralai/Voxtral-mini",
|
||||
help="HuggingFace model ID (default: mistralai/Voxtral-mini)"
|
||||
help="HuggingFace model ID (default: mistralai/Voxtral-mini)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["cpu", "cuda", "auto"],
|
||||
help="Device to use (default: auto)"
|
||||
help="Device to use (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="Maximum number of tokens to generate per chunk (default: 8192)"
|
||||
help="Maximum number of tokens to generate per chunk (default: 8192)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-len",
|
||||
type=float,
|
||||
default=1500,
|
||||
help="Chunk duration in seconds (default: 1500 = 25 minutes)"
|
||||
help="Chunk duration in seconds (default: 1500 = 25 minutes)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -217,6 +213,7 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user