upload side scripts
- utils/beamforming.py +67 -0
- utils/pretraining.py +150 -0
- utils/res1dcnn.py +88 -0
utils/beamforming.py
ADDED
@@ -0,0 +1,67 @@
#%% PACKAGES & MODULES
import numpy as np
import torch
from input_preprocess import DeepMIMO_data_gen, deepmimo_data_cleaning, tokenizer
from inference import lwm_inference, create_raw_dataset
from lwm_model import lwm

#%% DEEPMIMO DATA GENERATION
scenario_names = np.array([
    "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
    "city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
])

bf_scenario_idx = 3
scenario_idxs = np.array([bf_scenario_idx])
selected_scenario_names = scenario_names[scenario_idxs]

deepmimo_data = [DeepMIMO_data_gen(scenario_name) for scenario_name in selected_scenario_names]
cleaned_deepmimo_data = [deepmimo_data_cleaning(deepmimo_data[scenario_idx]) for scenario_idx in range(len(deepmimo_data))]

#%% FUNCTION FOR MRT BEAMFORMING
def compute_mrt_beamforming(channel_data, snr_db=None):

    channel_data = torch.tensor(channel_data[0])
    mrt_vectors = []
    snr_linear = 10 ** (snr_db / 10) if snr_db is not None else None

    for idx in range(channel_data.shape[0]):
        channel = channel_data[idx, 0, :, :]  # Shape: (32, 32)

        if snr_db is not None:
            # Add complex Gaussian noise to the channel
            noise_power = torch.mean(torch.abs(channel) ** 2) / snr_linear
            noise = torch.sqrt(noise_power / 2) * (
                torch.randn_like(channel) + 1j * torch.randn_like(channel)
            )
            channel = channel + noise

        # Compute the MRT beamforming vector for each user
        h_avg = torch.mean(channel, dim=1, keepdim=True)  # Shape: (32, 1)
        h_conj = torch.conj(h_avg)  # Conjugate of the averaged channel vector
        mrt_vector = h_conj / torch.norm(h_conj, dim=0, keepdim=True)  # Normalize

        mrt_vectors.append(mrt_vector)

    return torch.stack(mrt_vectors, dim=0)  # Shape: (N, 32, 1)

#%% GENERATE BEAMFORMING VECTORS
beamforming_vectors = compute_mrt_beamforming(cleaned_deepmimo_data)

#%% GENERATE LWM EMBEDDINGS FROM MASKED INPUT CHANNELS
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False)  # gen_raw=False masks 15% of the input patches, and LWM will act as a denoiser

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Loading the LWM model on {device} ...")
model = lwm.from_pretrained(device=device)

input_types = ['cls_emb', 'channel_emb', 'raw']
selected_input_type = input_types[1]

if selected_input_type in ['cls_emb', 'channel_emb']:
    dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
else:
    dataset = create_raw_dataset(preprocessed_chs, device)
utils/pretraining.py
ADDED
@@ -0,0 +1,150 @@
#%% PACKAGES & MODULES
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from inference import prepare_for_lwm
from input_preprocess import tokenizer
from lwm_model import lwm
import numpy as np

#%% PARAMETERS
n_epochs = 100
n_layers = 12
n_heads = 12
d_model = 64
d_ff = d_model * 4
d_k = d_model // n_heads
d_v = d_model // n_heads
dropout = 0.1
max_len = 129
element_length = 16
batch_size = 64
train_ratio = 0.7
val_ratio = 0.2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#%% PRE-TRAINING DATA GENERATION
# The following DeepMIMO scenarios are not enough for pre-training a
# Transformer-based foundation model like LWM. Add more scenarios for
# more effective pre-training. Instructions for reproducing the actual
# dataset used to pre-train LWM can be found in the Huggingface forum.
scenario_names = np.array([
    "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
    "city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
])

scenario_idxs = np.array([0, 1, 2, 3, 4, 5])
selected_scenario_names = scenario_names[scenario_idxs]

preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False)

#%% DATALOADER
train_size = int(train_ratio * len(preprocessed_chs))
val_size = int(val_ratio * len(preprocessed_chs))
test_size = len(preprocessed_chs) - val_size - train_size

