Spaces:
Runtime error
Runtime error
| import pdb | |
| import sys | |
# Column indices into a term record (a list such as the ones built by
# set_POS_based_on_entities).
WORD_POS = 1  # the word text
TAG_POS = 2   # the part-of-speech tag

# Placeholder that replaces a masked run of noun-tagged terms.
MASK_TAG = "__entity__"
# Suffix a user appends to a word to mark it as an entity, e.g. "Paris:__entity__".
INPUT_MASK_TAG = ":" + MASK_TAG
# Tag given to input words that did not carry INPUT_MASK_TAG.
RESET_POS_TAG = 'RESET'

# Tags treated as noun-like when locating spans to mask; noun_tags[0] is also
# the tag forced onto explicitly marked / fallback terms.
noun_tags = [
    'NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS',
    'JJS', 'JJR', 'NNP', 'POS', 'CD',
]
# Tags whose words get their first character upper-cased by capitalize().
cap_tags = [
    'NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS',
    'JJS', 'JJR', 'NNP', 'PRP',
]
def detect_masked_positions(terms_arr):
    """Mask every noun run in terms_arr and also return the plain word list.

    Returns a 3-tuple ``(words, sentence_arr, span_arr)``:
      words        -- the WORD_POS field of every term, in order.
      sentence_arr -- masked sentences from generate_masked_sentences.
      span_arr     -- per-term 0/1 span flags from generate_masked_sentences.

    Note: generate_masked_sentences may retag terms_arr in place (see
    hack_for_no_nouns_case) before the word list is collected.
    """
    masked_sentences, spans = generate_masked_sentences(terms_arr)
    words = [term[WORD_POS] for term in terms_arr]
    return words, masked_sentences, spans
def generate_masked_sentences(terms_arr):
    """Build one masked sentence per contiguous run of noun-tagged terms.

    hack_for_no_nouns_case is applied first; it may set terms_arr[0]'s tag in
    place when no term carries a tag from noun_tags.

    Returns ``(sentence_arr, span_arr)``:
      sentence_arr -- list of word lists, one per noun run, each produced by
                      gen_sentence (the run collapsed to a single MASK_TAG).
      span_arr     -- one int per term: 1 if the term lies in a noun run,
                      0 otherwise.
    """
    total = len(terms_arr)
    sentences = []
    spans = []
    hack_for_no_nouns_case(terms_arr)
    pos = 0
    while pos < total:
        if terms_arr[pos][TAG_POS] not in noun_tags:
            spans.append(0)
            pos += 1
            continue
        # gen_sentence consumes the whole contiguous noun run starting here.
        run_len = gen_sentence(sentences, terms_arr, pos)
        pos += run_len
        spans.extend([1] * run_len)
    return sentences, spans
def hack_for_no_nouns_case(terms_arr):
    """Guarantee at least one noun-tagged term so masking has something to do.

    Workaround for inputs with no explicitly marked entity and no noun at
    all (e.g. a lone word such as "eg"). If no term's tag is in noun_tags,
    the first term is retagged in place as noun_tags[0]. An empty terms_arr
    is left untouched. Returns None.
    """
    if not terms_arr:
        return
    for term_info in terms_arr:
        if term_info[TAG_POS] in noun_tags:
            return
    terms_arr[0][TAG_POS] = noun_tags[0]
def gen_sentence(sentence_arr, terms_arr, index):
    """Append a masked copy of the sentence to sentence_arr; return the run length.

    Starting at ``index``, the contiguous run of terms whose tag is in
    noun_tags is replaced by a single MASK_TAG. Words before ``index`` and
    after the run are copied unchanged (WORD_POS field).

    Raises AssertionError when the run is empty, i.e. terms_arr[index] is not
    noun-tagged (note: asserts are stripped under ``python -O``).
    """
    total = len(terms_arr)
    masked = [term[WORD_POS] for term in terms_arr[:index]]

    run_end = index
    while run_end < total and terms_arr[run_end][TAG_POS] in noun_tags:
        run_end += 1
    run_len = run_end - index

    masked.append(MASK_TAG)
    masked.extend([terms_arr[k][WORD_POS] for k in range(run_end, total)])

    assert run_len != 0
    sentence_arr.append(masked)
    return run_len
