wi-lab committed on
Commit f9384cb · verified · 1 Parent(s): 80d1f41

Upload 4 files

Files changed (4)
  1. inference.py +84 -84
  2. input_preprocess.py +373 -374
  3. lwm_model.py +134 -134
  4. utils.py +323 -0
inference.py CHANGED
@@ -1,85 +1,85 @@
 # -*- coding: utf-8 -*-
 """
 Created on Sun Sep 15 18:27:17 2024
 
 @author: salikha4
 """
 
 import os
 import csv
 import json
 import shutil
 import random
 import argparse
 from datetime import datetime
 import pandas as pd
 import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import Dataset, DataLoader, TensorDataset
 from torch.optim import Adam
 import numpy as np
 import warnings
 warnings.filterwarnings('ignore')
 
 def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
 
     dataset = prepare_for_lwm(preprocessed_chs, device)
     # Process data through LWM
     lwm_loss, embedding_data = evaluate(lwm_model, dataset)
-    print(f'LWM loss: {lwm_loss:.4f}')
+    # print(f'LWM loss: {lwm_loss:.4f}')
 
     if input_type == 'cls_emb':
         embedding_data = embedding_data[:, 0]
     elif input_type == 'channel_emb':
         embedding_data = embedding_data[:, 1:]
 
     dataset = embedding_data.float()
     return dataset
 
 def prepare_for_lwm(data, device, batch_size=64, shuffle=False):
 
     input_ids, masked_tokens, masked_pos = zip(*data)
 
     input_ids_tensor = torch.tensor(input_ids, device=device).float()
     masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
     masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
 
     dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
 
     return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
 
 def evaluate(model, dataloader):
 
     model.eval()
     running_loss = 0.0
     outputs = []
     criterionMCM = nn.MSELoss()
 
     with torch.no_grad():
         for idx, batch in enumerate(dataloader):
             input_ids = batch[0]
             masked_tokens = batch[1]
             masked_pos = batch[2]
 
             logits_lm, output = model(input_ids, masked_pos)
 
             output_batch_preproc = output
             outputs.append(output_batch_preproc)
 
             loss_lm = criterionMCM(logits_lm, masked_tokens)
             loss = loss_lm / torch.var(masked_tokens)
             running_loss += loss.item()
 
     average_loss = running_loss / len(dataloader)
     output_total = torch.cat(outputs, dim=0)
 
     return average_loss, output_total
 
 def create_raw_dataset(data, device):
     """Create a dataset for raw channel data."""
     input_ids, _, _ = zip(*data)
     input_data = torch.tensor(input_ids, device=device)[:, 1:]
     return input_data.float()
 
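Taken together, inference.py turns tokenized channels into LWM embeddings. A minimal usage sketch (not part of the commit), assuming the sibling modules from this repository are importable and a local model_weights.pth checkpoint (the default in lwm.from_pretrained) exists:

    import torch
    from input_preprocess import tokenizer
    from lwm_model import lwm
    from inference import lwm_inference

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Tokenize one DeepMIMO scenario into (input_ids, masked_tokens, masked_pos) triples
    preprocessed_chs = tokenizer(selected_scenario_names='city_18_denver', gen_raw=True)

    # Load the pretrained model and pull one embedding per channel:
    # 'cls_emb' keeps only the [CLS] token, 'channel_emb' keeps the patch tokens
    model = lwm.from_pretrained(ckpt_name='model_weights.pth', device=device)
    cls_emb = lwm_inference(preprocessed_chs, 'cls_emb', model, device)
    print(cls_emb.shape)  # (n_users, 64)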
input_preprocess.py CHANGED
@@ -1,374 +1,373 @@
 # -*- coding: utf-8 -*-
 """
 Created on Fri Sep 13 16:13:29 2024
 
 This script generates preprocessed data from wireless communication scenarios,
 including token generation, patch creation, and data sampling for machine learning models.
 
 @author: salikha4
 """
 
 import numpy as np
 import os
 from tqdm import tqdm
 import time
 import pickle
 import DeepMIMOv3
 import torch
-
+from utils import plot_coverage, generate_gaussian_noise
 #%% Scenarios List
 def scenarios_list():
     """Returns an array of available scenarios."""
     return np.array([
         'city_18_denver', 'city_15_indianapolis', 'city_19_oklahoma',
         'city_12_fortworth', 'city_11_santaclara', 'city_7_sandiego'
     ])
 
 #%% Token Generation
-def tokenizer(selected_scenario_names=None, manual_data=None, gen_raw=True):
+def tokenizer(selected_scenario_names=None, manual_data=None, gen_raw=True, snr_db=None):
     """
     Generates tokens by preparing and preprocessing the dataset.
 
     Args:
         selected_scenario_names (str or list): Name(s) of the scenarios to load. Defaults to None.
         manual_data (array-like): Channel data supplied directly instead of being generated. Defaults to None.
         gen_raw (bool): Whether to keep the raw (unmasked) tokens. Defaults to True.
+        snr_db (float): If given, additive Gaussian noise at this SNR (dB) is applied to the channels. Defaults to None.
 
     Returns:
         preprocessed_data (list): Preprocessed samples, one per user.
     """
 
     if manual_data is not None:
-        patches = patch_maker(np.expand_dims(np.array(manual_data), axis=1))
-        #patches = patch_maker(torch.tensor(manual_data, dtype=torch.complex64).unsqueeze(1))
+        patches = patch_maker(np.expand_dims(np.array(manual_data), axis=1), snr_db=snr_db)
     else:
         # Patch generation or loading
         if isinstance(selected_scenario_names, str):
             selected_scenario_names = [selected_scenario_names]
         deepmimo_data = [DeepMIMO_data_gen(scenario_name) for scenario_name in selected_scenario_names]
-
         n_scenarios = len(selected_scenario_names)
 
         cleaned_deepmimo_data = [deepmimo_data_cleaning(deepmimo_data[scenario_idx]) for scenario_idx in range(n_scenarios)]
 
-        patches = [patch_maker(cleaned_deepmimo_data[scenario_idx]) for scenario_idx in range(n_scenarios)]
+        patches = [patch_maker(cleaned_deepmimo_data[scenario_idx], snr_db=snr_db) for scenario_idx in range(n_scenarios)]
         patches = np.vstack(patches)
 
     # Define dimensions
     patch_size = patches.shape[2]
     n_patches = patches.shape[1]
     n_masks_half = int(0.15 * n_patches / 2)
-    # sequence_length = n_patches + 1
-    # element_length = patch_size
 
     word2id = {'[CLS]': 0.2 * np.ones((patch_size)), '[MASK]': 0.1 * np.ones((patch_size))}
 
     # Generate preprocessed channels
     preprocessed_data = []
     for user_idx in tqdm(range(len(patches)), desc="Processing items"):
         sample = make_sample(user_idx, patches, word2id, n_patches, n_masks_half, patch_size, gen_raw=gen_raw)
         preprocessed_data.append(sample)
 
     return preprocessed_data
 
 #%%
 def deepmimo_data_cleaning(deepmimo_data):
     idxs = np.where(deepmimo_data['user']['LoS'] != -1)[0]
     cleaned_deepmimo_data = deepmimo_data['user']['channel'][idxs]
     return np.array(cleaned_deepmimo_data) * 1e6
 
 #%% Patch Creation
-def patch_maker(original_ch, patch_size=16, norm_factor=1e6):
+def patch_maker(original_ch, patch_size=16, norm_factor=1e6, snr_db=None):
     """
     Creates patches from the dataset based on the scenario.
 
     Args:
         original_ch (numpy array): Channel data to patch.
         patch_size (int): Size of each patch.
         norm_factor (int): Normalization factor for channels.
+        snr_db (float): If given, additive Gaussian noise at this SNR (dB) is applied before patching. Defaults to None.
 
     Returns:
         patch (numpy array): Generated patches.
     """
     flat_channels = original_ch.reshape((original_ch.shape[0], -1)).astype(np.csingle)
+    if snr_db is not None:
+        flat_channels += generate_gaussian_noise(flat_channels, snr_db)
+
     flat_channels_complex = np.hstack((flat_channels.real, flat_channels.imag))
 
     # Create patches
     n_patches = flat_channels_complex.shape[1] // patch_size
     patch = np.zeros((len(flat_channels_complex), n_patches, patch_size))
     for idx in range(n_patches):
         patch[:, idx, :] = flat_channels_complex[:, idx * patch_size:(idx + 1) * patch_size]
 
     return patch
 
-
 #%% Data Generation for Scenario Areas
 def DeepMIMO_data_gen(scenario):
     """
     Generates or loads data for a given scenario.
 
     Args:
         scenario (str): Scenario name.
 
     Returns:
         data (dict): Loaded or generated data.
     """
     import DeepMIMOv3
 
     parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
 
     deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)
     uniform_idxs = uniform_sampling(deepMIMO_dataset, [1, 1], len(parameters['user_rows']),
                                     users_per_row=row_column_users[scenario]['n_per_row'])
     data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]
 
     return data
 
 #%%%
 def get_parameters(scenario):
 
-    n_ant_bs = 32 #32
+    n_ant_bs = 32
     n_ant_ue = 1
-    n_subcarriers = 32 #32
+    n_subcarriers = 32
     scs = 30e3
 
     row_column_users = {
         'city_18_denver': {
             'n_rows': 85,
             'n_per_row': 82
         },
         'city_15_indianapolis': {
             'n_rows': 80,
             'n_per_row': 79
         },
         'city_19_oklahoma': {
             'n_rows': 82,
             'n_per_row': 75
         },
         'city_12_fortworth': {
             'n_rows': 86,
             'n_per_row': 72
         },
         'city_11_santaclara': {
             'n_rows': 47,
             'n_per_row': 114
         },
         'city_7_sandiego': {
             'n_rows': 71,
             'n_per_row': 83
         }}
 
     parameters = DeepMIMOv3.default_params()
     parameters['dataset_folder'] = './scenarios'
     parameters['scenario'] = scenario
 
     if scenario == 'O1_3p5':
         parameters['active_BS'] = np.array([4])
     elif scenario in ['city_18_denver', 'city_15_indianapolis']:
         parameters['active_BS'] = np.array([3])
     else:
         parameters['active_BS'] = np.array([1])
 
     if scenario == 'Boston5G_3p5':
         parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
                                             row_column_users[scenario]['n_rows'][1])
     else:
         parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])
     parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])  # Horizontal, Vertical
     parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # (x, y, z)
     parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
     parameters['enable_BS2BS'] = False
     parameters['OFDM']['subcarriers'] = n_subcarriers
     parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)
 
     parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9
     parameters['num_paths'] = 20
 
     return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
 
 #%% Sample Generation
 def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
     """
     Generates a sample for each user, including masking and tokenizing.
 
     Args:
         user_idx (int): Index of the user.
         patch (numpy array): Patches data.
         word2id (dict): Dictionary for special tokens.
         n_patches (int): Number of patches.
         n_masks (int): Number of masks.
         patch_size (int): Size of each patch.
         gen_raw (bool): Whether to generate raw tokens.
 
     Returns:
         sample (list): Generated sample for the user.
     """
 
     tokens = patch[user_idx]
     input_ids = np.vstack((word2id['[CLS]'], tokens))
 
     real_tokens_size = int(n_patches / 2)
     masks_pos_real = np.random.choice(range(0, real_tokens_size), size=n_masks, replace=False)
     masks_pos_imag = masks_pos_real + real_tokens_size
     masked_pos = np.hstack((masks_pos_real, masks_pos_imag)) + 1
 
     masked_tokens = []
     for pos in masked_pos:
         original_masked_tokens = input_ids[pos].copy()
         masked_tokens.append(original_masked_tokens)
         if not gen_raw:
             rnd_num = np.random.rand()
             if rnd_num < 0.1:
                 input_ids[pos] = np.random.rand(patch_size)
             elif rnd_num < 0.9:
                 input_ids[pos] = word2id['[MASK]']
 
     return [input_ids, masked_tokens, masked_pos]
 
 
 #%% Sampling and Data Selection
 def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
     """
     Performs uniform sampling on the dataset.
 
     Args:
         dataset (dict): DeepMIMO dataset.
         sampling_div (list): Step sizes along [x, y] dimensions.
         n_rows (int): Number of rows for user selection.
         users_per_row (int): Number of users per row.
 
     Returns:
         uniform_idxs (numpy array): Indices of the selected samples.
     """
     cols = np.arange(users_per_row, step=sampling_div[0])
     rows = np.arange(n_rows, step=sampling_div[1])
     uniform_idxs = np.array([j + i * users_per_row for i in rows for j in cols])
 
     return uniform_idxs
 
 def select_by_idx(dataset, idxs):
     """
     Selects a subset of the dataset based on the provided indices.
 
     Args:
         dataset (dict): Dataset to trim.
         idxs (numpy array): Indices of users to select.
 
     Returns:
         dataset_t (list): Trimmed dataset based on selected indices.
     """
     dataset_t = []  # Trimmed dataset
     for bs_idx in range(len(dataset)):
         dataset_t.append({})
         for key in dataset[bs_idx].keys():
             dataset_t[bs_idx]['location'] = dataset[bs_idx]['location']
             dataset_t[bs_idx]['user'] = {k: dataset[bs_idx]['user'][k][idxs] for k in dataset[bs_idx]['user']}
 
     return dataset_t
 
 #%% Save and Load Utilities
 def save_var(var, path):
     """
     Saves a variable to a pickle file.
 
     Args:
         var (object): Variable to be saved.
         path (str): Path to save the file.
 
     Returns:
         None
     """
     path_full = path if path.endswith('.p') else (path + '.pickle')
     with open(path_full, 'wb') as handle:
         pickle.dump(var, handle)
 
 def load_var(path):
     """
     Loads a variable from a pickle file.
 
     Args:
         path (str): Path of the file to load.
 
     Returns:
         var (object): Loaded variable.
     """
     path_full = path if path.endswith('.p') else (path + '.pickle')
     with open(path_full, 'rb') as handle:
         var = pickle.load(handle)
 
     return var
 
 #%% Label Generation
 def label_gen(task, data, scenario, n_beams=64):
 
     idxs = np.where(data['user']['LoS'] != -1)[0]
 
     if task == 'LoS/NLoS Classification':
         label = data['user']['LoS'][idxs]
+
+        losChs = np.where(data['user']['LoS'] == -1, np.nan, data['user']['LoS'])
+        plot_coverage(data['user']['location'], losChs)
+
     elif task == 'Beam Prediction':
-        parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
+        parameters, row_column_users = get_parameters(scenario)[:2]
         n_users = len(data['user']['channel'])
         n_subbands = 1
-        fov = 120
+        fov = 180
 
         # Setup Beamformers
         beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)
 
         F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
                                     phi=azi*np.pi/180,
                                     kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
                        for azi in beam_angles])
 
         full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
         for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
             if data['user']['LoS'][ue_idx] == -1:
                 full_dbm[:, :, ue_idx] = np.nan
             else:
                 chs = F1 @ data['user']['channel'][ue_idx]
                 full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
                 full_dbm[:, :, ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)
 
         best_beams = np.argmax(np.mean(full_dbm, axis=1), axis=0)
         best_beams = best_beams.astype(float)
         best_beams[np.isnan(full_dbm[0, 0, :])] = np.nan
-        # max_bf_pwr = np.max(np.mean(full_dbm,axis=1), axis=0)
+
+        plot_coverage(data['user']['location'], best_beams)
 
         label = best_beams[idxs]
 
     return label.astype(int)
 
 def steering_vec(array, phi=0, theta=0, kd=np.pi):
     idxs = DeepMIMOv3.ant_indices(array)
     resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
     return resp / np.linalg.norm(resp)
 
 def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
     labels = []
     for scenario_idx in scenario_idxs:
         scenario_name = scenarios_list()[scenario_idx]
-        # data = DeepMIMO_data_gen(scenario_name)
         data = deepmimo_data[scenario_idx]
         labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
 
     preprocessed_chs = [preprocessed_chs[i] + [labels[i]] for i in range(len(preprocessed_chs))]
 
     return preprocessed_chs
 
 def create_labels(task, scenario_names, n_beams=64):
     labels = []
     if isinstance(scenario_names, str):
         scenario_names = [scenario_names]
     for scenario_name in scenario_names:
         data = DeepMIMO_data_gen(scenario_name)
         labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
-    return torch.tensor(labels)
+    return torch.tensor(labels).long()
+#%%
 
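The main functional change in this file is the new snr_db argument, which tokenizer forwards to patch_maker so that utils.generate_gaussian_noise can corrupt the channels before patching. A small sketch (not part of the commit) of that path, using synthetic channels in place of DeepMIMO output and assuming the DeepMIMOv3 package is installed, since it is imported at module level:

    import numpy as np
    from input_preprocess import tokenizer

    # Synthetic complex channels: 10 users, 32 BS antennas x 32 subcarriers
    rng = np.random.default_rng(0)
    chs = rng.standard_normal((10, 32, 32)) + 1j * rng.standard_normal((10, 32, 32))

    clean = tokenizer(manual_data=chs, gen_raw=True)             # noise-free tokens
    noisy = tokenizer(manual_data=chs, gen_raw=True, snr_db=20)  # channels corrupted at 20 dB SNR

    # Each sample is [input_ids, masked_tokens, masked_pos]; input_ids stacks a
    # [CLS] patch on top of 128 real/imaginary patches of length 16
    print(len(noisy), noisy[0][0].shape)  # 10 (129, 16)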
lwm_model.py CHANGED
@@ -1,134 +1,134 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 
 ELEMENT_LENGTH = 16
 D_MODEL = 64
 MAX_LEN = 129
 N_LAYERS = 12
 N_HEADS = 12
 D_FF = D_MODEL * 4
 D_K = D_MODEL // N_HEADS
 D_V = D_MODEL // N_HEADS
 DROPOUT = 0.1
 
 class LayerNormalization(nn.Module):
     def __init__(self, d_model: int, eps: float = 1e-6) -> None:
         super().__init__()
         self.eps = eps
         self.alpha = nn.Parameter(torch.ones(d_model))
         self.bias = nn.Parameter(torch.zeros(d_model))
 
     def forward(self, x):
         mean = x.mean(dim=-1, keepdim=True)
         std = x.std(dim=-1, keepdim=True)
         return self.alpha * (x - mean) / (std + self.eps) + self.bias
 
 class Embedding(nn.Module):
     def __init__(self, element_length, d_model, max_len):
         super().__init__()
         self.element_length = element_length
         self.d_model = d_model
         self.proj = nn.Linear(element_length, d_model)
         self.pos_embed = nn.Embedding(max_len, d_model)
         self.norm = LayerNormalization(d_model)
 
     def forward(self, x):
         seq_len = x.size(1)
         pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
         pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
         tok_emb = self.proj(x.float())
         embedding = tok_emb + self.pos_embed(pos)
         return self.norm(embedding)
 
 class ScaledDotProductAttention(nn.Module):
     def __init__(self):
         super().__init__()
 
     def forward(self, Q, K, V):
         scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(D_K)
         attn = F.softmax(scores, dim=-1)
         context = torch.matmul(attn, V)
         return context, attn
 
 class MultiHeadAttention(nn.Module):
     def __init__(self):
         super().__init__()
         self.W_Q = nn.Linear(D_MODEL, D_K * N_HEADS)
         self.W_K = nn.Linear(D_MODEL, D_K * N_HEADS)
         self.W_V = nn.Linear(D_MODEL, D_V * N_HEADS)
         self.linear = nn.Linear(N_HEADS * D_V, D_MODEL)
         self.norm = LayerNormalization(D_MODEL)
         self.dropout = nn.Dropout(DROPOUT)
 
     def forward(self, Q, K, V):
         residual, batch_size = Q, Q.size(0)
         q_s = self.W_Q(Q).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
         k_s = self.W_K(K).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
         v_s = self.W_V(V).view(batch_size, -1, N_HEADS, D_V).transpose(1, 2)
 
         context, attn = ScaledDotProductAttention()(q_s, k_s, v_s)
         output = context.transpose(1, 2).contiguous().view(batch_size, -1, N_HEADS * D_V)
         output = self.linear(output)
         return residual + self.dropout(output), attn
 
 class PoswiseFeedForwardNet(nn.Module):
     def __init__(self):
         super().__init__()
         self.fc1 = nn.Linear(D_MODEL, D_FF)
         self.fc2 = nn.Linear(D_FF, D_MODEL)
         self.dropout = nn.Dropout(DROPOUT)
         self.norm = LayerNormalization(D_MODEL)
 
     def forward(self, x):
         output = self.fc2(self.dropout(F.relu(self.fc1(x))))
         return x + self.dropout(output)
 
 class EncoderLayer(nn.Module):
     def __init__(self):
         super().__init__()
         self.enc_self_attn = MultiHeadAttention()
         self.pos_ffn = PoswiseFeedForwardNet()
         self.norm = LayerNormalization(D_MODEL)
 
     def forward(self, enc_inputs):
         attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
         attn_outputs = self.norm(attn_outputs)
         enc_outputs = self.pos_ffn(attn_outputs)
         return enc_outputs, attn
 
 class lwm(torch.nn.Module):
     def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
         super().__init__()
         self.embedding = Embedding(element_length, d_model, max_len)
         self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
         self.linear = nn.Linear(d_model, d_model)
         self.norm = LayerNormalization(d_model)
 
         embed_weight = self.embedding.proj.weight
         d_model, n_dim = embed_weight.size()
         self.decoder = nn.Linear(d_model, n_dim, bias=False)
         self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
 
     @classmethod
     def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda', use_auth_token=None):
         model = cls().to(device)
 
         ckpt_path = ckpt_name
         model.load_state_dict(torch.load(ckpt_path, map_location=device))
         print(f"Model loaded successfully from {ckpt_path} to {device}")
 
         return model
 
     def forward(self, input_ids, masked_pos):
         output = self.embedding(input_ids)
         for layer in self.layers:
             output, _ = layer(output)
 
         masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
         h_masked = torch.gather(output, 1, masked_pos)
         h_masked = self.norm(F.relu(self.linear(h_masked)))
         logits_lm = self.decoder(h_masked) + self.decoder_bias
 
         return logits_lm, output
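
The diff for lwm_model.py shows identical content on both sides, likely a whitespace or line-ending change. As a shape reference, a forward-pass sketch (not part of the commit) with random tensors and the default configuration, 128 patches plus a [CLS] token, 16 values per patch, 64-dim embeddings:

    import torch
    from lwm_model import lwm

    model = lwm()  # element_length=16, d_model=64, max_len=129, n_layers=12
    input_ids = torch.randn(4, 129, 16)          # batch of 4 tokenized channels
    masked_pos = torch.randint(1, 129, (4, 18))  # 18 masked patch positions per sample

    logits_lm, output = model(input_ids, masked_pos)
    print(logits_lm.shape)  # torch.Size([4, 18, 16]): reconstructed masked patches
    print(output.shape)     # torch.Size([4, 129, 64]): per-token embeddings
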
utils.py ADDED
@@ -0,0 +1,323 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+import umap
+#%%
+def plot_dimensionality_reduction(tensor, method='all', labels=None, input_type='Unknown', task='Unknown'):
+    """
+    Plots 2D projections of high-dimensional data using PCA, t-SNE, or UMAP.
+
+    Parameters:
+        tensor (torch.Tensor): Input data of shape (n_samples, n_features).
+        method (str or list): One of ['pca', 'tsne', 'umap'] or 'all' for all three.
+        labels (array-like): Optional labels for coloring the scatter plot.
+        input_type (str): Type of input data for the title.
+        task (str): Task description for the title.
+    """
+    tensor = tensor.view(tensor.size(0), -1)
+    # Convert to numpy if it's a PyTorch tensor
+    if isinstance(tensor, torch.Tensor):
+        tensor = tensor.cpu().numpy()
+
+    methods = []
+    if method == 'all':
+        methods = ['pca', 'tsne', 'umap']
+    elif isinstance(method, str):
+        methods = [method]
+    elif isinstance(method, list):
+        methods = method
+
+    plt.figure(figsize=(6 * len(methods), 5))
+    plt.suptitle(f"Input: {input_type}, Task: {task}", fontsize=16)
+
+    for i, m in enumerate(methods):
+        if m == 'pca':
+            reducer = PCA(n_components=2)
+            title = 'PCA'
+        elif m == 'tsne':
+            reducer = TSNE(n_components=2, perplexity=2, random_state=42)
+            title = 't-SNE'
+        elif m == 'umap':
+            reducer = umap.UMAP(n_components=2, random_state=42)
+            title = 'UMAP'
+        else:
+            raise ValueError(f"Unknown method: {m}")
+
+        reduced_data = reducer.fit_transform(tensor)
+
+        plt.subplot(1, len(methods), i + 1)
+
+        if labels is not None:
+            unique_labels = np.unique(labels)
+            cmap = plt.get_cmap('Spectral', len(unique_labels))
+
+            scatter = plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=labels, cmap=cmap, alpha=0.75)
+
+            cbar = plt.colorbar(scatter, ticks=unique_labels)
+            cbar.set_ticklabels(unique_labels)
+        else:
+            plt.scatter(reduced_data[:, 0], reduced_data[:, 1], alpha=0.75)
+
+        plt.title(title, fontsize=14)
+        plt.xlabel("Component 1")
+        plt.ylabel("Component 2")
+        plt.grid(True, linestyle='--', alpha=0.5)
+
+    plt.tight_layout(rect=[0, 0, 1, 0.95])
+    plt.show()
+#%%
+def plot_coverage(receivers, coverage_map, dpi=200, figsize=(6, 4), cbar_title=None, title=None,
+                  scatter_size=12, transmitter_position=None, transmitter_orientation=None,
+                  legend=False, limits=None, proj_3d=False, equal_aspect=False, tight_layout=True,
+                  colormap='tab20'):
+    # Set up plot parameters
+    plot_params = {'cmap': colormap}
+    if limits:
+        plot_params['vmin'], plot_params['vmax'] = limits[0], limits[1]
+
+    # Extract coordinates
+    x, y = receivers[:, 0], receivers[:, 1]
+
+    # Create figure and axis
+    fig, ax = plt.subplots(dpi=dpi, figsize=figsize,
+                           subplot_kw={})
+
+    # Plot the coverage map
+    ax.scatter(x, y, c=coverage_map, s=scatter_size, marker='s', edgecolors='black', linewidth=.15, **plot_params)
+
+    # Set axis labels
+    ax.set_xlabel('x (m)')
+    ax.set_ylabel('y (m)')
+
+    # Add legend if requested
+    if legend:
+        ax.legend(loc='upper center', ncols=10, framealpha=0.5)
+
+    # Adjust plot limits
+    if tight_layout:
+        padding = 1
+        mins = np.min(receivers, axis=0) - padding
+        maxs = np.max(receivers, axis=0) + padding
+
+        ax.set_xlim([mins[0], maxs[0]])
+        ax.set_ylim([mins[1], maxs[1]])
+
+    # Set equal aspect ratio for 2D plots
+    if equal_aspect:
+        ax.set_aspect('equal')
+
+    # Show plot
+    plt.show()
+#%%
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, TensorDataset, random_split
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import f1_score
+
+# Data Preparation
+def get_data_loaders(data_tensor, labels_tensor, batch_size=32, split_ratio=0.8):
+    dataset = TensorDataset(data_tensor, labels_tensor)
+
+    train_size = int(split_ratio * len(dataset))
+    test_size = len(dataset) - train_size
+    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
+
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+
+    return train_loader, test_loader
+
+class FCN(nn.Module):
+    def __init__(self, input_dim, num_classes):
+        super(FCN, self).__init__()
+        self.fc1 = nn.Linear(input_dim, 128)
+        self.bn1 = nn.BatchNorm1d(128)
+        self.dropout1 = nn.Dropout(0.3)
+
+        self.fc2 = nn.Linear(128, 64)
+        self.bn2 = nn.BatchNorm1d(64)
+        self.dropout2 = nn.Dropout(0.3)
+
+        self.fc3 = nn.Linear(64, num_classes)
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.fc1(x)))
+        x = self.dropout1(x)
+        x = F.relu(self.bn2(self.fc2(x)))
+        x = self.dropout2(x)
+        return self.fc3(x)
+
+
+# Training Function
+def train_model(model, train_loader, test_loader, epochs=20, lr=0.001, device="cpu", decay_step=10, decay_rate=0.5):
+    model.to(device)
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=decay_step, gamma=decay_rate)
+    criterion = nn.CrossEntropyLoss()
+
+    train_losses, test_f1_scores = [], []
+
+    for epoch in range(epochs):
+        model.train()
+        epoch_loss = 0
+        for batch_x, batch_y in train_loader:
+            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
+
+            optimizer.zero_grad()
+            outputs = model(batch_x)
+            loss = criterion(outputs, batch_y)
+            loss.backward()
+            optimizer.step()
+
+            epoch_loss += loss.item()
+
+        train_losses.append(epoch_loss / len(train_loader))
+        scheduler.step()
+
+        # Evaluate on test set
+        f1 = evaluate_model(model, test_loader, device)
+        test_f1_scores.append(f1)
+
+        print(f"Epoch [{epoch+1}/{epochs}], Loss: {train_losses[-1]:.4f}, F1-score: {f1:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")
+
+    return train_losses, np.array([test_f1_scores])
+
+# Model Evaluation
+def evaluate_model(model, test_loader, device):
+    model.eval()
+    all_preds, all_labels = [], []
+
+    with torch.no_grad():
+        for batch_x, batch_y in test_loader:
+            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
+
+            outputs = model(batch_x)
+            preds = torch.argmax(outputs, dim=1)
+
+            all_preds.extend(preds.cpu().numpy())
+            all_labels.extend(batch_y.cpu().numpy())
+
+    return f1_score(all_labels, all_preds, average='weighted')
+
+# Visualization
+import matplotlib.cm as cm
+def plot_metrics(test_f1_scores, input_types, n_train=None, flag=0):
+    """
+    Plots the F1-score over epochs or number of training samples.
+
+    Parameters:
+        test_f1_scores (np.ndarray): Array of F1-scores, one row per input type.
+        input_types (list): List of input type names.
+        n_train (list, optional): Number of training samples (used when flag=1).
+        flag (int): 0 for plotting F1-score over epochs, 1 for F1-score over training samples.
+    """
+    plt.figure(figsize=(7, 5), dpi=200)
+    colors = cm.get_cmap('Spectral', test_f1_scores.shape[0])  # Using Spectral colormap
+    markers = ['o', 's', 'D', '^', 'v', 'P', '*', 'X', 'h']  # Different markers for curves
+
+    for r in range(test_f1_scores.shape[0]):
+        color = colors(r / (test_f1_scores.shape[0] - 1))  # Normalize color index
+        marker = markers[r % len(markers)]  # Cycle through markers
+        if flag == 0:
+            plt.plot(test_f1_scores[r], linewidth=2, marker=marker, markersize=5, markeredgewidth=1.5,
+                     markeredgecolor=color, color=color, label=f"{input_types[r]}")
+        else:
+            plt.plot(n_train, test_f1_scores[r], linewidth=2, marker=marker, markersize=6, markeredgewidth=1.5,
+                     markeredgecolor=color, markerfacecolor='none', color=color, label=f"{input_types[r]}")
+            plt.xscale('log')
+
+    x_label = "Epochs" if flag == 0 else "Number of training samples"
+    plt.xlabel(f"{x_label}", fontsize=12)
+    plt.ylabel("F1-score", fontsize=12)
+
+    plt.legend()
+    plt.grid(alpha=0.3)
+    plt.show()
+
+#%%
+def classify_by_euclidean_distance(train_loader, test_loader, device="cpu"):
+    """
+    Classifies test samples based on the Euclidean distance to the mean of training samples from each class.
+    Computes the F1-score for evaluation.
+
+    Parameters:
+    - train_loader (DataLoader): DataLoader for training data.
+    - test_loader (DataLoader): DataLoader for test data.
+    - device (str): Device to run computations on ("cpu" or "cuda").
+
+    Returns:
+    - f1 (float): Weighted F1-score of the distance-based predictions.
+    """
+
+    # Store all training data and labels
+    train_data_list, train_labels_list = [], []
+    for batch_x, batch_y in train_loader:
+        train_data_list.append(batch_x.to(device))
+        train_labels_list.append(batch_y.to(device))
+
+    train_data = torch.cat(train_data_list)
+    train_labels = torch.cat(train_labels_list)
+
+    unique_classes = torch.unique(train_labels)
+    class_means = {}
+
+    # Compute mean feature vector for each class
+    for cls in unique_classes:
+        class_means[cls.item()] = train_data[train_labels == cls].mean(dim=0)
+
+    # Convert class means to tensor for vectorized computation
+    class_means_tensor = torch.stack([class_means[cls.item()] for cls in unique_classes])
+
+    # Store all test data and labels
+    test_data_list, test_labels_list = [], []
+    for batch_x, batch_y in test_loader:
+        test_data_list.append(batch_x.to(device))
+        test_labels_list.append(batch_y.to(device))
+
+    test_data = torch.cat(test_data_list)
+    test_labels = torch.cat(test_labels_list)
+
+    # Compute Euclidean distance between each test sample and all class means
+    dists = torch.cdist(test_data, class_means_tensor)  # Shape (n_test, n_classes)
+
+    # Assign the class with the minimum distance
+    predictions = unique_classes[torch.argmin(dists, dim=1)]
+
+    # Compute F1-score
+    f1 = f1_score(test_labels.cpu().numpy(), predictions.cpu().numpy(), average='weighted')
+
+    return f1
+#%%
+def generate_gaussian_noise(data, snr_db):
+    """
+    Generates complex-valued Gaussian noise at a given SNR relative to the data.
+
+    Args:
+        data (np.ndarray): Input data array of shape (n_samples, n_features), assumed to be complex-valued.
+        snr_db (float): Signal-to-Noise Ratio in decibels (dB).
+
+    Returns:
+        np.ndarray: Complex-valued Gaussian noise of the same shape as data.
+    """
+    # Compute signal power
+    signal_power = np.mean(np.abs(data) ** 2, axis=1, keepdims=True)  # Shape: (n_samples, 1)
+
+    # Compute noise power from SNR
+    snr_linear = 10 ** (snr_db / 10)
+    noise_power = signal_power / snr_linear
+
+    # Generate complex Gaussian noise (real + imaginary parts)
+    noise_real = np.random.randn(*data.shape) * np.sqrt(noise_power / 2)
+    noise_imag = np.random.randn(*data.shape) * np.sqrt(noise_power / 2)
+
+    # Combine real and imaginary parts to form complex noise
+    noise = noise_real + 1j * noise_imag
+
+    return noise
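
A quick sanity check (not part of the commit) for generate_gaussian_noise: because the noise power is split evenly between the real and imaginary parts, the measured SNR of the returned noise should land close to the requested value:

    import numpy as np
    from utils import generate_gaussian_noise

    rng = np.random.default_rng(1)
    data = rng.standard_normal((100, 512)) + 1j * rng.standard_normal((100, 512))

    noise = generate_gaussian_noise(data, snr_db=10)
    measured = 10 * np.log10(np.mean(np.abs(data) ** 2) / np.mean(np.abs(noise) ** 2))
    print(f"measured SNR: {measured:.2f} dB")  # approximately 10 dB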