Files
impact/lib/services/yolo_impact_detection_service.dart
2026-03-12 22:03:40 +01:00

175 lines
4.8 KiB
Dart

import 'dart:io';
import 'dart:math' as math;
import 'dart:typed_data';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:image/image.dart' as img;
import 'target_detection_service.dart';
/// Detects "impact" objects in image files with a YOLO TFLite model.
///
/// The interpreter is loaded lazily by [init] (also triggered by the first
/// [detectImpacts] call) and reused afterwards; call [dispose] to free it.
class YOLOImpactDetectionService {
  Interpreter? _interpreter;

  static const String modelPath = 'assets/models/yolov11n_impact.tflite';
  static const String labelsPath = 'assets/models/labels.txt';

  /// YOLOv8 model loaded when [modelPath] cannot be loaded.
  static const String _fallbackModelPath =
      'assets/models/yolov8n_impact.tflite';

  /// Square input size used when the model's input tensor is not reported
  /// as `[1, height, width, 3]`.
  static const int _defaultInputSize = 640;

  /// Minimum class score for a raw prediction to be considered at all.
  static const double _confidenceThreshold = 0.25;

  /// IoU above which a lower-confidence box is suppressed during NMS.
  static const double _iouThreshold = 0.45;

  /// Box centres larger than this cannot be normalized coordinates, so the
  /// model must be emitting input-pixel coordinates.
  static const double _normalizedCoordLimit = 2.0;

  /// Loads the YOLO interpreter if it is not already loaded.
  ///
  /// Tries the YOLOv11 model at [modelPath] first and falls back to a YOLOv8
  /// model. Any failure is logged and leaves the interpreter unset, in which
  /// case [detectImpacts] returns an empty list.
  Future<void> init() async {
    if (_interpreter != null) return;
    try {
      // Try loading the specific YOLOv11 model first, fallback to v8 if not found
      try {
        _interpreter = await Interpreter.fromAsset(modelPath);
      } catch (e) {
        print('YOLOv11 model not found at $modelPath, trying YOLOv8 fallback');
        _interpreter = await Interpreter.fromAsset(_fallbackModelPath);
      }
      print('YOLO Interpreter loaded successfully');
    } catch (e) {
      print('Error loading YOLO model: $e');
    }
  }

  /// Releases the native interpreter; [init] may be called again afterwards.
  void dispose() {
    _interpreter?.close();
    _interpreter = null;
  }

  /// Runs the model on the image at [imagePath] and returns its detections.
  ///
  /// Result positions are normalized to `[0, 1]` relative to the image
  /// width/height. Returns an empty list when the model could not be loaded,
  /// the file cannot be decoded, or inference fails; errors are logged, not
  /// thrown.
  Future<List<DetectedImpactResult>> detectImpacts(String imagePath) async {
    if (_interpreter == null) await init();
    final interpreter = _interpreter;
    if (interpreter == null) return [];
    try {
      // Asynchronous read so file I/O does not block the calling isolate.
      final bytes = await File(imagePath).readAsBytes();
      final originalImage = img.decodeImage(bytes);
      if (originalImage == null) return [];

      // The input is packed as [1, height, width, 3] float32 (see
      // _imageToByteListFloat32); YOLOv8/v11 exports are usually 640x640.
      final inputShape = interpreter.getInputTensor(0).shape;
      final bool isNhwc = inputShape.length == 4 && inputShape[3] == 3;
      final int inputHeight = isNhwc ? inputShape[1] : _defaultInputSize;
      final int inputWidth = isNhwc ? inputShape[2] : _defaultInputSize;

      // Plain stretch (no letterbox): normalized positions in the resized
      // image are the same as in the original image.
      final resizedImage = img.copyResize(
        originalImage,
        width: inputWidth,
        height: inputHeight,
      );
      final input = _imageToByteListFloat32(resizedImage);

      // Raw YOLO output is [1, 4 + numClasses, numAnchors], e.g. [1, 5, 8400]
      // for a single-class model at 640x640. The transposed
      // [1, numAnchors, 4 + numClasses] layout is not supported.
      final outputShape = interpreter.getOutputTensor(0).shape;
      if (outputShape.length != 3 ||
          outputShape[1] < 5 ||
          outputShape[1] > outputShape[2]) {
        print('Unexpected YOLO output shape: $outputShape');
        return [];
      }
      final int channels = outputShape[1];
      final int anchors = outputShape[2];
      final output = List<double>.filled(channels * anchors, 0)
          .reshape([1, channels, anchors]);
      interpreter.run(input, output);

      // Re-type the rows explicitly: the nested element type produced by
      // reshape is not guaranteed to be double.
      final rows = [
        for (final row in output[0] as List) List<double>.from(row as List),
      ];
      return _processOutput(rows, inputWidth, inputHeight);
    } catch (e) {
      print('Error during YOLO inference: $e');
      return [];
    }
  }

