175 lines
4.8 KiB
Dart
175 lines
4.8 KiB
Dart
import 'dart:io';
|
|
import 'dart:math' as math;
|
|
import 'dart:typed_data';
|
|
import 'package:tflite_flutter/tflite_flutter.dart';
|
|
import 'package:image/image.dart' as img;
|
|
import 'target_detection_service.dart';
|
|
|
|
/// Detects impact points in an image with a YOLO TFLite model (YOLOv11n,
/// falling back to YOLOv8n).
///
/// The interpreter is loaded lazily by [detectImpacts] (or eagerly via
/// [init]); call [dispose] to release it when the service is no longer used.
class YOLOImpactDetectionService {
  Interpreter? _interpreter;

  /// Asset path of the preferred YOLOv11n model.
  static const String modelPath = 'assets/models/yolov11n_impact.tflite';

  /// Asset path of the label file.
  ///
  /// NOTE(review): not read anywhere in this class — confirm whether other
  /// code uses it.
  static const String labelsPath = 'assets/models/labels.txt';

  /// Asset path of the YOLOv8n model tried when [modelPath] fails to load.
  static const String _fallbackModelPath =
      'assets/models/yolov8n_impact.tflite';

  /// Square input size used when the model's input tensor shape cannot be
  /// read as `[1, height, width, 3]`.
  static const int _defaultInputSize = 640;

  /// Loads the TFLite interpreter unless it is already loaded.
  ///
  /// Tries [modelPath] first and, on any failure, the YOLOv8 fallback asset.
  /// Errors are printed and swallowed: if both loads fail the interpreter
  /// stays `null` and [detectImpacts] returns an empty list.
  Future<void> init() async {
    if (_interpreter != null) return;

    try {
      try {
        _interpreter = await Interpreter.fromAsset(modelPath);
      } catch (e) {
        print('YOLOv11 model not found at $modelPath, trying YOLOv8 fallback');
        _interpreter = await Interpreter.fromAsset(_fallbackModelPath);
      }
      print('YOLO Interpreter loaded successfully');
    } catch (e) {
      print('Error loading YOLO model: $e');
    }
  }

  /// Releases the native interpreter.
  ///
  /// A later call to [init] or [detectImpacts] loads the model again.
  void dispose() {
    _interpreter?.close();
    _interpreter = null;
  }

  /// Runs the model on the image at [imagePath] and returns detected impacts.
  ///
  /// Returned `x`/`y` are normalized to `0..1` of the image size. The image is
  /// stretched to the model input without letterboxing, so normalized input
  /// coordinates equal normalized original-image coordinates. A detection is
  /// kept only if its best class score is above [confidenceThreshold]; among
  /// overlapping boxes (IoU above [iouThreshold]) only the most confident
  /// survives.
  ///
  /// Returns an empty list if the model could not be loaded, the file cannot
  /// be read or decoded, or inference fails. Errors are printed, not thrown.
  Future<List<DetectedImpactResult>> detectImpacts(
    String imagePath, {
    double confidenceThreshold = 0.25,
    double iouThreshold = 0.45,
  }) async {
    if (_interpreter == null) await init();
    final interpreter = _interpreter;
    if (interpreter == null) return [];

    try {
      // Async read: a sync read would block this isolate inside an async API.
      final bytes = await File(imagePath).readAsBytes();
      final originalImage = img.decodeImage(bytes);
      if (originalImage == null) return [];

      // Take the input resolution from the model when it is NHWC RGB
      // ([1, height, width, 3]); otherwise assume the usual 640x640.
      final inputShape = interpreter.getInputTensor(0).shape;
      final isNhwcRgb = inputShape.length == 4 && inputShape[3] == 3;
      final inputHeight = isNhwcRgb ? inputShape[1] : _defaultInputSize;
      final inputWidth = isNhwcRgb ? inputShape[2] : _defaultInputSize;

      final resizedImage = img.copyResize(
        originalImage,
        width: inputWidth,
        height: inputHeight,
      );
      final input = _imageToByteListFloat32(resizedImage);

      // Raw YOLOv8/v11 head: [1, 4 + numClasses, numAnchors], e.g.
      // [1, 5, 8400] for a single "impact" class at 640x640.
      final outputShape = interpreter.getOutputTensor(0).shape;
      if (outputShape.length != 3 ||
          outputShape[0] != 1 ||
          outputShape[1] < 5) {
        throw StateError(
          'Unexpected YOLO output shape $outputShape, '
          'expected [1, 4 + classes, anchors]',
        );
      }
      final channels = outputShape[1];
      final anchors = outputShape[2];
      final output = List<double>.filled(
        channels * anchors,
        0,
      ).reshape([1, channels, anchors]);

      interpreter.run(input, output);

      // `reshape` is untyped (its element type is `dynamic`), so the rows are
      // List<dynamic> at runtime. Convert them explicitly rather than relying
      // on an implicit downcast to List<List<double>>, which fails.
      final rows = <List<double>>[
        for (final row in output[0] as List)
          [for (final value in row as List) (value as num).toDouble()],
      ];

      return _processOutput(
        rows,
        inputWidth: inputWidth,
        inputHeight: inputHeight,
        confidenceThreshold: confidenceThreshold,
        iouThreshold: iouThreshold,
      );
    } catch (e) {
      print('Error during YOLO inference: $e');
      return [];
    }
  }

