mallet-topic-inferencer / lib /mallet2topic_assignment_jsonl.py
Simon Clematide
Initial commit with models, scripts, and JAR files
fc83ec7
#!/usr/bin/env python3
"""
Typical output of the script:
{"topic_model":"tm-fr-all-v2.0","topic_count":100,"lang":"fr","ci_ref":"actionfem-1936-02-15-a-i0022","topics":[],"min_p":0.02}
{
"topic_count": 100,
"lang": "de",
"topics": [
{"t": "tm-de-all-v2.0_tp02_de", "p": 0.027},
{"t": "tm-de-all-v2.0_tp11_de", "p": 0.119},
{"t": "tm-de-all-v2.0_tp26_de", "p": 0.045}
],
"min_p": 0.02,
"ts": "2024.08.29",
"id": "actionfem-1927-12-15-a-i0001",
"sys_id": "tm-de-all-v2.0"
}
"""
import datetime
import logging
import argparse
import traceback
import math
import json
import re
import collections
from typing import Generator, List, Dict, Any, Optional
from smart_open import open
# Pattern used to reduce a document name/path to its content-item id.
#   group 1: optional leading part up to and including a "/"
#   group 2: the id itself - a slash-free run ending in
#            "-YYYY-MM-DD-<one word character>-i<four digits>"
#   the rest of the same path component (e.g. a file extension) is matched
#   but not captured.
# Substituting r"\2" therefore keeps only the id; a string that does not
# match the pattern is left unchanged by re.sub.
CI_ID_REGEX = re.compile(r"^(.+?/)?([^/]+?-\d{4}-\d{2}-\d{2}-\w-i\d{4})[^/]*$")
class Mallet2TopicAssignment:
    """Convert mallet topic modeling output rows into topic-assignment records.

    The converter reads tab-separated input files (lines starting with ``#``
    are ignored), keeps for every row only the topics whose probability is
    at least the configured threshold, and produces one dictionary per
    content item. Records are either written as JSON Lines to ``output`` or
    handed back as a generator when ``output`` is ``"<generator>"``.

    Two input layouts are supported:

    * ``matrix``: columns 3.. hold one probability per topic, in topic order.
    * ``sparse``: columns 3.. hold alternating ``topic_index probability``
      pairs.
    """

    def __init__(
        self,
        args: Optional[argparse.Namespace] = None,
        topic_assignment_threshold: Optional[float] = None,
        lang: Optional[str] = None,
        topic_model: Optional[str] = None,
        numeric_topic_ids: Optional[bool] = None,
        format_type: Optional[str] = None,
        topic_count: Optional[int] = None,
        output: Optional[str] = None,
        input_files: Optional[List[str]] = None,
    ) -> None:
        """Configure the converter from a namespace and/or keyword arguments.

        Each setting is resolved in this order: the explicit keyword argument
        (when not ``None``), then the attribute of the same name on ``args``,
        then the command-line default declared in ``main``.

        Args:
            args: Parsed command-line options; may be ``None`` when the
                class is used programmatically.
            topic_assignment_threshold: Minimum probability for a topic to be
                kept; must lie strictly between 0 and 1.
            lang: Language code embedded in records and topic identifiers.
            topic_model: Topic model identifier.
            numeric_topic_ids: Emit integer topic indices instead of
                formatted identifiers.
            format_type: ``"matrix"`` or ``"sparse"`` (case-insensitive).
            topic_count: Number of topics of the model; determines the
                zero-padding width of formatted topic identifiers.
            output: Output path, or ``"<generator>"`` to get a generator
                from ``run``.
            input_files: Input paths; defaults to ``args.INPUT_FILES``.

        Raises:
            ValueError: If the resolved options are inconsistent (see
                ``validate_options``).
        """

        def pick(name: str, explicit: Any, default: Any = None) -> Any:
            # Explicit keyword > namespace attribute > built-in default.
            if explicit is not None:
                return explicit
            from_namespace = getattr(args, name, None)
            return default if from_namespace is None else from_namespace

        self.eps = pick("topic_assignment_threshold", topic_assignment_threshold, 0.02)
        self.lang = pick("lang", lang, "und")
        self.topic_model = pick("topic_model", topic_model, "tm000")
        self.numeric_topic_ids = pick("numeric_topic_ids", numeric_topic_ids, False)
        self.format_type = str(pick("format_type", format_type, "matrix")).lower()
        self.topic_count = pick("topic_count", topic_count)
        self.output = pick("output", output, "/dev/stdout")
        self.input_files = list(pick("INPUT_FILES", input_files, []))
        self.args = args  # Keep the namespace for callers that inspect it

        self.validate_options()

        # One digit more than the threshold's order of magnitude, e.g.
        # a threshold of 0.02 rounds probabilities to 3 decimal places.
        self.precision = math.ceil(abs(math.log10(self.eps))) + 1
        # Width needed to print the largest topic index (topic_count - 1).
        # Without a topic count (numeric ids only) no padding is needed.
        self.padding_length = (
            math.ceil(math.log10(self.topic_count)) if self.topic_count else 0
        )
        self.topic_id_format = (
            f"{self.topic_model}_tp{{t:0{self.padding_length}d}}_{self.lang}"
        )
        # strftime is used instead of isoformat() + "Z": isoformat() of an
        # aware datetime already ends in "+00:00", so appending "Z" produced
        # a malformed "...+00:00Z" value.
        self.last_timestamp = datetime.datetime.now(
            tz=datetime.timezone.utc
        ).strftime("%Y-%m-%dT%H:%M:%SZ")

    def validate_options(self) -> None:
        """Check the resolved options for consistency.

        Raises:
            ValueError: If the threshold is not strictly between 0 and 1, the
                format type is unknown, or the topic count is missing or not
                positive where it is needed.
        """
        if self.eps is None or self.eps <= 0 or self.eps >= 1:
            raise ValueError("topic_assignment_threshold must be between 0 and 1.")
        if self.format_type not in ("matrix", "sparse"):
            raise ValueError(f"Invalid format type: {self.format_type}")
        if self.format_type == "sparse" and not self.topic_count:
            raise ValueError(
                "The --topic_count option is required when using the 'sparse' format."
            )
        if self.topic_count is None:
            # Formatted topic identifiers are zero-padded according to the
            # topic count, so it can only be omitted with numeric ids.
            if not self.numeric_topic_ids:
                raise ValueError(
                    "topic_count is required unless numeric topic ids are used."
                )
        elif self.topic_count < 1:
            raise ValueError("topic_count must be a positive integer.")

    def read_tsv_files(self, filenames: List[str]) -> Generator[List[str], None, None]:
        """Yield the rows of all given files, one file after the other."""
        for filename in filenames:
            yield from self.read_tsv_file(filename)

    def read_tsv_file(self, filename: str) -> Generator[List[str], None, None]:
        """Yield the tab-separated fields of every data line in ``filename``.

        Lines starting with ``#`` and blank lines are skipped. Progress is
        logged every 1000 physical lines.
        """
        with open(filename, "r", encoding="utf-8") as file:
            for line_count, line in enumerate(file, start=1):
                stripped = line.strip()
                # A blank line would yield [""] and fail later on row[1].
                if stripped and not line.startswith("#"):
                    yield stripped.split("\t")
                if line_count % 1000 == 0:
                    logging.info("Processed lines: %s", line_count)

    def _topic_label(self, topic_index: int) -> Any:
        """Return the topic index itself or its formatted identifier."""
        if self.numeric_topic_ids:
            return topic_index
        return self.topic_id_format.format(t=topic_index)

    def _build_record(
        self, ci_id: str, topic_count: Optional[int], topics: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Assemble the output record shared by both input formats."""
        return {
            "ci_id": ci_id,
            "model_id": self.topic_model,
            "lang": self.lang,
            "topic_count": topic_count,
            "topics": topics,
            "min_p": self.eps,
            "ts": self.last_timestamp,
        }

    def convert_matrix_row(self, row: List[str]) -> Dict[str, Any]:
        """Convert one ``matrix`` row (one probability column per topic).

        The reported ``topic_count`` is the number of probability columns
        found in the row.
        """
        ci_id = CI_ID_REGEX.sub(r"\2", row[1])
        probabilities = row[2:]
        topics = [
            {"t": self._topic_label(index), "p": round(prob, self.precision)}
            for index, raw in enumerate(probabilities)
            if (prob := float(raw)) >= self.eps
        ]
        return self._build_record(ci_id, len(probabilities), topics)

    def convert_sparse_row(self, row: List[str]) -> Dict[str, Any]:
        """Convert one ``sparse`` row (alternating topic index / probability).

        The reported ``topic_count`` is the configured ``self.topic_count``.

        Raises:
            ValueError: If the row does not hold complete index/probability
                pairs, or a field is not numeric.
        """
        ci_id = CI_ID_REGEX.sub(r"\2", row[1])
        pairs = row[2:]
        if len(pairs) % 2:
            raise ValueError(
                f"Sparse row for {ci_id!r} has an odd number of topic fields"
                f" ({len(pairs)}); expected topic/probability pairs."
            )
        topics = []
        for raw_topic, raw_prob in zip(pairs[0::2], pairs[1::2]):
            topic_index = int(raw_topic)
            prob = float(raw_prob)
            if prob >= self.eps:
                topics.append(
                    {
                        "t": self._topic_label(topic_index),
                        "p": round(prob, self.precision),
                    }
                )
        return self._build_record(ci_id, self.topic_count, topics)

    def parse_mallet_files(
        self, filenames: List[str]
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Process the rows of multiple input files and yield topic assignments.

        Only the first row seen for a content-item id is converted; later
        rows with the same id are skipped and counted. The number of skipped
        duplicates is logged once all input has been consumed.

        Args:
            filenames (List[str]): List of paths to the input files.

        Yields:
            Dict[str, Any]: Parsed topic assignment from each line in the input files.

        Raises:
            ValueError: If ``self.format_type`` is neither ``sparse`` nor ``matrix``.
        """
        if self.format_type == "sparse":
            convert_row = self.convert_sparse_row
        elif self.format_type == "matrix":
            convert_row = self.convert_matrix_row
        else:
            raise ValueError(f"Invalid format type: {self.format_type}")

        # Ids and the duplicate counter are kept apart so that no id can
        # collide with a bookkeeping key.
        seen_ids = set()
        duplicate_count = 0
        for row in self.read_tsv_files(filenames):
            ci_id = CI_ID_REGEX.sub(r"\2", row[1])
            if ci_id in seen_ids:
                duplicate_count += 1
                continue
            seen_ids.add(ci_id)
            yield convert_row(row)
        logging.info("DUPLICATE-COUNT: %d", duplicate_count)

    def run(self) -> Optional[Generator[Dict[str, Any], None, None]]:
        """
        Process the configured input files.

        Returns a generator if output is set to '<generator>', otherwise
        writes one compact JSON object per line to the output file. Any
        error while writing is logged with its traceback and terminates the
        process with exit status 1.

        Returns:
            Optional[Generator[Dict[str, Any], None, None]]: A generator for topic assignments
            if output is set to '<generator>', otherwise None.
        """
        if self.output == "<generator>":
            # Return a generator if the output is set to '<generator>'
            return self.parse_mallet_files(self.input_files)

        try:
            with open(self.output, "w", encoding="utf-8") as out_file:
                for topic_assignment in self.parse_mallet_files(self.input_files):
                    out_file.write(
                        json.dumps(
                            topic_assignment, ensure_ascii=False, separators=(",", ":")
                        )
                        + "\n"
                    )
        except Exception as e:
            logging.exception("An error occurred: %s", e)
            # The builtin exit() is only installed by the site module;
            # raising SystemExit works in every interpreter setup.
            raise SystemExit(1) from e
        return None

    @staticmethod
    def setup_logging(options: argparse.Namespace) -> None:
        """
        Set up logging configuration based on command line options.

        ``--debug`` forces DEBUG; otherwise ``--level`` is used (INFO when
        absent), and ``--quiet`` raises the level to at least WARNING so that
        status messages are suppressed. Log records go to ``--logfile`` when
        given.
        """
        level_name = str(getattr(options, "level", None) or "INFO").upper()
        log_level = getattr(logging, level_name, logging.INFO)
        if getattr(options, "debug", False):
            log_level = logging.DEBUG
        elif getattr(options, "quiet", False):
            log_level = max(log_level, logging.WARNING)
        logging.basicConfig(
            level=log_level, filename=getattr(options, "logfile", None) or None
        )

    @staticmethod
    def main(
        args: Optional[List[str]] = None,
    ) -> Optional[Generator[Dict[str, Any], None, None]]:
        """
        Static method serving as the entry point of the script.

        If the output option is set to '<generator>', it returns a Python generator
        for topic assignments, otherwise prints results or writes to a file.

        Args:
            args: Command-line arguments without the program name; when
                ``None`` argparse reads them from ``sys.argv``.

        Returns:
            Optional[Generator[Dict[str, Any], None, None]]: Generator for topic assignments
            if output is set to '<generator>', otherwise None.
        """
        parser = argparse.ArgumentParser(
            usage="%(prog)s [OPTIONS] INPUT [INPUT ...]",
            description=(
                "Return topic assignments from mallet textual topic modeling output."
            ),
            epilog="Contact [email protected] for more information.",
        )
        parser.add_argument("--version", action="version", version="2024.10.23")
        parser.add_argument(
            "-l", "--logfile", help="Write log information to FILE", metavar="FILE"
        )
        parser.add_argument(
            "-q",
            "--quiet",
            action="store_true",
            help="Do not print status messages to stderr",
        )
        parser.add_argument(
            "-d", "--debug", action="store_true", help="Print debug information"
        )
        parser.add_argument(
            "-L",
            "--lang",
            "--language",
            default="und",
            help="ISO 639 language code two-letter or 'und' for undefined",
        )
        parser.add_argument(
            "-M",
            "--topic_model",
            default="tm000",
            help="Topic model identifier, e.g., tm001",
        )
        parser.add_argument(
            "-N",
            "--numeric_topic_ids",
            action="store_true",
            help="Use numeric topic IDs in the topic assignment",
        )
        parser.add_argument(
            "-T",
            "--topic_assignment_threshold",
            type=float,
            default=0.02,
            help="Minimum probability for inclusion in the output",
        )
        parser.add_argument(
            "-F",
            "--format_type",
            choices=["matrix", "sparse"],
            default="matrix",
            help="Format of the input file: 'matrix' or 'sparse'",
        )
        parser.add_argument(
            "-C",
            "--topic_count",
            type=int,
            help="Needed for formatting ",
            required=True,
        )
        parser.add_argument(
            "-o",
            "--output",
            help=(
                "Path to the output file (%(default)s). If set to '<generator>' it will"
                " return a generator that can be used to enumerate all results in a"
                " flexible way. "
            ),
            default="/dev/stdout",
        )
        parser.add_argument(
            "--level",
            default="INFO",
            choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
            help="Set the logging level. Default: %(default)s",
        )
        parser.add_argument(
            "INPUT_FILES", nargs="+", help="One or more input files to process."
        )
        options = parser.parse_args(args=args)

        # Configure logging
        Mallet2TopicAssignment.setup_logging(options)

        # Validate specific arguments (also catches an explicit count of 0)
        if options.format_type == "sparse" and not options.topic_count:
            parser.error(
                "The --topic_count option is required when using the 'sparse' format"
            )

        # Create the application instance
        app = Mallet2TopicAssignment(args=options)

        # run() returns a generator for '<generator>' output and None after
        # writing to a file or stdout, which is exactly what main returns.
        return app.run()
if __name__ == "__main__":
Mallet2TopicAssignment.main()