#!/usr/bin/env python3
"""
Dataset Pipeline Builder for Computer Vision

Production-grade tool for building and managing CV dataset pipelines.
Supports format conversion, splitting, augmentation config, and validation.

Supported formats:
- COCO (JSON annotations)
- YOLO (txt per image)
- Pascal VOC (XML annotations)
- CVAT (XML export)

Usage:
    python dataset_pipeline_builder.py analyze --input /path/to/dataset
    python dataset_pipeline_builder.py convert --input /path/to/coco --output /path/to/yolo --format yolo
    python dataset_pipeline_builder.py split --input /path/to/dataset --train 0.8 --val 0.1 --test 0.1
    python dataset_pipeline_builder.py augment-config --task detection --output augmentations.yaml
    python dataset_pipeline_builder.py validate --input /path/to/dataset --format coco
"""

import argparse
import copy
import hashlib
import json
import logging
import os
import random
import shutil
import struct
import sys
import xml.etree.ElementTree as ET
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

# Configure root logging once at import time: INFO level, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)


# ============================================================================
# Dataset Format Definitions
# ============================================================================

# Image file suffixes recognised across the pipeline. Entries are lower-case
# and include the leading dot; '.tif' is the common short spelling of TIFF
# and is listed alongside '.tiff'.
SUPPORTED_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}

# Skeleton of a complete COCO annotation file (despite the name it holds more
# than categories): an 'info' header, a single placeholder license and empty
# images/annotations/categories lists. The 'year' and 'date_created' fields
# are evaluated once, when this module is imported.
COCO_CATEGORIES_TEMPLATE = dict(
    info=dict(
        description="Custom Dataset",
        version="1.0",
        year=datetime.now().year,
        contributor="Dataset Pipeline Builder",
        date_created=datetime.now().isoformat(),
    ),
    licenses=[dict(id=1, name="Unknown", url="")],
    images=[],
    annotations=[],
    categories=[],
)

# str.format() template for a YOLO ``data.yaml``. Placeholders:
# dataset_path, train_path, val_path, test_path, num_classes, class_names.
YOLO_DATA_YAML_TEMPLATE = "\n".join([
    "# YOLO Dataset Configuration",
    "# Generated by Dataset Pipeline Builder",
    "",
    "path: {dataset_path}",
    "train: {train_path}",
    "val: {val_path}",
    "test: {test_path}",
    "",
    "# Classes",
    "nc: {num_classes}",
    "names: {class_names}",
    "",
    "# Optional: Download script",
    "# download:",
    "",
])

# Augmentation presets, keyed first by task ('detection', 'segmentation',
# 'classification') and then by strength ('light', 'medium', 'heavy').
# Each preset maps an augmentation name either to a bare float or to a dict of
# parameters that always includes a probability 'p'.
# NOTE(review): a bare float (e.g. 'horizontal_flip': 0.5) presumably means
# "apply with this probability" -- confirm with whatever consumes these presets.
# NOTE(review): parameter names (limit, blur_limit, var_limit, num_holes,
# max_h_size, ...) look like Albumentations transform arguments, while
# 'mosaic'/'mixup' look like trainer-level (e.g. YOLO) options -- verify the
# mapping against the consumer and the installed library version.
# NOTE(review): some values are tuples, e.g. 'var_limit': (10, 50) and
# 'scale': (0.5, 1.0); confirm they survive YAML serialization as plain lists.
AUGMENTATION_PRESETS = {
    'detection': {
        'light': {
            'horizontal_flip': 0.5,
            'vertical_flip': 0.0,
            'rotate': {'limit': 10, 'p': 0.3},
            'brightness_contrast': {'brightness_limit': 0.1, 'contrast_limit': 0.1, 'p': 0.3},
            'blur': {'blur_limit': 3, 'p': 0.1}
        },
        'medium': {
            'horizontal_flip': 0.5,
            'vertical_flip': 0.1,
            'rotate': {'limit': 15, 'p': 0.5},
            'scale': {'scale_limit': 0.2, 'p': 0.5},
            'brightness_contrast': {'brightness_limit': 0.2, 'contrast_limit': 0.2, 'p': 0.5},
            'hue_saturation': {'hue_shift_limit': 10, 'sat_shift_limit': 20, 'p': 0.3},
            'blur': {'blur_limit': 5, 'p': 0.2},
            'noise': {'var_limit': (10, 50), 'p': 0.2}
        },
        'heavy': {
            'horizontal_flip': 0.5,
            'vertical_flip': 0.2,
            'rotate': {'limit': 30, 'p': 0.7},
            'scale': {'scale_limit': 0.3, 'p': 0.6},
            'brightness_contrast': {'brightness_limit': 0.3, 'contrast_limit': 0.3, 'p': 0.6},
            'hue_saturation': {'hue_shift_limit': 20, 'sat_shift_limit': 30, 'p': 0.5},
            'blur': {'blur_limit': 7, 'p': 0.3},
            'noise': {'var_limit': (10, 80), 'p': 0.3},
            # Multi-image augmentations; not per-image transforms.
            'mosaic': {'p': 0.5},
            'mixup': {'p': 0.3},
            'cutout': {'num_holes': 8, 'max_h_size': 32, 'max_w_size': 32, 'p': 0.3}
        }
    },
    # Segmentation favours geometric warps (elastic/grid/optical distortion),
    # which presumably must be applied identically to image and mask.
    'segmentation': {
        'light': {
            'horizontal_flip': 0.5,
            'rotate': {'limit': 10, 'p': 0.3},
            'elastic_transform': {'alpha': 50, 'sigma': 5, 'p': 0.1}
        },
        'medium': {
            'horizontal_flip': 0.5,
            'vertical_flip': 0.2,
            'rotate': {'limit': 20, 'p': 0.5},
            'scale': {'scale_limit': 0.2, 'p': 0.4},
            'elastic_transform': {'alpha': 100, 'sigma': 10, 'p': 0.3},
            'grid_distortion': {'num_steps': 5, 'distort_limit': 0.3, 'p': 0.3}
        },
        'heavy': {
            'horizontal_flip': 0.5,
            'vertical_flip': 0.3,
            'rotate': {'limit': 45, 'p': 0.7},
            'scale': {'scale_limit': 0.4, 'p': 0.6},
            'elastic_transform': {'alpha': 200, 'sigma': 20, 'p': 0.5},
            'grid_distortion': {'num_steps': 7, 'distort_limit': 0.5, 'p': 0.4},
            'optical_distortion': {'distort_limit': 0.5, 'shift_limit': 0.5, 'p': 0.3}
        }
    },
    'classification': {
        'light': {
            'horizontal_flip': 0.5,
            'rotate': {'limit': 15, 'p': 0.3},
            # Only a probability is given; limits are left to the consumer's defaults.
            'brightness_contrast': {'p': 0.3}
        },
        'medium': {
            'horizontal_flip': 0.5,
            'rotate': {'limit': 30, 'p': 0.5},
            'color_jitter': {'brightness': 0.2, 'contrast': 0.2, 'saturation': 0.2, 'hue': 0.1, 'p': 0.5},
            # Fixed 224x224 crop: presumably assumes inputs at least that large -- confirm.
            'random_crop': {'height': 224, 'width': 224, 'p': 0.5},
            'cutout': {'num_holes': 1, 'max_h_size': 40, 'max_w_size': 40, 'p': 0.3}
        },
        'heavy': {
            'horizontal_flip': 0.5,
            'vertical_flip': 0.2,
            'rotate': {'limit': 45, 'p': 0.7},
            'color_jitter': {'brightness': 0.4, 'contrast': 0.4, 'saturation': 0.4, 'hue': 0.2, 'p': 0.7},
            'random_resized_crop': {'height': 224, 'width': 224, 'scale': (0.5, 1.0), 'p': 0.6},
            'cutout': {'num_holes': 4, 'max_h_size': 60, 'max_w_size': 60, 'p': 0.5},
            'auto_augment': {'policy': 'imagenet', 'p': 0.5},
            'rand_augment': {'num_ops': 2, 'magnitude': 9, 'p': 0.5}
        }
    }
}


# ============================================================================
# Dataset Analysis
# ============================================================================