train_data, val_data, test_data = torch.utils.data.random_split(
    preprocessed_chs, [train_size, val_size, test_size]
)

train_loader = prepare_for_lwm(train_data, device, batch_size=batch_size, shuffle=True)
val_loader = prepare_for_lwm(val_data, device, batch_size=batch_size, shuffle=True)
test_loader = prepare_for_lwm(test_data, device, batch_size=batch_size, shuffle=True)

#%% MODEL
load_model = False

model = lwm()
model.to(device)

if load_model:
    model_name = 'models/pretrained_model.pth'
    model.load_state_dict(torch.load(model_name))
    print(f"Model loaded from {model_name}")

# Loss function
criterionMLM = nn.MSELoss()

#%% OPTIMIZER AND SCHEDULER
adaptive_lr = False

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = (
    optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    if adaptive_lr
    else StepLR(optimizer, step_size=10, gamma=0.9)
)

#%% TRAINING
training_loss = []
validation_loss = []

def train(model, dataloader, optimizer, scheduler=None, device="cuda"):

    model.train()
    running_loss = 0.0
    criterionMCM = nn.MSELoss()

    for idx, batch in enumerate(dataloader):
        input_ids = batch[0].to(device)
        masked_tokens = batch[1].to(device)
        masked_pos = batch[2].to(device)

        optimizer.zero_grad()

        logits_lm, _ = model(input_ids, masked_pos)
        loss_lm = criterionMCM(logits_lm, masked_tokens)
        loss = loss_lm / torch.var(masked_tokens)

        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        running_loss += loss.item()

    average_loss = running_loss / len(dataloader)

    return average_loss

def validate(model, dataloader, device="cuda"):
    model.eval()
    running_loss = 0.0
    criterionMCM = nn.MSELoss()

    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            input_ids = batch[0].to(device)
            masked_tokens = batch[1].to(device)
            masked_pos = batch[2].to(device)

            logits_lm, _ = model(input_ids, masked_pos)

            loss_lm = criterionMCM(logits_lm, masked_tokens)
            loss = loss_lm / torch.var(masked_tokens)

            running_loss += loss.item()

    average_loss = running_loss / len(dataloader)

    return average_loss

#%% TRAINING LOOP
for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}/{n_epochs}")

    # Training step
    train_loss = train(model, train_loader, optimizer, scheduler, device)
    training_loss.append(train_loss)
    print(f"Training Loss: {train_loss:.4f}")

    # Validation step
    if val_loader is not None:
        val_loss = validate(model, val_loader, device)
        validation_loss.append(val_loss)
        print(f"Validation Loss: {val_loss:.4f}")
utils/res1dcnn.py
ADDED
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

class ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Shortcut connection to match dimensions when needed
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels)
            )

    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += self.shortcut(residual)
        x = F.relu(x)
        return x


class ResNet1DCNN(nn.Module):

    def __init__(self, input_channels, sequence_length, num_classes):
        super(ResNet1DCNN, self).__init__()

        # Initial convolution layer
        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual layers
        self.layer1 = self._make_layer(32, 32, 2)
        self.layer2 = self._make_layer(32, 64, 3)
        self.layer3 = self._make_layer(64, 128, 4)

        # Calculate the size of the flattened features
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_channels, sequence_length)
            dummy_output = self.compute_conv_output(dummy_input)
            self.flatten_size = dummy_output.numel()

        # Fully connected layers
        self.fc1 = nn.Linear(self.flatten_size, 128)
        self.bn_fc1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, num_classes)

        self.dropout = nn.Dropout(0.5)

    def _make_layer(self, in_channels, out_channels, num_blocks):

        layers = [ResidualBlock(in_channels, out_channels)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def compute_conv_output(self, x):

        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.adaptive_avg_pool1d(x, 1)
        return x

    def forward(self, x):

        x = x.transpose(1, 2)

        x = self.compute_conv_output(x)

        x = x.view(x.size(0), -1)
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        x = self.fc2(x)

        return x