|
/**
 * Lightweight per-tag logistic-regression classifier that learns online from
 * user feedback and persists itself to `localStorage` under the key
 * 'clipTaggerModel'.
 *
 * Every tag owns an independent weight vector and a scalar bias; the
 * confidence for a tag is sigmoid(weights · features + bias).
 */
class LocalClassifier {
  constructor() {
    /** @type {Map<string, number[]>} Weight vector keyed by tag. */
    this.weights = new Map();
    /** @type {Map<string, number>} Bias term keyed by tag. */
    this.biases = new Map();
    /** Step size of the gradient update in trainOnFeedback. */
    this.learningRate = 0.01;
    /** Length of newly created weight vectors and of extracted feature vectors. */
    this.featureDim = 512;
    this.isInitialized = false;
  }

  /**
   * Set the dimensionality used for newly created tag weight vectors and by
   * extractSimpleFeatures, and mark the classifier as initialized. Weight
   * vectors that already exist are not resized.
   *
   * @param {number} [featureDim=512] - Feature vector length.
   */
  initialize(featureDim = 512) {
    this.featureDim = featureDim;
    this.isInitialized = true;
  }

  /**
   * Apply one stochastic-gradient step of logistic regression for `tag`.
   *
   * 'positive' and 'custom' feedback train toward 1, 'negative' toward 0; any
   * other feedback value is ignored. A tag seen for the first time gets a
   * weight vector of length `featureDim` filled with small random values in
   * [-0.005, 0.005) and a zero bias.
   *
   * @param {number[]} features - Feature vector (e.g. from extractSimpleFeatures).
   * @param {string} tag - Tag whose weights are updated.
   * @param {string} feedback - 'positive' | 'negative' | 'custom'.
   */
  trainOnFeedback(features, tag, feedback) {
    if (!this.isInitialized) {
      // Keep any featureDim configured before first use instead of
      // silently resetting it to the default.
      this.initialize(this.featureDim);
    }

    let target;
    switch (feedback) {
      case 'positive':
      case 'custom':
        target = 1.0;
        break;
      case 'negative':
        target = 0.0;
        break;
      default:
        return;
    }

    if (!this.weights.has(tag)) {
      this.weights.set(
        tag,
        Array.from({ length: this.featureDim }, () => (Math.random() - 0.5) * 0.01),
      );
      this.biases.set(tag, 0);
    }

    const weights = this.weights.get(tag);
    const bias = this.biases.get(tag);

    // Same forward pass as predict(), so training and inference agree.
    const prediction = this.predict(features, tag);
    const error = prediction - target;

    // Only update dimensions present in both vectors; indexing past the end
    // of `weights` would yield NaN and corrupt the whole tag permanently.
    const overlap = Math.min(features.length, weights.length);
    for (let i = 0; i < overlap; i++) {
      weights[i] -= this.learningRate * error * features[i];
    }
    // `weights` is mutated in place, so only the bias needs writing back.
    this.biases.set(tag, bias - this.learningRate * error);
  }

  /**
   * Confidence in (0, 1) that `tag` applies to `features`, or `null` if the
   * tag has never been trained. Only dimensions present in both the feature
   * and the weight vector contribute.
   *
   * @param {number[]} features
   * @param {string} tag
   * @returns {number|null}
   */
  predict(features, tag) {
    if (!this.weights.has(tag)) {
      return null;
    }

    const weights = this.weights.get(tag);
    const bias = this.biases.get(tag);

    let logit = bias;
    for (let i = 0; i < Math.min(features.length, weights.length); i++) {
      logit += weights[i] * features[i];
    }

    return 1 / (1 + Math.exp(-logit));
  }

  /**
   * Score every candidate tag that has a trained model.
   *
   * @param {number[]} features
   * @param {Iterable<string>} candidateTags
   * @returns {{tag: string, confidence: number}[]} Sorted by descending
   *   confidence; untrained tags are omitted.
   */
  predictAll(features, candidateTags) {
    const predictions = [];

    for (const tag of candidateTags) {
      const confidence = this.predict(features, tag);
      if (confidence !== null) {
        predictions.push({ tag, confidence });
      }
    }

    return predictions.sort((a, b) => b.confidence - a.confidence);
  }

  /**
   * Train on a batch of stored feedback records. Records without a truthy
   * `audioFeatures` and `correctedTags` are skipped; each entry of
   * `correctedTags` is passed to trainOnFeedback as `{ tag, feedback }`.
   *
   * @param {Iterable<{audioFeatures?: object, correctedTags?: Iterable<{tag: string, feedback: string}>}>} feedbackData
   */
  retrainOnBatch(feedbackData) {
    for (const item of feedbackData) {
      if (item.audioFeatures && item.correctedTags) {
        const features = this.extractSimpleFeatures(item.audioFeatures);
        for (const tagData of item.correctedTags) {
          this.trainOnFeedback(features, tagData.tag, tagData.feedback);
        }
      }
    }
  }

