# title: ENA dataset utility functions
# author: Taewook Kang, Kyubyung Kang
# date: 2024.3.27
# license: MIT
# reference: https://pyautocad.readthedocs.io/en/latest/_modules/pyautocad/api.html
# version
#   0.1. 2024.3.27. create file
# 
import json, os, re, logging, numpy as np
from transformers import BertTokenizer

def load_train_chunk_data(data_dir, sort_fname=False):
	"""Load geometry records from every ``.json`` chunk file in ``data_dir``.

	Each chunk file is read with ``json.load`` and iterated as a list of
	cross-section dicts. Each cross section provides a ``'station'`` value and a
	``'geom'`` list. Every geometry dict is tagged in place with its cross
	section's ``'station'``. It is kept only if its ``'earthwork_feature'`` value
	is non-empty.

	Args:
		data_dir: directory containing the chunk files; non-``.json`` entries
			are ignored.
		sort_fname: when True, visit files in ascending order of the first
			integer in their name (``chunk_2.json`` before ``chunk_10.json``).
			Files whose name contains no digits are visited last, by name.

	Returns:
		List of geometry dicts with a non-empty ``'earthwork_feature'``.
	"""
	def sort_key(name):
		# A name without digits used to crash the sort (re.search -> None);
		# order such names after the numbered ones instead.
		match = re.search(r'\d+', name)
		if match is None:
			return (1, 0, name)
		return (0, int(match.group()), name)

	# Filter before sorting so unrelated files never reach the numeric key.
	fnames = [name for name in os.listdir(data_dir) if name.endswith('.json')]
	if sort_fname:
		fnames.sort(key=sort_key)

	geom_list = []
	xsec_count = 0
	for file_name in fnames:
		# JSON text is UTF-8; don't depend on the platform's default encoding.
		with open(os.path.join(data_dir, file_name), 'r', encoding='utf-8') as f:
			chunk = json.load(f)
		for xsec in chunk:
			xsec_count += 1
			for g in xsec['geom']:
				g['station'] = xsec['station']
				if not g['earthwork_feature']:
					continue
				geom_list.append(g)
	print(f'Loaded {xsec_count} cross sections')
	return geom_list

def update_feature_dims_token(geom_list, tokenizer=None):
	"""Encode each geom's earthwork features as token ids padded to a common length.

	For every geom, ``geom['feature_dims']`` is replaced by one vocabulary id per
	entry of ``geom['earthwork_feature']``. All lists are then right-padded with
	the id of the ``'padding'`` token to the length of the longest list.

	Args:
		geom_list: geometry dicts carrying an ``'earthwork_feature'`` list of
			strings; mutated in place.
		tokenizer: optional preloaded ``BertTokenizer``. When None, the
			``bert-base-uncased`` tokenizer is loaded via ``from_pretrained``.
			A ``'padding'`` token is added to whichever tokenizer is used.

	Returns:
		Distinct feature words (as parsed by ``extract_word_and_count``) in
		first-seen order. Features that do not parse are skipped.
	"""
	if tokenizer is None:
		tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) # Load the BERT tokenizer

	# add_tokens() returns how many tokens were *added* (0 or 1), not an id,
	# so look the padding token's id up explicitly.
	tokenizer.add_tokens(['padding'])
	padding_token_id = tokenizer.convert_tokens_to_ids('padding')

	feature_dims = []
	seen_words = set()
	max_token = 0
	for geom in geom_list:
		features = geom['earthwork_feature']
		# NOTE(review): each whole feature string (e.g. 'word(3)') is looked up as
		# a single vocabulary token; strings absent from the vocabulary map to the
		# unknown-token id. Confirm this is intended rather than tokenizing first.
		geom['feature_dims'] = [tokenizer.convert_tokens_to_ids(feature) for feature in features]

		for feature in features:
			word, _ = extract_word_and_count(feature)
			if word is None or word in seen_words:
				continue
			seen_words.add(word)
			feature_dims.append(word)

		max_token = max(max_token, len(geom['feature_dims']))

	for geom in geom_list:
		geom['feature_dims'] += [padding_token_id] * (max_token - len(geom['feature_dims']))

	print(f'Max token length: {max_token}')
	return feature_dims