def capitalize(terms_arr):
    """Upper-case the first character of every term whose tag is in cap_tags.

    Mutates terms_arr in place (the WORD_POS field of matching terms) and
    returns None. Empty words are left as-is: an empty WORD_POS can occur,
    e.g. set_POS_based_on_entities turns a bare ":__entity__" token into an
    empty word tagged noun_tags[0] ('NFP'), which is also in cap_tags.
    """
    for term_tag in terms_arr:
        if term_tag[TAG_POS] not in cap_tags:
            continue
        word = term_tag[WORD_POS]
        # word[:1] instead of word[0] so an empty string does not raise IndexError.
        term_tag[WORD_POS] = word[:1].upper() + word[1:]
def set_POS_based_on_entities(sent):
    """Build a terms array from a sentence whose entities are explicitly marked.

    The sentence is split on whitespace. Each word becomes a 5-slot list
    initialised with '-' (only WORD_POS and TAG_POS are filled here; the
    width presumably mirrors the term layout used elsewhere -- TODO confirm).

    * A word ending in INPUT_MASK_TAG is tagged noun_tags[0] and stored with
      that trailing marker removed.
    * Any other word is tagged RESET_POS_TAG and stored verbatim.

    Returns the list of term lists.
    """
    suffix_len = len(INPUT_MASK_TAG)
    terms_arr = []
    for word in sent.split():
        term_tag = ['-'] * 5
        if word.endswith(INPUT_MASK_TAG):
            term_tag[TAG_POS] = noun_tags[0]
            # Strip only the trailing marker; str.replace would also delete
            # any earlier occurrence of the marker inside the word.
            term_tag[WORD_POS] = word[:len(word) - suffix_len]
        else:
            term_tag[TAG_POS] = RESET_POS_TAG
            term_tag[WORD_POS] = word
        terms_arr.append(term_tag)
    return terms_arr
def filter_common_noun_spans(span_arr, masked_sent_arr, terms_arr, common_descs):
    """Drop noun spans that consist solely of common words.

    span_arr flags each term with 1 when it belongs to a noun run. Runs are
    maximal stretches of consecutive 1s. masked_sent_arr is expected to hold
    one masked sentence per run, in the same order.

    For every run whose words (lower-cased WORD_POS) are all in common_descs:
      * the run's positions are set to 0 in the returned span array,
      * its masked sentence is omitted from the returned list,
      * the run's words are printed after "Filtering common span: ".
    Other runs keep their masked sentence.

    Returns ``(kept_masked_sentences, new_span_arr)``. span_arr itself is
    not modified (a copy is returned).
    """
    filtered_spans = span_arr.copy()
    kept_sentences = []
    total = len(span_arr)
    run_number = 0
    pos = 0
    while pos < total:
        if span_arr[pos] != 1:
            pos += 1
            continue

        run_start = pos
        while pos < total and span_arr[pos] == 1:
            pos += 1
        run = range(run_start, pos)

        # Evaluate every term (no short-circuit), matching a full scan of the run.
        common_flags = [terms_arr[k][WORD_POS].lower() in common_descs for k in run]
        if all(common_flags):
            print("Filtering common span: ", end='')
            for k in run:
                print(terms_arr[k][WORD_POS], ' ', end='')
                filtered_spans[k] = 0
            print()
        else:
            kept_sentences.append(masked_sent_arr[run_number])
        run_number += 1

    return kept_sentences, filtered_spans
def normalize_casing(sent):
    """Keep each word's first character and lower-case the rest.

    The sentence is split on whitespace and rejoined with single spaces, so
    leading, trailing and repeated whitespace is collapsed.

    >>> normalize_casing("HELLO wORLD")
    'Hello world'
    """
    return ' '.join(token[0] + token[1:].lower() for token in sent.split())