wi-lab committed on
Commit
713dc9d
·
verified ·
1 Parent(s): 30e95b9

upload side scripts

Browse files
Files changed (3) hide show
  1. utils/beamforming.py +67 -0
  2. utils/pretraining.py +150 -0
  3. utils/res1dcnn.py +88 -0
utils/beamforming.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%% PACKAGES & MODULES
2
+ import numpy as np
3
+ import torch
4
+ from input_preprocess import DeepMIMO_data_gen, deepmimo_data_cleaning, tokenizer
5
+ from inference import lwm_inference, create_raw_dataset
6
+ from lwm_model import lwm
7
+
8
+ #%% DEEPMIMO DATA GENERATION
9
# Available DeepMIMO scenario names; one entry is selected below by index.
scenario_names = np.array([
    "city_18_denver",
    "city_15_indianapolis",
    "city_19_oklahoma",
    "city_12_fortworth",
    "city_11_santaclara",
    "city_7_sandiego",
])

# Position (within scenario_names) of the scenario used for beamforming.
bf_scenario_idx = 3
scenario_idxs = np.array([bf_scenario_idx])
selected_scenario_names = scenario_names[scenario_idxs]

# Generate DeepMIMO data for every selected scenario, then pass each result
# through deepmimo_data_cleaning (same order as selected_scenario_names).
deepmimo_data = [DeepMIMO_data_gen(name) for name in selected_scenario_names]
cleaned_deepmimo_data = [deepmimo_data_cleaning(raw) for raw in deepmimo_data]
20
+
21
+ #%% FUNCTION FOR MRT BEAMFORMING
22
def compute_mrt_beamforming(channel_data, snr_db=None):
    """Compute one unit-norm MRT (conjugate) beamforming vector per sample.

    For every sample the channel matrix is averaged over its last axis,
    conjugated, and divided by its L2 norm.

    Args:
        channel_data: Sequence whose first element is array-like (NumPy array
            or tensor) indexed as ``[sample, 0, row, col]``. Only
            ``channel_data[0]`` is used, and only index 0 of its second axis
            is read.
        snr_db: Optional signal-to-noise ratio in dB. When given, complex
            Gaussian noise is added to each sample's channel before the
            beamforming vector is computed; the noise power is that sample's
            mean channel power divided by the linear SNR. ``None`` (default)
            adds no noise.

    Returns:
        Tensor of shape ``(N, rows, 1)``. A sample whose averaged channel is
        exactly zero yields an all-zero vector rather than NaN. An input with
        zero samples yields an empty ``(0, rows, 1)`` tensor.
    """
    # as_tensor avoids the extra copy (and the UserWarning) that torch.tensor
    # produces when the input is already a tensor.
    channel_data = torch.as_tensor(channel_data[0])
    snr_linear = 10 ** (snr_db / 10) if snr_db is not None else None

    # torch.stack cannot handle an empty list, so short-circuit empty input.
    if channel_data.shape[0] == 0:
        n_rows = channel_data.shape[2] if channel_data.dim() >= 3 else 0
        return channel_data.new_empty((0, n_rows, 1))

    mrt_vectors = []
    for channel in channel_data[:, 0, :, :]:
        if snr_linear is not None:
            # Add complex Gaussian noise to the channel
            noise_power = torch.mean(torch.abs(channel) ** 2) / snr_linear
            noise = torch.sqrt(noise_power / 2) * (
                torch.randn_like(channel) + 1j * torch.randn_like(channel)
            )
            channel = channel + noise

        # Average over the last axis, then conjugate: shape (rows, 1)
        h_avg = torch.mean(channel, dim=1, keepdim=True)
        h_conj = torch.conj(h_avg)

        # Only an exactly-zero norm is replaced (by 1) so that non-zero
        # channels are normalized exactly as before and zero channels do not
        # turn into NaN.
        norm = torch.norm(h_conj, dim=0, keepdim=True)
        safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))

        mrt_vectors.append(h_conj / safe_norm)

    return torch.stack(mrt_vectors, dim=0)
47
+
48
+ #%% GENERATE BEAMFORMING VECTORS
49
# Beamforming vectors for the selected scenario (snr_db left at its default).
beamforming_vectors = compute_mrt_beamforming(cleaned_deepmimo_data)

#%% GENERATE LWM EMBEDDINGS FROM MASKED INPUT CHANNELS
# gen_raw=False masks 15% of the input patches, and LWM will act as a denoiser
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False,
)

# Prefer the GPU whenever one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Loading the LWM model on {device} ...")
model = lwm.from_pretrained(device=device)

# Supported dataset flavours; index 1 selects 'channel_emb'.
input_types = ['cls_emb', 'channel_emb', 'raw']
selected_input_type = input_types[1]

