Fine-tuned Gemma-3 4B for Multi-Task Customer Service Complaint Analysis
This repository contains a google/gemma-3-4b-it model fine-tuned with QLoRA for a multi-task customer service application. The model was trained on a synthetic dataset of fashion-related customer complaints. It performs causal language modeling (generating a structured JSON response) and eight classification tasks at the same time, with each classification task handled by its own specialized head.
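For readers new to the method: QLoRA freezes the base model, keeps its weights quantized to 4 bits, and trains small LoRA adapters on top. Each adapted weight matrix then computes the standard LoRA forward pass shown here, where only A and B are trained. The rank r and scaling alpha used for this checkpoint are not stated in this card.

```latex
h = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d \times k}\ \text{(frozen)},\quad
B \in \mathbb{R}^{d \times r},\quad A \in \mathbb{R}^{r \times k},\quad r \ll \min(d, k)
```

At inference time, PeftModel.from_pretrained in the usage example attaches these adapter matrices to the base model.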
This model is designed to act as an "agent" that can ingest a customer complaint and its surrounding context, then output a complete analysis covering multiple business-critical dimensions.
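The card does not say how the language-modeling loss and the classification losses were combined during training. A common choice for this kind of multi-task setup is to add one cross-entropy term per head to the LM loss, optionally with per-task weights. The plain-Python sketch below shows that approach under this assumption. The function names and the weighting scheme are hypothetical, and this is not the training notebook's code.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one example: -log softmax(logits)[target], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def multitask_loss(lm_loss, head_logits, head_targets, weights=None):
    """Hypothetical combined loss: LM loss plus a weighted cross-entropy per head.

    head_logits and head_targets are dicts keyed by task name (e.g. "sentiment").
    weights maps task names to floats; any task not listed gets weight 1.0.
    """
    weights = weights or {}
    total = lm_loss
    for task, logits in head_logits.items():
        total += weights.get(task, 1.0) * cross_entropy(logits, head_targets[task])
    return total
```

With uniform logits, a head over n classes contributes log(n), so an untrained 11-way complaint_category head would add about 2.4 to the total.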
Model Capabilities
This model is trained to perform 8 classification tasks simultaneously based on the input complaint:
is_actionable: Whether the complaint requires a direct action (boolean).
complaint_category: One of 11 categories (e.g., "Sizing Issue", "Damaged Item").
decision_recommendation: A recommended course of action, chosen from 11 options (e.g., "Full_Refund_With_Return").
info_complete: Whether all the information needed to resolve the issue is present (boolean).
tone: The tone a formal response should take, chosen from 7 classes (e.g., "Empathetic_Standard").
refund_percentage: A suggested refund percentage (0-100), predicted as one of 13 classes.
sentiment: The customer's sentiment, chosen from 6 classes (e.g., "negative", "very_negative").
aggression: The level of aggression in the customer's message, chosen from 5 levels.
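Each head produces its own logits, so turning the eight outputs into one analysis comes down to a softmax and an argmax per head. The sketch below does this in plain Python over lists of floats, for example outputs["logits_sentiment"][0].tolist() after running the model. The logits_<task> key pattern is an assumption inferred from the logits_complaint_category key in the usage example. The class counts are copied from num_labels_dict in that example. Only the complaint category names appear in this card, so heads without supplied names fall back to class indices. The helper names are hypothetical.

```python
import math

# Classes per head, copied from num_labels_dict in the usage example.
NUM_LABELS = {
    "is_actionable": 2, "complaint_category": 11, "decision_recommendation": 11,
    "info_complete": 2, "tone": 7, "refund_percentage": 13,
    "sentiment": 6, "aggression": 5,
}

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def decode_heads(logits_by_task, label_names=None):
    """Map {task: logits} to {task: (label, probability)} for the top class.

    label_names maps a task to its ordered class names. Tasks without names
    return the raw class index. Raises ValueError on a size mismatch.
    """
    label_names = label_names or {}
    result = {}
    for task, logits in logits_by_task.items():
        if len(logits) != NUM_LABELS[task]:
            raise ValueError(f"{task}: expected {NUM_LABELS[task]} logits, got {len(logits)}")
        probs = softmax(logits)
        best = max(range(len(probs)), key=probs.__getitem__)
        names = label_names.get(task)
        result[task] = (names[best] if names else best, probs[best])
    return result
```

Returning the softmax probability alongside each label makes it easy to flag low-confidence predictions for human review.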
How to Use (for Classification)
This model uses custom classification heads. To use it correctly, you need the GemmaComplaintResolver wrapper class from the training notebook.
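The classification heads read a forward pass over the same prompt layout used in training: a user turn that holds the complaint, followed by an open model turn. A small helper keeps that layout in one place and makes sure the line breaks are real newline characters. build_prompt is a hypothetical name, and the layout is taken from the 'text_for_lm' format described in the example.

```python
def build_prompt(complaint_details: str) -> str:
    # User turn with the complaint, then an open model turn. This matches the
    # prompt part of 'text_for_lm' in training. Each "\n" is a real newline.
    return (
        "<start_of_turn>user\n"
        f"{complaint_details}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
```

Usage: inputs = tokenizer(build_prompt(customer_complaint), return_tensors="pt").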
```python
import torch
from transformers import AutoTokenizer, AutoConfig
from peft import PeftModel
from huggingface_hub import hf_hub_download

# You must have the GemmaComplaintResolver class definition in your environment,
# defined as it was in the training notebook.

# --- Configuration ---
repo_id = "ShovalBenjer/gemma-3-4b-fashion-multitask_A4000_v7"
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 1. Load Tokenizer and Model Config ---
tokenizer = AutoTokenizer.from_pretrained(repo_id)
config = AutoConfig.from_pretrained("google/gemma-3-4b-it", trust_remote_code=True)

# Number of classes per head, as used during training
num_labels_dict = {
    "is_actionable": 2, "complaint_category": 11, "decision_recommendation": 11,
    "info_complete": 2, "tone": 7, "refund_percentage": 13,
    "sentiment": 6, "aggression": 5,
}

# --- 2. Instantiate the Custom Model Wrapper ---
# IMPORTANT: This requires the GemmaComplaintResolver class to be defined.
model = GemmaComplaintResolver(
    base_model_name_or_path="google/gemma-3-4b-it",
    num_labels_dict=num_labels_dict,
    model_config_for_base_loading=config,
)

# --- 3. Load the Fine-Tuned Weights ---
# a) Load the classification head weights. strict=False because the file is
#    expected to hold only the head weights, not the full model.
weights_path = hf_hub_download(repo_id=repo_id, filename="classification_heads.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"), strict=False)

# b) Apply the LoRA adapter
model = PeftModel.from_pretrained(model, repo_id)

# --- 4. Prepare for Inference ---
# On GPU use bfloat16 where supported, else float16. On CPU use float32,
# since float16 is poorly supported there.
if device == "cuda":
    compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    compute_dtype = torch.float32
model.to(dtype=compute_dtype).to(device).eval()

# --- 5. Run Inference ---
customer_complaint = "The t-shirt I ordered arrived with a huge hole in it! I'm very angry and want a full refund immediately."

# The model expects the full prompt structure used during training. In the training
# notebook, the pre-processed column 'text_for_lm' had the form:
#   <start_of_turn>user\n{complaint_details}<end_of_turn>\n<start_of_turn>model\n{json_output}<eos>
# The classification heads only need the prompt part, up to the open model turn.
# Note: "\n" must be a real newline, not an escaped "\\n".
input_text = f"<start_of_turn>user\n{customer_complaint}<end_of_turn>\n<start_of_turn>model\n"
inputs = tokenizer(input_text, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# --- 6. Decode a Prediction ---
# Example: get the predicted complaint category.
# The label order must match the label encoding used during training.
category_logits = outputs["logits_complaint_category"]
predicted_category_id = torch.argmax(category_logits, dim=-1).item()
complaint_categories = [
    "Sizing Issue", "Damaged Item", "Not as Described", "Shipping Problem",
    "Policy Inquiry", "Late Delivery", "Wrong Item Received", "Quality Issue",
    "Return Process Issue", "Other", "N/A",
]
predicted_category = complaint_categories[predicted_category_id]

print(f"Customer Complaint: '{customer_complaint}'")
print(f"Predicted Complaint Category: {predicted_category}")
```
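The example above reads only the classification heads. For the structured JSON produced by the language-modeling side, one option is to decode the generated token ids while keeping special tokens, for example with tokenizer.decode(ids, skip_special_tokens=False), and then parse the model turn. The helper below is a hypothetical sketch. It assumes the generation follows the training layout and that the model emits valid JSON. The response schema is not documented in this card, so the key in the usage note is a placeholder.

```python
import json

MODEL_TURN = "<start_of_turn>model\n"

def extract_json_response(decoded_text: str) -> dict:
    """Parse the JSON object in the last model turn of a decoded generation.

    Assumes the layout <start_of_turn>user ... <start_of_turn>model {json} <eos>,
    with the JSON ending at <end_of_turn> or <eos>, whichever comes first.
    """
    start = decoded_text.rfind(MODEL_TURN)
    body = decoded_text[start + len(MODEL_TURN):] if start != -1 else decoded_text
    # Truncate at each stop marker in turn, which leaves text before the earliest one.
    for stop in ("<end_of_turn>", "<eos>"):
        cut = body.find(stop)
        if cut != -1:
            body = body[:cut]
    return json.loads(body.strip())
```

For a generation ending in {"complaint_category": "Damaged Item"}<eos>, the helper returns that dictionary. Malformed model output raises json.JSONDecodeError, which callers may want to catch and retry.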