Source code for wefe.metrics.base_metric

from abc import ABC, abstractmethod
import numpy as np
from typing import Any, Callable, Dict, List, Union, Tuple
from ..query import Query
from ..word_embedding_model import WordEmbeddingModel


[docs]class BaseMetric(ABC): """ A base class to implement any metric following the framework described by WEFE. It contains the name of the metric, the templates (cardinalities) that it supports and the abstract function run_query, which must be implemented by any metric that extends this class. """ # A tuple that indicates the cardinality of target and attribute sets metric_template: Tuple[Union[int, str], Union[int, str]] # The name of the metric metric_name: str # The initials or short name of the metric metric_short_name: str def _check_input( self, query: Query, word_embedding: WordEmbeddingModel, ) -> None: """Check if the Query and WordEmbeddingModel parameters are valid when executing run_query. Parameters ---------- query : Query The query that the method will execute. word_embedding : A word embedding model. Raises ------ TypeError if query is not instance of Query. TypeError if word_embedding is not instance of . TypeError if lost_vocabulary_threshold is not a float number. TypeError if warn_filtered_words is not a bool. Exception if the metric require different number of target sets than the delivered query Exception if the metric require different number of attribute sets than the delivered query """ # check if the query passed is a instance of Query if not isinstance(query, Query): raise TypeError('query should be a Query instance, got {}'.format(query)) # check if the word_embedding is a instance of if not isinstance(word_embedding, WordEmbeddingModel): raise TypeError('word_embedding should be a WordEmbeddingModel instance, ' 'got: {}'.format(word_embedding)) # templates: # check the cardinality of the target sets of the provided query if self.metric_template[0] != 'n' and query.template[0] != self.metric_template[ 0]: raise Exception('The cardinality of the set of target words of the \'{}\' ' 'query does not match with the cardinality required by {}. ' 'Provided query: {}, metric: {}.'.format( query.query_name, self.metric_name, query.template[0], self.metric_template[0])) # check the cardinality of the attribute sets of the provided query if self.metric_template[1] != 'n' and query.template[1] != self.metric_template[ 1]: raise Exception('The cardinality of the set of attribute words of the ' '\'{}\' query does not match with the cardinality ' 'required by {}. Provided query: {}, metric: {}.'.format( query.query_name, self.metric_name, query.template[1], self.metric_template[1]))
[docs] @abstractmethod def run_query(self, query: Query, word_embedding: WordEmbeddingModel, lost_vocabulary_threshold: float = 0.2, preprocessor_options: Dict[str, Union[bool, str, Callable, None]] = { 'strip_accents': False, 'lowercase': False, 'preprocessor': None, }, secondary_preprocessor_options: Union[Dict[str, Union[bool, str, Callable, None]], None] = None, warn_not_found_words: bool = False, *args: Any, **kwargs: Any) -> Dict[str, Any]: self._check_input(query=query, word_embedding=word_embedding) return {}