# Embedding types go through LWM inference; anything else uses the raw dataset.
dataset = (
    lwm_inference(preprocessed_chs, selected_input_type, model, device)
    if selected_input_type in ['cls_emb', 'channel_emb']
    else create_raw_dataset(preprocessed_chs, device)
)
utils/pretraining.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%% PACKAGES & MODULES
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.optim.lr_scheduler import StepLR
6
+ from inference import prepare_for_lwm
7
+ from input_preprocess import tokenizer
8
+ from lwm_model import lwm
9
+ import numpy as np
10
+
11
+ #%% PARAMETERS
12
# Number of passes over the training set.
n_epochs = 100

# Model-size settings (presumably mirror the lwm_model architecture — confirm
# against that module before changing them).
n_layers = 12
n_heads = 12
d_model = 64
d_ff = 4 * d_model
d_k = d_v = d_model // n_heads
dropout = 0.1
max_len = 129
element_length = 16

# Batching and dataset split fractions; the test split gets the remainder.
batch_size = 64
train_ratio = 0.7
val_ratio = 0.2

# Prefer the GPU whenever one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
26
+
27
+ #%% PRE-TRAINING DATA GENERATION
28
+ # The following DeepMIMO scenarios are not enough for pre-training a
29
+ # Transformer-based foundation model like LWM. Add more scenarios for
30
+ # more effective pre-training. The instruction for reproducing the actual
31
+ # dataset used for pre-training LWM can be found in the Huggingface forum.
32
# Available DeepMIMO scenario names.
scenario_names = np.array([
    "city_18_denver",
    "city_15_indianapolis",
    "city_19_oklahoma",
    "city_12_fortworth",
    "city_11_santaclara",
    "city_7_sandiego",
])

# Select all six scenarios listed above.
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])
selected_scenario_names = scenario_names[scenario_idxs]

# Tokenize the selected scenarios (gen_raw=False, no manual data supplied).
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False,
)
44
+
45
+ #%% DATALOADER
46
# Split sizes: train/val are truncated fractions of the dataset, and the test
# split receives whatever is left so the three sizes always sum to the total.
train_size, val_size = (
    int(ratio * len(preprocessed_chs)) for ratio in (train_ratio, val_ratio)
)
test_size = len(preprocessed_chs) - val_size - train_size

train_data, val_data, test_data = torch.utils.data.random_split(
    preprocessed_chs,
    [train_size, val_size, test_size],
)

# Build the three loaders in train, val, test order, all with shuffling on.
train_loader, val_loader, test_loader = (
    prepare_for_lwm(subset, device, batch_size=batch_size, shuffle=True)
    for subset in (train_data, val_data, test_data)
)
57
+
58
+ # %% Model
59
# Set to True to resume from the checkpoint path below instead of starting
# from freshly initialized weights.
load_model = False

model = lwm()
model.to(device)

if load_model:
    model_name = 'models/pretrained_model.pth'
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU machine still loads on a CPU-only host.
    model.load_state_dict(torch.load(model_name, map_location=device))
    print(f"Model loaded from {model_name}")

# Loss function
criterionMLM = nn.MSELoss()
71
+
72
+ # %% Optimizer and Scheduler
73
# True  -> ReduceLROnPlateau (mode='min')
# False -> StepLR (step_size=10, gamma=0.9)
adaptive_lr = False

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
if adaptive_lr:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
else:
    scheduler = StepLR(optimizer, step_size=10, gamma=0.9)

# %% Training
# Per-epoch loss history, appended to by the training loop.
training_loss = []
validation_loss = []
85
+
86
def train(model, dataloader, optimizer, scheduler=None, device="cuda"):
    """Run one training epoch and return the mean per-batch loss.

    Each batch is indexed as ``batch[0]`` (model input), ``batch[1]`` (target
    values for the masked positions) and ``batch[2]`` (masked positions). The
    model is called as ``model(input_ids, masked_pos)`` and its first output
    is compared with the targets using MSE divided by the variance of the
    targets in that batch.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Sized iterable of batches as described above.
        optimizer: Optimizer updated once per batch.
        scheduler: Optional learning-rate scheduler. A ``ReduceLROnPlateau``
            is stepped once at the end of the epoch with the mean loss (its
            ``step`` requires a metric); any other scheduler is stepped after
            every batch.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    criterionMCM = nn.MSELoss()

    # ReduceLROnPlateau.step() needs a metric argument; calling it bare (as a
    # per-batch step would) raises TypeError, so it is handled separately.
    plateau_scheduler = isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    )

    running_loss = 0.0
    for batch in dataloader:
        input_ids = batch[0].to(device)
        masked_tokens = batch[1].to(device)
        masked_pos = batch[2].to(device)

        optimizer.zero_grad()

        logits_lm, _ = model(input_ids, masked_pos)
        loss_lm = criterionMCM(logits_lm, masked_tokens)
        # Normalize by the target variance so the loss is scale-independent.
        loss = loss_lm / torch.var(masked_tokens)

        loss.backward()
        optimizer.step()

        # NOTE: non-plateau schedulers are stepped per batch, not per epoch,
        # so e.g. StepLR's step_size counts batches here.
        if scheduler is not None and not plateau_scheduler:
            scheduler.step()

        running_loss += loss.item()

    average_loss = running_loss / len(dataloader)

    if plateau_scheduler:
        scheduler.step(average_loss)

    return average_loss
114
+
115
def validate(model, dataloader, device="cuda"):
    """Evaluate the model over a dataloader and return the mean per-batch loss.

    Each batch is indexed as ``batch[0]`` (model input), ``batch[1]`` (target
    values for the masked positions) and ``batch[2]`` (masked positions). The
    per-batch loss is the MSE between the model's first output and the
    targets, divided by the variance of the targets.

    Args:
        model: Module to evaluate; switched to eval mode.
        dataloader: Sized iterable of batches as described above.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    mse = nn.MSELoss()
    model.eval()

    batch_losses = []
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in dataloader:
            input_ids, masked_tokens, masked_pos = (
                batch[i].to(device) for i in range(3)
            )
            predictions, _ = model(input_ids, masked_pos)
            normalized = mse(predictions, masked_tokens) / torch.var(masked_tokens)
            batch_losses.append(normalized.item())

    return sum(batch_losses) / len(dataloader)