class DatasetAnalyzer:
    """Analyze dataset structure and statistics.

    Walks ``dataset_path`` recursively, auto-detects the annotation format
    (COCO JSON, YOLO txt or Pascal VOC XML) and gathers image, annotation and
    quality statistics into ``self.stats``.
    """

    def __init__(self, dataset_path: str):
        self.dataset_path = Path(dataset_path)
        # Filled by analyze(): format, total_images, image_stats,
        # annotations, quality.
        self.stats: Dict[str, Any] = {}

    def analyze(self) -> Dict[str, Any]:
        """Run full dataset analysis.

        Returns:
            ``self.stats`` with keys ``format``, ``total_images``,
            ``image_stats``, ``annotations`` and ``quality``.
        """
        logger.info(f"Analyzing dataset at: {self.dataset_path}")
        # Clear in place (keeps the same dict object) so a repeated call does
        # not carry keys over from a previous run.
        self.stats.clear()

        detected_format = self._detect_format()
        self.stats['format'] = detected_format

        images = self._find_images()
        self.stats['total_images'] = len(images)
        self.stats['image_stats'] = self._analyze_images(images)

        annotation_analyzers = {
            'coco': self._analyze_coco,
            'yolo': self._analyze_yolo,
            'voc': self._analyze_voc,
        }
        analyze_annotations = annotation_analyzers.get(detected_format)
        if analyze_annotations is None:
            self.stats['annotations'] = {'error': 'Unknown format'}
        else:
            self.stats['annotations'] = analyze_annotations()

        # Quality checks read the statistics gathered above, so they run last.
        self.stats['quality'] = self._quality_checks()

        return self.stats

    def _detect_format(self) -> str:
        """Auto-detect the dataset format.

        Checks, in order:
        1. any JSON file whose top level is an object with ``images`` and
           ``annotations`` keys (COCO);
        2. up to five label ``.txt`` files (``classes.txt`` excluded) whose
           first line has exactly five numeric fields (YOLO);
        3. up to five XML files with an ``<annotation>`` root containing an
           ``<object>`` (Pascal VOC).

        Returns:
            ``'coco'``, ``'yolo'``, ``'voc'`` or ``'unknown'``.
        """
        for json_file in self.dataset_path.rglob('*.json'):
            try:
                with open(json_file, encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError):
                # Unreadable, not UTF-8, or not JSON: not a COCO candidate.
                continue
            if isinstance(data, dict) and 'annotations' in data and 'images' in data:
                return 'coco'

        txt_files = [p for p in self.dataset_path.rglob('*.txt') if p.name != 'classes.txt']
        for txt_file in txt_files[:5]:
            try:
                with open(txt_file, encoding='utf-8') as f:
                    first_line = f.readline().strip()
            except (OSError, ValueError):
                continue
            parts = first_line.split()
            # YOLO detection line: class x_center y_center width height
            if len(parts) == 5 and all(self._is_float(p) for p in parts):
                return 'yolo'

        xml_files = list(self.dataset_path.rglob('*.xml'))
        for xml_file in xml_files[:5]:
            try:
                root = ET.parse(xml_file).getroot()
            except (OSError, ET.ParseError):
                continue
            if root.tag == 'annotation' and root.find('object') is not None:
                return 'voc'

        return 'unknown'

    def _is_float(self, s: str) -> bool:
        """Return True if ``s`` parses as a float."""
        try:
            float(s)
        except ValueError:
            return False
        return True

    def _find_images(self) -> List[Path]:
        """Find all image files under the dataset root, sorted by path.

        A single recursive walk with a lower-cased suffix check matches
        ``.jpg``, ``.JPG`` and ``.Jpg`` alike, and returns each file exactly
        once even on case-insensitive filesystems (where globbing ``*.jpg``
        and ``*.JPG`` separately yields the same file twice).
        """
        return sorted(
            p for p in self.dataset_path.rglob('*')
            if p.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS and p.is_file()
        )

    def _analyze_images(self, images: List[Path]) -> Dict:
        """Summarize image files without decoding them.

        Args:
            images: Image paths located under ``self.dataset_path``.

        Returns:
            ``count``, per-extension counts, per-top-level-directory counts
            (``'root'`` for files directly in the dataset root) and, when any
            images exist, total/avg/min/max file sizes.
        """
        extensions: Dict[str, int] = defaultdict(int)
        locations: Dict[str, int] = defaultdict(int)
        sizes: List[int] = []

        for img in images:
            extensions[img.suffix.lower()] += 1
            sizes.append(img.stat().st_size)
            rel_parts = img.relative_to(self.dataset_path).parts
            locations[rel_parts[0] if len(rel_parts) > 1 else 'root'] += 1

        stats: Dict[str, Any] = {
            'count': len(images),
            'extensions': dict(extensions),
            'locations': dict(locations),
        }
        if sizes:
            total_bytes = sum(sizes)
            stats['total_size_mb'] = total_bytes / (1024 * 1024)
            stats['avg_size_kb'] = (total_bytes / len(sizes)) / 1024
            stats['min_size_kb'] = min(sizes) / 1024
            stats['max_size_kb'] = max(sizes) / 1024

        return stats

    def _analyze_coco(self) -> Dict:
        """Analyze COCO annotations aggregated over every COCO JSON file.

        Datasets often ship several annotation files (e.g. train/val); all
        statistics are accumulated across them. Image ids are only unique
        within one file, so images are keyed by ``(file, image_id)``.
        Annotations whose category id is not declared are reported as
        ``class_<id>``.
        """
        stats = {
            'total_annotations': 0,
            'classes': {},
            'images_with_annotations': 0,
            'annotations_per_image': {},
            'bbox_stats': {}
        }

        img_annotations: Dict[Tuple[str, Any], int] = defaultdict(int)
        bbox_widths: List[float] = []
        bbox_heights: List[float] = []
        bbox_areas: List[float] = []

        for json_file in self.dataset_path.rglob('*.json'):
            try:
                with open(json_file, encoding='utf-8') as f:
                    data = json.load(f)

                if not isinstance(data, dict) or 'annotations' not in data:
                    continue

                cat_map = {cat['id']: cat['name'] for cat in data.get('categories') or []}

                for ann in data['annotations']:
                    stats['total_annotations'] += 1
                    cat_id = ann.get('category_id')
                    cat_name = cat_map.get(cat_id, f'class_{cat_id}')
                    stats['classes'][cat_name] = stats['classes'].get(cat_name, 0) + 1
                    img_annotations[(str(json_file), ann.get('image_id'))] += 1

                    bbox = ann.get('bbox')  # [x, y, width, height]
                    if bbox is not None and len(bbox) == 4:
                        bbox_widths.append(bbox[2])
                        bbox_heights.append(bbox[3])
                        bbox_areas.append(bbox[2] * bbox[3])

            except Exception as e:
                # Per-file boundary: a broken file must not stop the analysis.
                logger.warning(f"Error parsing {json_file}: {e}")

        stats['images_with_annotations'] = len(img_annotations)
        if img_annotations:
            counts = list(img_annotations.values())
            stats['annotations_per_image'] = {
                'min': min(counts),
                'max': max(counts),
                'avg': sum(counts) / len(counts)
            }

        if bbox_areas:
            stats['bbox_stats'] = {
                'avg_width': sum(bbox_widths) / len(bbox_widths),
                'avg_height': sum(bbox_heights) / len(bbox_heights),
                'avg_area': sum(bbox_areas) / len(bbox_areas),
                'min_area': min(bbox_areas),
                'max_area': max(bbox_areas)
            }

        return stats

    def _analyze_yolo(self) -> Dict:
        """Analyze YOLO txt labels under the dataset root.

        Every ``*.txt`` except ``classes.txt`` is treated as a label file.
        Class names come from ``classes.txt`` at the dataset root when present
        (line index = class id). Lines with at least five fields count as
        annotations. Lines with an even number (>= 6) of coordinates are
        segmentation polygons; their fields 3-4 are polygon points, not box
        size, so they are excluded from the width/height statistics.
        Malformed lines are skipped and reported once per file. A file counts
        towards ``images_with_annotations`` only if it has a valid line.
        """
        stats = {
            'total_annotations': 0,
            'classes': defaultdict(int),
            'images_with_annotations': 0,
            'bbox_stats': {}
        }

        class_names: Dict[int, str] = {}
        classes_file = self.dataset_path / 'classes.txt'
        if classes_file.exists():
            with open(classes_file, encoding='utf-8') as f:
                class_names = {i: line.strip() for i, line in enumerate(f)}

        bbox_widths: List[float] = []
        bbox_heights: List[float] = []

        for txt_file in self.dataset_path.rglob('*.txt'):
            if txt_file.name == 'classes.txt':
                continue

            try:
                with open(txt_file, encoding='utf-8') as f:
                    lines = f.readlines()
            except (OSError, ValueError) as e:
                logger.warning(f"Error parsing {txt_file}: {e}")
                continue

            file_annotations = 0
            malformed = 0
            for line in lines:
                parts = line.split()
                if len(parts) < 5:
                    continue
                try:
                    class_id = int(parts[0])
                    coords = [float(p) for p in parts[1:]]
                except ValueError:
                    malformed += 1
                    continue

                stats['total_annotations'] += 1
                file_annotations += 1
                stats['classes'][class_names.get(class_id, f'class_{class_id}')] += 1

                is_polygon = len(coords) >= 6 and len(coords) % 2 == 0
                if not is_polygon:
                    # Normalized box width/height.
                    bbox_widths.append(coords[2])
                    bbox_heights.append(coords[3])

            if file_annotations:
                stats['images_with_annotations'] += 1
            if malformed:
                logger.warning(f"Skipped {malformed} malformed line(s) in {txt_file}")

        stats['classes'] = dict(stats['classes'])

        if bbox_widths:
            stats['bbox_stats'] = {
                'avg_width_normalized': sum(bbox_widths) / len(bbox_widths),
                'avg_height_normalized': sum(bbox_heights) / len(bbox_heights),
                'min_width_normalized': min(bbox_widths),
                'max_width_normalized': max(bbox_widths)
            }

        return stats

    def _analyze_voc(self) -> Dict:
        """Analyze Pascal VOC XML annotations under the dataset root.

        XML files whose root tag is not ``annotation`` are ignored. Objects
        without a name are counted in the totals but not per class. An object
        is ``difficult`` only when its ``<difficult>`` text is ``1``.
        """
        stats = {
            'total_annotations': 0,
            'classes': defaultdict(int),
            'images_with_annotations': 0,
            'difficulties': {'easy': 0, 'difficult': 0}
        }

        for xml_file in self.dataset_path.rglob('*.xml'):
            try:
                root = ET.parse(xml_file).getroot()
            except (OSError, ET.ParseError) as e:
                logger.warning(f"Error parsing {xml_file}: {e}")
                continue

            if root.tag != 'annotation':
                continue

            objects = root.findall('object')
            if objects:
                stats['images_with_annotations'] += 1

            for obj in objects:
                stats['total_annotations'] += 1
                name = (obj.findtext('name') or '').strip()
                if name:
                    stats['classes'][name] += 1

                if (obj.findtext('difficult') or '').strip() == '1':
                    stats['difficulties']['difficult'] += 1
                else:
                    stats['difficulties']['easy'] += 1

        stats['classes'] = dict(stats['classes'])
        return stats

    def _quality_checks(self) -> Dict:
        """Run heuristic quality checks over the statistics gathered so far.

        Returns:
            Dict with ``issues``, ``warnings`` and ``recommendations`` lists of
            human-readable strings (``issues`` is not populated by these checks).
        """
        checks = {
            'issues': [],
            'warnings': [],
            'recommendations': []
        }
        ann_stats = self.stats.get('annotations', {})
        total_images = self.stats.get('total_images', 0)

        # Class imbalance: ratio of rarest to most frequent class.
        classes = ann_stats.get('classes')
        if classes:
            counts = list(classes.values())
            max_count, min_count = max(counts), min(counts)
            if max_count > 0:
                ratio = min_count / max_count
                if ratio < 0.1:
                    checks['warnings'].append(
                        f"Severe class imbalance detected: ratio {ratio:.2%}"
                    )
                    checks['recommendations'].append(
                        "Consider oversampling minority classes or using focal loss"
                    )
                elif ratio < 0.3:
                    checks['warnings'].append(
                        f"Moderate class imbalance: ratio {ratio:.2%}"
                    )

        if total_images < 100:
            checks['warnings'].append(
                f"Small dataset: only {total_images} images"
            )
            checks['recommendations'].append(
                "Consider data augmentation or transfer learning"
            )

        if 'annotations' in self.stats:
            images_with_ann = ann_stats.get('images_with_annotations', 0)
            if total_images > 0 and images_with_ann < total_images:
                missing = total_images - images_with_ann
                checks['warnings'].append(
                    f"{missing} images have no annotations"
                )

        return checks


