Spaces:
Running
Running
| # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) | |
| # | |
| # See ../../../LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass | |
| from dataclasses import field | |
| from typing import Dict | |
| from typing import Generic | |
| from typing import List | |
| from typing import Optional | |
| from typing import TypeVar | |
| from typing import Union | |
Symbol = TypeVar('Symbol')

# SymbolTable is copied from
# https://github.com/k2-fsa/k2/blob/master/k2/python/k2/symbol_table.py
'''
SymbolTable: map symbol to id
'''


@dataclass(repr=False)
class SymbolTable(Generic[Symbol]):
    '''SymbolTable that maps symbol IDs, found on the FSA arcs to
    actual objects. These objects can be arbitrary Python objects
    that can serve as keys in a dictionary (i.e. they need to be
    hashable and immutable).

    The SymbolTable can only be read to/written from disk if the
    symbols are strings.
    '''
    _id2sym: Dict[int, Symbol] = field(default_factory=dict)
    '''Map an integer to a symbol.
    '''

    _sym2id: Dict[Symbol, int] = field(default_factory=dict)
    '''Map a symbol to an integer.
    '''

    _next_available_id: int = 1
    '''A helper internal field that helps adding new symbols
    to the table efficiently.
    '''

    eps: Symbol = '<eps>'
    '''Null symbol, always mapped to index 0.
    '''

    def __post_init__(self):
        '''Validate the two maps and make sure ``eps`` occupies id 0.

        The two dictionaries must be exact inverses of each other and,
        if id 0 is already present, it must map to ``eps``.
        '''
        assert all(self._sym2id[sym] == idx
                   for idx, sym in self._id2sym.items())
        assert all(self._id2sym[idx] == sym
                   for sym, idx in self._sym2id.items())
        assert 0 not in self._id2sym or self._id2sym[0] == self.eps

        # Any value passed for ``_next_available_id`` is overridden so
        # that it always points past the largest id in use.
        self._next_available_id = max(self._id2sym, default=0) + 1
        self._id2sym.setdefault(0, self.eps)
        self._sym2id.setdefault(self.eps, 0)

    @staticmethod
    def from_str(s: str) -> 'SymbolTable':
        '''Build a symbol table from a string.

        The string consists of lines. Every line has two fields separated
        by space(s), tab(s) or both. The first field is the symbol and the
        second the integer id of the symbol. Empty lines are skipped.

        Args:
          s:
            The input string with the format described above.
        Returns:
          An instance of :class:`SymbolTable`.
        Raises:
          AssertionError: if a line does not have exactly two fields, or
            if a symbol or an id appears more than once.
          ValueError: if the second field is not an integer.
        '''
        id2sym: Dict[int, str] = dict()
        sym2id: Dict[str, int] = dict()

        for line in s.split('\n'):
            fields = line.split()
            if len(fields) == 0:
                continue  # skip empty lines

            assert len(fields) == 2, \
                f'Expect a line with 2 fields. Given: {len(fields)}'
            sym, idx = fields[0], int(fields[1])
            assert sym not in sym2id, f'Duplicated symbol {sym}'
            assert idx not in id2sym, f'Duplicated id {idx}'
            id2sym[idx] = sym
            sym2id[sym] = idx

        # Whatever symbol the input maps to id 0 becomes epsilon.
        eps = id2sym.get(0, '<eps>')

        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)

    @staticmethod
    def from_file(filename: str) -> 'SymbolTable':
        '''Build a symbol table from file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.
            The file is read as UTF-8.
        Returns:
          An instance of :class:`SymbolTable`.
        '''
        with open(filename, 'r', encoding='utf-8') as f:
            return SymbolTable.from_str(f.read().strip())

    def to_str(self) -> str:
        '''
        Returns:
          Return a string representation of this object. You can pass
          it to the method ``from_str`` to recreate an identical object.
        '''
        # One "<symbol> <id>" line per entry, ordered by id.
        return ''.join(f'{symbol} {idx}\n'
                       for idx, symbol in sorted(self._id2sym.items()))

    def to_file(self, filename: str):
        '''Serialize the SymbolTable to a file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.
        '''
        # Write UTF-8 explicitly so the file can always be read back by
        # ``from_file``, regardless of the platform's default encoding.
        with open(filename, 'w', encoding='utf-8') as f:
            for idx, symbol in sorted(self._id2sym.items()):
                print(symbol, idx, file=f)

    def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
        '''Add a new symbol to the SymbolTable.

        Args:
          symbol:
            The symbol to be added.
          index:
            Optional int id to which the symbol should be assigned.
            If it is not available, a ValueError will be raised.
        Returns:
          The int id to which the symbol has been assigned. If the symbol
          is already in the table, its existing id is returned and
          ``index`` is ignored.
        Raises:
          ValueError: if ``index`` is already occupied by another symbol.
        '''
        # Already in the table? Return its ID.
        if symbol in self._sym2id:
            return self._sym2id[symbol]

        # Specific ID not provided - use next available.
        if index is None:
            index = self._next_available_id

        # Specific ID provided but not available.
        if index in self._id2sym:
            raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
                             f"already occupied by {self._id2sym[index]}")

        self._sym2id[symbol] = index
        self._id2sym[index] = symbol

        # Update next available ID if needed
        if self._next_available_id <= index:
            self._next_available_id = index + 1

        return index

    def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
        '''Get a symbol for an id or get an id for a symbol

        Args:
          k:
            If it is an id, it tries to find the symbol corresponding
            to the id; if it is a symbol, it tries to find the id
            corresponding to the symbol.
        Returns:
          An id or a symbol depending on the given `k`.
        Raises:
          KeyError: if `k` is not in the table.
        '''
        if isinstance(k, int):
            return self._id2sym[k]
        else:
            return self._sym2id[k]

    def merge(self, other: 'SymbolTable') -> 'SymbolTable':
        '''Create a union of two SymbolTables.
        Raises an AssertionError if the same IDs are occupied by
        different symbols.

        Args:
          other:
            A symbol table to merge with ``self``.
        Returns:
          A new symbol table.
        '''
        self._check_compatible(other)

        return SymbolTable(
            _id2sym={**self._id2sym, **other._id2sym},
            _sym2id={**self._sym2id, **other._sym2id},
            eps=self.eps
        )

    def _check_compatible(self, other: 'SymbolTable') -> None:
        '''Assert that ``other`` can be merged with ``self``.

        The epsilon symbols must match, every shared id must map to the
        same symbol and every shared symbol must map to the same id.
        '''
        # Epsilon compatibility
        assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
                                      f'{self.eps} != {other.eps}'
        # IDs compatibility
        common_ids = set(self._id2sym).intersection(other._id2sym)
        for idx in common_ids:
            assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
                                            f'self[idx] = "{self[idx]}", ' \
                                            f'other[idx] = "{other[idx]}"'
        # Symbols compatibility
        common_symbols = set(self._sym2id).intersection(other._sym2id)
        for sym in common_symbols:
            assert self[sym] == other[sym], \
                f'ID conflict for symbol: {sym}, ' \
                f'self[sym] = "{self[sym]}", ' \
                f'other[sym] = "{other[sym]}"'

    def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
        return self.get(item)

    def __contains__(self, item: Union[int, Symbol]) -> bool:
        if isinstance(item, int):
            return item in self._id2sym
        else:
            return item in self._sym2id

    def __len__(self) -> int:
        return len(self._id2sym)

    def __eq__(self, other: object) -> bool:
        '''Two tables are equal when they hold the same symbol-to-id map.'''
        if not isinstance(other, SymbolTable):
            return NotImplemented
        # Comparing the maps directly avoids a KeyError when the tables
        # have the same size but contain different symbols.
        return self._sym2id == other._sym2id

    @property
    def ids(self) -> List[int]:
        '''Returns a list of integer IDs corresponding to the symbols.
        The list is sorted in ascending order.
        '''
        return sorted(self._id2sym)

    @property
    def symbols(self) -> List[Symbol]:
        '''Returns a sorted list of symbols (e.g., strings) corresponding
        to the integer IDs.
        '''
        return sorted(self._sym2id)