136
+
137
+ # %% Training Loop
138
for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}/{n_epochs}")

    # One optimization pass over the training loader; record its mean loss.
    train_loss = train(model, train_loader, optimizer, scheduler, device)
    training_loss.append(train_loss)
    print(f"Training Loss: {train_loss:.4f}")

    # Validation is skipped entirely when no validation loader exists.
    if val_loader is None:
        continue

    val_loss = validate(model, val_loader, device)
    validation_loss.append(val_loss)
    print(f"Validation Loss: {val_loss:.4f}")
utils/res1dcnn.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.optim import Adam
5
+ from torch.optim.lr_scheduler import MultiStepLR
6
+
7
class ResidualBlock(nn.Module):
    """Residual unit built from two Conv1d + BatchNorm1d stages.

    Both convolutions use kernel_size=3 with padding=1, so the length of the
    last axis is preserved. The block input is added back to the second
    stage's output through ``shortcut`` before the final ReLU.
    """

    def __init__(self, in_channels, out_channels):
        """Create the two conv stages and the shortcut path.

        Args:
            in_channels: Channel count of the block input.
            out_channels: Channel count produced by the block.
        """
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # An empty Sequential passes the input through untouched; a 1x1
        # convolution (+ batch norm) is used only when the channel counts
        # differ and the skip path must be projected.
        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))
31
+
32
+
33
class ResNet1DCNN(nn.Module):
    """1-D residual CNN followed by a two-layer fully connected head.

    The convolutional trunk is a strided Conv1d + BatchNorm1d + max-pool stem
    followed by three stacks of ``ResidualBlock`` (32, 64 and 128 output
    channels) and a global average pool. ``forward`` swaps axes 1 and 2 of
    its input before the trunk, so the input's last axis must hold
    ``input_channels`` values.
    """

    def __init__(self, input_channels, sequence_length, num_classes):
        """Build the network.

        Args:
            input_channels: Channel count fed to the first convolution.
            sequence_length: Length used for the shape probe that sizes the
                first fully connected layer.
            num_classes: Size of the final output layer.
        """
        super().__init__()

        # Initial convolution layer
        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual layers
        self.layer1 = self._make_layer(32, 32, 2)
        self.layer2 = self._make_layer(32, 64, 3)
        self.layer3 = self._make_layer(64, 128, 4)

        # Calculate the size of the flattened features with a dummy pass.
        # The probe runs in eval mode: in training mode a single all-zero
        # sample would update the BatchNorm running statistics and would
        # raise for short sequences that shrink to one value per channel.
        self.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_channels, sequence_length)
            self.flatten_size = self.compute_conv_output(dummy_input).numel()
        # Restore the default training mode for the modules built so far.
        self.train()

        # Fully connected layers
        self.fc1 = nn.Linear(self.flatten_size, 128)
        self.bn_fc1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, num_classes)

        self.dropout = nn.Dropout(0.5)

    def _make_layer(self, in_channels, out_channels, num_blocks):
        """Stack ``num_blocks`` residual blocks; only the first changes width."""
        layers = [ResidualBlock(in_channels, out_channels)]
        layers.extend(
            ResidualBlock(out_channels, out_channels) for _ in range(1, num_blocks)
        )
        return nn.Sequential(*layers)

    def compute_conv_output(self, x):
        """Run the convolutional trunk and pool the last axis down to length 1."""
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.adaptive_avg_pool1d(x, 1)
        return x

    def forward(self, x):
        """Return the ``num_classes``-wide output for a batch ``x``.

        Axes 1 and 2 of ``x`` are swapped first so that the axis of size
        ``input_channels`` becomes the Conv1d channel axis.
        """
        x = x.transpose(1, 2)

        x = self.compute_conv_output(x)

        # Flatten everything but the batch axis for the dense head.
        x = x.view(x.size(0), -1)
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        x = self.fc2(x)

        return x