  /**
   * Build a `featureDim`-length feature vector from audio metadata.
   *
   * Slots 0-2 hold duration / 60, sampleRate / 48000 and numberOfChannels
   * (missing or non-numeric values count as 0). The remaining slots are
   * deterministic pseudo-random values in [0, 0.1) seeded by a hash of the
   * metadata's JSON. A falsy `audioFeatures` yields an all-zero vector.
   *
   * @param {object|null|undefined} audioFeatures
   * @returns {number[]}
   */
  extractSimpleFeatures(audioFeatures) {
    const dim = this.featureDim;
    const features = new Array(dim).fill(0);

    if (!audioFeatures) {
      return features;
    }

    // A single NaN feature would make every weight it touches NaN, so
    // coerce missing/non-numeric metadata to 0.
    const finiteOrZero = (value) => {
      const num = Number(value);
      return Number.isFinite(num) ? num : 0;
    };
    const base = [
      finiteOrZero(audioFeatures.duration) / 60,
      finiteOrZero(audioFeatures.sampleRate) / 48000,
      finiteOrZero(audioFeatures.numberOfChannels),
    ];
    // Never write past featureDim, even for very small dimensions.
    for (let i = 0; i < Math.min(base.length, dim); i++) {
      features[i] = base[i];
    }

    const seed = this.simpleHash(JSON.stringify(audioFeatures));
    for (let i = base.length; i < dim; i++) {
      features[i] = this.seededRandom(seed + i) * 0.1;
    }

    return features;
  }

  /**
   * Non-negative 32-bit string hash using the h = h * 31 + charCode
   * recurrence ((h << 5) - h == 31 * h), truncated to a signed 32-bit
   * integer after every character.
   *
   * @param {string} str
   * @returns {number}
   */
  simpleHash(str) {
    let hash = 0;
    for (let i = 0; i < str.length; i++) {
      const char = str.charCodeAt(i);
      hash = ((hash << 5) - hash) + char;
      hash = hash & hash; // ToInt32: keep the value in signed 32-bit range.
    }
    return Math.abs(hash);
  }

  /**
   * Deterministic pseudo-random value in [0, 1) derived from `seed` via the
   * fractional part of sin(seed) * 10000. Not suitable for anything
   * security-sensitive.
   *
   * @param {number} seed
   * @returns {number}
   */
  seededRandom(seed) {
    const x = Math.sin(seed) * 10000;
    return x - Math.floor(x);
  }

  /**
   * Serialize weights, biases, featureDim and learningRate as JSON into
   * localStorage['clipTaggerModel']. Errors from localStorage.setItem
   * propagate to the caller.
   */
  saveModel() {
    const modelData = {
      weights: Object.fromEntries(this.weights),
      biases: Object.fromEntries(this.biases),
      featureDim: this.featureDim,
      learningRate: this.learningRate,
    };

    localStorage.setItem('clipTaggerModel', JSON.stringify(modelData));
  }

  /**
   * Restore a model previously written by saveModel.
   *
   * The current model is replaced only if the stored payload parses and has
   * object-valued `weights` and `biases`; otherwise the error is logged and
   * the current state is left untouched.
   *
   * @returns {boolean} true if a model was loaded, false otherwise.
   */
  loadModel() {
    const saved = localStorage.getItem('clipTaggerModel');
    if (!saved) {
      return false;
    }

    try {
      const modelData = JSON.parse(saved);
      const isObject = (value) => typeof value === 'object' && value !== null;
      if (!isObject(modelData) || !isObject(modelData.weights) || !isObject(modelData.biases)) {
        throw new Error('Malformed model data');
      }

      // Build everything first and commit afterwards, so a failure part-way
      // through cannot leave the model half-replaced.
      const weights = new Map(Object.entries(modelData.weights));
      const biases = new Map(Object.entries(modelData.biases));
      const featureDim = Number.isInteger(modelData.featureDim) && modelData.featureDim > 0
        ? modelData.featureDim
        : 512;
      // A saved learning rate of 0 is a legitimate value; only fall back when
      // it is missing or not a finite number.
      const learningRate = Number.isFinite(modelData.learningRate)
        ? modelData.learningRate
        : 0.01;

      this.weights = weights;
      this.biases = biases;
      this.featureDim = featureDim;
      this.learningRate = learningRate;
      this.isInitialized = true;
      return true;
    } catch (error) {
      console.error('Error loading model:', error);
      return false;
    }
  }

  /**
   * Summary of the current model.
   *
   * @returns {{trainedTags: number, featureDim: number, learningRate: number, tags: string[]}}
   */
  getModelStats() {
    return {
      trainedTags: this.weights.size,
      featureDim: this.featureDim,
      learningRate: this.learningRate,
      tags: Array.from(this.weights.keys()),
    };
  }

  /**
   * Drop all trained weights/biases and remove the persisted copy from
   * localStorage. featureDim and learningRate are left unchanged.
   */
  clearModel() {
    this.weights.clear();
    this.biases.clear();
    localStorage.removeItem('clipTaggerModel');
  }
}
|
|
|
export default LocalClassifier; |