import json
from os import environ

import numpy as np
from PIL import Image
from pyarrow import Table as pa_Table
from pyarrow.parquet import ParquetFile

from datasets import Dataset
from datasets import load_dataset as _load_dataset

DATASET = "satellogic/EarthView"

sets = {
    "satellogic": {
        "shards" : 3676,
    },
    "sentinel_1": {
        "shards" : 1763,
    },
    "neon": {
        "config" : "default",
        "shards" : 607,
        "path"   : "data",
    }
}
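
# The shard counts above are used by load_dataset below to build parquet file
# names of the form "{path}/{split}-{shard:05d}-of-{nshards:05d}.parquet".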

def get_subsets():
    return sets.keys()

def get_nshards(subset):
    return sets[subset]["shards"]

def get_path(subset):
    return sets[subset].get("path", subset)

def get_config(subset):
    return sets[subset].get("config", subset)

def load_dataset(subset, dataset=DATASET, split="train", shards=None, streaming=True, **kwargs):
    """
    Loads one subset of the EarthView dataset from the Hugging Face Hub.

    subset:    one of "satellogic", "sentinel_1", "neon"
    dataset:   Hub repository to load from
    split:     dataset split, "train" by default
    shards:    optional iterable of shard indices; if None, all shards are loaded
    streaming: if True (default), returns a streaming (iterable) dataset

    Uses the HF_TOKEN environment variable for authentication, if set.
    """
    config = get_config(subset)
    nshards = get_nshards(subset)
    path = get_path(subset)
    if shards is None:
        data_files = None
    else:
        data_files = [f"{path}/{split}-{shard:05d}-of-{nshards:05d}.parquet" for shard in shards]
        data_files = {split: data_files}

    ds = _load_dataset(
        path=dataset,
        name=config,
        save_infos=True,
        split=split,
        data_files=data_files,
        streaming=streaming,
        token=environ.get("HF_TOKEN", None),
        **kwargs)

    return ds
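
# Example usage (illustrative sketch; shard indices are arbitrary, and HF_TOKEN
# must be set in the environment if the Hub repository requires authentication):
#
#     ds = load_dataset("satellogic", shards=range(2))
#     first = next(iter(ds))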

def load_parquet(subset_or_filename, batch_size=100):
    """
    Loads a local parquet file into an in-memory Dataset.

    subset_or_filename: a subset name (loads the bundled
        dataset/<subset>/sample.parquet) or a path to a parquet file.
    batch_size: record batch size used when reading the file.
    """
    if subset_or_filename in get_subsets():
        pqfile = ParquetFile(f"dataset/{subset_or_filename}/sample.parquet")
    else:
        pqfile = ParquetFile(subset_or_filename)

    batches = pqfile.iter_batches(batch_size=batch_size)
    return Dataset(pa_Table.from_batches(batches))
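
# Example usage (illustrative sketch; assumes the repository's bundled
# dataset/neon/sample.parquet file is available locally):
#
#     sample = load_parquet("neon")
#     print(len(sample), sample.column_names)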

def item_to_images(subset, item):
    """
    Converts the images within an item (stored as arrays, as retrieved from
    the dataset) to PIL.Image objects.

    subset: the name of the subset, one of "satellogic", "sentinel_1", "neon"
    item: the item as retrieved from the subset

    Returns the item, with the arrays converted to PIL.Image and the number
    of images stored in item["metadata"]["count"].
    """
    metadata = item["metadata"]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)

    # Convert every field except the metadata to uint8 numpy arrays
    item = {
        k: np.asarray(v).astype("uint8")
        for k, v in item.items()
        if k != "metadata"
    }
    item["metadata"] = metadata
    
    if subset == "satellogic":
        # item["rgb"] = [
        #     Image.fromarray(np.average(image.transpose(1,2,0), 2).astype("uint8"))
        #         for image in item["rgb"]
        # ]
        rgbs = []
        for rgb in item["rgb"]:
            rgbs.append(Image.fromarray(rgb.transpose(1,2,0)))
            # rgbs.append(Image.fromarray(rgb[0,:,:]))      # Red
            # rgbs.append(Image.fromarray(rgb[1,:,:]))      # Green
            # rgbs.append(Image.fromarray(rgb[2,:,:]))      # Blue
        item["rgb"] = rgbs
        item["1m"] = [
            Image.fromarray(image[0,:,:])
                for image in item["1m"]
        ]
        count = len(item["1m"])
    elif subset == "sentinel_1":
        # Maps the two polarization channels (VV and VH) plus their ratio to a
        # false-color RGB composite. This mapping may not be correct, see:
        # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
        i10m = item["10m"]
        # Append a third channel: the ratio of band 0 to band 1, rescaled and cast to uint8
        i10m = np.concatenate(
            (   i10m,
                np.expand_dims(
                    i10m[:,0,:,:]/(i10m[:,1,:,:]+0.01)*256,
                    1
                ).astype("uint8")
            ),
            1
        )
        item["10m"] = [
            Image.fromarray(image.transpose(1,2,0))
                for image in i10m
        ]
        count = len(item["10m"])
    elif subset == "neon":
        item["rgb"] = [
            Image.fromarray(image.transpose(1,2,0))
                for image in item["rgb"]
        ]
        item["chm"] = [
            Image.fromarray(image[0])
                for image in item["chm"]
        ]

        # A rough, fairly arbitrary conversion of the 369 hyperspectral bands to RGB:
        # it averages each third of the bands and assigns the result to one channel
        item["1m"] = [
            Image.fromarray(
                np.concatenate((
                    np.expand_dims(np.average(image[:124],0),2),
                    np.expand_dims(np.average(image[124:247],0),2),
                    np.expand_dims(np.average(image[247:],0),2))
                ,2).astype("uint8"))
                    for image in item["1m"]
        ]
        count = len(item["rgb"])
    else:
        raise ValueError(f"Unknown subset: {subset!r}")

    item["metadata"]["count"] = count
    return item
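

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): load the local Satellogic
    # sample parquet, convert the first item's arrays to PIL images, and
    # report what was produced. Assumes dataset/satellogic/sample.parquet
    # exists with the columns handled above ("rgb", "1m", "metadata").
    sample = load_parquet("satellogic")
    item = item_to_images("satellogic", sample[0])
    print("images in item:", item["metadata"]["count"])
    print("rgb sizes:", [im.size for im in item["rgb"]])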