  /// Converts the YOLO head output into normalized impact points.
  ///
  /// [output] is `[4 + numClasses][numAnchors]`. Rows 0-3 are box center x,
  /// center y, width and height; the remaining rows are per-class scores.
  List<DetectedImpactResult> _processOutput(
    List<List<double>> output, {
    required int inputWidth,
    required int inputHeight,
    required double confidenceThreshold,
    required double iouThreshold,
  }) {
    if (output.length < 5) return [];

    final xs = output[0];
    final ys = output[1];
    final ws = output[2];
    final hs = output[3];

    // Depending on the exporter/version, box coordinates are either pixels in
    // model-input space or already normalized to 0..1. In pixel space the
    // anchor centers span roughly 0..inputWidth, so a maximum x near 1 means
    // the output is already normalized and must not be divided again.
    final maxX = xs.fold<double>(0, (best, x) => math.max(best, x));
    final alreadyNormalized = maxX <= 2.0;
    final scaleX = alreadyNormalized ? 1.0 : inputWidth.toDouble();
    final scaleY = alreadyNormalized ? 1.0 : inputHeight.toDouble();

    final candidates = <_Detection>[];
    for (var i = 0; i < xs.length; i++) {
      // Confidence is the best class score; for one class it is just row 4.
      var confidence = output[4][i];
      for (var c = 5; c < output.length; c++) {
        confidence = math.max(confidence, output[c][i]);
      }
      if (confidence > confidenceThreshold) {
        candidates.add(
          _Detection(
            x: xs[i] / scaleX,
            y: ys[i] / scaleY,
            w: ws[i] / scaleX,
            h: hs[i] / scaleY,
            confidence: confidence,
          ),
        );
      }
    }

    // Radius is a fixed 5.0 and suggestedScore 0; box size is not propagated.
    return [
      for (final det in _nms(candidates, iouThreshold))
        DetectedImpactResult(
          x: det.x,
          y: det.y,
          radius: 5.0,
          suggestedScore: 0,
        ),
    ];
  }

  /// Greedy non-max suppression.
  ///
  /// Visits [detections] in descending confidence order and keeps each one
  /// unless its IoU with an already-kept detection exceeds [iouThreshold].
  /// The input list is not modified.
  List<_Detection> _nms(
    List<_Detection> detections, [
    double iouThreshold = 0.45,
  ]) {
    final byConfidence = [...detections]
      ..sort((a, b) => b.confidence.compareTo(a.confidence));

    final selected = <_Detection>[];
    for (final candidate in byConfidence) {
      final overlapsKept = selected.any(
        (kept) => _iou(kept, candidate) > iouThreshold,
      );
      if (!overlapsKept) selected.add(candidate);
    }
    return selected;
  }

  /// Intersection-over-union of two center-format (x, y, w, h) boxes.
  ///
  /// Returns 0 when the union is not positive (degenerate boxes), instead of
  /// dividing by zero.
  double _iou(_Detection a, _Detection b) {
    final left = math.max(a.x - a.w / 2, b.x - b.w / 2);
    final top = math.max(a.y - a.h / 2, b.y - b.h / 2);
    final right = math.min(a.x + a.w / 2, b.x + b.w / 2);
    final bottom = math.min(a.y + a.h / 2, b.y + b.h / 2);

    final intersection =
        math.max(0.0, right - left) * math.max(0.0, bottom - top);
    final union = a.w * a.h + b.w * b.h - intersection;
    return union > 0 ? intersection / union : 0.0;
  }

  /// Packs [image] into a `[1, height, width, 3]` float32 tensor (RGB, each
  /// channel scaled to 0..1). The result is the raw bytes passed to
  /// [Interpreter.run].
  Uint8List _imageToByteListFloat32(img.Image image) {
    final floats = Float32List(image.height * image.width * 3);
    var offset = 0;
    for (var y = 0; y < image.height; y++) {
      for (var x = 0; x < image.width; x++) {
        final pixel = image.getPixel(x, y);
        // Normalized channels divide by the image's own max channel value, so
        // 16-bit (and other non-8-bit) sources also land in 0..1.
        floats[offset++] = pixel.rNormalized.toDouble();
        floats[offset++] = pixel.gNormalized.toDouble();
        floats[offset++] = pixel.bNormalized.toDouble();
      }
    }
    return floats.buffer.asUint8List();
  }
}
|
|
|
|
/// A candidate box produced by the YOLO head, before non-max suppression.
///
/// [x] and [y] are the box center and [w] and [h] its width and height, all
/// in the same coordinate space. [confidence] is the score used to rank
/// candidates.
class _Detection {
  final double x, y, w, h, confidence;

  // All fields are final and the class is a plain value, so the constructor
  // can be const (canonicalized instances, usable in const contexts).
  const _Detection({
    required this.x,
    required this.y,
    required this.w,
    required this.h,
    required this.confidence,
  });
}
|