Source code for wefe.word_embedding_model

"""A Word Embedding contanier based on gensim BaseKeyedVectors."""

from collections.abc import Sequence
from typing import Any

import gensim
import numpy as np
from numpy.typing import NDArray
import semantic_version

GENSIM_VERSION = semantic_version.Version.coerce(gensim.__version__)
# Single switch used by the rest of the module to pick between the gensim 4+
# API (key_to_index, fill_norms, ...) and the pre-4 API (vocab, init_sims, ...).
GENSIM_V4_OR_GREATER = 4 <= GENSIM_VERSION.major  # type: ignore

if not GENSIM_V4_OR_GREATER:
    # Before gensim 4 the base class lives in gensim.models.keyedvectors.
    from gensim.models.keyedvectors import BaseKeyedVectors  # type: ignore
else:
    # From gensim 4 on, KeyedVectors is aliased as BaseKeyedVectors so the
    # rest of the module can refer to a single name.
    from gensim.models import KeyedVectors as BaseKeyedVectors

# --- Type Aliases ---
# NDArray (instead of a bare np.ndarray) lets the element dtype appear in hints.
# Lookup table from a word to its embedding vector.
EmbeddingDict = dict[str, NDArray[np.float64]]
# Lookup table from a string key to a whole EmbeddingDict, i.e. several
# word -> embedding tables grouped under names.
EmbeddingSets = dict[str, EmbeddingDict]


