"""
|
|
Translation module with extensible backend support.
|
|
|
|
To add a new translation provider:
|
|
1. Create a class that inherits from TranslationBackend
|
|
2. Implement the translate() method
|
|
3. Register it in the TranslationService class
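
For example (a minimal sketch; UppercaseBackend is hypothetical and exists
only to illustrate the interface):

    class UppercaseBackend(TranslationBackend):
        # Toy backend that "translates" by upper-casing the text.
        def translate(self, text, target_lang, source_lang="en"):
            return text.upper() if text else ""

Register it under a name in the backends dict of
TranslationService._create_backend, e.g. 'upper': UppercaseBackend.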
"""

from abc import ABC, abstractmethod
from typing import Optional
import os


class TranslationBackend(ABC):
    """Abstract base class for translation backends."""

    @abstractmethod
    def translate(self, text: str, target_lang: str, source_lang: str = "en") -> Optional[str]:
        """
        Translate text from source language to target language.

        Args:
            text: Text to translate
            target_lang: Target language code (e.g., 'de', 'fr', 'es')
            source_lang: Source language code (default: 'en')

        Returns:
            Translated text, or None if translation fails
        """
        pass


class HuggingFaceTranslator(TranslationBackend):
    """
    Translation backend using HuggingFace transformers.

    Uses Helsinki-NLP Opus-MT models (English source only).
    """

    def __init__(self):
        self._model = None
        self._tokenizer = None
        self._model_name = None
        self._device = 'cpu'  # Default device

    def _load_model(self, target_lang: str):
        """Lazily load the translation model for the requested target language."""
        try:
            from transformers import MarianMTModel, MarianTokenizer
            import torch
        except ImportError as e:
            raise ImportError(
                "transformers library not installed. "
                "Install with: pip install transformers torch"
            ) from e

        try:
            # Check for SentencePiece (required by MarianTokenizer)
            import sentencepiece  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "SentencePiece library not installed. "
                "Install with: pip install sentencepiece"
            ) from e

        # Map target language codes to model names. All of these models
        # translate from English, so this backend effectively ignores
        # source_lang values other than 'en'.
        model_map = {
            'de': 'Helsinki-NLP/opus-mt-en-de',
            'fr': 'Helsinki-NLP/opus-mt-en-fr',
            'es': 'Helsinki-NLP/opus-mt-en-es',
            'it': 'Helsinki-NLP/opus-mt-en-it',
            'pt': 'Helsinki-NLP/opus-mt-en-pt',
            'ru': 'Helsinki-NLP/opus-mt-en-ru',
        }

        model_name = model_map.get(target_lang)
        if not model_name:
            raise ValueError(f"No model available for language: {target_lang}")

        # Only reload if the language changed
        if self._model_name != model_name:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Load the tokenizer first (it doesn't need a device)
            self._tokenizer = MarianTokenizer.from_pretrained(model_name)

            # Load the model, trying to place it directly on the target
            # device to avoid meta tensor issues
            try:
                if device == 'cpu':
                    # For CPU, load normally
                    self._model = MarianMTModel.from_pretrained(model_name)
                    self._model.eval()
                else:
                    # For CUDA, try loading with device_map, or load then move
                    try:
                        # Try loading with device_map if supported
                        self._model = MarianMTModel.from_pretrained(
                            model_name,
                            device_map='auto'
                        )
                        self._model.eval()
                        # Update device based on where the model actually ended up
                        actual_device = next(self._model.parameters()).device.type
                        device = actual_device if actual_device in ['cuda', 'cpu'] else 'cpu'
                    except (TypeError, ValueError):
                        # Fallback: load to CPU first, then move
                        self._model = MarianMTModel.from_pretrained(model_name)
                        self._model.eval()
                        try:
                            self._model = self._model.to(device)
                        except Exception:
                            # If moving fails, keep the model on CPU
                            device = 'cpu'
            except Exception as e:
                # Ultimate fallback: load to CPU
                print(f"Warning: Error loading model to {device}, using CPU: {e}")
                self._model = MarianMTModel.from_pretrained(model_name)
                self._model.eval()
                device = 'cpu'

            self._model_name = model_name
            self._device = device

    def translate(self, text: str, target_lang: str, source_lang: str = "en") -> Optional[str]:
        """Translate using a HuggingFace model."""
        if not text:
            return ""

        try:
            self._load_model(target_lang)
            import torch

            # Split the text into paragraphs first, then lines
            paragraphs = text.split('\n\n')
            translated_paragraphs = []

            for para in paragraphs:
                if not para.strip():
                    translated_paragraphs.append(para)
                    continue

                # Split into lines (a simple stand-in for sentence segmentation)
                sentences = para.split('\n')
                translated_sentences = []

                for sentence in sentences:
                    if not sentence.strip():
                        translated_sentences.append(sentence)
                        continue

                    try:
                        # Tokenize and move the inputs to the model's device
                        inputs = self._tokenizer(
                            [sentence],
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self._device)

                        # Generate the translation
                        with torch.no_grad():
                            translated = self._model.generate(**inputs, max_length=512)

                        translated_text = self._tokenizer.decode(translated[0], skip_special_tokens=True)
                        translated_sentences.append(translated_text)
                    except Exception as e:
                        print(f"Error translating sentence: {e}")
                        translated_sentences.append(sentence)  # Fall back to the original text

                translated_paragraphs.append('\n'.join(translated_sentences))

            return '\n\n'.join(translated_paragraphs)

        except Exception as e:
            import traceback
            print(f"Translation error: {e}")
            print(traceback.format_exc())
            return None


class GoogleTranslateBackend(TranslationBackend):
    """
    Translation backend using Google Translate (via the googletrans library).

    Requires the GOOGLE_TRANSLATE_API_KEY environment variable to be set.
    Note that googletrans itself talks to the free web endpoint and never
    receives this key; the check below only gates backend selection.
    """

    def __init__(self):
        self.api_key = os.getenv('GOOGLE_TRANSLATE_API_KEY')
        if not self.api_key:
            raise ValueError("GOOGLE_TRANSLATE_API_KEY environment variable not set")

    def translate(self, text: str, target_lang: str, source_lang: str = "en") -> Optional[str]:
        """Translate using the googletrans library."""
        try:
            from googletrans import Translator
            translator = Translator()
            result = translator.translate(text, dest=target_lang, src=source_lang)
            return result.text
        except ImportError as e:
            raise ImportError(
                "googletrans library not installed. "
                "Install with: pip install googletrans==4.0.0rc1"
            ) from e
        except Exception as e:
            print(f"Google Translate error: {e}")
            return None


class DeepLTranslator(TranslationBackend):
    """
    Translation backend using the DeepL API.

    Requires the DEEPL_API_KEY environment variable.
    """

    def __init__(self):
        self.api_key = os.getenv('DEEPL_API_KEY')
        if not self.api_key:
            raise ValueError("DEEPL_API_KEY environment variable not set")

    def translate(self, text: str, target_lang: str, source_lang: str = "en") -> Optional[str]:
        """Translate using the DeepL API."""
        try:
            import deepl
            translator = deepl.Translator(self.api_key)
            # DeepL expects upper-case language codes (e.g. 'DE', 'EN')
            result = translator.translate_text(
                text,
                target_lang=target_lang.upper(),
                source_lang=source_lang.upper()
            )
            return result.text
        except ImportError as e:
            raise ImportError(
                "deepl library not installed. Install with: pip install deepl"
            ) from e
        except Exception as e:
            print(f"DeepL translation error: {e}")
            return None


class TranslationService:
    """
    Translation service that manages multiple translation backends.

    Automatically selects the best available backend based on configuration.
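
    Example (a sketch; the selected backend's dependencies must be installed):

        service = TranslationService()          # auto-select by priority
        service = TranslationService('deepl')   # or force a specific backend
        german = service.translate("Hello", target_lang="de")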
    """

    def __init__(self, backend: Optional[str] = None):
        """
        Initialize the translation service.

        Args:
            backend: Backend name ('huggingface', 'google', 'deepl').
                If None, auto-selects based on availability.
        """
        self.backend_name = backend or self._auto_select_backend()
        self.backend = self._create_backend(self.backend_name)

    def _auto_select_backend(self) -> str:
        """Auto-select the best available backend."""
        # Priority: DeepL > Google > HuggingFace
        if os.getenv('DEEPL_API_KEY'):
            return 'deepl'
        elif os.getenv('GOOGLE_TRANSLATE_API_KEY'):
            return 'google'
        else:
            return 'huggingface'  # Default to the local model

    def _create_backend(self, backend_name: str) -> TranslationBackend:
        """Create a translation backend instance."""
        # New providers are registered here (see the module docstring)
        backends = {
            'huggingface': HuggingFaceTranslator,
            'google': GoogleTranslateBackend,
            'deepl': DeepLTranslator,
        }

        backend_class = backends.get(backend_name.lower())
        if not backend_class:
            raise ValueError(f"Unknown backend: {backend_name}")

        try:
            return backend_class()
        except Exception as e:
            # Fall back to HuggingFace if another backend fails to initialize
            if backend_name != 'huggingface':
                print(f"Failed to initialize {backend_name}, falling back to HuggingFace: {e}")
                return HuggingFaceTranslator()
            raise

    def translate(self, text: str, target_lang: str, source_lang: str = "en") -> Optional[str]:
        """Translate text using the configured backend."""
        if not text:
            return ""
        return self.backend.translate(text, target_lang, source_lang)
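

# A minimal usage sketch, assuming at least one backend's dependencies are
# installed (HuggingFace needs no API key but downloads a model on first use):
if __name__ == "__main__":
    service = TranslationService()
    print(f"Selected backend: {service.backend_name}")
    result = service.translate("Hello, world!", target_lang="de")
    print(result if result is not None else "Translation failed")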