"""
ShootTracker — AI microservice

FastAPI + YOLOv8 + OpenCV

Endpoints:
    POST /analyze — analyse a photo of a shooting target
    GET  /health  — service health
    GET  /        — service information
"""


import asyncio
import logging
import os
import threading
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, field_validator

from analyzer import TargetAnalyzer, AnalysisResult


# ─── Logging ──────────────────────────────────────────────────────────────────

# Timestamp, level and logger name in front of every record.
_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
# Time of day only; the date is not included in log lines.
_LOG_DATEFMT = '%H:%M:%S'

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)


# ─── Global model ─────────────────────────────────────────────────────────────

# Shared analyzer instance: created by `lifespan` at startup, reset to None at
# shutdown. Endpoints check for None before using it.
analyzer: Optional[TargetAnalyzer] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the analyzer at startup and release it at shutdown.

    The model path is read from the ``YOLO_MODEL_PATH`` environment variable
    (default ``yolov8n.pt``) and handed to ``TargetAnalyzer``.

    Args:
        app: the FastAPI application (required by the lifespan protocol, unused).
    """
    global analyzer
    model_path = os.getenv('YOLO_MODEL_PATH', 'yolov8n.pt')
    logger.info("Chargement du modèle YOLOv8 : %s", model_path)
    analyzer = TargetAnalyzer(model_path=model_path)
    logger.info("✓ Service IA prêt")
    try:
        yield
    finally:
        # Runs even if an exception is thrown into the context manager during
        # shutdown. Dropping the reference makes /health report the model as
        # unloaded and lets the model be garbage-collected.
        analyzer = None
        logger.info("Service IA arrêté")


# ─── FastAPI application ──────────────────────────────────────────────────────

# `lifespan` builds the analyzer before the first request is served.
app = FastAPI(
    lifespan=lifespan,
    title="ShootTracker AI Service",
    version="1.0.0",
    description="Analyse d'impacts de tir avec YOLOv8 + OpenCV",
)


# Browser origins allowed to call this service, comma-separated in the
# CORS_ORIGINS environment variable. Defaults to "*" (any origin); an empty
# value also falls back to "*".
_cors_origins = [
    origin.strip()
    for origin in os.getenv('CORS_ORIGINS', '*').split(',')
    if origin.strip()
] or ['*']

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentials (cookies / auth headers) are only allowed with an explicit
    # origin list. Combining them with a wildcard origin is disallowed by the
    # FastAPI CORS docs and would let any website send credentialed requests.
    allow_credentials='*' not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ─── Pydantic schemas ─────────────────────────────────────────────────────────

# Target types accepted by POST /analyze (tuple: immutable, built once).
_ALLOWED_TARGET_TYPES = ('issf', 'silhouette', 'libre')


class AnalyzeRequest(BaseModel):
    """Request body of ``POST /analyze``.

    Attributes:
        image_base64: base64-encoded photo of the target to analyse.
        target_type: one of ``issf``, ``silhouette`` or ``libre``
            (case and surrounding whitespace are ignored).
        previous_image_base64: optional earlier photo of the same target; when
            given, only the new impacts are counted. Blank strings are treated
            as absent.
    """

    image_base64: str
    target_type: str
    previous_image_base64: Optional[str] = None

    @field_validator('target_type')
    @classmethod
    def validate_target_type(cls, v: str) -> str:
        """Normalise ``target_type`` to lowercase and reject unknown values.

        Raises:
            ValueError: if the value is not one of the allowed target types.
        """
        normalized = v.strip().lower()
        if normalized not in _ALLOWED_TARGET_TYPES:
            # list(...) keeps the message identical to the historical one.
            raise ValueError(f"target_type doit être parmi : {list(_ALLOWED_TARGET_TYPES)}")
        return normalized

    @field_validator('image_base64')
    @classmethod
    def validate_image(cls, v: str) -> str:
        """Reject payloads too short to plausibly be an encoded image.

        Raises:
            ValueError: if the string is shorter than 100 characters
                (this includes the empty string).
        """
        if len(v) < 100:
            raise ValueError("image_base64 invalide ou trop courte")
        return v

    @field_validator('previous_image_base64')
    @classmethod
    def validate_previous_image(cls, v: Optional[str]) -> Optional[str]:
        """Map a blank previous image to None.

        This makes "no previous image" consistent everywhere: the endpoint's
        log line already treated an empty string as absent.
        """
        if v is not None and not v.strip():
            return None
        return v


class ImpactOut(BaseModel):
    """One detected impact in the ``POST /analyze`` response.

    /analyze builds instances from the dicts in ``AnalysisResult.impacts``
    (``ImpactOut(**imp)``), so each dict must provide all five keys.
    """

    # Impact position; units and origin are defined by the analyzer (TODO confirm: pixels?).
    x: float
    y: float
    # Scoring zone of the impact. Required key, but may be null
    # (presumably when no zone applies — verify in analyzer).
    zone: Optional[int]
    # Points awarded for this impact.
    points: int
    # Detection confidence from the analyzer (assumed to be in 0–1 — confirm).
    confidence: float


class AnalyzeResponse(BaseModel):
    """Response body of ``POST /analyze``.

    Every field is copied from the ``AnalysisResult`` returned by the
    analyzer. ``Optional`` fields have no default: they are always present in
    the payload but may be null.
    """

    # Detected impacts (only the new ones when a previous image was supplied).
    impacts: list[ImpactOut]
    # Impact count as reported by the analyzer.
    total_detected: int
    # Score components and their total, as computed by the analyzer
    # (presumably score_total = score_zone + score_groupement — verify in analyzer).
    score_zone: int
    score_groupement: int
    score_total: int
    # Group spread and centre point; units defined by the analyzer (TODO confirm).
    dispersion_radius: float
    center_x: float
    center_y: float
    # Annotated photo, base64-encoded.
    annotated_image_base64: Optional[str]
    # Analyzer error message. When it is set and there is no annotated image,
    # /analyze answers HTTP 422 instead of returning this model.
    error: Optional[str]
    # Analyzer warning message, passed through unchanged.
    warning: Optional[str]


# ─── Endpoints ────────────────────────────────────────────────────────────────


@app.get("/", tags=["Info"])
async def root():
    """Describe the service: name, version, status and configured model path."""
    model_name = os.getenv('YOLO_MODEL_PATH', 'yolov8n.pt')
    return dict(
        service="ShootTracker AI",
        version="1.0.0",
        status="running",
        model=model_name,
    )


@app.get("/health", tags=["Info"])
async def health():
    """Health probe: always reports "ok" plus whether the model is loaded."""
    # The model counts as loaded only when the analyzer exists and holds a model.
    model_loaded = False if analyzer is None else analyzer.model is not None
    payload = {"status": "ok"}
    payload["model_loaded"] = model_loaded
    payload["service"] = "shoottracker-ai"
    return payload


# Serialises inference. The analyzer now runs in worker threads, and one shared
# model instance is not assumed to be thread-safe. This keeps the previous
# one-analysis-at-a-time behaviour without blocking the event loop.
_analysis_lock = threading.Lock()


def _run_analysis(target_analyzer: TargetAnalyzer, req: AnalyzeRequest) -> AnalysisResult:
    """Run the blocking analyzer on *req* while holding the inference lock."""
    with _analysis_lock:
        return target_analyzer.analyze(
            image_b64=req.image_base64,
            target_type=req.target_type,
            previous_image_b64=req.previous_image_base64,
        )


@app.post("/analyze", response_model=AnalyzeResponse, tags=["Analyse"])
async def analyze_target(req: AnalyzeRequest):
    """
    Analyse a photo of a shooting target.

    - Detects impacts with YOLOv8
    - Analyses zones according to the target type
    - If a previous image is supplied, only NEW impacts are counted
    - Computes the score and the grouping
    - Returns the annotated photo in base64
    """
    # Read the global once: it can be reset to None during shutdown.
    current = analyzer
    if current is None:
        raise HTTPException(status_code=503, detail="Service IA non initialisé")

    logger.info(
        "Analyse demandée : target_type=%s, prev=%s",
        req.target_type,
        'oui' if req.previous_image_base64 else 'non',
    )

    try:
        # Inference is synchronous and CPU-bound. Running it in a worker thread
        # keeps the event loop responsive (e.g. /health) during an analysis.
        result: AnalysisResult = await asyncio.to_thread(_run_analysis, current, req)
    except Exception as e:
        logger.exception("Erreur inattendue: %s", e)
        raise HTTPException(status_code=500, detail=f"Erreur interne : {e}") from e

    if result.error and not result.annotated_image_base64:
        # Critical error (e.g. unreadable image): nothing useful to return.
        raise HTTPException(status_code=422, detail=result.error)

    return AnalyzeResponse(
        impacts=[ImpactOut(**imp) for imp in result.impacts],
        total_detected=result.total_detected,
        score_zone=result.score_zone,
        score_groupement=result.score_groupement,
        score_total=result.score_total,
        dispersion_radius=result.dispersion_radius,
        center_x=result.center_x,
        center_y=result.center_y,
        annotated_image_base64=result.annotated_image_base64,
        error=result.error,
        warning=result.warning,
    )


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 8000))
    # Pass the app object instead of the "main:app" import string. This works
    # whatever this file is named and avoids importing the module a second
    # time. An import string is only needed for reload=True or workers > 1.
    uvicorn.run(app, host=os.getenv("HOST", "0.0.0.0"), port=port, reload=False, workers=1)