Leire Aguirre
		
	commited on
		
		
					Commit 
							
							·
						
						16179ad
	
1
								Parent(s):
							
							a4f672b
								
update color scheme and include tooltip
Browse files- src/memory.js +38 -24
    	
        src/memory.js
    CHANGED
    
    | @@ -220,27 +220,50 @@ export function updateGraph() { | |
| 220 |  | 
| 221 | 
             
                const color = d => {
         | 
| 222 | 
             
                    switch (d.data.name) {
         | 
| 223 | 
            -
                        case 'Parameters': return '# | 
| 224 | 
            -
                        case 'Gradients': return '# | 
| 225 | 
            -
                        case 'OptimizerAverages': return '# | 
| 226 | 
            -
                        case 'activationMemory': return '# | 
| 227 | 
            -
                        case 'fixed100GB': return '# | 
| 228 | 
            -
                        case 'Attention': return '# | 
| 229 | 
            -
                        case 'Feedforward': return '# | 
| 230 | 
            -
                        case 'LayerNorm': return '# | 
| 231 | 
            -
                        case 'Dropout': return '# | 
| 232 | 
            -
                        case 'Projection': return '# | 
| 233 | 
            -
                        case 'Cross Entropy': return '# | 
| 234 | 
            -
                        case 'Total': return '# | 
| 235 | 
            -
                        case 'root': return '# | 
| 236 | 
             
                        default: return '#a0c4ff';  // Lighter Blue (for unexpected cases)
         | 
| 237 | 
             
                    }
         | 
| 238 | 
             
                };
         | 
| 239 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 240 | 
             
                const cell = svg.selectAll("g")
         | 
| 241 | 
             
                    .data(root.descendants())
         | 
| 242 | 
             
                    .join("g")
         | 
| 243 | 
            -
                    .attr("transform", d => `translate(${d.x0},${d.y0})`) | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 244 |  | 
| 245 | 
             
                cell.append("rect")
         | 
| 246 | 
             
                    .attr("width", d => d.x1 - d.x0)
         | 
| @@ -276,15 +299,6 @@ export function updateGraph() { | |
| 276 | 
             
                        }
         | 
| 277 | 
             
                    });
         | 
| 278 |  | 
| 279 | 
            -
                // Add invisible rect for better hover area
         | 
| 280 | 
            -
                cell.append("rect")
         | 
| 281 | 
            -
                    .attr("width", d => d.x1 - d.x0)
         | 
| 282 | 
            -
                    .attr("height", d => d.y1 - d.y0)
         | 
| 283 | 
            -
                    .attr("fill", "none")
         | 
| 284 | 
            -
                    .attr("pointer-events", "all")
         | 
| 285 | 
            -
                    .append("title")
         | 
| 286 | 
            -
                    .text(d => `${d.data.name}: ${formatBytes(d.value)}`);
         | 
| 287 | 
            -
             | 
| 288 | 
             
                // Adjust legend positioning
         | 
| 289 | 
             
                const legendData = root.children[0].children.concat(root.children[0]);
         | 
| 290 | 
             
                const legend = svg.append("g")
         | 
