|
|
|
""" |
|
|
|
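Convert Mallet topic-modeling output into topic-assignment records.

Each non-comment row of one or more Mallet output files (``matrix`` or
``sparse`` format) becomes one record per content item (CI). Records are
written as one JSON object per line, or yielded from a generator. Rows whose
CI id was already seen are skipped. Only topics whose probability is at least
the threshold (``min_p``) are kept.

The records built by the code in this file use the keys ``ci_id``,
``model_id``, ``lang``, ``topic_count``, ``topics``, ``min_p`` and ``ts``.
The examples below use other key names (``topic_model``/``ci_ref`` and
``id``/``sys_id``) and another timestamp format. Read them as an
illustration of the overall shape rather than as an exact specification.

Typical output of the script: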
|
{"topic_model":"tm-fr-all-v2.0","topic_count":100,"lang":"fr","ci_ref":"actionfem-1936-02-15-a-i0022","topics":[],"min_p":0.02} |
|
|
|
|
|
{ |
|
"topic_count": 100, |
|
"lang": "de", |
|
"topics": [ |
|
{"t": "tm-de-all-v2.0_tp02_de", "p": 0.027}, |
|
{"t": "tm-de-all-v2.0_tp11_de", "p": 0.119}, |
|
{"t": "tm-de-all-v2.0_tp26_de", "p": 0.045} |
|
], |
|
"min_p": 0.02, |
|
"ts": "2024.08.29", |
|
"id": "actionfem-1927-12-15-a-i0001", |
|
"sys_id": "tm-de-all-v2.0" |
|
} |
|
""" |
|
import datetime |
|
import logging |
|
import argparse |
|
import traceback |
|
import math |
|
import json |
|
import re |
|
import collections |
|
from typing import Generator, List, Dict, Any, Optional |
|
from smart_open import open |
|
|
|
|
|
# Extracts the bare content-item id from a document name or path.
# Group 1 (optional) is any leading directory prefix ending in "/".
# Group 2 is the id itself, e.g. "actionfem-1936-02-15-a-i0022"
# (<prefix>-YYYY-MM-DD-<one word char>-iNNNN).
# Anything after the id that contains no "/" (such as a file extension) is
# also matched, so substituting with r"\2" keeps only the id. Names that do
# not match are returned unchanged by re.sub.
CI_ID_REGEX: re.Pattern = re.compile(
    r"^(.+?/)?([^/]+?-\d{4}-\d{2}-\d{2}-\w-i\d{4})[^/]*$"
)
|
|
|
|
|
class Mallet2TopicAssignment:
    """Convert rows of Mallet topic-modeling output into topic-assignment dicts.

    Options come either from an ``argparse.Namespace`` (as built by
    :meth:`main`) or from explicit keyword arguments. For each option, the
    precedence is:

    1. an explicit keyword argument that is not ``None``;
    2. the corresponding attribute of ``args``;
    3. ``_DEFAULTS``.
    """

    # Fallbacks for options given neither explicitly nor (non-None) on
    # ``args``. They mirror the argparse defaults declared in ``main``.
    # "INPUT_FILES" is the namespace attribute holding the positional inputs.
    _DEFAULTS: Dict[str, Any] = {
        "topic_assignment_threshold": 0.02,
        "lang": "und",
        "topic_model": "tm000",
        "numeric_topic_ids": False,
        "format_type": "matrix",
        "topic_count": None,
        "output": "/dev/stdout",
        "INPUT_FILES": None,
    }

    def __init__(
        self,
        args: Optional[argparse.Namespace] = None,
        topic_assignment_threshold: Optional[float] = None,
        lang: Optional[str] = None,
        topic_model: Optional[str] = None,
        numeric_topic_ids: Optional[bool] = None,
        format_type: Optional[str] = None,
        topic_count: Optional[int] = None,
        output: Optional[str] = None,
        input_files: Optional[List[str]] = None,
    ) -> None:
        """Collect and validate options and precompute formatting helpers.

        Args:
            args: Parsed command-line options. May be ``None`` when the
                needed options are passed as keyword arguments.
            topic_assignment_threshold: Minimum probability for a topic to be
                kept. Must lie strictly between 0 and 1. Stored as ``self.eps``.
            lang: Language code appended to textual topic ids.
            topic_model: Model identifier. It prefixes textual topic ids and
                is emitted as ``model_id``.
            numeric_topic_ids: Emit topic ids as integers instead of strings.
            format_type: ``"matrix"`` or ``"sparse"`` (case-insensitive).
            topic_count: Number of topics. Required for ``"sparse"``, and
                used to zero-pad textual topic ids.
            output: Output path, or ``"<generator>"`` to make :meth:`run`
                return a generator.
            input_files: Input paths. Falls back to ``args.INPUT_FILES``.

        Raises:
            ValueError: If the resulting options are invalid (see
                :meth:`validate_options`).
        """

        def option(explicit: Any, name: str) -> Any:
            # Precedence: explicit keyword > non-None attribute of ``args`` >
            # class default. getattr(None, ...) safely yields None.
            if explicit is not None:
                return explicit
            from_args = getattr(args, name, None)
            return from_args if from_args is not None else self._DEFAULTS[name]

        self.eps = option(topic_assignment_threshold, "topic_assignment_threshold")
        self.lang = option(lang, "lang")
        self.topic_model = option(topic_model, "topic_model")
        self.numeric_topic_ids = bool(option(numeric_topic_ids, "numeric_topic_ids"))
        self.format_type = option(format_type, "format_type").lower()
        self.topic_count = option(topic_count, "topic_count")
        self.output = option(output, "output")
        self.input_files = option(input_files, "INPUT_FILES")
        self.args = args

        self.validate_options()

        # Decimals kept for probabilities, derived from the threshold's order
        # of magnitude (0.02 -> 3 decimals).
        self.precision = math.ceil(abs(math.log10(self.eps))) + 1

        # Zero-padding width and str.format template (field ``t``) for textual
        # topic ids. These are only known up front when topic_count is given;
        # otherwise they are derived per matrix row.
        if self.topic_count:
            self.padding_length = self._padding_length_for(self.topic_count)
            self.topic_id_format = self._make_topic_id_format(self.padding_length)
        else:
            self.padding_length = None
            self.topic_id_format = None

        # UTC timestamp shared by every record of this run. Formatted
        # explicitly: isoformat() on an aware datetime already ends in
        # "+00:00", so appending "Z" would give an invalid "...+00:00Z".
        self.last_timestamp = datetime.datetime.now(
            tz=datetime.timezone.utc
        ).strftime("%Y-%m-%dT%H:%M:%SZ")

    def validate_options(self) -> None:
        """Check that the configured options are consistent.

        Raises:
            ValueError: If any of the following holds:
                - the threshold is not strictly between 0 and 1 (NaN included);
                - the format type is unknown;
                - ``topic_count`` is missing for the sparse format;
                - ``topic_count`` is not positive.
        """
        # A chained comparison also rejects NaN (all comparisons are False).
        if not 0 < self.eps < 1:
            raise ValueError("topic_assignment_threshold must be between 0 and 1.")
        if self.format_type not in ("matrix", "sparse"):
            raise ValueError(f"Invalid format type: {self.format_type}")
        if self.format_type == "sparse" and not self.topic_count:
            raise ValueError(
                "The --topic_count option is required when using the 'sparse' format."
            )
        if self.topic_count is not None and self.topic_count < 1:
            raise ValueError("topic_count must be a positive integer.")

    @staticmethod
    def _padding_length_for(topic_count: int) -> int:
        """Return the digit count of the largest topic id, ``topic_count - 1``.

        This is an integer-exact replacement for ``ceil(log10(topic_count))``
        (for example, 100 -> 2 and 101 -> 3). It never fails on 0.
        """
        return len(str(max(topic_count - 1, 0)))

    def _make_topic_id_format(self, padding_length: int) -> str:
        """Return the textual topic-id template for the given padding width."""
        return f"{self.topic_model}_tp{{t:0{padding_length}d}}_{self.lang}"

    def _id_format_for(self, fallback_topic_count: int) -> str:
        """Return the topic-id template.

        If no global ``topic_count`` was configured, the padding is derived
        from ``fallback_topic_count`` instead.
        """
        if self.topic_id_format is not None:
            return self.topic_id_format
        return self._make_topic_id_format(
            self._padding_length_for(fallback_topic_count)
        )

    @staticmethod
    def _extract_ci_id(name: str) -> str:
        """Return the content-item id embedded in a document name or path.

        If ``CI_ID_REGEX`` does not match, the name is returned unchanged.
        """
        return CI_ID_REGEX.sub(r"\2", name)

    def _topic_entry(self, t: int, p: float, id_format: Optional[str]) -> Dict[str, Any]:
        """Build one ``{"t": ..., "p": ...}`` entry.

        ``p`` is rounded to ``self.precision`` decimals. ``t`` stays numeric
        when ``numeric_topic_ids`` is set; otherwise it is rendered with
        ``id_format``.
        """
        topic_id: Any = t if self.numeric_topic_ids else id_format.format(t=t)
        return {"t": topic_id, "p": round(p, self.precision)}

    def read_tsv_files(self, filenames: List[str]) -> Generator[List[str], None, None]:
        """Yield the tab-split data rows of every file in ``filenames``, in order."""
        for filename in filenames:
            yield from self.read_tsv_file(filename)

    def read_tsv_file(self, filename: str) -> Generator[List[str], None, None]:
        """Yield the tab-split rows of one file.

        Lines starting with ``#`` and blank lines are skipped. Progress is
        logged every 1000 physical lines.
        """
        with open(filename, "r", encoding="utf-8") as file:
            for line_count, line in enumerate(file, start=1):
                stripped = line.strip()
                # Blank lines (e.g. a trailing empty line) would otherwise
                # produce [''] and fail on row[1] downstream.
                if stripped and not line.startswith("#"):
                    yield stripped.split("\t")
                if line_count % 1000 == 0:
                    logging.info("Processed lines: %s", line_count)

    def convert_matrix_row(self, row: List[str]) -> Dict[str, Any]:
        """Convert a matrix-format row into a topic-assignment dict.

        ``row[1]`` holds the document name, from which the CI id is taken.
        ``row[2:]`` holds one probability per topic, where the column
        position is the topic number. ``row[0]`` is ignored (presumably
        Mallet's document index).
        """
        ci_id = self._extract_ci_id(row[1])
        probabilities = row[2:]
        topic_count = len(probabilities)
        id_format = None if self.numeric_topic_ids else self._id_format_for(topic_count)
        topics = [
            self._topic_entry(t, fp, id_format)
            for t, p in enumerate(probabilities)
            if (fp := float(p)) >= self.eps
        ]
        return {
            "ci_id": ci_id,
            "model_id": self.topic_model,
            "lang": self.lang,
            "topic_count": topic_count,
            "topics": topics,
            "min_p": self.eps,
            "ts": self.last_timestamp,
        }

    def convert_sparse_row(self, row: List[str]) -> Dict[str, Any]:
        """Convert a sparse-format row into a topic-assignment dict.

        ``row[1]`` holds the document name. ``row[2:]`` alternates topic
        number and probability.

        Raises:
            ValueError: If a topic number has no matching probability.
        """
        ci_id = self._extract_ci_id(row[1])
        topic_pairs = row[2:]
        if len(topic_pairs) % 2:
            raise ValueError(
                f"Malformed sparse row for {ci_id}: odd number of topic/probability fields"
            )
        id_format = (
            None if self.numeric_topic_ids else self._id_format_for(len(topic_pairs) // 2)
        )
        topics = []
        for raw_t, raw_p in zip(topic_pairs[0::2], topic_pairs[1::2]):
            p = float(raw_p)
            if p >= self.eps:
                topics.append(self._topic_entry(int(raw_t), p, id_format))
        return {
            "ci_id": ci_id,
            "model_id": self.topic_model,
            "lang": self.lang,
            "topic_count": self.topic_count,
            "topics": topics,
            "min_p": self.eps,
            "ts": self.last_timestamp,
        }

    def parse_mallet_files(
        self, filenames: List[str]
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Process the Mallet topic output from multiple files and yield topic assignments.

        Only the first row for each CI id is converted. Later duplicates are
        skipped and counted, and the count is logged once all files have been
        read.

        Args:
            filenames (List[str]): List of paths to the input files.

        Yields:
            Dict[str, Any]: Parsed topic assignment from each line in the input files.

        Raises:
            ValueError: If ``self.format_type`` is neither "matrix" nor "sparse".
        """
        converters = {
            "sparse": self.convert_sparse_row,
            "matrix": self.convert_matrix_row,
        }
        try:
            convert_row = converters[self.format_type]
        except KeyError:
            raise ValueError(f"Invalid format type: {self.format_type}") from None

        seen_ids = set()
        duplicate_count = 0
        for row in self.read_tsv_files(filenames):
            ci_id = self._extract_ci_id(row[1])
            if ci_id in seen_ids:
                duplicate_count += 1
                continue
            seen_ids.add(ci_id)
            yield convert_row(row)

        logging.info("DUPLICATE-COUNT: %d", duplicate_count)

    def run(self) -> Optional[Generator[Dict[str, Any], None, None]]:
        """
        Process the configured input files.

        If output is set to '<generator>', return a generator. Otherwise write
        one compact JSON object per line to ``self.output``.

        Returns:
            Optional[Generator[Dict[str, Any], None, None]]: A generator for topic assignments
            if output is set to '<generator>', otherwise None.

        Raises:
            ValueError: If no input files were configured.
            SystemExit: With status 1 if processing or writing fails. The
                error and traceback are logged first.
        """
        if self.input_files is None:
            raise ValueError("No input files given (INPUT_FILES / input_files).")

        if self.output == "<generator>":
            return self.parse_mallet_files(self.input_files)

        try:
            with open(self.output, "w", encoding="utf-8") as out_file:
                for topic_assignment in self.parse_mallet_files(self.input_files):
                    out_file.write(
                        json.dumps(
                            topic_assignment, ensure_ascii=False, separators=(",", ":")
                        )
                        + "\n"
                    )
        except Exception as e:  # top-level CLI boundary: log and exit non-zero
            logging.error("An error occurred: %s", e)
            logging.error("Traceback: %s", traceback.format_exc())
            # Equivalent to sys.exit(1), but unlike the site-provided exit()
            # builtin it is always available.
            raise SystemExit(1) from None
        return None

    @staticmethod
    def setup_logging(options: argparse.Namespace) -> None:
        """
        Configure the root logger from command-line options.

        The level is chosen as follows:

        - ``--debug`` selects DEBUG;
        - otherwise ``--quiet`` selects WARNING, which hides INFO progress
          messages;
        - otherwise ``--level`` is used (INFO by default).

        ``--logfile`` sends log output to that file.
        """
        if getattr(options, "debug", False):
            log_level = logging.DEBUG
        elif getattr(options, "quiet", False):
            log_level = logging.WARNING
        else:
            level_name = str(getattr(options, "level", None) or "INFO").upper()
            log_level = getattr(logging, level_name, logging.INFO)
        logging.basicConfig(
            level=log_level, filename=getattr(options, "logfile", None) or None
        )

    @staticmethod
    def main(
        args: Optional[List[str]] = None,
    ) -> Optional[Generator[Dict[str, Any], None, None]]:
        """
        Entry point of the script.

        Args:
            args: Command-line arguments. ``None`` makes argparse read
                ``sys.argv[1:]``.

        Returns:
            Optional[Generator[Dict[str, Any], None, None]]: Generator for topic assignments
            if output is set to '<generator>', otherwise None (results are written
            to the output file, stdout by default).
        """
        parser = argparse.ArgumentParser(
            usage="%(prog)s [OPTIONS] INPUT [INPUT ...]",
            description=(
                "Return topic assignments from mallet textual topic modeling output."
            ),
            epilog="Contact [email protected] for more information.",
        )

        parser.add_argument("--version", action="version", version="2024.10.23")
        parser.add_argument(
            "-l", "--logfile", help="Write log information to FILE", metavar="FILE"
        )
        parser.add_argument(
            "-q",
            "--quiet",
            action="store_true",
            help="Do not print status messages to stderr",
        )
        parser.add_argument(
            "-d", "--debug", action="store_true", help="Print debug information"
        )
        parser.add_argument(
            "-L",
            "--lang",
            "--language",
            default="und",
            help="ISO 639 language code two-letter or 'und' for undefined",
        )
        parser.add_argument(
            "-M",
            "--topic_model",
            default="tm000",
            help="Topic model identifier, e.g., tm001",
        )
        parser.add_argument(
            "-N",
            "--numeric_topic_ids",
            action="store_true",
            help="Use numeric topic IDs in the topic assignment",
        )
        parser.add_argument(
            "-T",
            "--topic_assignment_threshold",
            type=float,
            default=0.02,
            help="Minimum probability for inclusion in the output",
        )
        parser.add_argument(
            "-F",
            "--format_type",
            choices=["matrix", "sparse"],
            default="matrix",
            help="Format of the input file: 'matrix' or 'sparse'",
        )
        parser.add_argument(
            "-C",
            "--topic_count",
            type=int,
            help="Needed for formatting ",
            required=True,
        )
        parser.add_argument(
            "-o",
            "--output",
            help=(
                "Path to the output file (%(default)s). If set to '<generator>' it will"
                " return a generator that can be used to enumerate all results in a"
                " flexible way. "
            ),
            default="/dev/stdout",
        )
        parser.add_argument(
            "--level",
            default="INFO",
            choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
            help="Set the logging level. Default: %(default)s",
        )
        parser.add_argument(
            "INPUT_FILES", nargs="+", help="One or more input files to process."
        )

        options = parser.parse_args(args=args)

        Mallet2TopicAssignment.setup_logging(options)

        if options.format_type == "sparse" and not options.topic_count:
            parser.error(
                "The --topic_count option is required when using the 'sparse' format"
            )

        app = Mallet2TopicAssignment(args=options)

        # run() returns the generator for '<generator>' output and None after
        # writing the output file otherwise.
        return app.run()
|
|
|
|
|
if __name__ == "__main__":
    # Pass None explicitly so argparse reads the options from sys.argv[1:];
    # this works whether or not main() gives its `args` parameter a default.
    Mallet2TopicAssignment.main(None)
|
|