dejanseo committed
Commit 88da8fc · verified · 1 Parent(s): d7b5eae

Update src/streamlit_app.py

Files changed (1): src/streamlit_app.py (+98 -97)
src/streamlit_app.py CHANGED
@@ -1,25 +1,39 @@
 import streamlit as st
 import torch
 import torch.nn.functional as F
-from torch.nn.functional import softmax
 from transformers import AutoTokenizer, AutoModelForTokenClassification
 import pandas as pd
 import trafilatura

-# Set Streamlit configuration
+# Streamlit config
 st.set_page_config(layout="wide", page_title="LinkBERT")

-# Load model and tokenizer (correct for XLM-RoBERTa Large)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-tokenizer = AutoTokenizer.from_pretrained("dejanseo/LinkBERT-XL")
-model = AutoModelForTokenClassification.from_pretrained("dejanseo/LinkBERT-XL").to(device)
+# Load tokenizer & model (avoid meta-tensor .to() issue)
+MODEL_ID = "dejanseo/LinkBERT-XL"
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
+
+load_kwargs = {}
+if torch.cuda.is_available():
+    # Load directly onto GPU(s); do NOT call .to(...) afterward
+    load_kwargs.update(dict(device_map="auto", torch_dtype=torch.float16))
+else:
+    # CPU load without meta tensors
+    load_kwargs.update(dict(device_map=None))
+
+model = AutoModelForTokenClassification.from_pretrained(MODEL_ID, **load_kwargs)
 model.eval()

 # Functions
-
 def tokenize_with_indices(text: str):
-    encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=True)
-    return encoded['input_ids'], encoded['offset_mapping']
+    encoded = tokenizer.encode_plus(
+        text,
+        return_offsets_mapping=True,
+        add_special_tokens=True,
+        truncation=True,
+        max_length=512
+    )
+    return encoded["input_ids"], encoded["offset_mapping"]

 def fetch_and_extract_content(url: str):
     downloaded = trafilatura.fetch_url(url)
@@ -29,109 +43,102 @@ def fetch_and_extract_content(url: str):
     return None

 def process_text(inputs: str, confidence_threshold: float):
-    max_chunk_length = 512 - 2
+    max_chunk_length = 512 - 2  # safe window for special tokens
     words = inputs.split()
     chunk_texts = []
-    current_chunk = []
-    current_length = 0
+    current_chunk, current_length = [], 0
     for word in words:
-        if len(tokenizer.tokenize(word)) + current_length > max_chunk_length:
+        tok_len = len(tokenizer.tokenize(word))
+        if tok_len + current_length > max_chunk_length:
             chunk_texts.append(" ".join(current_chunk))
             current_chunk = [word]
-            current_length = len(tokenizer.tokenize(word))
-
+            current_length = tok_len
         else:
             current_chunk.append(word)
-            current_length += len(tokenizer.tokenize(word))
-    chunk_texts.append(" ".join(current_chunk))
-
-    df_data = {
-        'Word': [],
-        'Prediction': [],
-        'Confidence': [],
-        'Start': [],
-        'End': []
-    }
+            current_length += tok_len
+    if current_chunk:
+        chunk_texts.append(" ".join(current_chunk))
+
+    df_data = {"Word": [], "Prediction": [], "Confidence": [], "Start": [], "End": []}
     reconstructed_text = ""
     original_position_offset = 0

-    for chunk in chunk_texts:
-        input_ids, token_offsets = tokenize_with_indices(chunk)
-        input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
-        with torch.no_grad():
-            outputs = model(input_ids_tensor)
-        logits = outputs.logits
-        predictions = torch.argmax(logits, dim=-1).squeeze().tolist()
-        softmax_scores = F.softmax(logits, dim=-1).squeeze().tolist()
-
-        word_info = {}
-
-        for idx, (start, end) in enumerate(token_offsets):
-            if idx == 0 or idx == len(token_offsets) - 1:
-                continue
-
-            word_start = start
-            while word_start > 0 and chunk[word_start-1] != ' ':
-                word_start -= 1
-
-            if word_start not in word_info:
-                word_info[word_start] = {'prediction': 0, 'confidence': 0.0, 'subtokens': []}
-
-            confidence_percentage = softmax_scores[idx][predictions[idx]] * 100
-
-            if predictions[idx] == 1 and confidence_percentage >= confidence_threshold:
-                word_info[word_start]['prediction'] = 1
+    with torch.no_grad():
+        for chunk in chunk_texts:
+            input_ids, token_offsets = tokenize_with_indices(chunk)
+            input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(model.device)