| @@ -302,7 +316,7 @@ export function updateGraph() { | |
| 302 | 
             
                    .attr("width", 19)
         | 
| 303 | 
             
                    .attr("height", 19)
         | 
| 304 | 
             
                    .attr("fill", d => color(d))
         | 
| 305 | 
            -
                    .attr("stroke", ' | 
| 306 | 
             
                    .attr("stroke-width", 2);
         | 
| 307 |  | 
| 308 | 
             
                legend.append("text")
         | 
|  | |
| 220 |  | 
| 221 | 
             
                const color = d => {
         | 
| 222 | 
             
                    switch (d.data.name) {
         | 
| 223 | 
            +
                        case 'Parameters': return '#117fc9';  // Blue
         | 
| 224 | 
            +
                        case 'Gradients': return '#ffad5c';  // Orange
         | 
| 225 | 
            +
                        case 'OptimizerAverages': return '#e1576b';  // Red
         | 
| 226 | 
            +
                        case 'activationMemory': return '#ffad5c';  // Orange
         | 
| 227 | 
            +
                        case 'fixed100GB': return '#80cb75';  // Green
         | 
| 228 | 
            +
                        case 'Attention': return '#e1576b';  // Red
         | 
| 229 | 
            +
                        case 'Feedforward': return '#2f94d9';  // Light Blue
         | 
| 230 | 
            +
                        case 'LayerNorm': return '#fb8b28';  // Dark Orange
         | 
| 231 | 
            +
                        case 'Dropout': return '#4ead4e';  // Dark Green
         | 
| 232 | 
            +
                        case 'Projection': return '#d94361';  // Dark Red
         | 
| 233 | 
            +
                        case 'Cross Entropy': return '#b492d3';  // Violet
         | 
| 234 | 
            +
                        case 'Total': return '#80cb75';  // Green
         | 
| 235 | 
            +
                        case 'root': return '#f3f3f3';  // Light Grey
         | 
| 236 | 
             
                        default: return '#a0c4ff';  // Lighter Blue (for unexpected cases)
         | 
| 237 | 
             
                    }
         | 
| 238 | 
             
                };
         | 
| 239 |  | 
| 240 | 
            +
                const tooltip = d3.select('body')
         | 
| 241 | 
            +
                  .append('div')
         | 
| 242 | 
            +
                  .attr('id', 'tooltip')
         | 
| 243 | 
            +
                  .style('opacity', 0)
         | 
| 244 | 
            +
                  .style('position', 'absolute')
         | 
| 245 | 
            +
                  .style('background-color', 'white')
         | 
| 246 | 
            +
                  .style('padding', '4px')
         | 
| 247 | 
            +
                  .style('font-size', '12px')
         | 
| 248 | 
            +
                  .style('border-radius', '5px')
         | 
| 249 | 
            +
                  .style('box-shadow', '0px 0px 5px 0px rgba(0,0,0,0.3)');
         | 
| 250 | 
            +
             | 
| 251 | 
            +
             | 
| 252 | 
             
                const cell = svg.selectAll("g")
         | 
| 253 | 
             
                    .data(root.descendants())
         | 
| 254 | 
             
                    .join("g")
         | 
| 255 | 
            +
                    .attr("transform", d => `translate(${d.x0},${d.y0})`)
         | 
| 256 | 
            +
                    .on('mouseover', (event, d) => {
         | 
| 257 | 
            +
                      const name = d.data.name;
         | 
| 258 | 
            +
                      const value = formatBytes(d.value);
         | 
| 259 | 
            +
                      tooltip.transition().duration(200).text(`${name}: ${value}`)
         | 
| 260 | 
            +
                    })
         | 
| 261 | 
            +
                    .on('mouseout', function() {
         | 
| 262 | 
            +
                      tooltip.style('opacity', 0)
         | 
| 263 | 
            +
                    })
         | 
| 264 | 
            +
                    .on('mousemove', function(event) {
         | 
| 265 | 
            +
                      tooltip.style('left', (event.pageX + 10) + 'px').style('top', (event.pageY + 10) + 'px').style('opacity', 1)
         | 
| 266 | 
            +
                    });
         | 
| 267 |  | 
| 268 | 
             
                cell.append("rect")
         | 
| 269 | 
             
                    .attr("width", d => d.x1 - d.x0)
         | 
|  | |
| 299 | 
             
                        }
         | 
| 300 | 
             
                    });
         | 
| 301 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 302 | 
             
                // Adjust legend positioning
         | 
| 303 | 
             
                const legendData = root.children[0].children.concat(root.children[0]);
         | 
| 304 | 
             
                const legend = svg.append("g")
         | 
|  | |
| 316 | 
             
                    .attr("width", 19)
         | 
| 317 | 
             
                    .attr("height", 19)
         | 
| 318 | 
             
                    .attr("fill", d => color(d))
         | 
| 319 | 
            +
                    .attr("stroke", '#f3f3f3')
         | 
| 320 | 
             
                    .attr("stroke-width", 2);
         | 
| 321 |  | 
| 322 | 
             
                legend.append("text")
         | 