# ============================================================================
# Format Conversion
# ============================================================================

class FormatConverter:
    """Convert annotation datasets between COCO, YOLO and Pascal VOC formats.

    Converters write annotation files under ``output_path`` only; image files
    are neither copied nor moved.
    """

    def __init__(self, input_path: str, output_path: str):
        self.input_path = Path(input_path)
        self.output_path = Path(output_path)

    def convert(self, target_format: str, source_format: Optional[str] = None) -> Dict:
        """Convert the dataset at ``input_path`` into ``target_format``.

        Args:
            target_format: ``'coco'``, ``'yolo'`` or ``'voc'`` (case-insensitive).
            source_format: Format of the input; auto-detected when ``None``.

        Returns:
            Counters ``converted_images`` / ``converted_annotations``, or
            ``{'error': ...}`` for an unsupported format pair.
        """
        if source_format is None:
            # Only the format is needed, so skip the full statistics pass.
            source_format = DatasetAnalyzer(str(self.input_path))._detect_format()

        source_format = source_format.lower()
        target_format = target_format.lower()
        logger.info(f"Converting from {source_format} to {target_format}")

        converters = {
            'coco_to_yolo': self._coco_to_yolo,
            'yolo_to_coco': self._yolo_to_coco,
            'voc_to_coco': self._voc_to_coco,
            'voc_to_yolo': self._voc_to_yolo,
            'coco_to_voc': self._coco_to_voc,
        }
        converter = converters.get(f"{source_format}_to_{target_format}")
        if converter is None:
            return {'error': f"Unsupported conversion: {source_format} -> {target_format}"}

        return converter()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _new_coco_skeleton() -> Dict[str, Any]:
        """Return a fresh COCO document built from COCO_CATEGORIES_TEMPLATE.

        A deep copy is used so nested 'info'/'licenses' objects are never
        shared with (or mutated through) the module-level template.
        """
        coco_data = copy.deepcopy(COCO_CATEGORIES_TEMPLATE)
        coco_data['images'] = []
        coco_data['annotations'] = []
        coco_data['categories'] = []
        return coco_data

    def _find_input_images(self) -> List[Path]:
        """Return image files under ``input_path`` (suffix matched case-insensitively), sorted."""
        return sorted(
            p for p in self.input_path.rglob('*')
            if p.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS and p.is_file()
        )

    @staticmethod
    def _find_yolo_label(img_path: Path) -> Optional[Path]:
        """Locate the YOLO label file belonging to ``img_path``.

        Candidates, in order:
        1. same directory: ``x.jpg`` -> ``x.txt``;
        2. ``labels`` beside the image's parent: ``images/x.jpg`` -> ``labels/x.txt``;
        3. last ``images`` directory swapped for ``labels``:
           ``images/train/x.jpg`` -> ``labels/train/x.txt``.

        Returns:
            The first existing candidate, or ``None``.
        """
        label_name = img_path.stem + '.txt'
        candidates = [
            img_path.with_suffix('.txt'),
            img_path.parent.parent / 'labels' / label_name,
        ]
        dirs = list(img_path.parent.parts)
        if 'images' in dirs:
            idx = len(dirs) - 1 - dirs[::-1].index('images')
            dirs[idx] = 'labels'
            candidates.append(Path(*dirs) / label_name)

        for candidate in candidates:
            if candidate.is_file():
                return candidate
        return None

    @staticmethod
    def _read_image_size(img_path: Path) -> Optional[Tuple[int, int]]:
        """Return ``(width, height)`` parsed from a PNG, JPEG or BMP header.

        Only header bytes are read; pixels are never decoded. Returns ``None``
        for other formats or when the header cannot be parsed.
        """
        try:
            with open(img_path, 'rb') as f:
                head = f.read(26)

                # PNG: 8-byte signature, then IHDR with big-endian width/height.
                if head[:8] == b'\x89PNG\r\n\x1a\n' and head[12:16] == b'IHDR':
                    width, height = struct.unpack('>II', head[16:24])
                    return width, height

                # BMP with BITMAPINFOHEADER (or later): little-endian int32
                # width/height at offset 18; negative height = top-down rows.
                if head[:2] == b'BM' and len(head) >= 26:
                    (dib_size,) = struct.unpack('<I', head[14:18])
                    if dib_size < 40:
                        return None
                    width, height = struct.unpack('<ii', head[18:26])
                    return width, abs(height)

                # JPEG: walk marker segments until a Start-Of-Frame marker.
                if head[:2] == b'\xff\xd8':
                    f.seek(2)
                    while True:
                        byte = f.read(1)
                        while byte and byte != b'\xff':
                            byte = f.read(1)
                        while byte == b'\xff':  # skip 0xFF fill bytes
                            byte = f.read(1)
                        if not byte:
                            return None
                        marker = byte[0]
                        if marker == 0x01 or 0xD0 <= marker <= 0xD8:
                            continue  # standalone markers have no length field
                        if marker in (0xD9, 0xDA):
                            return None  # EOI / start of scan before any frame header
                        (seg_len,) = struct.unpack('>H', f.read(2))
                        if seg_len < 2:
                            return None
                        if 0xC0 <= marker <= 0xCF and marker not in (0xC4, 0xC8, 0xCC):
                            # SOFn payload: precision (1 byte), height (2), width (2).
                            _, height, width = struct.unpack('>BHH', f.read(5))
                            return width, height
                        f.seek(seg_len - 2, 1)
        except (OSError, struct.error):
            return None
        return None

    # ------------------------------------------------------------------
    # Converters
    # ------------------------------------------------------------------

    def _coco_to_yolo(self) -> Dict:
        """Convert COCO JSON annotations to YOLO txt labels.

        For every JSON file with an ``annotations`` key, writes
        ``labels/<image stem>.txt`` (one ``class xc yc w h`` line per box,
        normalized by image size), plus ``classes.txt`` and ``data.yaml``
        under ``output_path``. The YOLO class index is the position of the
        category in the JSON ``categories`` list. Annotations with an unknown
        category are skipped (previously they were silently written as class 0),
        and images without a usable width/height are skipped.

        Note: labels are named by file stem only, so equal stems in different
        directories overwrite each other; with several JSON files the last one
        processed determines classes.txt and data.yaml.
        """
        results = {'converted_images': 0, 'converted_annotations': 0}

        for coco_file in self.input_path.rglob('*.json'):
            try:
                with open(coco_file, encoding='utf-8') as f:
                    coco_data = json.load(f)

                if not isinstance(coco_data, dict) or 'annotations' not in coco_data:
                    continue

                labels_dir = self.output_path / 'labels'
                labels_dir.mkdir(parents=True, exist_ok=True)

                categories = coco_data.get('categories') or []
                # COCO category id -> contiguous 0-based YOLO class index.
                cat_map = {cat['id']: idx for idx, cat in enumerate(categories)}
                img_map = {img['id']: img for img in coco_data.get('images') or []}

                annotations_by_image = defaultdict(list)
                for ann in coco_data['annotations']:
                    annotations_by_image[ann['image_id']].append(ann)

                unknown_category = 0
                for img_id, annotations in annotations_by_image.items():
                    img_info = img_map.get(img_id)
                    if img_info is None:
                        continue
                    img_w = img_info.get('width')
                    img_h = img_info.get('height')
                    if not img_w or not img_h:
                        logger.warning(
                            f"Skipping {img_info.get('file_name')} in {coco_file}: missing image size"
                        )
                        continue

                    label_lines = []
                    for ann in annotations:
                        bbox = ann.get('bbox')  # [x, y, width, height] in pixels
                        if not bbox or len(bbox) < 4:
                            continue
                        class_idx = cat_map.get(ann.get('category_id'))
                        if class_idx is None:
                            unknown_category += 1
                            continue
                        x, y, w, h = bbox[:4]
                        label_lines.append(
                            f"{class_idx} {(x + w / 2) / img_w:.6f} {(y + h / 2) / img_h:.6f} "
                            f"{w / img_w:.6f} {h / img_h:.6f}\n"
                        )

                    label_path = labels_dir / (Path(img_info['file_name']).stem + '.txt')
                    with open(label_path, 'w', encoding='utf-8') as f:
                        f.writelines(label_lines)
                    results['converted_annotations'] += len(label_lines)
                    results['converted_images'] += 1

                if unknown_category:
                    logger.warning(
                        f"{coco_file}: skipped {unknown_category} annotation(s) with unknown category_id"
                    )

                class_names = [cat['name'] for cat in categories]
                with open(self.output_path / 'classes.txt', 'w', encoding='utf-8') as f:
                    f.writelines(f"{name}\n" for name in class_names)

                yaml_content = YOLO_DATA_YAML_TEMPLATE.format(
                    dataset_path=str(self.output_path.absolute()),
                    train_path='images/train',
                    val_path='images/val',
                    test_path='images/test',
                    num_classes=len(class_names),
                    class_names=class_names
                )
                with open(self.output_path / 'data.yaml', 'w', encoding='utf-8') as f:
                    f.write(yaml_content)

            except Exception as e:
                # Per-file boundary: one bad JSON must not stop the others.
                logger.error(f"Error converting {coco_file}: {e}")

        return results

    def _yolo_to_coco(self) -> Dict:
        """Convert YOLO txt labels to a single COCO ``annotations.json``.

        Image sizes are read from PNG/JPEG/BMP headers. Other formats, or
        unreadable headers, fall back to 640x640 with a warning, because
        normalized YOLO coordinates cannot be denormalized without a size.
        Class names come from ``classes.txt`` at the input root; class ids
        without a name get a ``class_<id>`` category so every annotation
        references an existing category. Segmentation lines
        (``class x1 y1 ... xn yn``) are converted to their enclosing box.
        """
        results = {'converted_images': 0, 'converted_annotations': 0}
        coco_data = self._new_coco_skeleton()

        category_names: Dict[int, str] = {}
        classes_file = self.input_path / 'classes.txt'
        if classes_file.exists():
            with open(classes_file, encoding='utf-8') as f:
                category_names = {i: line.strip() for i, line in enumerate(f)}

        annotation_id = 1
        for img_id, img_path in enumerate(self._find_input_images(), 1):
            size = self._read_image_size(img_path)
            if size is None:
                logger.warning(f"Could not read image size of {img_path}; assuming 640x640")
                size = (640, 640)
            width, height = size

            coco_data['images'].append({
                'id': img_id,
                'file_name': img_path.name,
                'width': width,
                'height': height
            })
            results['converted_images'] += 1

            label_path = self._find_yolo_label(img_path)
            if label_path is None:
                continue

            with open(label_path, encoding='utf-8') as f:
                label_lines = f.readlines()

            for line in label_lines:
                parts = line.split()
                if len(parts) < 5:
                    continue
                try:
                    class_id = int(parts[0])
                    coords = [float(p) for p in parts[1:]]
                except ValueError:
                    logger.warning(f"Skipping malformed line in {label_path}: {line.strip()!r}")
                    continue

                if len(coords) >= 6 and len(coords) % 2 == 0:
                    # Polygon: take the enclosing box of its points.
                    xs, ys = coords[0::2], coords[1::2]
                    x, y = min(xs) * width, min(ys) * height
                    w = max(xs) * width - x
                    h = max(ys) * height - y
                else:
                    # Box: normalized centre and size -> pixel top-left + size.
                    w = coords[2] * width
                    h = coords[3] * height
                    x = coords[0] * width - w / 2
                    y = coords[1] * height - h / 2

                category_names.setdefault(class_id, f'class_{class_id}')
                coco_data['annotations'].append({
                    'id': annotation_id,
                    'image_id': img_id,
                    'category_id': class_id,
                    'bbox': [x, y, w, h],
                    'area': w * h,
                    'iscrowd': 0
                })
                annotation_id += 1
                results['converted_annotations'] += 1

        coco_data['categories'] = [
            {'id': class_id, 'name': name, 'supercategory': 'object'}
            for class_id, name in sorted(category_names.items())
        ]

        self.output_path.mkdir(parents=True, exist_ok=True)
        with open(self.output_path / 'annotations.json', 'w', encoding='utf-8') as f:
            json.dump(coco_data, f, indent=2)

        return results

    def _voc_to_coco(self) -> Dict:
        """Convert Pascal VOC XML files to a single COCO ``annotations.json``.

        XML files are processed in sorted path order, so image ids and
        category ids (0, 1, ... in order of first appearance) are
        deterministic. VOC corners are copied as-is into COCO
        ``[x, y, width, height]``. Objects lacking a name or a valid
        ``bndbox`` are skipped with a warning; the rest of the file is kept.
        """
        results = {'converted_images': 0, 'converted_annotations': 0}
        coco_data = self._new_coco_skeleton()

        class_to_id: Dict[str, int] = {}
        annotation_id = 1

        for img_id, xml_file in enumerate(sorted(self.input_path.rglob('*.xml')), 1):
            try:
                root = ET.parse(xml_file).getroot()
            except (OSError, ET.ParseError) as e:
                logger.warning(f"Error parsing {xml_file}: {e}")
                continue

            if root.tag != 'annotation':
                continue

            filename = root.findtext('filename')
            size = root.find('size')
            if filename is None or size is None:
                continue

            try:
                # float() first: some tools write sizes such as "500.0".
                width = int(float(size.findtext('width')))
                height = int(float(size.findtext('height')))
            except (TypeError, ValueError, OverflowError) as e:
                logger.warning(f"Error parsing {xml_file}: invalid <size>: {e}")
                continue

            coco_data['images'].append({
                'id': img_id,
                'file_name': filename,
                'width': width,
                'height': height
            })
            results['converted_images'] += 1

            for obj in root.findall('object'):
                name = (obj.findtext('name') or '').strip()
                bndbox = obj.find('bndbox')
                if not name or bndbox is None:
                    logger.warning(f"Skipping object without name/bndbox in {xml_file}")
                    continue
                try:
                    xmin, ymin, xmax, ymax = (
                        float(bndbox.findtext(tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
                    )
                except (TypeError, ValueError) as e:
                    logger.warning(f"Skipping object with invalid bndbox in {xml_file}: {e}")
                    continue

                if name not in class_to_id:
                    class_to_id[name] = len(class_to_id)
                    coco_data['categories'].append({
                        'id': class_to_id[name],
                        'name': name,
                        'supercategory': 'object'
                    })

                coco_data['annotations'].append({
                    'id': annotation_id,
                    'image_id': img_id,
                    'category_id': class_to_id[name],
                    'bbox': [xmin, ymin, xmax - xmin, ymax - ymin],
                    'area': (xmax - xmin) * (ymax - ymin),
                    'iscrowd': 0
                })
                annotation_id += 1
                results['converted_annotations'] += 1

        self.output_path.mkdir(parents=True, exist_ok=True)
        with open(self.output_path / 'annotations.json', 'w', encoding='utf-8') as f:
            json.dump(coco_data, f, indent=2)

        return results

    def _voc_to_yolo(self) -> Dict:
        """Convert Pascal VOC to YOLO via an intermediate COCO file.

        The intermediate file lives in ``output_path/_temp_coco`` and is
        removed even when a conversion step raises.
        """
        temp_coco = self.output_path / '_temp_coco'
        try:
            FormatConverter(str(self.input_path), str(temp_coco))._voc_to_coco()
            return FormatConverter(str(temp_coco), str(self.output_path))._coco_to_yolo()
        finally:
            shutil.rmtree(temp_coco, ignore_errors=True)

    def _coco_to_voc(self) -> Dict:
        """Convert COCO JSON annotations to Pascal VOC XML files.

        Writes one ``Annotations/<image stem>.xml`` per annotated image.
        Annotations without a ``bbox`` (e.g. segmentation-only) are skipped
        instead of aborting the whole file. Box corners are truncated to int.
        """
        results = {'converted_images': 0, 'converted_annotations': 0}

        annotations_dir = self.output_path / 'Annotations'
        annotations_dir.mkdir(parents=True, exist_ok=True)

        for coco_file in self.input_path.rglob('*.json'):
            try:
                with open(coco_file, encoding='utf-8') as f:
                    coco_data = json.load(f)

                if not isinstance(coco_data, dict) or 'annotations' not in coco_data:
                    continue

                cat_map = {cat['id']: cat['name'] for cat in coco_data.get('categories') or []}
                img_map = {img['id']: img for img in coco_data.get('images') or []}

                ann_by_image = defaultdict(list)
                for ann in coco_data['annotations']:
                    ann_by_image[ann['image_id']].append(ann)

                for img_id, annotations in ann_by_image.items():
                    img_info = img_map.get(img_id)
                    if img_info is None:
                        continue

                    annotation = ET.Element('annotation')
                    ET.SubElement(annotation, 'folder').text = 'images'
                    ET.SubElement(annotation, 'filename').text = img_info['file_name']

                    size = ET.SubElement(annotation, 'size')
                    ET.SubElement(size, 'width').text = str(img_info['width'])
                    ET.SubElement(size, 'height').text = str(img_info['height'])
                    ET.SubElement(size, 'depth').text = '3'

                    for ann in annotations:
                        bbox = ann.get('bbox')  # [x, y, width, height]
                        if not bbox or len(bbox) < 4:
                            continue
                        x, y, w, h = bbox[:4]

                        obj = ET.SubElement(annotation, 'object')
                        ET.SubElement(obj, 'name').text = cat_map.get(ann.get('category_id'), 'unknown')
                        ET.SubElement(obj, 'difficult').text = '0'

                        bndbox = ET.SubElement(obj, 'bndbox')
                        for tag, value in (('xmin', x), ('ymin', y), ('xmax', x + w), ('ymax', y + h)):
                            ET.SubElement(bndbox, tag).text = str(int(value))

                        results['converted_annotations'] += 1

                    xml_name = Path(img_info['file_name']).stem + '.xml'
                    ET.ElementTree(annotation).write(annotations_dir / xml_name)
                    results['converted_images'] += 1

            except Exception as e:
                # Per-file boundary: one bad JSON must not stop the others.
                logger.error(f"Error converting {coco_file}: {e}")

        return results


# ============================================================================
# Dataset Splitting
# ============================================================================

class DatasetSplitter:
    """Split a dataset into train/val/test sets.

    Images are discovered recursively under ``dataset_path``. The split is
    materialised under ``output_path`` as ``images/<split>`` and
    ``labels/<split>`` directories populated with symlinks (falling back to
    copies when symlinking fails).
    """

    def __init__(self, dataset_path: str, output_path: str = None):
        self.dataset_path = Path(dataset_path)
        # Without an explicit output path the split layout is written inside
        # the source dataset itself.
        self.output_path = Path(output_path) if output_path else self.dataset_path

    def split(self, train: float = 0.8, val: float = 0.1, test: float = 0.1,
              stratify: bool = True, seed: int = 42) -> Dict:
        """Split the dataset and write the split directory layout.

        Args:
            train: Fraction of images for the training split.
            val: Fraction of images for the validation split.
            test: Fraction of images for the test split (receives rounding
                leftovers).
            stratify: Group images by primary class before splitting. Only
                YOLO labels carry class info here; COCO images all fall into
                a single group.
            seed: Seed for the global ``random`` module so repeated runs on
                the same files yield the same split.

        Returns:
            Dict with ``train_count``, ``val_count``, ``test_count`` and
            ``output_path``, or ``{'error': 'No images found'}``.

        Raises:
            ValueError: If the three ratios do not sum to 1.0 (±0.001).
        """
        if abs(train + val + test - 1.0) > 0.001:
            raise ValueError(f"Split ratios must sum to 1.0, got {train + val + test}")

        random.seed(seed)
        logger.info(f"Splitting dataset: train={train}, val={val}, test={test}")

        # Detect format and find images
        analyzer = DatasetAnalyzer(str(self.dataset_path))
        analyzer.analyze()
        detected_format = analyzer.stats.get('format', 'unknown')

        # NOTE(review): when output_path == dataset_path, a re-run also
        # discovers the symlinks created by a previous split.
        images = []
        for ext in SUPPORTED_IMAGE_EXTENSIONS:
            images.extend(self.dataset_path.rglob(f'*{ext}'))

        if not images:
            return {'error': 'No images found'}

        # Iterating a set of str varies between interpreter runs (hash
        # randomisation), and rglob order is filesystem-dependent. Sort so
        # the seeded shuffle actually produces a reproducible split.
        images.sort()

        # Stratify if requested and we have class info
        if stratify and detected_format in ['coco', 'yolo']:
            splits = self._stratified_split(images, detected_format, train, val, test)
        else:
            splits = self._random_split(images, train, val, test)

        # Create output directories and copy/link files
        return self._create_split_directories(splits, detected_format)

    @staticmethod
    def _find_yolo_label(img: Path) -> Path:
        """Return the YOLO label path for ``img``.

        The label is looked for next to the image first, then in a ``labels``
        directory that is a sibling of the image's parent. The returned path
        may not exist.
        """
        label_path = img.with_suffix('.txt')
        if not label_path.exists():
            label_path = img.parent.parent / 'labels' / (img.stem + '.txt')
        return label_path

    @staticmethod
    def _primary_yolo_class(label_path: Path) -> int:
        """Return the class id on the first non-blank line of a YOLO label.

        Returns -1 when the label is missing, empty (a background image), or
        cannot be read or parsed.
        """
        if not label_path.exists():
            return -1
        try:
            with open(label_path) as f:
                for line in f:
                    parts = line.split()
                    if parts:
                        return int(parts[0])
        except (OSError, ValueError):
            # Stratification is best-effort. An unreadable label or a
            # non-integer class id just puts the image in the catch-all group.
            return -1
        # Empty / all-blank label file: keep the image (negative sample).
        return -1

    def _random_split(self, images: List[Path], train: float, val: float, test: float) -> Dict:
        """Shuffle a copy of ``images`` and cut it into train/val/test.

        Train and val sizes are floored; the remainder goes to test.
        """
        images = list(images)
        random.shuffle(images)

        n = len(images)
        train_end = int(n * train)
        val_end = train_end + int(n * val)

        return {
            'train': images[:train_end],
            'val': images[train_end:val_end],
            'test': images[val_end:]
        }

    def _stratified_split(self, images: List[Path], format: str,
                         train: float, val: float, test: float) -> Dict:
        """Split each primary-class group proportionally, then merge.

        For YOLO the primary class is the first label line's class id. Images
        without usable class info (including non-YOLO formats) share group -1.
        Every input image appears in exactly one split.
        """
        # Group images by their primary class
        class_images = defaultdict(list)
        for img in dict.fromkeys(images):  # dedupe, preserving order
            if format == 'yolo':
                class_id = self._primary_yolo_class(self._find_yolo_label(img))
            else:
                class_id = -1  # Default for other formats
            class_images[class_id].append(img)

        # Split each class proportionally (floored; leftovers go to test)
        splits = {'train': [], 'val': [], 'test': []}

        for class_imgs in class_images.values():
            random.shuffle(class_imgs)
            n = len(class_imgs)
            train_end = int(n * train)
            val_end = train_end + int(n * val)

            splits['train'].extend(class_imgs[:train_end])
            splits['val'].extend(class_imgs[train_end:val_end])
            splits['test'].extend(class_imgs[val_end:])

        # Shuffle final splits so classes are interleaved
        for key in splits:
            random.shuffle(splits[key])

        return splits

    def _create_split_directories(self, splits: Dict, format: str) -> Dict:
        """Create ``images/<split>`` and ``labels/<split>`` and link files in.

        Existing destination files are left untouched. For YOLO, a
        ``data.yaml`` is written using class names from ``classes.txt`` in the
        source dataset, if present.
        """
        results = {
            'train_count': len(splits['train']),
            'val_count': len(splits['val']),
            'test_count': len(splits['test']),
            'output_path': str(self.output_path)
        }

        # Create directory structure
        for split_name in ['train', 'val', 'test']:
            images_dir = self.output_path / 'images' / split_name
            labels_dir = self.output_path / 'labels' / split_name
            images_dir.mkdir(parents=True, exist_ok=True)
            labels_dir.mkdir(parents=True, exist_ok=True)

            for img_path in splits[split_name]:
                # Create symlink for image
                dst_img = images_dir / img_path.name
                if not dst_img.exists():
                    try:
                        dst_img.symlink_to(img_path.absolute())
                    except OSError:
                        # Fall back to copy if symlink fails (e.g. no privilege)
                        shutil.copy2(img_path, dst_img)

                # Handle label file
                if format == 'yolo':
                    label_path = self._find_yolo_label(img_path)
                    if label_path.exists():
                        dst_label = labels_dir / (img_path.stem + '.txt')
                        if not dst_label.exists():
                            try:
                                dst_label.symlink_to(label_path.absolute())
                            except OSError:
                                shutil.copy2(label_path, dst_label)

        # Generate data.yaml for YOLO
        if format == 'yolo':
            classes_file = self.dataset_path / 'classes.txt'
            class_names = []
            if classes_file.exists():
                with open(classes_file) as f:
                    # Skip blank lines so they don't inflate num_classes.
                    class_names = [line.strip() for line in f if line.strip()]

            yaml_content = YOLO_DATA_YAML_TEMPLATE.format(
                dataset_path=str(self.output_path.absolute()),
                train_path='images/train',
                val_path='images/val',
                test_path='images/test',
                num_classes=len(class_names),
                class_names=class_names
            )
            with open(self.output_path / 'data.yaml', 'w') as f:
                f.write(yaml_content)

        return results


# ============================================================================
# Augmentation Configuration
# ============================================================================

class AugmentationConfigGenerator:
    """Generate augmentation configurations for different CV tasks.

    Presets are read from ``AUGMENTATION_PRESETS[task][intensity]`` and
    translated into a target framework's parameter vocabulary.
    """

    @staticmethod
    def generate(task: str, intensity: str = 'medium',
                 framework: str = 'albumentations') -> Dict:
        """Generate an augmentation config for ``task`` and ``intensity``.

        Args:
            task: Key into ``AUGMENTATION_PRESETS``.
            intensity: Key into ``AUGMENTATION_PRESETS[task]``.
            framework: ``'albumentations'``, ``'torchvision'`` or
                ``'ultralytics'``. Any other value returns the raw preset.

        Returns:
            The framework-specific config, or ``{'error': ...}`` for an
            unknown task/intensity.
        """
        if task not in AUGMENTATION_PRESETS:
            return {'error': f"Unknown task: {task}. Use: detection, segmentation, classification"}

        if intensity not in AUGMENTATION_PRESETS[task]:
            return {'error': f"Unknown intensity: {intensity}. Use: light, medium, heavy"}

        base_config = AUGMENTATION_PRESETS[task][intensity]

        if framework == 'albumentations':
            return AugmentationConfigGenerator._to_albumentations(base_config, task)
        elif framework == 'torchvision':
            return AugmentationConfigGenerator._to_torchvision(base_config, task)
        elif framework == 'ultralytics':
            return AugmentationConfigGenerator._to_ultralytics(base_config, task)

        # Raw preset: deep-copy so callers editing the result cannot corrupt
        # the shared module-level preset table.
        import copy
        return copy.deepcopy(base_config)

    @staticmethod
    def _detection_bbox_params() -> Dict:
        """Return a fresh Albumentations ``BboxParams`` spec for detection.

        This spec is the single source for both the returned config and the
        generated code example.
        """
        return {
            'format': 'pascal_voc',
            'label_fields': ['class_labels'],
            'min_visibility': 0.3
        }

    @staticmethod
    def _to_albumentations(config: Dict, task: str) -> Dict:
        """Convert a preset to an Albumentations transform list.

        Flip presets are plain probabilities; other presets are dicts of
        parameters. Unrecognised preset keys are ignored.
        """
        transforms = []

        for aug_name, params in config.items():
            if aug_name == 'horizontal_flip':
                transforms.append({
                    'type': 'HorizontalFlip',
                    'p': params
                })
            elif aug_name == 'vertical_flip':
                transforms.append({
                    'type': 'VerticalFlip',
                    'p': params
                })
            elif aug_name == 'rotate':
                transforms.append({
                    'type': 'Rotate',
                    'limit': params.get('limit', 15),
                    'p': params.get('p', 0.5)
                })
            elif aug_name == 'scale':
                transforms.append({
                    'type': 'RandomScale',
                    'scale_limit': params.get('scale_limit', 0.2),
                    'p': params.get('p', 0.5)
                })
            elif aug_name == 'brightness_contrast':
                transforms.append({
                    'type': 'RandomBrightnessContrast',
                    'brightness_limit': params.get('brightness_limit', 0.2),
                    'contrast_limit': params.get('contrast_limit', 0.2),
                    'p': params.get('p', 0.5)
                })
            elif aug_name == 'hue_saturation':
                transforms.append({
                    'type': 'HueSaturationValue',
                    'hue_shift_limit': params.get('hue_shift_limit', 20),
                    'sat_shift_limit': params.get('sat_shift_limit', 30),
                    'p': params.get('p', 0.5)
                })
            elif aug_name == 'blur':
                transforms.append({
                    'type': 'Blur',
                    'blur_limit': params.get('blur_limit', 5),
                    'p': params.get('p', 0.3)
                })
            elif aug_name == 'noise':
                transforms.append({
                    'type': 'GaussNoise',
                    'var_limit': params.get('var_limit', (10, 50)),
                    'p': params.get('p', 0.3)
                })
            elif aug_name == 'elastic_transform':
                transforms.append({
                    'type': 'ElasticTransform',
                    'alpha': params.get('alpha', 100),
                    'sigma': params.get('sigma', 10),
                    'p': params.get('p', 0.3)
                })
            elif aug_name == 'cutout':
                transforms.append({
                    'type': 'CoarseDropout',
                    'max_holes': params.get('num_holes', 8),
                    'max_height': params.get('max_h_size', 32),
                    'max_width': params.get('max_w_size', 32),
                    'p': params.get('p', 0.3)
                })

        # Add bbox format for detection
        bbox_params = None
        if task == 'detection':
            bbox_params = AugmentationConfigGenerator._detection_bbox_params()

        return {
            'framework': 'albumentations',
            'task': task,
            'transforms': transforms,
            'bbox_params': bbox_params,
            'code_example': AugmentationConfigGenerator._albumentations_code(transforms, task)
        }

    @staticmethod
    def _albumentations_code(transforms: List, task: str) -> str:
        """Render a Python snippet building the equivalent ``A.Compose``.

        Values are rendered with ``repr`` so strings are quoted. For detection,
        the ``BboxParams`` call mirrors ``_detection_bbox_params()`` exactly.
        """
        code = """import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose([
"""
        for t in transforms:
            params = ', '.join(f"{k}={v!r}" for k, v in t.items() if k != 'type')
            code += f"    A.{t['type']}({params}),\n"

        code += "    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n"
        code += "    ToTensorV2(),\n"
        code += "]"

        if task == 'detection':
            bbox_args = ', '.join(
                f"{k}={v!r}"
                for k, v in AugmentationConfigGenerator._detection_bbox_params().items()
            )
            code += f", bbox_params=A.BboxParams({bbox_args}))"
        else:
            code += ")"

        return code

    @staticmethod
    def _to_torchvision(config: Dict, task: str) -> Dict:
        """Convert a preset to torchvision transform descriptors.

        Only flips, ``rotate`` and ``color_jitter`` are mapped; other preset
        keys are ignored.
        """
        transforms = []

        for aug_name, params in config.items():
            if aug_name == 'horizontal_flip':
                transforms.append({
                    'type': 'RandomHorizontalFlip',
                    'p': params
                })
            elif aug_name == 'vertical_flip':
                transforms.append({
                    'type': 'RandomVerticalFlip',
                    'p': params
                })
            elif aug_name == 'rotate':
                transforms.append({
                    'type': 'RandomRotation',
                    'degrees': params.get('limit', 15)
                })
            elif aug_name == 'color_jitter':
                transforms.append({
                    'type': 'ColorJitter',
                    'brightness': params.get('brightness', 0.2),
                    'contrast': params.get('contrast', 0.2),
                    'saturation': params.get('saturation', 0.2),
                    'hue': params.get('hue', 0.1)
                })

        return {
            'framework': 'torchvision',
            'task': task,
            'transforms': transforms
        }

    @staticmethod
    def _to_ultralytics(config: Dict, task: str) -> Dict:
        """Convert a preset to Ultralytics YOLO hyperparameters.

        HSV, translate, shear, perspective and copy_paste are fixed values.
        Mosaic and mixup are enabled only when present in the preset.
        """
        yolo_config = {
            'hsv_h': 0.015,
            'hsv_s': 0.7,
            'hsv_v': 0.4,
            'degrees': config.get('rotate', {}).get('limit', 0.0),
            'translate': 0.1,
            'scale': config.get('scale', {}).get('scale_limit', 0.5),
            'shear': 0.0,
            'perspective': 0.0,
            'flipud': config.get('vertical_flip', 0.0),
            'fliplr': config.get('horizontal_flip', 0.5),
            'mosaic': config.get('mosaic', {}).get('p', 1.0) if 'mosaic' in config else 0.0,
            'mixup': config.get('mixup', {}).get('p', 0.0) if 'mixup' in config else 0.0,
            'copy_paste': 0.0
        }

        return {
            'framework': 'ultralytics',
            'task': task,
            'config': yolo_config,
            'usage': "# Add to data.yaml or pass to Trainer\nmodel.train(data='data.yaml', augment=True, **aug_config)"
        }


# ============================================================================
# Dataset Validation
# ============================================================================

class DatasetValidator:
    """Validate dataset integrity and quality.

    Runs format-specific annotation checks (COCO, YOLO, VOC) plus generic
    image-file checks, and collects errors, warnings and stats into one
    result dict.
    """

    def __init__(self, dataset_path: str, format: str = None):
        self.dataset_path = Path(dataset_path)
        # None means "auto-detect via DatasetAnalyzer" in validate().
        self.format = format

    def validate(self) -> Dict:
        """Run all validation checks.

        Returns:
            Dict with ``valid`` (True iff no errors), ``errors``,
            ``warnings``, ``stats`` and the resolved ``format``.
        """
        results = {
            'valid': True,
            'errors': [],
            'warnings': [],
            'stats': {}
        }

        # Auto-detect format if not specified
        if self.format is None:
            analyzer = DatasetAnalyzer(str(self.dataset_path))
            analyzer.analyze()
            self.format = analyzer.stats.get('format', 'unknown')

        results['format'] = self.format

        # Run format-specific validation
        if self.format == 'coco':
            self._validate_coco(results)
        elif self.format == 'yolo':
            self._validate_yolo(results)
        elif self.format == 'voc':
            self._validate_voc(results)
        else:
            results['warnings'].append(f"Unknown format: {self.format}")

        # General checks
        self._validate_images(results)
        self._check_duplicates(results)

        # Set overall validity
        results['valid'] = len(results['errors']) == 0

        return results

    def _validate_coco(self, results: Dict):
        """Validate every COCO-style JSON (one with an ``annotations`` key).

        Stats are totals over all successfully checked files:
        ``coco_images`` and ``coco_annotations`` are summed, and
        ``coco_categories`` counts distinct category ids.
        """
        stats = results['stats']
        seen_categories = set()

        for json_file in sorted(self.dataset_path.rglob('*.json')):
            try:
                with open(json_file) as f:
                    data = json.load(f)

                if 'annotations' not in data:
                    continue

                # Check required fields
                if 'images' not in data:
                    results['errors'].append(f"{json_file}: Missing 'images' field")
                if 'categories' not in data:
                    results['warnings'].append(f"{json_file}: Missing 'categories' field")

                # Validate annotations
                image_ids = {img['id'] for img in data.get('images', [])}
                category_ids = {cat['id'] for cat in data.get('categories', [])}

                for ann in data['annotations']:
                    if ann.get('image_id') not in image_ids:
                        results['errors'].append(
                            f"Annotation {ann.get('id')} references non-existent image {ann.get('image_id')}"
                        )
                    if ann.get('category_id') not in category_ids:
                        results['warnings'].append(
                            f"Annotation {ann.get('id')} references unknown category {ann.get('category_id')}"
                        )

                    # Validate bbox: [x, y, w, h] with non-negative origin, positive size
                    if 'bbox' in ann:
                        bbox = ann['bbox']
                        if len(bbox) != 4:
                            results['errors'].append(
                                f"Annotation {ann.get('id')}: Invalid bbox format"
                            )
                        elif any(v < 0 for v in bbox[:2]) or any(v <= 0 for v in bbox[2:]):
                            results['warnings'].append(
                                f"Annotation {ann.get('id')}: Suspicious bbox values {bbox}"
                            )

                # Accumulate across files (e.g. train.json + val.json) instead
                # of letting the last file overwrite the others.
                stats['coco_images'] = stats.get('coco_images', 0) + len(data.get('images', []))
                stats['coco_annotations'] = stats.get('coco_annotations', 0) + len(data['annotations'])
                seen_categories |= category_ids
                stats['coco_categories'] = len(seen_categories)

            except json.JSONDecodeError as e:
                results['errors'].append(f"{json_file}: Invalid JSON - {e}")
            except Exception as e:
                results['errors'].append(f"{json_file}: Error - {e}")

    def _validate_yolo(self, results: Dict):
        """Validate every ``*.txt`` label file except ``classes.txt``.

        Each non-blank line must be ``class x y w h``: an integer class id
        followed by four floats. Out-of-range normalised values are warnings;
        malformed lines are errors.
        """
        label_files = list(self.dataset_path.rglob('*.txt'))
        valid_labels = 0
        invalid_labels = 0

        for txt_file in label_files:
            if txt_file.name == 'classes.txt':
                continue

            try:
                with open(txt_file) as f:
                    lines = f.readlines()

                for line_num, line in enumerate(lines, 1):
                    parts = line.strip().split()
                    if not parts:
                        continue

                    if len(parts) < 5:
                        results['errors'].append(
                            f"{txt_file}:{line_num}: Expected 5 values, got {len(parts)}"
                        )
                        invalid_labels += 1
                        continue

                    try:
                        int(parts[0])  # class id must parse as an integer
                        x, y, w, h = map(float, parts[1:5])

                        # Check normalized coordinates
                        if not (0 <= x <= 1 and 0 <= y <= 1):
                            results['warnings'].append(
                                f"{txt_file}:{line_num}: Center coords outside [0,1]: ({x}, {y})"
                            )
                        if not (0 < w <= 1 and 0 < h <= 1):
                            results['warnings'].append(
                                f"{txt_file}:{line_num}: Size outside (0,1]: ({w}, {h})"
                            )

                        valid_labels += 1

                    except ValueError as e:
                        results['errors'].append(
                            f"{txt_file}:{line_num}: Invalid values - {e}"
                        )
                        invalid_labels += 1

            except Exception as e:
                results['errors'].append(f"{txt_file}: Error - {e}")

        results['stats']['yolo_valid_labels'] = valid_labels
        results['stats']['yolo_invalid_labels'] = invalid_labels

    def _validate_voc(self, results: Dict):
        """Validate Pascal VOC XML files (root tag ``annotation``).

        ``voc_annotations`` counts every ``<object>`` visited, including ones
        that produced errors.
        """
        xml_files = list(self.dataset_path.rglob('*.xml'))
        valid_annotations = 0

        for xml_file in xml_files:
            try:
                tree = ET.parse(xml_file)
                root = tree.getroot()

                if root.tag != 'annotation':
                    continue

                # Check required fields
                filename = root.find('filename')
                if filename is None:
                    results['warnings'].append(f"{xml_file}: Missing filename")

                size = root.find('size')
                if size is None:
                    results['warnings'].append(f"{xml_file}: Missing size")
                else:
                    for dim in ['width', 'height']:
                        if size.find(dim) is None:
                            results['errors'].append(f"{xml_file}: Missing {dim}")

                # Validate objects
                for obj in root.findall('object'):
                    name = obj.find('name')
                    if name is None or not name.text:
                        results['errors'].append(f"{xml_file}: Object missing name")

                    bndbox = obj.find('bndbox')
                    if bndbox is None:
                        results['errors'].append(f"{xml_file}: Object missing bndbox")
                    else:
                        for coord in ['xmin', 'ymin', 'xmax', 'ymax']:
                            elem = bndbox.find(coord)
                            if elem is None:
                                results['errors'].append(f"{xml_file}: Missing {coord}")

                    valid_annotations += 1

            except ET.ParseError as e:
                results['errors'].append(f"{xml_file}: XML parse error - {e}")
            except Exception as e:
                results['errors'].append(f"{xml_file}: Error - {e}")

        results['stats']['voc_annotations'] = valid_annotations

    def _find_images(self) -> List[Path]:
        """Return all image files under the dataset, sorted for stable reporting."""
        images = []
        for ext in SUPPORTED_IMAGE_EXTENSIONS:
            images.extend(self.dataset_path.rglob(f'*{ext}'))
        return sorted(images)

    def _validate_images(self, results: Dict):
        """Flag empty (error), very small (<1000 bytes, warning) and
        unstat-able (error, e.g. dangling symlink) image files."""
        images = self._find_images()
        results['stats']['total_images'] = len(images)

        empty_count = 0
        small_count = 0
        unreadable_count = 0
        for img in images:
            try:
                size = img.stat().st_size
            except OSError:
                # e.g. a symlink created by a split whose source has since moved
                unreadable_count += 1
                continue
            if size == 0:
                empty_count += 1
            elif size < 1000:
                small_count += 1

        if unreadable_count:
            results['errors'].append(
                f"Found {unreadable_count} unreadable image files (e.g. broken symlinks)"
            )

        # Check for empty images
        if empty_count:
            results['errors'].append(f"Found {empty_count} empty image files")

        # Check for very small images (empty ones are already reported above)
        if small_count:
            results['warnings'].append(f"Found {small_count} very small images (<1KB)")

    @staticmethod
    def _hash_file(path: Path, chunk_size: int = 1 << 20) -> str:
        """Return the MD5 hex digest of ``path``, read in chunks to bound memory."""
        digest = hashlib.md5()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def _check_duplicates(self, results: Dict):
        """Check for byte-identical duplicate images by content hash (MD5)."""
        hashes = {}
        duplicates = []
        unhashable = 0

        for img in self._find_images():
            try:
                file_hash = self._hash_file(img)
            except OSError:
                # Best-effort: skip unreadable files but don't hide them entirely.
                unhashable += 1
                continue

            if file_hash in hashes:
                duplicates.append((img, hashes[file_hash]))
            else:
                hashes[file_hash] = img

        if unhashable:
            results['warnings'].append(
                f"Could not read {unhashable} image files for duplicate check"
            )

        if duplicates:
            results['warnings'].append(f"Found {len(duplicates)} duplicate images")
            results['stats']['duplicate_images'] = len(duplicates)


# ============================================================================
# Main CLI
# ============================================================================

def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser with one subcommand per pipeline operation."""
    parser = argparse.ArgumentParser(
        description="Dataset Pipeline Builder for Computer Vision",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  Analyze dataset:
    python dataset_pipeline_builder.py analyze --input /path/to/dataset

  Convert COCO to YOLO:
    python dataset_pipeline_builder.py convert --input /path/to/coco --output /path/to/yolo --format yolo

  Split dataset:
    python dataset_pipeline_builder.py split --input /path/to/dataset --train 0.8 --val 0.1 --test 0.1

  Generate augmentation config:
    python dataset_pipeline_builder.py augment-config --task detection --intensity heavy

  Validate dataset:
    python dataset_pipeline_builder.py validate --input /path/to/dataset --format coco
        """
    )

    subparsers = parser.add_subparsers(dest='command', help='Command to run')

    # Analyze command
    analyze_parser = subparsers.add_parser('analyze', help='Analyze dataset structure and statistics')
    analyze_parser.add_argument('--input', '-i', required=True, help='Path to dataset')
    analyze_parser.add_argument('--json', action='store_true', help='Output as JSON')

    # Convert command
    convert_parser = subparsers.add_parser('convert', help='Convert between annotation formats')
    convert_parser.add_argument('--input', '-i', required=True, help='Input dataset path')
    convert_parser.add_argument('--output', '-o', required=True, help='Output dataset path')
    convert_parser.add_argument('--format', '-f', required=True,
                                choices=['yolo', 'coco', 'voc'],
                                help='Target format')
    convert_parser.add_argument('--source-format', '-s',
                                choices=['yolo', 'coco', 'voc'],
                                help='Source format (auto-detected if not specified)')

    # Split command
    split_parser = subparsers.add_parser('split', help='Split dataset into train/val/test')
    split_parser.add_argument('--input', '-i', required=True, help='Input dataset path')
    split_parser.add_argument('--output', '-o', help='Output path (default: same as input)')
    split_parser.add_argument('--train', type=float, default=0.8, help='Train split ratio')
    split_parser.add_argument('--val', type=float, default=0.1, help='Validation split ratio')
    split_parser.add_argument('--test', type=float, default=0.1, help='Test split ratio')
    split_parser.add_argument('--stratify', action='store_true', help='Stratify by class')
    split_parser.add_argument('--seed', type=int, default=42, help='Random seed')

    # Augmentation config command
    aug_parser = subparsers.add_parser('augment-config', help='Generate augmentation configuration')
    aug_parser.add_argument('--task', '-t', required=True,
                            choices=['detection', 'segmentation', 'classification'],
                            help='CV task type')
    aug_parser.add_argument('--intensity', '-n', default='medium',
                            choices=['light', 'medium', 'heavy'],
                            help='Augmentation intensity')
    aug_parser.add_argument('--framework', '-f', default='albumentations',
                            choices=['albumentations', 'torchvision', 'ultralytics'],
                            help='Target framework')
    aug_parser.add_argument('--output', '-o', help='Output file path')

    # Validate command
    validate_parser = subparsers.add_parser('validate', help='Validate dataset integrity')
    validate_parser.add_argument('--input', '-i', required=True, help='Path to dataset')
    validate_parser.add_argument('--format', '-f',
                                 choices=['yolo', 'coco', 'voc'],
                                 help='Dataset format (auto-detected if not specified)')
    validate_parser.add_argument('--json', action='store_true', help='Output as JSON')
    # Opt-in so the default exit status is unchanged. It lets CI pipelines
    # fail on an invalid dataset.
    validate_parser.add_argument('--strict', action='store_true',
                                 help='Exit with status 1 if the dataset is invalid')

    return parser


def _print_capped(title: str, items: List[Any], marker: str, limit: int = 10) -> None:
    """Print a titled, counted list that shows at most *limit* entries."""
    print(f"\n{title} ({len(items)}):")
    for item in items[:limit]:
        print(f"  {marker} {item}")
    if len(items) > limit:
        print(f"  ... and {len(items) - limit} more")


def _print_analysis_report(results: Dict[str, Any]) -> None:
    """Print a human-readable summary of the analyzer's results dict."""
    print("\n" + "=" * 60)
    print("DATASET ANALYSIS REPORT")
    print("=" * 60)
    print(f"\nFormat: {results.get('format', 'unknown')}")
    print(f"Total Images: {results.get('total_images', 0)}")

    if 'image_stats' in results:
        stats = results['image_stats']
        print("\nImage Statistics:")
        print(f"  Total Size: {stats.get('total_size_mb', 0):.2f} MB")
        print(f"  Extensions: {stats.get('extensions', {})}")
        print(f"  Locations: {stats.get('locations', {})}")

    if 'annotations' in results:
        ann = results['annotations']
        print("\nAnnotations:")
        print(f"  Total: {ann.get('total_annotations', 0)}")
        print(f"  Images with annotations: {ann.get('images_with_annotations', 0)}")
        if 'classes' in ann:
            print(f"  Classes: {len(ann['classes'])}")
            # Ten most frequent classes, highest count first.
            for cls, count in sorted(ann['classes'].items(), key=lambda x: -x[1])[:10]:
                print(f"    - {cls}: {count}")

    if 'quality' in results:
        q = results['quality']
        if q.get('warnings'):
            print("\nWarnings:")
            for w in q['warnings']:
                print(f"  ⚠ {w}")
        if q.get('recommendations'):
            print("\nRecommendations:")
            for r in q['recommendations']:
                print(f"  → {r}")


def _print_validation_report(results: Dict[str, Any]) -> None:
    """Print a human-readable summary of the validator's results dict."""
    print("\n" + "=" * 60)
    print("DATASET VALIDATION REPORT")
    print("=" * 60)
    print(f"\nFormat: {results.get('format', 'unknown')}")
    print(f"Valid: {'✓' if results.get('valid') else '✗'}")

    if results.get('errors'):
        _print_capped("Errors", results['errors'], '✗')

    if results.get('warnings'):
        _print_capped("Warnings", results['warnings'], '⚠')

    if results.get('stats'):
        print("\nStatistics:")
        for key, value in results['stats'].items():
            print(f"  {key}: {value}")


def main():
    """Parse command-line arguments and run the selected subcommand.

    Exit status:
        0 on success.
        1 when no subcommand is given, when the command raises an exception,
          or, with ``validate --strict``, when the dataset is reported invalid.
        argparse itself exits with 2 on malformed arguments.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(1)

    exit_code = 0
    try:
        if args.command == 'analyze':
            analyzer = DatasetAnalyzer(args.input)
            results = analyzer.analyze()

            if args.json:
                print(json.dumps(results, indent=2, default=str))
            else:
                _print_analysis_report(results)

        elif args.command == 'convert':
            converter = FormatConverter(args.input, args.output)
            results = converter.convert(args.format, args.source_format)
            # default=str, as for 'analyze', so a non-JSON-native value (e.g. a
            # Path) in results cannot crash the command after the conversion
            # has already been written.
            print(json.dumps(results, indent=2, default=str))

        elif args.command == 'split':
            splitter = DatasetSplitter(args.input, args.output or args.input)
            results = splitter.split(
                train=args.train,
                val=args.val,
                test=args.test,
                stratify=args.stratify,
                seed=args.seed
            )
            print(json.dumps(results, indent=2, default=str))

        elif args.command == 'augment-config':
            config = AugmentationConfigGenerator.generate(
                args.task,
                args.intensity,
                args.framework
            )

            output = json.dumps(config, indent=2, default=str)

            if args.output:
                with open(args.output, 'w', encoding='utf-8') as f:
                    f.write(output)
                print(f"Configuration saved to {args.output}")
            else:
                print(output)

        elif args.command == 'validate':
            validator = DatasetValidator(args.input, args.format)
            results = validator.validate()

            if args.json:
                print(json.dumps(results, indent=2, default=str))
            else:
                _print_validation_report(results)

            if args.strict and not results.get('valid'):
                exit_code = 1

    except Exception as e:
        # Top-level CLI boundary: report the error and fail with status 1.
        logger.error("Error: %s", e)
        sys.exit(1)

    sys.exit(exit_code)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