-            word_info[word_start]['confidence'] = max(word_info[word_start]['confidence'], confidence_percentage)
-            word_info[word_start]['subtokens'].append((start, end, chunk[start:end]))
-
-        last_end = 0
-        for word_start in sorted(word_info.keys()):
-            word_data = word_info[word_start]
-            for subtoken_start, subtoken_end, subtoken_text in word_data['subtokens']:
-                escaped_subtoken_text = subtoken_text.replace('$', '\\$')
-                if last_end < subtoken_start:
-                    reconstructed_text += chunk[last_end:subtoken_start]
-                if word_data['prediction'] == 1:
-                    reconstructed_text += f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped_subtoken_text}</span>"
-                else:
-                    reconstructed_text += escaped_subtoken_text
-                last_end = subtoken_end
-
-                df_data['Word'].append(escaped_subtoken_text)
-                df_data['Prediction'].append(word_data['prediction'])
-                df_data['Confidence'].append(word_info[word_start]['confidence'])
-                df_data['Start'].append(subtoken_start + original_position_offset)
-                df_data['End'].append(subtoken_end + original_position_offset)
-
-        original_position_offset += len(chunk) + 1
-
-        reconstructed_text += chunk[last_end:].replace('$', '\\$')
+            outputs = model(input_ids_tensor)
+            logits = outputs.logits  # [1, seq_len, num_labels]
+            predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
+            softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()
+
+            word_info = {}
+            for idx, (start, end) in enumerate(token_offsets):
+                if idx == 0 or idx == len(token_offsets) - 1:
+                    continue  # skip specials
+
+                word_start = start
+                while word_start > 0 and chunk[word_start - 1] != ' ':
+                    word_start -= 1
+
+                if word_start not in word_info:
+                    word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}
+
+                conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
+                if predictions[idx] == 1 and conf_pct >= confidence_threshold:
+                    word_info[word_start]["prediction"] = 1
+                word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
+                word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))
+
+            last_end = 0
+            for word_start in sorted(word_info.keys()):
+                word_data = word_info[word_start]
+                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
+                    escaped = subtoken_text.replace("$", "\\$")
+                    if last_end < subtoken_start:
+                        reconstructed_text += chunk[last_end:subtoken_start]
+                    if word_data["prediction"] == 1:
+                        reconstructed_text += (
+                            f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped}</span>"
+                        )
+                    else:
+                        reconstructed_text += escaped
+                    last_end = subtoken_end
+
+                    df_data["Word"].append(escaped)
+                    df_data["Prediction"].append(word_data["prediction"])
+                    df_data["Confidence"].append(word_info[word_start]["confidence"])
+                    df_data["Start"].append(subtoken_start + original_position_offset)
+                    df_data["End"].append(subtoken_end + original_position_offset)
+
+            original_position_offset += len(chunk) + 1
+
+            reconstructed_text += chunk[last_end:].replace("$", "\\$")

     df_tokens = pd.DataFrame(df_data)
     return reconstructed_text, df_tokens

-# Streamlit Interface
-
-st.title('LinkBERT')
+# UI
+st.title("LinkBERT")
 st.markdown("""
-LinkBERT is a model developed by [Dejan Marketing](https://dejanmarketing.com/) designed to predict natural link placement within web content. You can either enter plain text or the URL for automated plain text extraction. To reduce the number of link predictions increase the threshold slider value.
+LinkBERT predicts natural link placement within web content. Enter text or a URL for extraction. Increase the threshold to reduce link predictions.
 """)

-confidence_threshold = st.slider('Confidence Threshold', 50, 100, 50)
+confidence_threshold = st.slider("Confidence Threshold", 50, 100, 50)

 tab1, tab2 = st.tabs(["Text Input", "URL Input"])

 with tab1:
     user_input = st.text_area("Enter text to process:")
-    if st.button('Process Text'):
+    if st.button("Process Text"):
         highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
         st.markdown(highlighted_text, unsafe_allow_html=True)
         st.dataframe(df_tokens)

 with tab2:
     url_input = st.text_input("Enter URL to process:")
-    if st.button('Fetch and Process'):
+    if st.button("Fetch and Process"):
         content = fetch_and_extract_content(url_input)
         if content:
             highlighted_text, df_tokens = process_text(content, confidence_threshold)
@@ -140,28 +147,22 @@ with tab2:
         else:
             st.error("Could not fetch content from the URL. Please check the URL and try again.")

-# Additional information at the end
 st.divider()
 st.markdown("""
-
 ## Applications of LinkBERT
-
-LinkBERT's applications are vast and diverse, tailored to enhance both the efficiency and quality of web content creation and analysis:
-
-- **Anchor Text Suggestion:** Acts as a mechanism during internal link optimization, suggesting potential anchor texts to web authors.
-- **Evaluation of Existing Links:** Assesses the naturalness of link placements within existing content, aiding in the refinement of web pages.
-- **Link Placement Guide:** Offers guidance to link builders by suggesting optimal placement for links within content.
-- **Anchor Text Idea Generator:** Provides creative anchor text suggestions to enrich content and improve SEO strategies.
-- **Spam and Inorganic SEO Detection:** Helps identify unnatural link patterns, contributing to the detection of spam and inorganic SEO tactics.
+- **Anchor Text Suggestion**
+- **Evaluation of Existing Links**
+- **Link Placement Guide**
+- **Anchor Text Idea Generator**
+- **Spam and Inorganic SEO Detection**

 ## Training and Performance
-
 LinkBERT was fine-tuned on a dataset of organic web content and editorial links.

 [Watch the video](https://www.youtube.com/watch?v=A0ZulyVqjZo)
-
+
 # Engage Our Team
 Interested in using this in an automated pipeline for bulk link prediction?

-Please [book an appointment](https://dejanmarketing.com/conference/) to discuss your needs.
-""")
+Please [book an appointment](https://dejanmarketing.com/conference/).
+""")