def extract_word_and_count(s):
	"""Split a feature string of the form ``word`` or ``word(N)``.

	Only the start of ``s`` is matched; anything after the optional ``(N)`` is
	ignored. A missing or malformed ``(N)`` suffix yields a count of 1.

	Returns:
		``(word, count)``, or ``(None, None)`` when ``s`` does not start with
		a word character.
	"""
	m = re.match(r'(\w+)(?:\((\d+)\))?', s)
	if m is None:
		return None, None
	word = m.group(1)
	raw_count = m.group(2)
	# group(2) is either None (no suffix) or a non-empty digit string.
	return word, (1 if raw_count is None else int(raw_count))

def update_feature_dims_freq(geom_list, augument=False):
	"""Build max-normalized word-count vectors from each geom's earthwork features.

	The vocabulary is every word parsed by ``extract_word_and_count`` across
	all geoms, sorted alphabetically. For each geom this sets:
	  - ``geom['feature_dims']``: one value per vocabulary word. The value is the
	    word's count in this geom divided by its maximum count over all geoms.
	    Columns whose maximum is 0 stay 0.0. If a word repeats within one geom,
	    the last occurrence wins.
	  - ``geom['feature_text']``: the parsed features as ``'word(count) '``
	    pieces, concatenated.
	  - ``geom['feature_dims_aug']`` (only when ``augument`` is True): each
	    normalized value ``v`` followed by ``v * v``.

	Args:
		geom_list: geometry dicts with an ``'earthwork_feature'`` list; mutated
			in place.
		augument: also build the squared-feature augmentation.

	Returns:
		The sorted vocabulary (column order of ``feature_dims``).
	"""
	# Pass 1: collect the vocabulary.
	vocab = set()
	for geom in geom_list:
		for feature in geom['earthwork_feature']:
			word, count = extract_word_and_count(feature)
			if word is None or count is None:
				continue
			vocab.add(word)
	feature_dims = sorted(vocab)
	index_of = {word: i for i, word in enumerate(feature_dims)}

	# Pass 2: raw counts per geom, tracking each column's maximum.
	max_feature_dims_count = [0.0] * len(feature_dims)
	for geom in geom_list:
		dims = [0.0] * len(feature_dims)
		text_parts = []
		for feature in geom['earthwork_feature']:
			word, count = extract_word_and_count(feature)
			if word is None or count is None:
				continue
			text_parts.append(f'{word}({count}) ')
			index = index_of[word]
			dims[index] = count
			max_feature_dims_count[index] = max(max_feature_dims_count[index], count)
		geom['feature_dims'] = dims
		geom['feature_text'] = ''.join(text_parts)

	# Pass 3: normalize by column maximum. A maximum of 0 means every count in
	# that column is 0, so emit 0.0 instead of dividing by zero.
	for geom in geom_list:
		dims = geom['feature_dims']
		for i, max_count in enumerate(max_feature_dims_count):
			dims[i] = dims[i] / max_count if max_count else 0.0

	# augment feature_dims dataset with squared terms
	if augument:
		for geom in geom_list:
			geom['feature_dims_aug'] = [x for v in geom['feature_dims'] for x in (v, v * v)]

	print(f'feature dims({len(feature_dims)}): {feature_dims}')
	return feature_dims

def update_onehot_encoding(geom_list):
	"""Attach a one-hot label vector to every geom and return the label vocabulary.

	Label kinds are collected in first-seen order. Each geom gets
	``geom['label_onehot']``: a ``numpy`` float array of length
	``len(label_kinds)``, with 1.0 at the index of that geom's label.

	Args:
		geom_list: geometry dicts, each with a hashable ``'label'`` value;
			mutated in place.

	Returns:
		Distinct labels in first-seen order (the one-hot column order).
	"""
	# dict.fromkeys preserves insertion order: O(n) order-preserving dedup
	# instead of repeated list membership scans.
	label_kinds = list(dict.fromkeys(geom['label'] for geom in geom_list))
	label_index = {label: i for i, label in enumerate(label_kinds)}

	for geom in geom_list:
		# Fresh array per geom so mutating one vector never affects another.
		onehot = np.zeros(len(label_kinds))
		onehot[label_index[geom['label']]] = 1.0
		geom['label_onehot'] = onehot
	return label_kinds