#!/usr/bin/python3
"""
DOCUMENTATION: This script performs vectorization and topic inference using Mallet models. It accepts a raw JSONL file,
identifies the language of the text, and applies the corresponding Mallet model for topic inference. It also supports
other input formats through a flexible InputReader abstraction (e.g., CSV, JSONL).
The benefit of this script with respect to the Mallet CLI is that it can handle
multiple languages in a single run without calling Mallet multiple times.
Classes:
- MalletVectorizer: Handles text-to-Mallet vectorization.
- LanguageInferencer: Performs topic inference using a Mallet inferencer and
the vectorizer.
- InputReader (abstract class): Defines the interface for reading input
documents.
- JsonlInputReader: Reads input from JSONL files.
- CsvInputReader: Reads input from CSV files (Mallet format).
- MalletTopicInferencer: Coordinates the process, identifies language, and manages
inference.
USAGE:
python mallet_topic_inferencer.py --input input.jsonl --output output.txt
--logfile logfile.log --input-format jsonl --level INFO --num_iterations 1000
--languages de,en --de_inferencer models/de.inferencer --de_pipe models/de.pipe
"""
import collections
import traceback
import jpype
import jpype.imports
import spacy
# from jpype.types import JString
import os
import logging
import argparse
import json
import csv
import tempfile
from typing import List, Dict, Generator, Tuple, Optional, Set
from abc import ABC, abstractmethod
import mallet2topic_assignment_jsonl as m2taj
from smart_open import open
log = logging.getLogger(__name__)
def save_text_as_csv(text: str) -> str:
"""
Save the given text as a temporary CSV file with an arbitrary ID and return the file name.
Args:
text (str): The text to be saved in the CSV file.
Returns:
str: The name of the temporary CSV file.
"""
# Create a temporary file with .csv suffix
temp_csv_file = tempfile.NamedTemporaryFile(
delete=False, mode="w", suffix=".csv", newline="", encoding="utf-8"
)
# Write the text to the CSV file with an arbitrary ID
csv_writer = csv.writer(temp_csv_file, delimiter="\t")
csv_writer.writerow(["ID", "DUMMYCLASS", "TEXT"]) # Header
csv_writer.writerow(["USERINPUT-2024-10-24-a-i0042", "dummy_class", text])
# Close the file to ensure all data is written
temp_csv_file.close()
return temp_csv_file.name
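# A minimal usage sketch for save_text_as_csv (the input text is illustrative;
# note that the file is tab-separated despite the .csv suffix):
#
#     tmp_path = save_text_as_csv("ein kurzer beispieltext")
#     # tmp_path -> e.g. /tmp/tmpa1b2c3.csv containing:
#     #   ID<TAB>DUMMYCLASS<TAB>TEXT
#     #   USERINPUT-2024-10-24-a-i0042<TAB>dummy_class<TAB>ein kurzer beispieltext
#     os.remove(tmp_path)  # caller must clean up (created with delete=False)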
class Lemmatizer:
    """
    Wraps spaCy pipelines and per-language token-to-lemma dictionaries for
    dictionary-based lemmatization of input texts.
    """
def __init__(
self,
languages_dict: Dict[str, str],
lang_lemmatization_dict: Dict[str, Dict[str, str]],
):
"""
Initializes the linguistic lemmatizer with specified languages and lemmatization dictionary.
Args:
languages (List[str]): List of language codes to load processing pipelines for.
lemmatization_dict (Dict[str, str]): Dictionary mapping tokens to their lemmas.
"""
self.languages_dict = languages_dict
self.lemmatization_dict = lang_lemmatization_dict
self.language_processors = self._load_language_processors(languages_dict)
def _load_language_processors(
self, languages_dict
) -> Dict[str, spacy.language.Language]:
"""
Loads spacy language processors for the specified languages.
Returns:
Dict[str, spacy.language.Language]: Dictionary mapping language codes to spacy NLP pipelines.
"""
processors = {}
for lang in languages_dict:
processors[lang] = spacy.load(
languages_dict[lang], disable=["parser", "ner"]
)
processors[lang].add_pipe("sentencizer")
return processors
def analyze_text(self, text: str, lang: str) -> List[str]:
"""
Analyzes text, performing tokenization, POS tagging, and lemma mapping.
Args:
text (str): Text to process.
lang (str): Language code for the text.
Returns:
List[str]: List of tokens that have matching entries in the lemmatization dictionary.
"""
        if lang not in self.language_processors:
            raise ValueError(f"No processing pipeline for language '{lang}'")
        nlp = self.language_processors[lang]
        doc = nlp(text)
        # Fall back to an empty mapping if no lemmatization dictionary was
        # loaded for this language instead of raising a KeyError.
        token2lemma = self.lemmatization_dict.get(lang, {})
matched_tokens = [
lemma for tok in doc if (lemma := token2lemma.get(tok.text.lower()))
]
return matched_tokens
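# A hypothetical construction sketch for Lemmatizer; the spaCy model name and
# the dictionary entries are assumptions for illustration only:
#
#     lemmatizer = Lemmatizer(
#         languages_dict={"de": "de_core_news_sm"},
#         lang_lemmatization_dict={"de": {"häuser": "haus"}},
#     )
#     lemmatizer.analyze_text("Die Häuser", "de")
#     # -> ["haus"]  (tokens without a dictionary entry are dropped)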
# ==================== Vectorization ====================
class MalletVectorizer:
"""
Handles the vectorization of multiple documents into a format usable by Mallet using the pipe file from the model.
"""
def __init__(self, language: str, pipe_file: str) -> None:
# noinspection PyUnresolvedReferences
from cc.mallet.classify.tui import Csv2Vectors # type: ignore # Import after JVM is started
self.vectorizer = Csv2Vectors()
self.pipe_file = pipe_file
self.language = language
def run_csv2vectors(
self,
input_file: str,
output_file: Optional[str] = None,
delete_input_file_after: bool = True,
) -> str:
"""
Run Csv2Vectors to vectorize the input file.
Simple java-internal command line interface to the Csv2Vectors class in Mallet.
Args:
input_file: Path to the csv input file to be vectorized.
output_file: Path where the output .mallet file should be saved.
"""
if not output_file:
output_file = input_file + ".mallet"
# Arguments for Csv2Vectors java main class
arguments = [
"--input",
input_file,
"--output",
output_file,
"--keep-sequence", # Keep sequence for feature extraction
"--encoding",
"UTF-8",
"--use-pipe-from",
self.pipe_file,
]
logging.info("Calling mallet Csv2Vector: %s", arguments)
self.vectorizer.main(arguments)
logging.debug("Csv2Vector call finished.")
if log.getEffectiveLevel() != logging.DEBUG and delete_input_file_after:
os.remove(input_file)
logging.info("Cleaning up input file: %s", input_file)
return output_file
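# Hypothetical usage of MalletVectorizer (paths are placeholders; the JVM must
# already be running, see MalletTopicInferencer.start_jvm):
#
#     vectorizer = MalletVectorizer(language="de", pipe_file="models/de.pipe")
#     mallet_file = vectorizer.run_csv2vectors("input.de.tsv")
#     # -> "input.de.tsv.mallet"; input.de.tsv is removed unless DEBUG logging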
class LanguageInferencer:
"""
A class to manage Mallet inferencing for a specific language.
Loads the inferencer and pipe file during initialization.
"""
def __init__(self, language: str, inferencer_file: str, pipe_file: str) -> None:
# noinspection PyUnresolvedReferences
from cc.mallet.topics.tui import InferTopics # type: ignore # Import after JVM is started
self.language = language
self.inferencer_file = inferencer_file
self.inferencer = InferTopics()
self.pipe_file = pipe_file
self.vectorizer = MalletVectorizer(language=language, pipe_file=self.pipe_file)
if not os.path.exists(self.inferencer_file):
raise FileNotFoundError(
f"Inferencer file not found: {self.inferencer_file}"
)
    def run_csv2topics(
        self, csv_file: str, delete_mallet_file_after: bool = True
    ) -> str:
        """
        Perform topic inference on a single input file.

        The input file must be a tab-separated file in the format expected by
        Mallet (ID, class, text); it is vectorized before inference.

        Returns:
            str: Path to the doc-topics output file written by Mallet.
        """
# Vectorize the input file and write to a temporary file
mallet_file = self.vectorizer.run_csv2vectors(csv_file)
topics_file = mallet_file + ".doctopics"
arguments = [
"--input",
mallet_file,
"--inferencer",
self.inferencer_file,
"--output-doc-topics",
topics_file,
"--random-seed",
"42",
]
logging.info("Calling mallet InferTopics: %s", arguments)
self.inferencer.main(arguments)
logging.debug("InferTopics call finished.")
        if log.getEffectiveLevel() != logging.DEBUG and delete_mallet_file_after:
            os.remove(mallet_file)
            logging.info("Cleaning up intermediate mallet file: %s", mallet_file)
return topics_file
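# Hypothetical usage of LanguageInferencer (paths are placeholders; requires a
# running JVM):
#
#     inferencer = LanguageInferencer(
#         language="de",
#         inferencer_file="models/de.inferencer",
#         pipe_file="models/de.pipe",
#     )
#     doctopics_path = inferencer.run_csv2topics("input.de.tsv")
#     # -> "input.de.tsv.mallet.doctopics"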
# ==================== Input Reader Abstraction ====================
class InputReader(ABC):
"""
Abstract base class for input readers.
Subclasses should implement the `read_documents` method to yield documents.
"""
@abstractmethod
def read_documents(self) -> Generator[Tuple[str, str], None, None]:
"""
Yields a tuple of (document_id, text).
Each implementation should handle its specific input format.
"""
pass
class JsonlInputReader(InputReader):
"""
Reads input from a JSONL file, where each line contains a JSON object
with at least "id" and "text" fields.
"""
def __init__(self, input_file: str) -> None:
self.input_file = input_file
def read_documents(self) -> Generator[Tuple[str, str], None, None]:
with open(self.input_file, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
document_id = data.get("id", "unknown_id")
text = data.get("text", "")
yield document_id, text
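# Expected JSONL input, one JSON object per line, e.g. (illustrative values):
#
#     {"id": "doc-001", "text": "Der schnelle braune Fuchs ..."}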
class CsvInputReader(InputReader):
"""
Reads input from a CSV file in Mallet's format (document ID, dummy class, text).
Assumes that the CSV has three columns: "id", "dummyclass", and "text".
"""
def __init__(self, input_file: str) -> None:
self.input_file = input_file
def read_documents(self) -> Generator[Tuple[str, str], None, None]:
with open(self.input_file, mode="r", encoding="utf-8") as f:
csv_reader = csv.reader(f, delimiter="\t")
for row in csv_reader:
if len(row) < 3:
continue
document_id, text = row[0], row[2]
yield document_id, text.lower()
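# Expected tab-separated input, one document per line, e.g. (illustrative
# values; <TAB> stands for a literal tab character):
#
#     doc-001<TAB>dummy_class<TAB>Der schnelle braune Fuchs ...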
# ==================== Main Application ====================
class MalletTopicInferencer:
"""
MalletTopicInferencer class coordinates the process of reading input documents, identifying their language, and performing topic inference using Mallet models.
"""
def __init__(self, args: argparse.Namespace) -> None:
self.args = args
self.languages = set(args.languages)
self.language_inferencers: Optional[Dict[str, LanguageInferencer]] = None
self.language_lemmatizations: Optional[Dict[str, Dict[str, str]]] = None
self.input_reader = None
self.inference_results: List[Dict[str, str]] = []
self.language_dict: Dict[str, str] = {}
self.seen_languages: Set[str] = set()
self.stats = collections.Counter()
def initialize(self) -> None:
"""Initialize the inferencers after JVM startup."""
self.language_inferencers = self.init_language_inferencers(self.args)
self.input_reader = self.build_input_reader(self.args)
self.language_lemmatizations = self.init_language_lemmatizations(self.args)
if self.args.language_file:
self.language_dict = self.read_language_file(self.args.language_file)
@staticmethod
def start_jvm() -> None:
"""Start the Java Virtual Machine if not already started."""
if not jpype.isJVMStarted():
current_dir = os.getcwd()
source_dir = os.path.dirname(os.path.abspath(__file__))
# Construct classpath relative to the current directory
classpath = [
os.path.join(current_dir, "mallet/lib/mallet-deps.jar"),
os.path.join(current_dir, "mallet/lib/mallet.jar"),
]
# Check if the files exist in the current directory
if not all(os.path.exists(path) for path in classpath):
# If not, construct classpath relative to the source directory
classpath = [
os.path.join(source_dir, "mallet/lib/mallet-deps.jar"),
os.path.join(source_dir, "mallet/lib/mallet.jar"),
]
jpype.startJVM(classpath=classpath)
log.info(f"JVM started successfully with classpath {classpath}.")
else:
log.warning("JVM already running.")
def read_language_file(self, language_file: str) -> Dict[str, str]:
"""Read the language file (JSONL) and return a dictionary of document_id -> language."""
language_dict = {}
with open(language_file, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
doc_id = data.get("doc_id")
language = data.get("language")
if doc_id and language:
language_dict[doc_id] = language
return language_dict
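    # Expected language-file line (illustrative values):
    #
    #     {"doc_id": "doc-001", "language": "de"}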
@staticmethod
def load_lemmatization_file(
lemmatization_file_path: str,
bidi: bool = False,
lowercase: bool = True,
ignore_pos: bool = True,
) -> Dict[str, str]:
"""
Load lemmatization data from the file.
:param lemmatization_file_path: Path to the lemmatization file.
:return: A dictionary mapping tokens to their corresponding lemmas.
"""
token2lemma = {}
n = 0
with open(lemmatization_file_path, "r", "utf-8") as file:
for line in file:
token, _, lemma = line.strip().split("\t")
if lowercase:
token2lemma[token.lower()] = lemma.lower()
else:
token2lemma[token] = lemma
n += 1
logging.info(
"Read %d lemmatization entries from %s", n, lemmatization_file_path
)
return token2lemma
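    # Expected lemmatization-file layout: three tab-separated columns per
    # line, where the middle column (presumably a POS tag) is ignored, e.g.:
    #
    #     Häuser<TAB>NN<TAB>Haus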
def init_language_lemmatizations(
self, args: argparse.Namespace
) -> Dict[str, Dict[str, str]]:
"""Build a mapping of languages to their respective lemmatization dictionaries."""
language_lemmatizations: Dict[str, Dict[str, str]] = {}
for language in args.languages:
lemmatization_key = f"{language}_lemmatization"
if getattr(args, lemmatization_key, None):
lemmatization_file = getattr(args, lemmatization_key)
language_lemmatizations[language] = self.load_lemmatization_file(
lemmatization_file
)
else:
log.info(
f"Lemmatization file for language: {language} not provided by"
" arguments. Skipping."
)
return language_lemmatizations
    def identify_language(self, document_id: str, text: str) -> str:
        """Identify the language of a document from the language file mapping,
        falling back to German ("de") if the document ID is not listed."""
        # Check if the document ID is in the language dictionary
        if document_id in self.language_dict:
            return self.language_dict[document_id]
        # Placeholder: assume German ("de") if not found in the dictionary
        return "de"
def init_language_inferencers(
self, args: argparse.Namespace
) -> Dict[str, LanguageInferencer]:
"""Build a mapping of languages to their respective inferencers
Includes the vectorizer pipe for each language as well.
"""
language_inferencers: Dict[str, LanguageInferencer] = {}
for language in args.languages:
inferencer_key = f"{language}_inferencer"
pipe_key = f"{language}_pipe"
if getattr(args, inferencer_key, None) and getattr(args, pipe_key, None):
language_inferencers[language] = LanguageInferencer(
language=language,
inferencer_file=getattr(args, inferencer_key),
pipe_file=getattr(args, pipe_key),
)
else:
log.info(
f"Inferencer or pipe file for language: {language} not provided by"
" arguments. Skipping."
)
return language_inferencers
def build_input_reader(self, args: argparse.Namespace) -> InputReader:
"""Select the appropriate input reader based on the input format."""
if args.input_format == "jsonl":
return JsonlInputReader(args.input)
elif args.input_format == "csv":
return CsvInputReader(args.input)
else:
raise ValueError(f"Unsupported input format: {args.input_format}")
def process_input_file(self) -> None:
"""Process the input file, identify language, and apply the appropriate Mallet model"""
temp_files_by_language = self.write_language_specific_csv_files()
doctopics_files = self.run_topic_inference(temp_files_by_language)
        logging.info("Doc-topics files by language: %s", doctopics_files)
if self.args.output_format == "csv":
self.merge_inference_results(doctopics_files)
elif self.args.output_format == "jsonl":
self.merge_inference_results_jsonl(doctopics_files)
    def merge_inference_results_jsonl(
        self, doctopics_files_by_language: Dict[str, str]
    ) -> None:
        """Convert per-language doc-topics files into a single impresso JSONL
        output file."""
        args = ["--output", "<generator>"]
        m2ta_converters = {}
        for lang, doctopics_file in doctopics_files_by_language.items():
            topic_model_id = self.args.__dict__[f"{lang}_model_id"]
            if "{lang}" in topic_model_id:
                topic_model_id = topic_model_id.format(lang=lang)
            args += [
                "--topic_model",
                topic_model_id,
                "--topic_count",
                str(self.args.__dict__[f"{lang}_topic_count"]),
                "--lang",
                lang,
                doctopics_file,  # input comes last!
            ]
            m2ta_converters[lang] = m2taj.Mallet2TopicAssignment.main(args)
        # Open the output file once so that the results of all languages are
        # written into it instead of each language overwriting the previous one.
        with open(self.args.output, "w", encoding="utf-8") as out_f:
            for lang, m2ta_converter in m2ta_converters.items():
                for row in m2ta_converter:
                    self.stats["content_items"] += 1
                    print(
                        json.dumps(row, ensure_ascii=False, separators=(",", ":")),
                        file=out_f,
                    )
def merge_inference_results(
self, doctopics_files_by_language: Dict[str, str]
) -> None:
"""Merge the inference results from multiple languages into a single output file."""
logging.info(
"Saving CSV inference results into file %s from multiple languages: %s",
self.args.output,
doctopics_files_by_language,
)
with open(self.args.output, "w", encoding="utf-8") as out_f:
for language, doctopics_file in doctopics_files_by_language.items():
with open(doctopics_file, "r", encoding="utf-8") as f:
for line in f:
if line.startswith("#"):
continue
doc_id, topic_dist = line.strip().split("\t", 1)
print(
doc_id + "__" + language,
topic_dist,
sep="\t",
end="\n",
file=out_f,
)
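    # Sketch of a resulting CSV line, assuming Mallet's usual doc-topics
    # layout (numeric doc index, document name, topic proportions):
    #
    #     0__de<TAB>doc-001<TAB>0.0123 0.0456 ...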
def write_language_specific_csv_files(self) -> Dict[str, str]:
"""Read documents and write to language-specific temporary files"""
tsv_files_by_language = {}
for document_id, text in self.input_reader.read_documents():
language_code = self.identify_language(document_id, text)
self.stats["LANGUAGE: " + language_code] += 1
if language_code not in self.languages:
continue
if language_code not in tsv_files_by_language:
tsv_files_by_language[language_code] = tempfile.NamedTemporaryFile(
delete=False,
mode="w",
suffix=f".{language_code}.tsv",
encoding="utf-8",
)
logging.info(
"Writing documents for language: %s in temp file: %s",
language_code,
tsv_files_by_language[language_code].name,
)
print(
document_id,
language_code,
text,
sep="\t",
end="\n",
file=tsv_files_by_language[language_code],
)
# Close all temporary files
for temp_file in tsv_files_by_language.values():
temp_file.close()
# noinspection PyShadowingNames
result = {
lang: temp_file.name for lang, temp_file in tsv_files_by_language.items()
}
return result
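    # Example return value (temp file names are generated and illustrative):
    #
    #     {"de": "/tmp/tmpab12cd.de.tsv", "fr": "/tmp/tmpef34gh.fr.tsv"}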
def run_topic_inference(
self, language_specific_csv_files: Dict[str, str]
) -> Dict[str, str]:
"""Run inference for each language"""
doctopics_files_by_language = {}
for language_code, csv_file in language_specific_csv_files.items():
inferencer = self.language_inferencers.get(language_code)
if not inferencer:
log.error(f"No inferencer found for language: {language_code}")
continue
doctopics_file = inferencer.run_csv2topics(csv_file)
doctopics_files_by_language[language_code] = doctopics_file
# Clean up the temporary vectorized file if logging level is not DEBUG
if log.getEffectiveLevel() != logging.DEBUG:
logging.info("Cleaning language specific csv file: %s", csv_file)
os.remove(csv_file)
logging.debug("Resulting doctopic files: %s", doctopics_files_by_language)
return doctopics_files_by_language
def write_results_to_output(self) -> None:
"""Write the final merged inference results to the output file."""
with open(self.args.output, "w", encoding="utf-8") as out_file:
for result in self.inference_results:
out_file.write(json.dumps(result) + "\n")
log.info(f"All inferences merged and written to {self.args.output}")
def run(self) -> None:
"""Main execution method."""
try:
self.start_jvm()
self.initialize()
self.process_input_file()
# self.write_results_to_output()
except Exception as e:
log.error(f"An error occurred: {e}")
log.error("Traceback: %s", traceback.format_exc())
finally:
jpype.shutdownJVM()
log.info("JVM shutdown.")
for key, value in sorted(self.stats.items()):
log.info(f"STATS: {key}: {value}")
if __name__ == "__main__":
languages = ["de", "fr", "lb"] # You can add more languages as needed
parser = argparse.ArgumentParser(description="Mallet Topic Inference in Python")
parser.add_argument("--logfile", help="Path to log file", default=None)
    parser.add_argument(
        "--level",
        default="DEBUG",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level (%(default)s)",
    )
parser.add_argument(
"--input", help="Path to input file (%(default)s)", required=True
)
parser.add_argument(
"--input-format",
choices=["jsonl", "csv"],
default="jsonl",
help="Format of the input file",
)
parser.add_argument(
"--output-format",
choices=["jsonl", "csv"],
        help=(
            "Format of the output file: csv: raw Mallet doc-topics output with"
            " doc ids patched into NUMERICID__LANG; jsonl: impresso JSONL format"
        ),
)
parser.add_argument(
"--output",
help="Path to final output file. (%(default)s)",
default="out.jsonl",
)
parser.add_argument(
"--languages",
nargs="+",
default=languages,
help="List of languages to support (%(default)s)",
)
parser.add_argument(
"--language-file",
help="Path to JSONL containing document_id to language mappings",
required=False,
)
parser.add_argument("--model_dir", help="Path to model directory", required=True)
# Dynamically generate arguments for each language's inferencer and pipe files
for lang in languages:
parser.add_argument(
f"--{lang}_inferencer",
help=f"Path to {lang} inferencer file",
)
parser.add_argument(f"--{lang}_pipe", help=f"Path to {lang} pipe file")
parser.add_argument(
f"--{lang}_lemmatization", help=f"Path to {lang} lemmatization file"
)
    # Dynamically generate model ID arguments for each language
for lang in languages:
parser.add_argument(
f"--{lang}_model_id",
default=f"tm-{lang}-all-v2.0",
help="Model ID can take a {lang} format placeholder (%(default)s)",
)
    for lang in languages:
        parser.add_argument(
            f"--{lang}_topic_count",
            type=int,
            default=100,
            help="Number of topics of the model (%(default)s)",
        )
args = parser.parse_args()
logging.basicConfig(
filename=args.logfile,
level=args.level,
format="%(asctime)-15s %(filename)s:%(lineno)d %(levelname)s: %(message)s",
force=True,
)
# Automatically construct file paths if not explicitly specified
for lang in args.languages:
model_id = getattr(args, f"{lang}_model_id")
model_dir = args.model_dir
pipe_path = os.path.join(model_dir, f"{model_id}.pipe")
inferencer_path = os.path.join(model_dir, f"{model_id}.inferencer")
lemmatization_path = os.path.join(
model_dir, f"{model_id}.vocab.lemmatization.tsv.gz"
)
if not getattr(args, f"{lang}_pipe") and os.path.exists(pipe_path):
logging.info("Automatically setting pipe path to %s", pipe_path)
setattr(args, f"{lang}_pipe", pipe_path)
if not getattr(args, f"{lang}_inferencer") and os.path.exists(inferencer_path):
logging.info("Automatically setting inferencer path to %s", inferencer_path)
setattr(args, f"{lang}_inferencer", inferencer_path)
if not getattr(args, f"{lang}_lemmatization") and os.path.exists(
lemmatization_path
):
logging.info(
"Automatically setting lemmatization path to %s", lemmatization_path
)
setattr(args, f"{lang}_lemmatization", lemmatization_path)
if not args.output_format:
if "jsonl" in args.output:
args.output_format = "jsonl"
else:
args.output_format = "csv"
logging.warning("Unspecified output format set to %s", args.output_format)
    # Iterate over a copy of the language list: removing items while iterating
    # over the list itself would silently skip the following language.
    for lang in list(args.languages):
        if not getattr(args, f"{lang}_inferencer") or not getattr(args, f"{lang}_pipe"):
            logging.warning(
                "Inferencer or pipe file not provided for language: %s. Ignoring"
                " content items for this language.",
                lang,
            )
            args.languages.remove(lang)
logging.info(
"Performing monolingual topic inference for the following languages: %s",
args.languages,
)
logging.info("Arguments: %s", args)
app = MalletTopicInferencer(args)
app.run()