import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
import chromadb
import logging
import io
import requests
from concurrent.futures import ThreadPoolExecutor
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader

# Logging setup: INFO-level root logging plus a module-scoped logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

class CustomFashionEmbeddingFunction:
    """Chroma embedding function mixing Marqo fashion-SigLIP features with colour.

    For each input image:

    * the CLIP image embedding is zero-padded or truncated to 768 values;
    * a 24-bin HSV colour histogram is stretched to 768 values.

    Both parts are L2-normalised, blended 0.7 / 0.3, and L2-normalised again,
    so every returned vector has unit length.

    Accepted inputs per image: an image URL (``str``), encoded image
    ``bytes``, a numpy array of pixel values, or any object ``self.preprocess``
    accepts directly (e.g. a PIL image), which is used unchanged.
    """

    EMBEDDING_DIM = 768  # length of every returned vector
    CLIP_WEIGHT = 0.7
    COLOR_WEIGHT = 0.3
    # Seconds to wait for an image URL; requests has no default timeout.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # open_clip creates models in train mode by default (see its README);
        # switch to eval so inference is deterministic.
        self.model = self.model.to(self.device).eval()

    def _load_image(self, img):
        """Turn one supported input into an RGB PIL image (other objects pass through unchanged)."""
        if isinstance(img, str):
            # A string is treated as an image URL.
            response = requests.get(img, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert('RGB')
        if isinstance(img, bytes):
            return Image.open(io.BytesIO(img)).convert('RGB')
        if isinstance(img, np.ndarray):
            return Image.fromarray(img.astype('uint8')).convert('RGB')
        return img

    def _combine(self, clip_emb, color_feat):
        """Blend one CLIP vector and one 24-bin colour histogram into a unit vector of length 768."""
        dim = self.EMBEDDING_DIM
        # Zero-pad or truncate the CLIP embedding to exactly `dim` values.
        if clip_emb.shape[0] < dim:
            clip_emb = np.concatenate([clip_emb, np.zeros(dim - clip_emb.shape[0])])
        else:
            clip_emb = clip_emb[:dim]

        # Repeat each histogram bin so the colour part also has 768 values (24 * 32).
        color_expanded = np.repeat(color_feat, 32)

        # Normalise both parts so the weights act on comparable scales.
        clip_emb = clip_emb / (np.linalg.norm(clip_emb) + 1e-8)
        color_expanded = color_expanded / (np.linalg.norm(color_expanded) + 1e-8)

        combined = clip_emb * self.CLIP_WEIGHT + color_expanded * self.COLOR_WEIGHT
        return combined / (np.linalg.norm(combined) + 1e-8)

    def __call__(self, input):
        """Embed a batch of images.

        Args:
            input: sequence of images in any of the forms listed on the class.

        Returns:
            numpy array of shape (len(input), 768) with unit-norm rows.

        Raises:
            Any download/decoding/model error (logged, then re-raised).
        """
        try:
            # Decode (and, for URLs, download) every input exactly once and
            # reuse it for both feature extractors.
            images = [self._load_image(img) for img in input]
            if not images:
                return np.empty((0, self.EMBEDDING_DIM))

            batch = torch.stack([self.preprocess(img) for img in images]).to(self.device)
            with torch.no_grad():
                clip_features = self.model.encode_image(batch).cpu().numpy()

            combined_embeddings = [
                self._combine(clip_emb, self.extract_color_histogram(img))
                for clip_emb, img in zip(clip_features, images)
            ]
            return np.array(combined_embeddings)

        except Exception as e:
            logger.error(f"Error in embedding function: {e}")
            raise

    def extract_color_histogram(self, image):
        """Return a 24-value HSV colour histogram (8 bins each for H, S, V).

        Each channel's 8 bins are normalised to sum to 1. ``image`` may be a
        URL, encoded bytes, a PIL image, or a numpy array. On any error the
        problem is logged and a zero vector of length 24 is returned.
        """
        try:
            if isinstance(image, (str, bytes)):
                if isinstance(image, str):
                    response = requests.get(image, timeout=self.REQUEST_TIMEOUT)
                    response.raise_for_status()
                    image = Image.open(io.BytesIO(response.content))
                else:
                    image = Image.open(io.BytesIO(image))

            img_array = image if isinstance(image, np.ndarray) else np.array(image)

            # Convert to HSV and compute a histogram per channel.
            img_hsv = Image.fromarray(img_array.astype('uint8')).convert('HSV')
            hsv_pixels = np.array(img_hsv)

            h_hist = np.histogram(hsv_pixels[:, :, 0], bins=8, range=(0, 256))[0]
            s_hist = np.histogram(hsv_pixels[:, :, 1], bins=8, range=(0, 256))[0]
            v_hist = np.histogram(hsv_pixels[:, :, 2], bins=8, range=(0, 256))[0]

            # Normalise each channel independently.
            h_hist = h_hist / (h_hist.sum() + 1e-8)
            s_hist = s_hist / (s_hist.sum() + 1e-8)
            v_hist = v_hist / (v_hist.sum() + 1e-8)

            return np.concatenate([h_hist, s_hist, v_hist])
        except Exception as e:
            logger.error(f"Color histogram extraction error: {e}")
            return np.zeros(24)

# Initialize session state: on the first run of a session, seed every key the
# UI reads with its default; later reruns keep whatever is already stored.
for _state_key, _state_default in (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
    ('search_clicked', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Load segmentation model (cached once per server process by Streamlit)
@st.cache_resource
def load_segmentation_model():
    """Load the SegFormer clothes-segmentation checkpoint and its processor.

    The model is moved to CUDA when a GPU is available.

    Returns:
        ``(model, image_processor)``.

    Raises:
        Any loading error, after logging it.
    """
    checkpoint = "mattmdjaga/segformer_b2_clothes"
    try:
        processor = AutoImageProcessor.from_pretrained(checkpoint)
        seg_model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint)
        if torch.cuda.is_available():
            seg_model = seg_model.to('cuda')
    except Exception as e:
        logger.error(f"Error loading segmentation model: {e}")
        raise
    return seg_model, processor

# ChromaDB setup
@st.cache_resource
def _get_fashion_embedding_function():
    """Build the SigLIP + colour embedding function once per server process.

    Constructing it loads the open_clip model, which is far too expensive to
    repeat on every search or DB update.
    """
    return CustomFashionEmbeddingFunction()


def setup_multimodal_collection():
    """Open (creating it if missing) the ``fashion_multimodal`` collection.

    The collection lives in the persistent Chroma DB at
    ``./fashion_multimodal_db_original``. It is wired to the custom fashion
    embedding function and an ``ImageLoader``.

    Returns:
        The Chroma collection.

    Raises:
        Any Chroma or model-loading error, after logging it.
    """
    try:
        client = chromadb.PersistentClient(path="./fashion_multimodal_db_original")
        # Use one name for lookup and creation. Previously the lookup used
        # "fashion_multimodal" but the fallback created "clothes_multimodal",
        # so any later call failed with "collection already exists".
        collection = client.get_or_create_collection(
            name="fashion_multimodal",
            embedding_function=_get_fashion_embedding_function(),
            data_loader=ImageLoader(),
        )
        logger.info("Connected to fashion_multimodal collection")
        return collection

    except Exception as e:
        logger.error(f"Error setting up multimodal collection: {e}")
        raise

def process_segmentation(image):
    """Segment clothing and body parts in ``image`` with the SegFormer clothes model.

    Args:
        image: PIL image; ``image.size`` is read as (width, height).

    Returns:
        A list with one dict per non-background class in the predicted map:

        * ``label``: the model's ``config.id2label`` name, or ``Item_<id>`` if unmapped;
        * ``score``: placeholder 1.0, since no confidence is computed;
        * ``mask``: float array of 0/1 at image resolution.

        Returns ``[]`` on any error.
    """
    try:
        model, image_processor = load_segmentation_model()

        # Preprocess the image.
        inputs = image_processor(image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to('cuda') for k, v in inputs.items()}

        # Inference.
        with torch.no_grad():
            outputs = model(**inputs)

        # Upsample the logits to the original image size, then take the per-pixel argmax.
        logits = outputs.logits.cpu()
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # (height, width)
            mode="bilinear",
            align_corners=False,
        )
        # Only one image is processed, so keep the first (and only) map.
        seg_map = upsampled_logits.argmax(dim=1)[0].numpy()

        id2label = getattr(model.config, "id2label", None) or {}

        processed_items = []
        for label_idx in np.unique(seg_map):
            label_idx = int(label_idx)
            if label_idx == 0:  # background
                continue
            processed_items.append({
                'label': id2label.get(label_idx, f"Item_{label_idx}"),
                'score': 1.0,  # placeholder; no per-segment confidence is computed
                'mask': (seg_map == label_idx).astype(float),
            })

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Segmentation error: {str(e)}")
        return []
    
def chroma_distance_to_similarity(distance, space="l2"):
    """Convert a Chroma query distance between unit vectors into a 0-100 score.

    Chroma reports, depending on the collection's ``hnsw:space``:

    * ``l2``: *squared* Euclidean distance, d = ||a - b||^2 = 2 * (1 - cos);
    * ``cosine``: 1 - cos;
    * ``ip``: 1 - <a, b>, which equals 1 - cos for unit vectors.

    The recovered cosine similarity in [-1, 1] is mapped linearly onto [0, 100].
    """
    if space == "l2":
        cosine_similarity = 1.0 - distance / 2.0
    else:
        cosine_similarity = 1.0 - distance
    # Guard against tiny floating-point excursions outside [-1, 1].
    cosine_similarity = max(-1.0, min(1.0, cosine_similarity))
    return (cosine_similarity + 1.0) / 2.0 * 100.0


def search_similar_items(image, mask=None, top_k=10):
    """Run a multimodal similarity search for ``image``, optionally restricted to ``mask``.

    Args:
        image: PIL image to query with.
        mask: optional 2-D 0/1 array. Pixels outside it are zeroed before
            embedding, so only the selected item drives the search.
        top_k: number of neighbours to request.

    Returns:
        A list of metadata dicts, each with an added ``similarity_score``
        (0-100), sorted best first. Returns ``[]`` on error or no results.
    """
    try:
        collection = setup_multimodal_collection()

        # Apply the mask.
        if mask is not None:
            mask_3d = np.stack([mask] * 3, axis=-1)
            masked_image = np.array(image) * mask_3d
            query_image = Image.fromarray(masked_image.astype(np.uint8))
        else:
            query_image = image

        # Run the query.
        results = collection.query(
            query_images=[np.array(query_image)],
            n_results=top_k,
            include=['metadatas', 'distances']
        )

        if not results or not results.get('metadatas') or not results.get('distances'):
            return []

        # Chroma's default space is "l2" (squared L2). Honour an explicit setting if present.
        space = (collection.metadata or {}).get("hnsw:space", "l2")

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            item_data = dict(metadata or {})
            item_data['similarity_score'] = chroma_distance_to_similarity(distance, space)
            similar_items.append(item_data)

        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similar_items

    except Exception as e:
        logger.error(f"Multimodal search error: {e}")
        return []

def _download_image_array(url, timeout=30):
    """Download one product image and return it as an RGB numpy array.

    Returns None, after logging, if the request fails, returns an HTTP error
    status, or the payload cannot be decoded. One bad URL therefore never
    aborts a whole batch.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return np.array(Image.open(io.BytesIO(response.content)).convert('RGB'))
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        return None


def update_db_with_multimodal():
    """Re-embed every item of the legacy ``clothes`` collection into the multimodal collection.

    For each metadata entry with an ``image_url``, the image is downloaded
    and added to the multimodal collection with the original metadata. The
    collection's embedding function computes the vectors. Downloads run
    concurrently in batches of 100, and progress is shown with Streamlit
    widgets.

    Returns:
        True if the run completed (individual image failures are only
        counted), False if a fatal error aborted it.
    """
    try:
        collection = setup_multimodal_collection()

        # Source data: the legacy collection.
        client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
        old_collection = client.get_collection("clothes")
        old_data = old_collection.get(include=['metadatas'])
        old_ids = old_data['ids']
        old_metadatas = old_data['metadatas']

        total_items = len(old_metadatas)
        progress_bar = st.progress(0)
        status_text = st.empty()

        batch_size = 100
        successful_updates = 0
        failed_updates = 0

        # Image downloads are I/O-bound, so fetch each batch concurrently.
        with ThreadPoolExecutor(max_workers=8) as pool:
            for i in range(0, total_items, batch_size):
                batch = old_metadatas[i:i + batch_size]
                batch_ids = old_ids[i:i + batch_size]

                # Entries without an image URL are skipped (not counted).
                candidates = [
                    (old_id, metadata)
                    for old_id, metadata in zip(batch_ids, batch)
                    if metadata and 'image_url' in metadata
                ]
                downloaded = pool.map(
                    _download_image_array,
                    [metadata['image_url'] for _, metadata in candidates],
                )

                images = []
                valid_metadatas = []
                valid_ids = []
                for (old_id, metadata), img in zip(candidates, downloaded):
                    if img is None:
                        failed_updates += 1
                        continue
                    images.append(img)
                    valid_metadatas.append(metadata)
                    # Stable string ID: the metadata's own id if present,
                    # otherwise the legacy collection's id. (The old fallback,
                    # hash(url), is salted per process, so re-runs created
                    # new IDs and duplicated rows.)
                    valid_ids.append(str(metadata.get('id', old_id)))

                if images:
                    collection.add(
                        ids=valid_ids,
                        images=images,
                        metadatas=valid_metadatas
                    )
                    # Count successes only after the add has succeeded.
                    successful_updates += len(images)

                # Update progress.
                progress = (i + len(batch)) / total_items
                progress_bar.progress(progress)
                status_text.text(f"Processing: {i + len(batch)}/{total_items} items. "
                                 f"Success: {successful_updates}, Failed: {failed_updates}")

        status_text.text(f"Update completed. Successfully processed: {successful_updates}, "
                         f"Failed: {failed_updates}")
        return True

    except Exception as e:
        logger.error(f"Multimodal DB update error: {e}")
        return False

def format_price(value):
    """Format a price for display.

    Numbers get thousands separators and a 원 suffix; anything else (e.g. an
    already-formatted string) is shown as-is.
    """
    if isinstance(value, (int, float)):
        return f"{value:,}원"
    return f"{value}"


def show_similar_items(similar_items):
    """Display similar items two per row, with image, similarity, and product details.

    Args:
        similar_items: metadata dicts as returned by ``search_similar_items``;
            each must contain ``similarity_score``. The keys ``image_url``,
            ``brand``, ``name``, ``price``, ``discount`` and
            ``original_price`` are shown when present.
    """
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    items_per_row = 2
    for i in range(0, len(similar_items), items_per_row):
        cols = st.columns(items_per_row)
        for col, item in zip(cols, similar_items[i:i + items_per_row]):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)

                    st.markdown(f"**Similarity: {item['similarity_score']:.1f}%**")

                    st.write(f"Brand: {item.get('brand', 'Unknown')}")
                    name = item.get('name', 'Unknown')
                    if len(name) > 50:
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")

                    st.write(f"Price: {format_price(item.get('price', 0))}")

                    if item.get('discount'):
                        st.write(f"Discount: {item['discount']}%")
                        if 'original_price' in item:
                            # Previously formatted with ':,' unconditionally, which
                            # raised for string prices.
                            st.write(f"Original: {format_price(item['original_price'])}")

                    st.divider()

                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")

def process_search(image, mask, num_results):
    """Run the similar-item search for one masked item behind a spinner.

    Returns whatever ``search_similar_items`` returns, or None if anything in
    this wrapper raises (the error is logged).
    """
    try:
        with st.spinner("Finding similar items..."):
            return search_similar_items(image, mask, num_results)
    except Exception as err:
        logger.error(f"Search processing error: {err}")
        return None

def handle_file_upload():
    """``on_change`` callback of the uploader: decode the upload into session state.

    Streamlit reruns the script automatically after a callback. Calling
    ``st.rerun()`` inside a callback is a no-op that only produces a warning,
    so it is not called here.
    """
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except (OSError, ValueError) as e:
        # PIL raises UnidentifiedImageError (an OSError subclass) for non-images.
        logger.error(f"Could not open uploaded image: {e}")
        st.error("The uploaded file could not be read as an image.")
        return
    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'

def handle_detection():
    """``on_click`` callback of "Detect Items": segment the uploaded image.

    Results go into session state. The automatic post-callback rerun renders
    them, so no ``st.rerun()`` is needed (inside a callback it is a no-op).
    """
    if st.session_state.image is None:
        return
    st.session_state.detected_items = process_segmentation(st.session_state.image)
    st.session_state.upload_state = 'items_detected'

def handle_search():
    """Button callback: remember in session state that a search was requested."""
    st.session_state["search_clicked"] = True

def _render_detected_items(detected_items):
    """Show each detected item's masked image, label and confidence in two columns."""
    cols = st.columns(2)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                if item.get('mask') is not None:
                    masked_img = np.array(st.session_state.image) * np.expand_dims(item['mask'], axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")

                st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
                score = item.get('score')
                if score is not None and isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")


def _render_image_workflow():
    """Render the uploaded image, the detection step and the similarity search.

    May return early (e.g. nothing searchable). The caller still renders the
    "Start New Search" button afterwards.
    """
    st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

    detected_items = st.session_state.detected_items
    if detected_items is None:
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
        return

    if not detected_items:
        st.warning("No clothing items were detected in this image.")
        return

    # Show the detected items and offer a search.
    _render_detected_items(detected_items)

    valid_items = [i for i, item in enumerate(detected_items) if item.get('mask') is not None]
    if not valid_items:
        st.warning("No valid items detected for search.")
        return

    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{detected_items[i].get('label', 'Unknown')}",
        key='item_selector'
    )

    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items",
                                   key='search_button',
                                   type="primary")
    with search_col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    if not (search_clicked or st.session_state.get('search_clicked', False)):
        return
    st.session_state.search_clicked = True

    selected_item = detected_items[selected_idx]
    if selected_item.get('mask') is None:
        st.error("Selected item has no valid mask for search.")
        return

    # Recompute when the button is pressed or when the query (item / result
    # count) differs from the one the cached results belong to. Previously the
    # first results were cached forever and shown for every later query.
    query_key = (selected_idx, num_results)
    if (search_clicked
            or 'search_results' not in st.session_state
            or st.session_state.get('search_query') != query_key):
        st.session_state.search_results = process_search(st.session_state.image,
                                                         selected_item['mask'],
                                                         num_results)
        st.session_state.search_query = query_key

    if st.session_state.search_results:
        show_similar_items(st.session_state.search_results)
    else:
        st.warning("No similar items found.")


def main():
    """Render the Fashion Search page.

    The flow is driven by ``st.session_state``:

    1. Upload an image.
    2. Run clothing segmentation.
    3. Pick a detected item.
    4. Search the multimodal DB for similar products.

    An optional admin sidebar rebuilds the DB.
    """
    st.title("Fashion Search App")

    # Admin controls in sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        if st.sidebar.button("Update Database (Multimodal)"):
            with st.spinner("Updating database with multimodal support..."):
                success = update_db_with_multimodal()
                if success:
                    st.sidebar.success("Database updated successfully!")
                else:
                    st.sidebar.error("Failed to update database")
        st.divider()

    # File uploader: shown until an image has been accepted by the callback.
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    # An image has been uploaded.
    if st.session_state.image is not None:
        _render_image_workflow()

    # New-search button, rendered unconditionally so the user can always reset
    # (early returns used to hide it).
    if st.button("Start New Search", key='new_search'):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()

if __name__ == "__main__":
    # Print a startup marker ("시작" = "start") to stdout, then render the app.
    print("시작")
    main()