# NILM Gemma 2B Fine-tuned for Appliance State Detection

This repository hosts a fine-tuned version of `google/gemma-2b`, adapted for Non-Intrusive Load Monitoring (NILM), also known as energy disaggregation. The model identifies the operational (on/off) states of common household appliances from a text-based representation of an aggregate electrical power signal.

The fine-tuning was performed with Parameter-Efficient Fine-Tuning (PEFT) using QLoRA, making the model trainable even in resource-constrained environments such as Google Colab's free tier.
## Model Description

The model takes a JSON string representing a sequence of aggregate power readings (in Watts) and outputs a JSON string indicating the on/off state (1 for on, 0 for off) for a predefined set of appliances.

**Predefined appliances:**

- refrigerator
- microwave
- kettle
- lights
## Input Format

The input to the model should be a JSON string with a single key, `"aggregate_signal"`, containing a list of numerical power values.

**Example Input** (illustrative values):

```json
{"aggregate_signal": [152.3, 148.9, 1355.7, 1349.2, 1360.1, 151.8, 149.5, 150.2, 148.7, 152.0]}
```
## Output Format

The output from the model will be a JSON string with keys for each appliance and their predicted on/off state (1 or 0).

**Example Output:**

```json
{"refrigerator": 1, "microwave": 1, "kettle": 0, "lights": 0}
```
## How to Use

To use this model, you'll need the `transformers`, `peft`, `bitsandbytes`, and `torch` libraries.

### Installation

```bash
pip install -q -U transformers peft bitsandbytes accelerate trl numpy
```
### Inference Code

Here's how you can load the model and make predictions:

```python
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# Define the base model and your fine-tuned model path on Hugging Face Hub
base_model_id = "google/gemma-2b"
hf_model_path = "louijiec/nilm-gemma-2b-finetuned" # Your model's path
# QLoRA configuration (must match the training configuration)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# Load the base model in 4-bit
print(f"Loading base model: {base_model_id}...")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_id,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Load the fine-tuned adapter from Hugging Face Hub
print(f"Loading PEFT adapter from Hugging Face Hub: {hf_model_path}...")
model = PeftModel.from_pretrained(base_model, hf_model_path)
model.eval() # Set model to evaluation mode
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Fine-tuned model loaded for inference.")
def predict_appliance_states(signal: list):
    """
    Predicts appliance states for a given aggregate signal using the fine-tuned LLM.
    """
    user_input = json.dumps({"aggregate_signal": signal})

    # The prompt template should exactly match the one used during training
    prompt = (
        f"### System: You are an energy disaggregation assistant. "
        f"Analyze the aggregate electrical signal (a sequence of power readings in Watts) "
        f"and identify the operational states (on/off) of the predefined household appliances. "
        f"Output the states as a JSON object, where 1 means 'on' and 0 means 'off'.\n\n"
        f"### User: {user_input}\n\n"
        f"### Assistant: "
    )

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,   # Adjust based on expected output length
            do_sample=False,      # For deterministic output
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the first (and only) sequence in the batch
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the assistant's response part
    assistant_prefix = "### Assistant:"
    if assistant_prefix in response:
        response = response.split(assistant_prefix, 1)[1].strip()

    # Attempt to parse the JSON output
    try:
        predicted_states = json.loads(response)
        return predicted_states
    except json.JSONDecodeError:
        print(f"Warning: Could not parse JSON output: {response}")
        return None
# --- Example Usage ---
# NOTE: the signals below are illustrative values (10 readings each), not taken from the training data.

# Example 1: Refrigerator and Microwave running
signal_1 = [1352.4, 1348.1, 1355.9, 1349.7, 1351.2, 1353.8, 1347.5, 1350.6, 1352.1, 1349.3]
print(f"\nInput: {json.dumps({'aggregate_signal': signal_1})}")
prediction_1 = predict_appliance_states(signal_1)
print(f"Predicted: {prediction_1}")

# Example 2: Only Lights are on
signal_2 = [61.2, 59.8, 60.5, 62.1, 58.9, 60.3, 61.7, 59.4, 60.8, 61.0]
print(f"\nInput: {json.dumps({'aggregate_signal': signal_2})}")
prediction_2 = predict_appliance_states(signal_2)
print(f"Predicted: {prediction_2}")

# Example 3: Kettle is turned on mid-way (output reflects final state during signal)
signal_3 = [4.8, 5.3, 6.1, 5.0, 5.7, 2405.2, 2398.7, 2401.4, 2407.9, 2399.5]
print(f"\nInput: {json.dumps({'aggregate_signal': signal_3})}")
prediction_3 = predict_appliance_states(signal_3)
print(f"Predicted: {prediction_3}")

# Example 4: All appliances off
signal_4 = [3.2, 4.1, 2.8, 3.9, 4.5, 3.1, 2.6, 4.3, 3.7, 2.9]
print(f"\nInput: {json.dumps({'aggregate_signal': signal_4})}")
prediction_4 = predict_appliance_states(signal_4)
print(f"Predicted: {prediction_4}")
```
## Training Details

### Base Model

The model is based on `google/gemma-2b`, a lightweight, efficient open model from Google.

### Fine-tuning Method

Parameter-Efficient Fine-Tuning (PEFT) using QLoRA was employed to adapt the model to the NILM task. This method significantly reduces memory and computational requirements, enabling training on consumer-grade GPUs. The quantization and LoRA settings are listed below, followed by a short code sketch.
**QLoRA Configuration:**

- `load_in_4bit=True`
- `bnb_4bit_quant_type="nf4"`
- `bnb_4bit_compute_dtype=torch.bfloat16`
- `bnb_4bit_use_double_quant=True`

**LoRA Configuration:**

- `r=16` (LoRA attention dimension)
- `lora_alpha=16` (scaling factor)
- `target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]` (layers to apply LoRA to)
- `lora_dropout=0.05`
- `bias="none"`
- `task_type="CAUSAL_LM"`
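A minimal sketch of how these settings could be assembled with `transformers` and `peft` (values copied from the lists above; the original training script is not included in this repository):

```python
import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig

# 4-bit QLoRA quantization settings (values from the list above)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# LoRA adapter settings (values from the list above)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
```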
### Dataset

A synthetic dataset of 10,000 samples was generated specifically for this task. Each sample consists of an aggregate power signal (10 readings) and the corresponding on/off states of the four target appliances (`refrigerator`, `microwave`, `kettle`, `lights`). The dataset includes scenarios where appliance states change mid-way through the signal window to simulate real-world events.

The data generation process ensures variability in power consumption, accounting for baseline noise and appliance-specific power fluctuations.
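The generator script itself is not published here; the following is a simplified sketch of the process described above, with assumed (not actual) appliance power levels:

```python
import json
import random

# Assumed nominal power draws in Watts (illustrative, not the values used for training)
APPLIANCE_POWER = {"refrigerator": 150, "microwave": 1200, "kettle": 2400, "lights": 60}

def generate_sample(num_readings: int = 10):
    """Generate one (input, labels) pair; labels reflect the final state in the window."""
    states = {name: random.randint(0, 1) for name in APPLIANCE_POWER}
    # Occasionally flip one appliance mid-way through the window to simulate an event
    flip_at = random.randint(1, num_readings - 1) if random.random() < 0.3 else None
    flip_name = random.choice(list(APPLIANCE_POWER)) if flip_at is not None else None

    signal = []
    for t in range(num_readings):
        if flip_at is not None and t == flip_at:
            states[flip_name] = 1 - states[flip_name]
        total = random.uniform(2.0, 10.0)  # baseline noise
        for name, watts in APPLIANCE_POWER.items():
            if states[name]:
                total += watts * random.uniform(0.9, 1.1)  # per-appliance fluctuation
        signal.append(round(total, 1))
    return {"aggregate_signal": signal}, states

sample_input, sample_labels = generate_sample()
print(json.dumps(sample_input))
print(json.dumps(sample_labels))
```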
### Training Environment

The model was fine-tuned on a Google Colab free-tier instance, typically leveraging a T4 GPU.

### Training Arguments (Illustrative)

- `per_device_train_batch_size=2`
- `gradient_accumulation_steps=4`
- `optim="paged_adamw_8bit"`
- `logging_steps=50`
- `learning_rate=2e-4`
- `max_steps=500` (can be adjusted for more training)
- `bf16=True`
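As a sketch, these values map onto `transformers.TrainingArguments`, which would then be handed to `trl`'s `SFTTrainer` together with the LoRA config and the synthetic dataset (the `output_dir` below is an assumption):

```python
from transformers import TrainingArguments

# Illustrative training arguments (values from the list above; output_dir is an assumption)
training_args = TrainingArguments(
    output_dir="./nilm-gemma-2b-finetuned",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    logging_steps=50,
    learning_rate=2e-4,
    max_steps=500,
    bf16=True,
)
# training_args would then be passed to trl's SFTTrainer together with the
# quantized base model, the LoRA config above, and the synthetic dataset.
```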
## Limitations and Considerations
- Synthetic Data: The model was trained on synthetic data. Its performance on real-world NILM datasets might vary and could require further fine-tuning on actual sensor data.
- Fixed Appliances: The model is trained for a fixed set of four appliances. Extending it to new appliances would require further fine-tuning with a dataset including those appliances.
- JSON Output Robustness: While instruction-tuned to output JSON, LLMs can sometimes deviate from strict formatting, especially with unusual inputs. Error handling for JSON parsing is crucial in practical applications; see the sketch after this list.
- Signal Length: The model was trained on signals of 10 readings. Significant deviations in input signal length might impact performance.
- Power Profiles: The synthetic data uses simplified power profiles. Real appliances have more complex and varied power signatures.
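For instance, one way to harden the JSON parsing (a sketch, not part of the published inference code) is to extract the first `{...}` object from the generated text before calling `json.loads`:

```python
import json
import re

def extract_first_json(text: str):
    """Return the first flat {...} object found in the model output, or None."""
    match = re.search(r"\{[^{}]*\}", text)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None

print(extract_first_json('### Assistant: {"refrigerator": 1, "kettle": 0} extra text'))
# {'refrigerator': 1, 'kettle': 0}
```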
## Citation

If you use this model, please consider citing the original Gemma model and the libraries used for fine-tuning:

```bibtex
@article{google2024gemma,
  title={Gemma: A Family of Lightweight, Open Models},
  author={Google},
  year={2024},
  url={https://blog.google/technology/ai/gemma-open-models/}
}

@software{peft,
  author = {Sourab Mangrulkar and others},
  title = {PEFT: Parameter-Efficient Fine-Tuning},
  url = {https://github.com/huggingface/peft},
  year = {2022}
}

@software{trl,
  author = {Leandro von Werra and others},
  title = {TRL: Transformer Reinforcement Learning},
  url = {https://github.com/huggingface/trl},
  year = {2020}
}
```