  /// Converts raw YOLO output rows into normalized, NMS-filtered detections.
  ///
  /// [output] is `[4 + numClasses][numAnchors]`: rows 0-3 hold box centre x,
  /// centre y, width and height; the remaining rows hold per-class scores.
  /// [inputWidth]/[inputHeight] are the model input size, used to normalize
  /// coordinates when the model emits them in input pixels.
  List<DetectedImpactResult> _processOutput(
    List<List<double>> output,
    int inputWidth,
    int inputHeight,
  ) {
    if (output.length < 5) return [];
    final int anchors = output[0].length;
    final List<_Detection> candidates = [];
    for (int i = 0; i < anchors; i++) {
      // Best class score for this anchor (just row 4 for a single class).
      double confidence = output[4][i];
      for (int c = 5; c < output.length; c++) {
        confidence = math.max(confidence, output[c][i]);
      }
      if (confidence > _confidenceThreshold) {
        candidates.add(
          _Detection(
            x: output[0][i],
            y: output[1][i],
            w: output[2][i],
            h: output[3][i],
            confidence: confidence,
          ),
        );
      }
    }
    if (candidates.isEmpty) return [];

    // Depending on how the model was exported, box coordinates are either
    // already normalized to [0, 1] or expressed in input pixels. Anchor
    // centres span the whole input, so the largest x across all anchors
    // distinguishes the two cases.
    final double maxX = output[0].reduce(math.max);
    final bool pixelSpace = maxX > _normalizedCoordLimit;
    final double scaleX = pixelSpace ? inputWidth.toDouble() : 1.0;
    final double scaleY = pixelSpace ? inputHeight.toDouble() : 1.0;

    return _nms(candidates)
        .map(
          (det) => DetectedImpactResult(
            x: det.x / scaleX,
            y: det.y / scaleY,
            radius: 5.0,
            suggestedScore: 0,
          ),
        )
        .toList();
  }

  /// Greedy non-maximum suppression.
  ///
  /// Keeps detections in descending confidence order, dropping any whose IoU
  /// with an already-kept detection exceeds [_iouThreshold]. Does not modify
  /// [detections].
  List<_Detection> _nms(List<_Detection> detections) {
    if (detections.isEmpty) return [];
    final sorted = [...detections]
      ..sort((a, b) => b.confidence.compareTo(a.confidence));
    final selected = <_Detection>[];
    final active = List<bool>.filled(sorted.length, true);
    for (int i = 0; i < sorted.length; i++) {
      if (!active[i]) continue;
      selected.add(sorted[i]);
      for (int j = i + 1; j < sorted.length; j++) {
        if (active[j] && _iou(sorted[i], sorted[j]) > _iouThreshold) {
          active[j] = false;
        }
      }
    }
    return selected;
  }

  /// Intersection-over-union of two centre/size boxes.
  ///
  /// Returns 0 when the union area is not positive (degenerate boxes), which
  /// would otherwise yield NaN from 0 / 0.
  double _iou(_Detection a, _Detection b) {
    final double x1 = math.max(a.x - a.w / 2, b.x - b.w / 2);
    final double y1 = math.max(a.y - a.h / 2, b.y - b.h / 2);
    final double x2 = math.min(a.x + a.w / 2, b.x + b.w / 2);
    final double y2 = math.min(a.y + a.h / 2, b.y + b.h / 2);
    final double intersection =
        math.max(0.0, x2 - x1) * math.max(0.0, y2 - y1);
    final double union = a.w * a.h + b.w * b.h - intersection;
    return union <= 0 ? 0.0 : intersection / union;
  }

  /// Packs [image] into a row-major float32 `H x W x 3` buffer with RGB
  /// scaled to `[0, 1]`, returned as its raw bytes for the interpreter input.
  Uint8List _imageToByteListFloat32(img.Image image) {
    final floats = Float32List(image.width * image.height * 3);
    int index = 0;
    for (int y = 0; y < image.height; y++) {
      for (int x = 0; x < image.width; x++) {
        final pixel = image.getPixel(x, y);
        // The *Normalized getters divide by the format's max channel value,
        // so 16-bit and float images are scaled correctly, not just 8-bit.
        floats[index++] = pixel.rNormalized.toDouble();
        floats[index++] = pixel.gNormalized.toDouble();
        floats[index++] = pixel.bNormalized.toDouble();
      }
    }
    return floats.buffer.asUint8List();
  }
}
/// One raw model prediction before non-maximum suppression.
///
/// [x]/[y] are the box centre and [w]/[h] its size, in whatever coordinate
/// space the model emitted; [confidence] is the prediction's score.
class _Detection {
  const _Detection({
    required this.x,
    required this.y,
    required this.w,
    required this.h,
    required this.confidence,
  });

  /// Horizontal box centre.
  final double x;

  /// Vertical box centre.
  final double y;

  /// Box width.
  final double w;

  /// Box height.
  final double h;

  /// Score used to rank and threshold detections.
  final double confidence;
}