Update src/streamlit_app.py

src/streamlit_app.py (CHANGED, +98 -97)
@@ -1,25 +1,39 @@
 import streamlit as st
 import torch
 import torch.nn.functional as F
-from torch.nn.functional import softmax
 from transformers import AutoTokenizer, AutoModelForTokenClassification
 import pandas as pd
 import trafilatura
 
-#
+# Streamlit config
 st.set_page_config(layout="wide", page_title="LinkBERT")
 
-# Load
-
-
-
+# Load tokenizer & model (avoid meta-tensor .to() issue)
+MODEL_ID = "dejanseo/LinkBERT-XL"
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
+
+load_kwargs = {}
+if torch.cuda.is_available():
+    # Load directly onto GPU(s); do NOT call .to(...) afterward
+    load_kwargs.update(dict(device_map="auto", torch_dtype=torch.float16))
+else:
+    # CPU load without meta tensors
+    load_kwargs.update(dict(device_map=None))
+
+model = AutoModelForTokenClassification.from_pretrained(MODEL_ID, **load_kwargs)
 model.eval()
 
 # Functions
-
 def tokenize_with_indices(text: str):
-    encoded = tokenizer.encode_plus(
-
+    encoded = tokenizer.encode_plus(
+        text,
+        return_offsets_mapping=True,
+        add_special_tokens=True,
+        truncation=True,
+        max_length=512
+    )
+    return encoded["input_ids"], encoded["offset_mapping"]
 
 def fetch_and_extract_content(url: str):
     downloaded = trafilatura.fetch_url(url)
@@ -29,109 +43,102 @@ def fetch_and_extract_content(url: str):
     return None
 
 def process_text(inputs: str, confidence_threshold: float):
-    max_chunk_length = 512 - 2
+    max_chunk_length = 512 - 2  # safe window for special tokens
     words = inputs.split()
     chunk_texts = []
-    current_chunk = []
-    current_length = 0
+    current_chunk, current_length = [], 0
     for word in words:
-
+        tok_len = len(tokenizer.tokenize(word))
+        if tok_len + current_length > max_chunk_length:
             chunk_texts.append(" ".join(current_chunk))
             current_chunk = [word]
-            current_length =
-
+            current_length = tok_len
         else:
             current_chunk.append(word)
-            current_length +=
-
-
-
-
-        'Prediction': [],
-        'Confidence': [],
-        'Start': [],
-        'End': []
-    }
+            current_length += tok_len
+    if current_chunk:
+        chunk_texts.append(" ".join(current_chunk))
+
+    df_data = {"Word": [], "Prediction": [], "Confidence": [], "Start": [], "End": []}
     reconstructed_text = ""
     original_position_offset = 0
 
-
-
-
-
-        outputs = model(input_ids_tensor)
-        logits = outputs.logits
-        predictions = torch.argmax(logits, dim=-1).squeeze().tolist()
-        softmax_scores = F.softmax(logits, dim=-1).squeeze().tolist()
-
-        word_info = {}
-
-        for idx, (start, end) in enumerate(token_offsets):
-            if idx == 0 or idx == len(token_offsets) - 1:
-                continue
-
-            word_start = start
-            while word_start > 0 and chunk[word_start-1] != ' ':
-                word_start -= 1
-
-            if word_start not in word_info:
-                word_info[word_start] = {'prediction': 0, 'confidence': 0.0, 'subtokens': []}
-
-            confidence_percentage = softmax_scores[idx][predictions[idx]] * 100
-
-            if predictions[idx] == 1 and confidence_percentage >= confidence_threshold:
-                word_info[word_start]['prediction'] = 1
+    with torch.no_grad():
+        for chunk in chunk_texts:
+            input_ids, token_offsets = tokenize_with_indices(chunk)
+            input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(model.device)
+
+            outputs = model(input_ids_tensor)
+            logits = outputs.logits  # [1, seq_len, num_labels]
+            predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
+            softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()
+
+            word_info = {}
+            for idx, (start, end) in enumerate(token_offsets):
+                if idx == 0 or idx == len(token_offsets) - 1:
+                    continue  # skip specials
+
+                word_start = start
+                while word_start > 0 and chunk[word_start - 1] != ' ':
+                    word_start -= 1
+
+                if word_start not in word_info:
+                    word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}
+
+                conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
+                if predictions[idx] == 1 and conf_pct >= confidence_threshold:
+                    word_info[word_start]["prediction"] = 1
+                word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
+                word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))
 
-
-
-
-
-
-
-        for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            last_end = 0
+            for word_start in sorted(word_info.keys()):
+                word_data = word_info[word_start]
+                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
+                    escaped = subtoken_text.replace("$", "\\$")
+                    if last_end < subtoken_start:
+                        reconstructed_text += chunk[last_end:subtoken_start]
+                    if word_data["prediction"] == 1:
+                        reconstructed_text += (
+                            f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped}</span>"
+                        )
+                    else:
+                        reconstructed_text += escaped
+                    last_end = subtoken_end
+
+                    df_data["Word"].append(escaped)
+                    df_data["Prediction"].append(word_data["prediction"])
+                    df_data["Confidence"].append(word_info[word_start]["confidence"])
+                    df_data["Start"].append(subtoken_start + original_position_offset)
+                    df_data["End"].append(subtoken_end + original_position_offset)
+
+            original_position_offset += len(chunk) + 1
+
+            reconstructed_text += chunk[last_end:].replace("$", "\\$")
 
     df_tokens = pd.DataFrame(df_data)
     return reconstructed_text, df_tokens
 
-#
-
-st.title('LinkBERT')
+# UI
+st.title("LinkBERT")
 st.markdown("""
-LinkBERT
+LinkBERT predicts natural link placement within web content. Enter text or a URL for extraction. Increase the threshold to reduce link predictions.
 """)
 
-confidence_threshold = st.slider(
+confidence_threshold = st.slider("Confidence Threshold", 50, 100, 50)
 
 tab1, tab2 = st.tabs(["Text Input", "URL Input"])
 
 with tab1:
     user_input = st.text_area("Enter text to process:")
-    if st.button(
+    if st.button("Process Text"):
         highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
         st.markdown(highlighted_text, unsafe_allow_html=True)
         st.dataframe(df_tokens)
 
 with tab2:
     url_input = st.text_input("Enter URL to process:")
-    if st.button(
+    if st.button("Fetch and Process"):
         content = fetch_and_extract_content(url_input)
         if content:
             highlighted_text, df_tokens = process_text(content, confidence_threshold)
@@ -140,28 +147,22 @@ with tab2:
     else:
         st.error("Could not fetch content from the URL. Please check the URL and try again.")
 
-# Additional information at the end
 st.divider()
 st.markdown("""
-
 ## Applications of LinkBERT
-
-
-
-- **Anchor Text
-- **
-- **Link Placement Guide:** Offers guidance to link builders by suggesting optimal placement for links within content.
-- **Anchor Text Idea Generator:** Provides creative anchor text suggestions to enrich content and improve SEO strategies.
-- **Spam and Inorganic SEO Detection:** Helps identify unnatural link patterns, contributing to the detection of spam and inorganic SEO tactics.
+- **Anchor Text Suggestion**
+- **Evaluation of Existing Links**
+- **Link Placement Guide**
+- **Anchor Text Idea Generator**
+- **Spam and Inorganic SEO Detection**
 
 ## Training and Performance
-
 LinkBERT was fine-tuned on a dataset of organic web content and editorial links.
 
 [Watch the video](https://www.youtube.com/watch?v=A0ZulyVqjZo)
-
+
 # Engage Our Team
 Interested in using this in an automated pipeline for bulk link prediction?
 
-Please [book an appointment](https://dejanmarketing.com/conference/)
-""")
+Please [book an appointment](https://dejanmarketing.com/conference/).
+""")
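Note on the loading change in the first hunk: calling .to(device) on a model whose weights are still on PyTorch's meta device typically fails with "NotImplementedError: Cannot copy out of meta tensor". The added block sidesteps this by choosing placement at load time instead. A minimal standalone sketch of that pattern, assuming the model ID from the diff (device_map="auto" additionally requires the accelerate package):

import torch
from transformers import AutoModelForTokenClassification

MODEL_ID = "dejanseo/LinkBERT-XL"  # model ID as used in the diff above

# Decide device placement at load time; never call .to(device) afterward.
if torch.cuda.is_available():
    load_kwargs = dict(device_map="auto", torch_dtype=torch.float16)  # load straight onto GPU(s)
else:
    load_kwargs = dict(device_map=None)  # plain CPU load, no meta tensors

model = AutoModelForTokenClassification.from_pretrained(MODEL_ID, **load_kwargs)
model.eval()

Input tensors then follow the model via model.device, as the diff does with .to(model.device).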
|