"""
ShootTracker: AI microservice (FastAPI + YOLOv8 + OpenCV)

Endpoints:
    POST /analyze  Analyze a photo of a shooting target
    GET  /health   Service health check
    GET  /         Service information
"""

import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, field_validator

from analyzer import AnalysisResult, TargetAnalyzer

# ─── Logging ──────────────────────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    datefmt='%H:%M:%S',
)
logger = logging.getLogger(__name__)

# ─── Global model ─────────────────────────────────────────────────────────────
analyzer: Optional[TargetAnalyzer] = None
# Serializes inference: the shared YOLO model may not be thread-safe.
# Created in lifespan() so it belongs to the server's running event loop.
inference_lock: Optional[asyncio.Lock] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model at startup and log the shutdown."""
    global analyzer, inference_lock
    model_path = os.getenv('YOLO_MODEL_PATH', 'yolov8n.pt')
    logger.info("Loading YOLOv8 model: %s", model_path)
    analyzer = TargetAnalyzer(model_path=model_path)
    inference_lock = asyncio.Lock()
    logger.info("✓ AI service ready")
    yield
    logger.info("AI service stopped")


# ─── FastAPI application ──────────────────────────────────────────────────────
app = FastAPI(
    title="ShootTracker AI Service",
    description="Shot impact analysis with YOLOv8 + OpenCV",
    version="1.0.0",
    lifespan=lifespan,
)

# Wide-open CORS: restrict allow_origins before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ─── Pydantic schemas ─────────────────────────────────────────────────────────
class AnalyzeRequest(BaseModel):
    image_base64: str
    target_type: str
    previous_image_base64: Optional[str] = None

    @field_validator('target_type')
    @classmethod
    def validate_target_type(cls, v: str) -> str:
        allowed = ['issf', 'silhouette', 'libre']
        if v not in allowed:
            raise ValueError(f"target_type must be one of: {allowed}")
        return v

    @field_validator('image_base64')
    @classmethod
    def validate_image(cls, v: str) -> str:
        if not v or len(v) < 100:
            raise ValueError("image_base64 is invalid or too short")
        return v


class ImpactOut(BaseModel):
    x: float
    y: float
    zone: Optional[int] = None
    points: int
    confidence: float


class AnalyzeResponse(BaseModel):
    impacts: list[ImpactOut]
    total_detected: int
    score_zone: int
    score_groupement: int
    score_total: int
    dispersion_radius: float
    center_x: float
    center_y: float
    annotated_image_base64: Optional[str] = None
    error: Optional[str] = None
    warning: Optional[str] = None


# ─── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/", tags=["Info"])
async def root():
    return {
        "service": "ShootTracker AI",
        "version": "1.0.0",
        "status": "running",
        "model": os.getenv('YOLO_MODEL_PATH', 'yolov8n.pt'),
    }


@app.get("/health", tags=["Info"])
async def health():
    return {
        "status": "ok",
        "model_loaded": analyzer is not None and analyzer.model is not None,
        "service": "shoottracker-ai",
    }


@app.post("/analyze", response_model=AnalyzeResponse, tags=["Analysis"])
async def analyze_target(req: AnalyzeRequest):
    """
    Analyze a photo of a shooting target.

    - Detects impacts with YOLOv8
    - Scores the zones according to the target type
    - If previous_image_base64 is provided, counts only the NEW impacts
    - Computes the zone score and the grouping score
    - Returns the annotated photo as base64
    """
    if analyzer is None or inference_lock is None:
        raise HTTPException(status_code=503, detail="AI service not initialized")

    logger.info(
        "Analysis requested: target_type=%s, previous_image=%s",
        req.target_type,
        'yes' if req.previous_image_base64 else 'no',
    )

    try:
        # analyze() is synchronous and CPU-bound: run it in a worker thread so the
        # event loop (and /health) stays responsive, one inference at a time.
        async with inference_lock:
            result: AnalysisResult = await run_in_threadpool(
                analyzer.analyze,
                image_b64=req.image_base64,
                target_type=req.target_type,
                previous_image_b64=req.previous_image_base64,
            )
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e

    if result.error and not result.annotated_image_base64:
        # Critical error (e.g. unreadable image): nothing could be annotated
        raise HTTPException(status_code=422, detail=result.error)

    return AnalyzeResponse(
        impacts=[ImpactOut(**imp) for imp in result.impacts],
        total_detected=result.total_detected,
        score_zone=result.score_zone,
        score_groupement=result.score_groupement,
        score_total=result.score_total,
        dispersion_radius=result.dispersion_radius,
        center_x=result.center_x,
        center_y=result.center_y,
        annotated_image_base64=result.annotated_image_base64,
        error=result.error,
        warning=result.warning,
    )


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "8000"))
    # The import string "main:app" assumes this file is saved as main.py.
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False, workers=1)
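A minimal client sketch for POST /analyze, using only the Python standard library. It builds the JSON body that AnalyzeRequest expects (base64 text, one of the accepted target_type values, optional previous image) and turns the service's HTTP errors (422, 503, 500) into a Python exception. The URL http://localhost:8000/analyze and the helper names encode_image, build_analyze_payload and post_analyze are hypothetical; adapt them to your deployment.

```python
import base64
import json
import urllib.error
import urllib.request
from typing import Optional

# Hypothetical URL: the service listens on $PORT (default 8000).
SERVICE_URL = "http://localhost:8000/analyze"

# Mirrors the server-side validate_target_type() check.
ALLOWED_TARGET_TYPES = ("issf", "silhouette", "libre")


def encode_image(image_bytes: bytes) -> str:
    """Encode raw image bytes as the base64 text the service expects."""
    return base64.b64encode(image_bytes).decode("ascii")


def build_analyze_payload(image_bytes: bytes, target_type: str,
                          previous_image_bytes: Optional[bytes] = None) -> dict:
    """Build the JSON body for POST /analyze (hypothetical helper)."""
    if target_type not in ALLOWED_TARGET_TYPES:
        raise ValueError(f"target_type must be one of {ALLOWED_TARGET_TYPES}")
    payload = {
        "image_base64": encode_image(image_bytes),
        "target_type": target_type,
    }
    # Only send the previous photo when we want just the NEW impacts counted.
    if previous_image_bytes is not None:
        payload["previous_image_base64"] = encode_image(previous_image_bytes)
    return payload


def post_analyze(payload: dict, url: str = SERVICE_URL, timeout: float = 60.0) -> dict:
    """Send the payload and return the decoded AnalyzeResponse JSON."""
    data = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(
        url, data=data, headers={"Content-Type": "application/json"}, method="POST"
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as err:
        # 422: invalid request or unreadable image; 503: model not loaded;
        # 500: unexpected failure inside the analyzer.
        detail = err.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"/analyze failed with HTTP {err.code}: {detail}") from err
```

Note that the server rejects image_base64 strings shorter than 100 characters, so very small test files will come back as 422 even if they are valid images.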
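The analyze_target docstring says that when a previous image is supplied, only the NEW impacts are counted. TargetAnalyzer is not shown here, so the following is only one possible approach and not necessarily what it does: each impact found on the previous photo may absorb at most one detection on the current photo within a distance tolerance (greedy nearest-neighbour matching), and the unmatched current detections are reported as new. It assumes both photos are aligned so that the same hole keeps roughly the same (x, y); the function name filter_new_impacts and the default tolerance of 5.0 (in the same units as x and y) are illustrative assumptions.

```python
import math
from typing import Iterable, List, Tuple

Point = Tuple[float, float]


def filter_new_impacts(current: Iterable[Point], previous: Iterable[Point],
                       tolerance: float = 5.0) -> List[Point]:
    """Return the current impacts not matched to any previous impact.

    Hypothetical sketch: assumes aligned photos. Each previous impact absorbs
    at most one current detection lying within `tolerance` (inclusive).
    """
    remaining = list(current)
    for px, py in previous:
        best_i, best_d = None, tolerance
        # Find the nearest still-unmatched current detection for this old hole.
        for i, (x, y) in enumerate(remaining):
            d = math.hypot(x - px, y - py)
            if d <= best_d:
                best_i, best_d = i, d
        if best_i is not None:
            remaining.pop(best_i)  # an old hole seen again, not a new shot
    return remaining
```

Greedy matching can pair detections sub-optimally when holes overlap; an optimal assignment (for example the Hungarian algorithm) is the usual upgrade if that matters.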