class WordEmbeddingModel:
    """A wrapper for Word Embedding pre-trained models.

    It can hold gensim's KeyedVectors or gensim's api loaded models.
    It includes the name of the model and some vocab prefix if needed.
    """

    def __init__(
        self,
        wv: BaseKeyedVectors,
        name: str | None = None,
        vocab_prefix: str | None = None,
    ) -> None:
        """Initialize the word embedding model.

        Parameters
        ----------
        wv : BaseKeyedVectors
            An instance of word embedding loaded through gensim KeyedVector
            interface or gensim's api.
        name : str, optional
            The name of the model. When None (the default), the model is named
            ``"Unnamed model"``.
        vocab_prefix : str, optional
            A prefix that will be concatenated with all words in the model
            vocab, by default None.

        Raises
        ------
        TypeError
            if wv is not a BaseKeyedVectors instance.
        TypeError
            if name is not None and not an instance of str.
        TypeError
            if vocab_prefix is not None and not an instance of str.

        Examples
        --------
        >>> from gensim.test.utils import common_texts
        >>> from gensim.models import Word2Vec
        >>> from wefe.word_embedding_model import WordEmbeddingModel
        >>> dummy_model = Word2Vec(common_texts, window=5,
        ...                        min_count=1, workers=1).wv
        >>> model = WordEmbeddingModel(dummy_model, 'Dummy model dim=10',
        ...                            vocab_prefix='/en/')
        >>> print(model.name)
        Dummy model dim=10
        >>> print(model.vocab_prefix)
        /en/
        """
        # Type checking
        if not isinstance(wv, BaseKeyedVectors):
            raise TypeError(
                f"wv must be an instance of gensim's BaseKeyedVectors, "
                f"but got {type(wv)}."
            )
        if name is not None and not isinstance(name, str):
            raise TypeError(f"name must be a string or None, but got {type(name)}.")
        if vocab_prefix is not None and not isinstance(vocab_prefix, str):
            raise TypeError(
                f"vocab_prefix must be a string or None, but got {type(vocab_prefix)}."
            )

        # Assign the attributes
        self.wv = wv
        # gensim 4 exposes the word -> row-index mapping as ``key_to_index``;
        # older versions expose a word -> vocab-entry mapping as ``vocab``.
        if GENSIM_V4_OR_GREATER:
            self.vocab = self.wv.key_to_index
        else:
            self.vocab = self.wv.vocab
        self.vocab_prefix = vocab_prefix
        self.name = "Unnamed model" if name is None else name

    def __eq__(self, other: Any) -> bool:
        """Check if other is the same WordEmbeddingModel as self.

        Parameters
        ----------
        other : Any
            Some object.

        Returns
        -------
        bool
            True if other is a WordEmbeddingModel that has the same wv, name
            and vocab_prefix. False in any other case.
        """
        if not isinstance(other, WordEmbeddingModel):
            return False
        if self.wv != other.wv:
            return False
        if self.name != other.name:
            return False
        return self.vocab_prefix == other.vocab_prefix

    def __len__(self) -> int:
        """Return the number of words in the vocabulary."""
        return len(self.vocab)

    def __getitem__(self, key: str) -> NDArray[np.float64]:
        """Retrieve the embedding for a word.

        Parameters
        ----------
        key : str
            A word.

        Returns
        -------
        np.ndarray
            The embedding associated with the word.

        Raises
        ------
        TypeError
            If the key is not a string.
        KeyError
            If the word is not in the vocabulary.
        """
        if not isinstance(key, str):
            raise TypeError(f"key must be a string, but got {type(key)}.")
        if key not in self.vocab:
            raise KeyError(f"word '{key}' not in model vocab.")
        return self.wv[key]

    def __contains__(self, key: str) -> bool:
        """Check if a word is in the model's vocabulary.

        Parameters
        ----------
        key : str
            Some word.

        Returns
        -------
        bool
            True if the word exists in the model's vocabulary.
        """
        return key in self.vocab

    def __repr__(self) -> str:
        """Generate a string representation of the WordEmbeddingModel.

        Format::

            <WordEmbeddingModel named '{name}' with {n} word embeddings of {dims} dims>

        For unnamed models ``named '{name}'`` is replaced by ``'Unnamed model'``,
        and when a vocab prefix is set `` and '{vocab_prefix}' as word prefix``
        is inserted before the closing ``>``.

        Returns
        -------
        str
            The generated representation.
        """
        try:
            if self.name == "Unnamed model":
                label = "'Unnamed model'"
            else:
                label = f"named '{self.name}'"
            prefix = self.vocab_prefix
            n_embeddings = self.wv.vectors.shape[0]
            dims = self.wv.vectors.shape[1]

            text = (
                f"<WordEmbeddingModel {label} with {n_embeddings}"
                f" word embeddings of {dims} dims"
            )
            if prefix is not None:
                text += f" and '{prefix}' as word prefix"
            return text + ">"
        except AttributeError:
            # it can happen if some of the attributes (name, vocab_prefix or wv)
            # are not defined.
            return "<WordEmbeddingModel with wrong __repr__>"

    def get(self, word: str, default: Any | None = None) -> NDArray[np.float64] | None:
        """Retrieve a word's embedding, returning a default value if not found.

        Parameters
        ----------
        word : str
            The word to look up.
        default : Any, optional
            Value returned when ``word`` is not in the vocabulary, by default
            None.

        Returns
        -------
        NDArray[np.float64] | None
            The embedding of ``word``, or ``default`` if the word is missing.
        """
        return self.wv[word] if word in self else default

    def normalize(self) -> None:
        """Normalize the word vectors to unit L2 length.

        This method uses the underlying gensim functionality to perform L2
        normalization. The model's vectors are modified in-place.

        **Warning**: This is a destructive operation.

        Raises
        ------
        AttributeError
            If the underlying model does not support normalization.
        """
        if hasattr(self.wv, "get_normed_vectors"):
            # gensim 4's get_normed_vectors() reuses the cached ``norms`` array
            # when it exists. If vectors were modified after that cache was
            # filled, normalizing with it would not give unit-length rows, so
            # recompute the norms from the current vectors first.
            if GENSIM_V4_OR_GREATER:
                self.wv.fill_norms(force=True)
            # Re-assign the vectors to keep the "in-place" behavior.
            self.wv.vectors = self.wv.get_normed_vectors()
            # Ensure the norms are also updated for similarity calculations.
            if GENSIM_V4_OR_GREATER:
                self.wv.fill_norms(force=True)
        elif hasattr(self.wv, "init_sims"):
            self.wv.init_sims(replace=True)
        else:
            raise AttributeError(
                "The underlying gensim model does not have a "
                "known normalization method ('get_normed_vectors' or 'init_sims')."
            )

    def update(self, word: str, embedding: NDArray[np.float64]) -> None:
        """Update the value of an embedding of the model.

        If the method is executed with a word that is not in the vocabulary,
        an exception will be raised.

        Parameters
        ----------
        word : str
            The word to update. It must already exist in the vocabulary.
        embedding : NDArray[np.float64]
            The new embedding for the word. Must be a 1D array that matches the
            model's vector size and dtype.

        Raises
        ------
        TypeError
            if word is not a string.
        TypeError
            if embedding is not an np.array.
        ValueError
            if word is not in the model's vocabulary.
        ValueError
            if embedding is not a 1D array.
        ValueError
            if the embedding is not the same size as the size of the model's
            embeddings.
        ValueError
            if the dtype of the embedding values is not the same as the model's
            embeddings.
        """
        if not isinstance(word, str):
            raise TypeError(f"Word must be a string, but got {type(word)}.")
        if word not in self:
            raise ValueError(f"Word '{word}' is not in the model's vocabulary.")
        if not isinstance(embedding, np.ndarray):
            raise TypeError(
                f"Embedding must be a NumPy array, but got {type(embedding)}."
            )
        # Only a 1D array can be written into a single row of the vectors
        # matrix; a 0-d array has no shape[0] at all.
        if embedding.ndim != 1:
            raise ValueError(
                f"The embedding of '{word}' must be a 1D array, but got an array "
                f"of shape {embedding.shape}."
            )

        embedding_size = embedding.shape[0]
        if self.wv.vector_size != embedding_size:
            raise ValueError(
                f"The size of '{word}' embedding ({embedding_size}) is different from "
                f"the size of the embeddings in the model ({self.wv.vector_size})."
            )
        if not np.issubdtype(self.wv.vectors.dtype, embedding.dtype):
            raise ValueError(
                f"embedding dtype ({embedding.dtype}) is not the same of model's dtype "
                f"({self.wv.vectors.dtype})."
            )

        self.wv.vectors[self._word_index(word)] = embedding.astype(
            self.wv.vectors.dtype
        )
        # The cached norms no longer match the overwritten vector.
        self._invalidate_norms()

    def batch_update(
        self,
        words: Sequence[str],
        embeddings: Sequence[np.ndarray] | np.ndarray,
    ) -> None:
        """Update a batch of embeddings in the model.

        This method updates the embeddings for a given sequence of words
        efficiently by leveraging NumPy's advanced indexing. All validations
        (word existence, embedding shape, and data type) are performed
        collectively before any modifications are applied to the model. This
        ensures atomicity: either all updates succeed, or none do.

        An empty batch (no words and no embeddings) is a no-op.

        NOTE: if the same word appears more than once in ``words``, NumPy's
        advanced-index assignment does not guarantee which of its embeddings
        is kept.

        Parameters
        ----------
        words : Sequence[str]
            A sequence (list, tuple, or np.ndarray) containing the words whose
            representations will be updated. All words must already exist in
            the model's vocabulary and must be strings.
        embeddings : Union[Sequence[np.ndarray], np.ndarray]
            A sequence (list or tuple) of NumPy arrays, or a 2D NumPy array,
            that contains all the new embeddings. Each embedding must be a 1D
            NumPy array with the same size as the model's embeddings and a
            dtype that can be safely cast to the model's dtype. The length of
            `embeddings` must match the length of `words`.

        Raises
        ------
        TypeError
            If `words` is not a sequence of strings, or if `embeddings` is not
            a sequence of NumPy arrays or a single NumPy array. Also, if
            individual elements within `words` are not strings, or elements
            within `embeddings` are not NumPy arrays.
        ValueError
            If `words` and `embeddings` do not have the same number of
            elements. If any word in `words` is not found in the model's
            vocabulary. If any embedding has a different dimension than the
            model's embeddings. If any embedding has a data type that cannot be
            safely cast to the model's dtype.

        Examples
        --------
        >>> from gensim.test.utils import common_texts
        >>> from gensim.models import Word2Vec
        >>> from wefe.word_embedding_model import WordEmbeddingModel
        >>> import numpy as np
        >>> # Create a dummy WordEmbeddingModel
        >>> kv_model = Word2Vec(common_texts, vector_size=10, min_count=1).wv
        >>> model = WordEmbeddingModel(kv_model, 'Dummy Model')
        >>> # Prepare words and new embeddings
        >>> words_to_update = ['human', 'system']
        >>> new_embeddings = [
        ...     np.zeros(10, dtype=model.wv.vectors.dtype),
        ...     np.ones(10, dtype=model.wv.vectors.dtype)
        ... ]
        >>> # Update embeddings
        >>> model.batch_update(words_to_update, new_embeddings)
        >>> # Verify updates
        >>> assert np.all(model['human'] == np.zeros(10))
        >>> assert np.all(model['system'] == np.ones(10))
        >>> # Test with missing word (will raise error)
        >>> try:
        ...     model.batch_update(['nonexistent_word'], [np.zeros(10)])
        ... except ValueError as e:
        ...     print(e)
        The following words are not in the model's vocabulary: nonexistent_word.
        """
        # Initial type and length validation for the input containers.
        if not isinstance(words, list | tuple | np.ndarray):
            raise TypeError(
                f"words argument should be a list, tuple or np.array of strings, "
                f"but got {type(words)}."
            )
        if not isinstance(embeddings, list | tuple | np.ndarray):
            raise TypeError(
                "embeddings argument should be a list, tuple or np.array of "
                f"NumPy arrays, but got {type(embeddings)}."
            )
        if len(words) != len(embeddings):
            raise ValueError(
                "words and embeddings must have the same number of elements, "
                f"but got {len(words)} words and {len(embeddings)} embeddings."
            )

        # 1. Validate 'words' elements and collect their row indices.
        missing_words = []
        word_indices = []
        for word in words:
            if not isinstance(word, str):
                raise TypeError(
                    f"All elements in 'words' must be strings, but found a "
                    f"{type(word)}."
                )
            if word in self.vocab:
                word_indices.append(self._word_index(word))
            else:
                missing_words.append(word)

        if missing_words:
            raise ValueError(
                f"The following words are not in the model's vocabulary: "
                f"{', '.join(missing_words)}."
            )

        # Convert collected indices to a NumPy array for advanced indexing.
        np_word_indices = np.array(word_indices, dtype=int)

        # Expected properties of the new embeddings, taken from the model.
        expected_vector_size = self.wv.vector_size
        model_dtype = self.wv.vectors.dtype

        # 2. Validate and prepare 'embeddings' for the batch update.
        embeddings_to_update: np.ndarray
        if isinstance(embeddings, np.ndarray):
            # A single 2D array is validated as a whole.
            if embeddings.ndim != 2 or embeddings.shape[1] != expected_vector_size:
                raise ValueError(
                    f"Input embeddings array has shape {embeddings.shape}, "
                    f"but expected a 2D array with {expected_vector_size} columns "
                    f"(model's vector size {expected_vector_size})."
                )
            if not np.can_cast(embeddings.dtype, model_dtype):
                raise ValueError(
                    f"Input embeddings array dtype ({embeddings.dtype}) cannot be "
                    f"safely cast to model's dtype ({model_dtype})."
                )
            # copy=False avoids a copy when the dtype already matches.
            embeddings_to_update = embeddings.astype(model_dtype, copy=False)
        else:
            # A sequence of arrays is validated element by element.
            for i, emb in enumerate(embeddings):
                if not isinstance(emb, np.ndarray):
                    raise TypeError(
                        f"Embedding at index {i} ('{words[i]}') is not a NumPy array, "
                        f"but got {type(emb)}."
                    )
                if emb.ndim != 1 or emb.shape[0] != expected_vector_size:
                    raise ValueError(
                        f"Embedding at index {i} ('{words[i]}') has shape {emb.shape} "
                        f"which is different from the model's embedding size "
                        f"({expected_vector_size},)."
                    )
                if not np.can_cast(emb.dtype, model_dtype):
                    raise ValueError(
                        f"Embedding at index {i} ('{words[i]}') with dtype "
                        f"({emb.dtype}) "
                        f"cannot be safely cast to model's dtype ({model_dtype})."
                    )
            # The reshape keeps an empty batch at shape (0, vector_size). Without
            # it, np.array([]) has shape (0,), which cannot be assigned to the
            # (0, vector_size) selection below.
            embeddings_to_update = np.array(embeddings, dtype=model_dtype).reshape(
                len(embeddings), expected_vector_size
            )

        # 3. Perform the batch update using advanced indexing.
        self.wv.vectors[np_word_indices] = embeddings_to_update
        # The cached norms no longer match the overwritten vectors.
        self._invalidate_norms()

    def _word_index(self, word: str) -> int:
        """Return the row of ``self.wv.vectors`` that stores ``word``."""
        if GENSIM_V4_OR_GREATER:
            return self.wv.key_to_index[word]
        return self.wv.vocab[word].index

    def _invalidate_norms(self) -> None:
        """Drop gensim's cached L2 norms so they are recomputed on next use.

        gensim caches per-vector norms for similarity queries (``norms`` in
        gensim 4, ``vectors_norm`` before). Once rows of ``self.wv.vectors``
        are overwritten, those cached values are stale. Setting the cache to
        None makes gensim rebuild it lazily the next time it is needed.
        """
        if GENSIM_V4_OR_GREATER:
            self.wv.norms = None
        else:
            self.wv.vectors_norm = None