feat: ShootTracker SQLite+JWT+YOLOv8
This commit is contained in:
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
ShootTracker — Analyseur d'impacts de tir
|
||||
YOLOv8 + OpenCV
|
||||
|
||||
Responsabilités :
|
||||
1. Détecter les impacts sur la photo de cible (YOLOv8)
|
||||
2. Analyser les zones selon le type de cible (OpenCV)
|
||||
3. Gérer la soustraction d'image pour les impacts cumulés
|
||||
4. Calculer score, groupement et dispersion
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import math
|
||||
import logging
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ─── Structures ───────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class ImpactPoint:
|
||||
x: float # coordonnée normalisée [0,1] depuis le centre
|
||||
y: float # coordonnée normalisée [0,1] depuis le centre
|
||||
zone: Optional[int] = None
|
||||
points: int = 0
|
||||
confidence: float = 1.0
|
||||
|
||||
@dataclass
|
||||
class AnalysisResult:
|
||||
impacts: list = field(default_factory=list)
|
||||
total_detected: int = 0
|
||||
score_zone: int = 0
|
||||
score_groupement: int = 0
|
||||
score_total: int = 0
|
||||
dispersion_radius: float = 0.0
|
||||
center_x: float = 0.5
|
||||
center_y: float = 0.5
|
||||
annotated_image_base64: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
warning: Optional[str] = None
|
||||
|
||||
|
||||
# ─── Constantes de cibles ─────────────────────────────────────────────────────
|
||||
|
||||
# Cible ISSF : 10 zones concentriques
|
||||
# zones[i] = rayon relatif de la zone (10 à l'intérieur, 1 à l'extérieur)
|
||||
# Normalisé sur 0.5 (demi-image = bord de la cible)
|
||||
ISSF_ZONES = {
|
||||
10: 0.05, 9: 0.10, 8: 0.16, 7: 0.22, 6: 0.28,
|
||||
5: 0.33, 4: 0.38, 3: 0.43, 2: 0.47, 1: 0.50
|
||||
}
|
||||
|
||||
# Couleurs OpenCV BGR pour l'annotation
|
||||
ZONE_COLORS_BGR = {
|
||||
10: (0, 200, 0), # Vert
|
||||
9: (80, 220, 0),
|
||||
8: (140, 220, 0),
|
||||
7: (0, 200, 100),
|
||||
6: (0, 200, 200),
|
||||
5: (0, 150, 255),
|
||||
4: (0, 80, 255),
|
||||
3: (0, 0, 255),
|
||||
2: (0, 0, 180),
|
||||
1: (0, 0, 100),
|
||||
}
|
||||
DEFAULT_COLOR_BGR = (128, 128, 128)
|
||||
|
||||
# ─── Classe principale ────────────────────────────────────────────────────────
|
||||
|
||||
class TargetAnalyzer:
|
||||
def __init__(self, model_path: str = 'yolov8n.pt'):
|
||||
self.model = None
|
||||
self.model_path = model_path
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Charge le modèle YOLOv8."""
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
self.model = YOLO(self.model_path)
|
||||
logger.info(f"Modèle YOLOv8 chargé : {self.model_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Impossible de charger YOLOv8 : {e}")
|
||||
self.model = None
|
||||
|
||||
def analyze(
|
||||
self,
|
||||
image_b64: str,
|
||||
target_type: str,
|
||||
previous_image_b64: Optional[str] = None
|
||||
) -> AnalysisResult:
|
||||
"""
|
||||
Point d'entrée principal.
|
||||
|
||||
Args:
|
||||
image_b64: Image base64 de la cible
|
||||
target_type: 'issf' | 'silhouette' | 'libre'
|
||||
previous_image_b64: Image précédente de la même cible (pour soustraction)
|
||||
"""
|
||||
result = AnalysisResult()
|
||||
try:
|
||||
# 1. Décoder les images
|
||||
img_cv = self._decode_b64_to_cv2(image_b64)
|
||||
if img_cv is None:
|
||||
result.error = "Image illisible ou corrompue"
|
||||
return result
|
||||
|
||||
prev_cv = None
|
||||
if previous_image_b64:
|
||||
prev_cv = self._decode_b64_to_cv2(previous_image_b64)
|
||||
|
||||
# 2. Soustraction d'image si image précédente fournie
|
||||
working_img = img_cv.copy()
|
||||
if prev_cv is not None:
|
||||
working_img = self._subtract_images(img_cv, prev_cv)
|
||||
|
||||
# 3. Détection des impacts
|
||||
if self.model:
|
||||
impacts_raw = self._detect_with_yolo(working_img)
|
||||
else:
|
||||
# Fallback OpenCV si YOLOv8 non disponible
|
||||
impacts_raw = self._detect_with_opencv(working_img)
|
||||
|
||||
# 4. Analyse des zones selon le type de cible
|
||||
h, w = img_cv.shape[:2]
|
||||
scored_impacts = self._score_impacts(impacts_raw, w, h, target_type)
|
||||
result.impacts = [vars(imp) for imp in scored_impacts]
|
||||
result.total_detected = len(scored_impacts)
|
||||
|
||||
# 5. Calcul scores
|
||||
result.score_zone = sum(imp.points for imp in scored_impacts)
|
||||
|
||||
# 6. Calcul groupement (dispersion)
|
||||
if len(scored_impacts) >= 2:
|
||||
center_x, center_y, radius = self._compute_dispersion(scored_impacts)
|
||||
result.center_x = center_x
|
||||
result.center_y = center_y
|
||||
result.dispersion_radius = radius
|
||||
result.score_groupement = self._groupement_score(radius, result.score_zone)
|
||||
elif len(scored_impacts) == 1:
|
||||
result.center_x = scored_impacts[0].x
|
||||
result.center_y = scored_impacts[0].y
|
||||
result.dispersion_radius = 0.0
|
||||
result.score_groupement = round(result.score_zone * 0.10)
|
||||
|
||||
result.score_total = result.score_zone + result.score_groupement
|
||||
|
||||
# 7. Annotation de l'image
|
||||
annotated = self._annotate_image(img_cv, scored_impacts, target_type)
|
||||
result.annotated_image_base64 = self._encode_cv2_to_b64(annotated)
|
||||
|
||||
if result.total_detected == 0:
|
||||
result.warning = "Aucun impact détecté. Vérifiez la qualité de l'image."
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Erreur analyze: {e}")
|
||||
result.error = f"Erreur d'analyse : {str(e)}"
|
||||
|
||||
return result
|
||||
|
||||
# ─── Décodage / Encodage ──────────────────────────────────────────────────
|
||||
|
||||
def _decode_b64_to_cv2(self, b64_str: str) -> Optional[np.ndarray]:
|
||||
try:
|
||||
# Supprimer le préfixe data URL si présent
|
||||
if ',' in b64_str:
|
||||
b64_str = b64_str.split(',', 1)[1]
|
||||
data = base64.b64decode(b64_str)
|
||||
pil_img = Image.open(io.BytesIO(data)).convert('RGB')
|
||||
# Redimensionner si trop grande (max 2000px)
|
||||
max_dim = 2000
|
||||
w, h = pil_img.size
|
||||
if max(w, h) > max_dim:
|
||||
ratio = max_dim / max(w, h)
|
||||
pil_img = pil_img.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
|
||||
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
except Exception as e:
|
||||
logger.error(f"Décodage image échoué: {e}")
|
||||
return None
|
||||
|
||||
def _encode_cv2_to_b64(self, img: np.ndarray) -> str:
|
||||
_, buffer = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
return 'data:image/jpeg;base64,' + base64.b64encode(buffer.tobytes()).decode('utf-8')
|
||||
|
||||
# ─── Détection YOLOv8 ─────────────────────────────────────────────────────
|
||||
|
||||
def _detect_with_yolo(self, img: np.ndarray) -> list[tuple[float, float, float]]:
|
||||
"""
|
||||
Détecte les impacts avec YOLOv8.
|
||||
Utilise yolov8n.pt (modèle généraliste) en cherchant des objets ronds/petits.
|
||||
Pour un fine-tuning ultérieur, entraîner sur des photos de cibles annotées.
|
||||
|
||||
Retourne une liste de (cx_norm, cy_norm, confidence) en coordonnées normalisées.
|
||||
"""
|
||||
h, w = img.shape[:2]
|
||||
try:
|
||||
results = self.model(img, conf=0.15, verbose=False)
|
||||
detections = []
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
# Toutes les détections (le modèle de base peut détecter des "trous")
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
cx = (x1 + x2) / 2 / w
|
||||
cy = (y1 + y2) / 2 / h
|
||||
conf = float(box.conf[0])
|
||||
# Filtrer les boîtes trop grandes (pas des impacts)
|
||||
box_w = (x2 - x1) / w
|
||||
box_h = (y2 - y1) / h
|
||||
if box_w < 0.15 and box_h < 0.15: # Max 15% de l'image
|
||||
detections.append((cx, cy, conf))
|
||||
logger.info(f"YOLOv8 détecté {len(detections)} impacts potentiels")
|
||||
return detections
|
||||
except Exception as e:
|
||||
logger.warning(f"YOLOv8 échoué, fallback OpenCV: {e}")
|
||||
return self._detect_with_opencv(img)
|
||||
|
||||
# ─── Détection OpenCV (fallback) ──────────────────────────────────────────
|
||||
|
||||
def _detect_with_opencv(self, img: np.ndarray) -> list[tuple[float, float, float]]:
|
||||
"""
|
||||
Détection des impacts par analyse d'image OpenCV.
|
||||
Recherche des zones sombres circulaires (trous dans la cible).
|
||||
"""
|
||||
h, w = img.shape[:2]
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Amélioration du contraste
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
gray = clahe.apply(gray)
|
||||
|
||||
# Détecter les bords
|
||||
blurred = cv2.GaussianBlur(gray, (5, 5), 0)
|
||||
edges = cv2.Canny(blurred, 30, 100)
|
||||
|
||||
# Détection de cercles (Hough)
|
||||
min_r = max(3, int(min(w, h) * 0.005)) # Rayon min ~ 0.5%
|
||||
max_r = max(15, int(min(w, h) * 0.03)) # Rayon max ~ 3%
|
||||
circles = cv2.HoughCircles(
|
||||
gray, cv2.HOUGH_GRADIENT, dp=1.2,
|
||||
minDist=int(min(w, h) * 0.02),
|
||||
param1=100, param2=20,
|
||||
minRadius=min_r, maxRadius=max_r
|
||||
)
|
||||
|
||||
detections = []
|
||||
if circles is not None:
|
||||
circles = np.round(circles[0, :]).astype(int)
|
||||
for (cx, cy, r) in circles[:30]: # Max 30 impacts
|
||||
detections.append((cx / w, cy / h, 0.8))
|
||||
|
||||
# Si peu de cercles détectés, essayer par zones sombres
|
||||
if len(detections) < 2:
|
||||
detections.extend(self._detect_dark_spots(gray, w, h))
|
||||
|
||||
logger.info(f"OpenCV détecté {len(detections)} impacts")
|
||||
return detections
|
||||
|
||||
def _detect_dark_spots(self, gray: np.ndarray, w: int, h: int) -> list[tuple[float, float, float]]:
|
||||
"""Détecte les zones sombres (trous) dans l'image."""
|
||||
_, thresh = cv2.threshold(gray, 80, 255, cv2.THRESH_BINARY_INV)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
opened = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
|
||||
contours, _ = cv2.findContours(opened, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
detections = []
|
||||
min_area = (w * h) * 0.0001 # 0.01% de l'image
|
||||
max_area = (w * h) * 0.005 # 0.5% de l'image
|
||||
|
||||
for cnt in contours:
|
||||
area = cv2.contourArea(cnt)
|
||||
if min_area <= area <= max_area:
|
||||
M = cv2.moments(cnt)
|
||||
if M['m00'] > 0:
|
||||
cx = M['m10'] / M['m00'] / w
|
||||
cy = M['m01'] / M['m00'] / h
|
||||
detections.append((cx, cy, 0.6))
|
||||
|
||||
return detections[:25] # Max 25
|
||||
|
||||
# ─── Soustraction d'image ────────────────────────────────────────────────
|
||||
|
||||
def _subtract_images(self, current: np.ndarray, previous: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Soustrait l'image précédente pour isoler les NOUVEAUX impacts.
|
||||
"""
|
||||
# Redimensionner si nécessaire
|
||||
if current.shape[:2] != previous.shape[:2]:
|
||||
previous = cv2.resize(previous, (current.shape[1], current.shape[0]))
|
||||
|
||||
# Conversion en niveaux de gris
|
||||
curr_gray = cv2.cvtColor(current, cv2.COLOR_BGR2GRAY)
|
||||
prev_gray = cv2.cvtColor(previous, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Soustraction absolue
|
||||
diff = cv2.absdiff(curr_gray, prev_gray)
|
||||
_, mask = cv2.threshold(diff, 25, 255, cv2.THRESH_BINARY)
|
||||
|
||||
# Appliquer le masque sur l'image courante
|
||||
result = current.copy()
|
||||
background = np.full_like(current, 200) # Fond gris clair
|
||||
background[mask == 255] = current[mask == 255]
|
||||
|
||||
return background
|
||||
|
||||
# ─── Scoring par zones ───────────────────────────────────────────────────
|
||||
|
||||
def _score_impacts(
|
||||
self, raw_detections: list[tuple[float, float, float]],
|
||||
w: int, h: int, target_type: str
|
||||
) -> list[ImpactPoint]:
|
||||
"""Assigne une zone et des points à chaque impact détecté."""
|
||||
impacts = []
|
||||
for (cx, cy, conf) in raw_detections:
|
||||
imp = ImpactPoint(x=cx, y=cy, confidence=conf)
|
||||
if target_type == 'issf':
|
||||
imp.zone, imp.points = self._get_issf_zone(cx, cy)
|
||||
else:
|
||||
# Silhouette / libre : pas de scoring par zones
|
||||
imp.zone = None
|
||||
imp.points = 1 # 1 point par tir
|
||||
impacts.append(imp)
|
||||
return impacts
|
||||
|
||||
def _get_issf_zone(self, cx: float, cy: float) -> tuple[Optional[int], int]:
|
||||
"""
|
||||
Calcule la zone ISSF (1-10) pour un impact à (cx, cy) normalisé.
|
||||
Assume que le centre de la cible est à (0.5, 0.5).
|
||||
"""
|
||||
dx = cx - 0.5
|
||||
dy = cy - 0.5
|
||||
dist = math.sqrt(dx*dx + dy*dy) # Distance au centre [0, ~0.7]
|
||||
|
||||
for zone in range(10, 0, -1): # De 10 (centre) à 1 (bord)
|
||||
if dist <= ISSF_ZONES[zone]:
|
||||
return zone, zone
|
||||
return None, 0 # Hors cible
|
||||
|
||||
# ─── Calcul groupement ───────────────────────────────────────────────────
|
||||
|
||||
def _compute_dispersion(self, impacts: list[ImpactPoint]) -> tuple[float, float, float]:
|
||||
"""Calcule le centre et le rayon de dispersion (en pixels normalisés * 1000)."""
|
||||
xs = [imp.x for imp in impacts]
|
||||
ys = [imp.y for imp in impacts]
|
||||
cx = sum(xs) / len(xs)
|
||||
cy = sum(ys) / len(ys)
|
||||
# Rayon moyen des distances au centre
|
||||
dists = [math.sqrt((x - cx)**2 + (y - cy)**2) for x, y in zip(xs, ys)]
|
||||
radius = max(dists) if dists else 0.0
|
||||
# Convertir en "pixels" (on multiplie par 1000 pour un nombre lisible)
|
||||
return cx, cy, round(radius * 1000, 2)
|
||||
|
||||
def _groupement_score(self, radius: float, zone_score: int) -> int:
|
||||
"""
|
||||
Bonus de groupement : jusqu'à 10% du score de zones.
|
||||
Radius : normalisé * 1000. Moins c'est grand, meilleur c'est.
|
||||
"""
|
||||
if radius <= 0: return 0
|
||||
# Excellent groupement (radius < 10) → 10%, mauvais (>100) → 0%
|
||||
if radius <= 10: ratio = 0.10
|
||||
elif radius <= 30: ratio = 0.07
|
||||
elif radius <= 60: ratio = 0.04
|
||||
elif radius <= 100: ratio = 0.02
|
||||
else: ratio = 0.0
|
||||
return round(zone_score * ratio)
|
||||
|
||||
# ─── Annotation de l'image ───────────────────────────────────────────────
|
||||
|
||||
def _annotate_image(
|
||||
self, img: np.ndarray,
|
||||
impacts: list[ImpactPoint],
|
||||
target_type: str
|
||||
) -> np.ndarray:
|
||||
"""Dessine les impacts annotés sur l'image."""
|
||||
annotated = img.copy()
|
||||
h, w = annotated.shape[:2]
|
||||
dot_radius = max(8, int(min(w, h) * 0.012))
|
||||
|
||||
for imp in impacts:
|
||||
px = int(imp.x * w)
|
||||
py = int(imp.y * h)
|
||||
color = ZONE_COLORS_BGR.get(imp.zone or 0, DEFAULT_COLOR_BGR)
|
||||
|
||||
# Cercle extérieur (blanc)
|
||||
cv2.circle(annotated, (px, py), dot_radius + 2, (255, 255, 255), 2)
|
||||
# Cercle de couleur par zone
|
||||
cv2.circle(annotated, (px, py), dot_radius, color, -1)
|
||||
# Texte de la zone
|
||||
if imp.zone:
|
||||
text = str(imp.zone)
|
||||
font_scale = max(0.4, dot_radius / 20)
|
||||
thickness = max(1, dot_radius // 10)
|
||||
(tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
||||
cv2.putText(
|
||||
annotated, text,
|
||||
(px - tw // 2, py + th // 2),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, font_scale,
|
||||
(255, 255, 255), thickness, cv2.LINE_AA
|
||||
)
|
||||
|
||||
# Dessiner le centre de gravité si plusieurs impacts
|
||||
if len(impacts) >= 2:
|
||||
cx_px = int(sum(i.x for i in impacts) / len(impacts) * w)
|
||||
cy_px = int(sum(i.y for i in impacts) / len(impacts) * h)
|
||||
cv2.drawMarker(
|
||||
annotated, (cx_px, cy_px), (0, 255, 255),
|
||||
cv2.MARKER_CROSS, 20, 2
|
||||
)
|
||||
|
||||
# Watermark
|
||||
cv2.putText(
|
||||
annotated, 'ShootTracker AI',
|
||||
(10, h - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
||||
(100, 100, 100), 1, cv2.LINE_AA
|
||||
)
|
||||
|
||||
return annotated
|
||||
Reference in New Issue
Block a user