import io
import json
import logging
import os
from uuid import uuid4
from typing import List, Dict, Optional
from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.response import ModelResponse
from label_studio_ml.utils import get_image_size, DATA_UNDEFINED_NAME
from urllib.parse import urlparse
from ocr import get_ocr_output

logger = logging.getLogger(__name__)


class OCIOCR(LabelStudioMLBase):
    """Label Studio ML backend that pre-annotates document pages with OCI OCR.

    For every word returned by ``get_ocr_output`` three results sharing one
    region id are emitted: a ``rectangle`` carrying the word's polygon points,
    a ``textarea`` carrying the recognised text, and a ``labels`` result
    pre-set to ``"ignore"``.
    """

    # URL prefix Label Studio uses for files served from local storage.
    _LOCAL_FILES_PREFIX = '/data/local-files/?d='

    def _lazy_init(self):
        self.model = None

    def setup(self):
        """Configure any parameters of your model here."""
        self.set("model_version", f'{self.__class__.__name__}-v0.0.1')

    def _get_image_path(self, task, value='pages'):
        """Resolve local image path(s) referenced by ``task['data'][value]``.

        Local-files URLs are mapped onto ``task['data']['ls_document_root']``
        (empty string when absent).

        Args:
            task: Label Studio task dict.
            value: key in ``task['data']`` holding the page reference(s).

        Returns:
            If ``data[value]`` is a list of ``{'page': url}`` dicts, the list
            of resolved paths for entries whose URL is a local-files URL
            (other entries are dropped). Otherwise the resolved path for the
            single value, or the value unchanged when it is not a local-files
            URL.
        """
        data = task.get('data', {})
        document = data.get(value, [])
        document_root = data.get("ls_document_root", "")
        prefix = self._LOCAL_FILES_PREFIX

        def extract_path(image_url):
            # Only local-storage URLs can be mapped to a file on disk.
            if image_url and image_url.startswith(prefix):
                # Strip the leading prefix only (str.replace would also remove
                # any later occurrence inside the path).
                image_name = image_url[len(prefix):]
                return os.path.join(document_root, image_name)
            return None

        if isinstance(document, list):
            # Resolve each page once, then drop entries that are not local files.
            resolved = (extract_path(page.get('page')) for page in document)
            return [path for path in resolved if path]
        return extract_path(document) or document

    @staticmethod
    def _make_region(region_id, value, from_name, to_name, region_type,
                     width, height, score):
        """Build one Label Studio result entry for an OCR region."""
        return {
            'original_width': width,
            'original_height': height,
            'image_rotation': 0,
            'value': value,
            'id': region_id,
            'from_name': from_name,
            'to_name': to_name,
            'type': region_type,
            'origin': 'manual',
            'score': score,
        }

    def predict_single(self, task):
        """Run OCR on every page of one task and build its prediction.

        Args:
            task: Label Studio task dict; page references are read from
                ``task['data']['pages']``.

        Returns:
            A dict with ``result`` (list of region results), ``score`` (mean
            word confidence, 0 when there are no words) and ``model_version``;
            or ``None`` when OCR returns nothing for any page, in which case
            the whole task is skipped.
        """
        logger.debug('Task data: %s', task['data'])
        from_name_poly, _, _ = self.get_first_tag_occurence('Rectangle', 'Image')
        from_name_trans, _, _ = self.get_first_tag_occurence('TextArea', 'Image')

        labels = sum([list(l) for l in self.label_interface.labels], [])
        if len(labels) > 1:
            # Every predicted region is labelled "ignore" (see the 'labels'
            # region below); configured labels are not used for prediction.
            logger.warning(
                'More than one label in the tag. Predicted regions are labelled "ignore": %s',
                labels,
            )

        # Pages are always read from the "pages" data key, regardless of the
        # value attribute of the Image tag in the labeling config.
        image_paths = self._get_image_path(task, 'pages')

        # A str means a single page; a list means a multi-page document.
        single_page = isinstance(image_paths, str)
        if single_page:
            image_paths = [image_paths]
        logger.debug('Single page: %s', single_page)

        result = []
        all_scores = []

        for index, image_path in enumerate(image_paths):
            oci_result = get_ocr_output(image_path)
            if not oci_result:
                logger.warning('No OCR output for %s; skipping task', image_path)
                return None
            # The OCR response is converted to JSON via its string form.
            page_data = json.loads(str(oci_result))["pages"][0]
            img_width = page_data["dimensions"]["width"]
            img_height = page_data["dimensions"]["height"]

            # Tag names are fixed per page, so compute them once outside the
            # word loop. Multi-page configs use per-page indexed tag names.
            if single_page:
                rect_name = from_name_poly
                text_name = from_name_trans
                label_name = "label"
                target = "image"
            else:
                rect_name = f"bbox_{index}"
                text_name = f"transcription_{index}"
                label_name = f"labels_{index}"
                target = f"page_{index}"

            for word in page_data["words"]:
                if not word:
                    logger.warning('Empty result from the model')
                    continue

                score = word['confidence']
                # Normalized vertices scaled to percentages; vertices outside
                # the page (> 1) are dropped.
                points = [
                    [v['x'] * 100, v['y'] * 100]
                    for v in word['bounding_polygon']['normalized_vertices']
                    if v['x'] <= 1 and v['y'] <= 1
                ]

                # The three results share one id so Label Studio links them
                # into a single region. 10 hex chars (40 bits) keep accidental
                # collisions negligible; 4 chars collided after a few hundred
                # words and merged unrelated regions.
                region_id = uuid4().hex[:10]

                result.append(self._make_region(
                    region_id, {'points': points}, rect_name, target,
                    'rectangle', img_width, img_height, score))
                result.append(self._make_region(
                    region_id, {"text": [word['text']]}, text_name, target,
                    'textarea', img_width, img_height, score))
                result.append(self._make_region(
                    region_id, {'labels': ["ignore"]}, label_name, target,
                    'labels', img_width, img_height, score))

                all_scores.append(score)

        return {
            'result': result,
            'score': sum(all_scores) / max(len(all_scores), 1),
            'model_version': self.get('model_version'),
        }

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
        """Predict OCR regions for a batch of Label Studio tasks.

        Tasks for which ``predict_single`` returns nothing are omitted.
        """
        self._lazy_init()
        predictions = []
        for task in tasks:
            logger.debug('Task: %s', task)
            prediction = self.predict_single(task)
            if prediction:
                predictions.append(prediction)

        # ModelResponse's field is ``model_version`` (the previous
        # ``model_versions`` keyword was silently ignored).
        return ModelResponse(predictions=predictions, model_version=self.get('model_version'))