import React, { useState, useLayoutEffect, useRef, useEffect } from 'react';
import { Link } from 'react-router-dom';
import { llm_inference, sqwish, tokenize } from './api';
import sqwishLogo from '../assets/logo/sqwish-white-off-thick-high-res.png';
import backgroundImage from '../assets/backgrounds/noise_overlay-7.png';

import ChatBubble from './ChatBubble';
import SearchBar from './SearchBar';
import { CSSTransition } from 'react-transition-group';

/** Model identifier passed to `llm_inference` and shown as the role of each response bubble. */
const INFERENCE_MODEL = 'gpt4o-mini';

const Playground: React.FC = () => {
    // Shapes of the stat payloads rendered under each chat bubble.
    type PromptStats = { number_tokens: number | null };
    type ResponseStats = { throughput: number; TTFT: number };
    type CompressionStats = { number_tokens: number; throughput: number; total_latency: number };
    type ComparedResponseStats = {
        throughput: number;
        throughput_delta: number | null;
        TTFT: number;
        TTFT_delta: number | null;
    };

    // --- Refs ---------------------------------------------------------------
    // Node handed to CSSTransition for the compression badge.
    const nodeRef = useRef(null);
    // Set once the deltas for the current run have been computed; reset on submit.
    const deltasComputedRef = useRef(false);
    // Attached to the wrappers around the two prompt bubbles (no measurement
    // code is visible in this file — confirm whether they are still needed).
    const originalPromptRef = useRef<HTMLDivElement | null>(null);
    const compressedPromptRef = useRef<HTMLDivElement | null>(null);

    // --- Prompt / response content -------------------------------------------
    const [inputText, setInputText] = useState('');
    const [originalText, setOriginalText] = useState<string | null>(null);
    const [originalResponse, setOriginalResponse] = useState<string | null>(null);
    const [sqwishText, setSqwishText] = useState<string | null>(null);
    const [sqwishResponse, setSqwishResponse] = useState<string | null>(null);

    // --- UI state --------------------------------------------------------------
    const [loading, setLoading] = useState({ original_reponse: false, sqwish: false, sqwish_reponse: false });
    const [errorMessage, setErrorMessage] = useState<string | null>(null);
    // The setter is never called in this file, so the switch is always on here.
    const [isSwitchOn, setIsSwitchOn] = useState(true);
    const [hasCompressed, setHasCompressed] = useState(false);

    // --- Stats -----------------------------------------------------------------
    const [originalTextStats, setOriginalTextStats] = useState<PromptStats | null>(null);
    const [originalResponseStats, setOriginalResponseStats] = useState<ResponseStats | null>(null);
    const [sqwishTextStats, setSqwishTextStats] = useState<CompressionStats | null>(null);
    const [sqwishResponseStats, setSqwishResponseStats] = useState<ComparedResponseStats | null>(null);

    /**
     * Asks the tokenizer endpoint for the token count of the raw prompt and
     * stores it in `originalTextStats`. Any failure is reported through
     * `errorMessage` instead of being rethrown.
     */
    const tokenize_original = async (inputText: string) => {
        try {
            const { tokens_count } = await tokenize(inputText);
            setOriginalTextStats({ number_tokens: tokens_count });
        } catch {
            setErrorMessage('Error fetching tokenizer response.');
        }
    };

    /**
     * Runs the inference model on the uncompressed prompt and stores the reply
     * together with its throughput / TTFT figures. Failures are reported through
     * `errorMessage`; the `original_reponse` loading flag is cleared either way.
     */
    const llm_original_response = async (inputText: string) => {
        try {
            const { response, throughput, TTFT } = await llm_inference(inputText, INFERENCE_MODEL);
            setOriginalResponse(response);
            setOriginalResponseStats({ throughput, TTFT });
        } catch {
            setErrorMessage('Error fetching LLM response.');
        } finally {
            setLoading((flags) => ({ ...flags, original_reponse: false }));
        }
    };

    /**
     * Runs the inference model on the compressed prompt and stores the reply
     * with its throughput / TTFT figures. Both delta fields start as `null`.
     * Failures are reported through `errorMessage`; the `sqwish` and
     * `sqwish_reponse` loading flags are cleared either way.
     */
    const llm_sqwish_response = async (inputText: string) => {
        try {
            const { response, throughput, TTFT } = await llm_inference(inputText, INFERENCE_MODEL);
            setSqwishResponse(response);
            setSqwishResponseStats({
                throughput,
                throughput_delta: null,
                TTFT,
                TTFT_delta: null,
            });
        } catch {
            setErrorMessage('Error fetching LLM response.');
        } finally {
            setLoading((flags) => ({ ...flags, sqwish: false, sqwish_reponse: false }));
        }
    };

    /**
     * Compresses the prompt, stores the compressed text and its stats, then runs
     * the inference model on the original and the compressed prompt concurrently.
     *
     * Both LLM helpers handle their own errors, so the `catch` below is reached
     * when compression itself fails — i.e. before `llm_original_response` was
     * started. Its `finally` therefore never runs in that case, so every loading
     * flag (including `original_reponse`) has to be cleared here.
     *
     * @param inputText The raw prompt to compress and compare.
     */
    const sqwish_response = async (inputText: string) => {
        try {
            const sqwishTextResult = await sqwish(inputText);
            setSqwishText(sqwishTextResult.response);
            // Update the stats for the Sqwish response
            setSqwishTextStats({
                number_tokens: sqwishTextResult.number_tokens,
                throughput: sqwishTextResult.throughput,
                total_latency: sqwishTextResult.total_latency
            });
            // Compression is done; the compressed-prompt LLM call is now pending.
            setLoading((prev) => ({ ...prev, sqwish: false, sqwish_reponse: true }));
            await Promise.all([llm_original_response(inputText), llm_sqwish_response(sqwishTextResult.response)]);
        } catch (error) {
            setErrorMessage('Error fetching Sqwish response.');
            setLoading((prev) => ({ ...prev, original_reponse: false, sqwish: false, sqwish_reponse: false }));
        }
    };

    // Once both the original and the compressed response stats are available,
    // fill in the TTFT / throughput deltas (compressed minus original) exactly
    // once per run. `deltasComputedRef` stops the effect from re-running the
    // update when the state change it causes re-triggers this effect.
    useEffect(() => {
        if (!sqwishResponseStats || !originalResponseStats || deltasComputedRef.current) {
            return;
        }
        deltasComputedRef.current = true;
        setSqwishResponseStats((prevStats) =>
            // Guard instead of asserting non-null: if the stats were reset before
            // this updater runs, leave them untouched rather than throwing.
            prevStats
                ? {
                      ...prevStats,
                      TTFT_delta: prevStats.TTFT - originalResponseStats.TTFT,
                      throughput_delta: prevStats.throughput - originalResponseStats.throughput,
                  }
                : prevStats
        );
    }, [sqwishResponseStats, originalResponseStats]);

    /**
     * Starts a new comparison run for the current input.
     *
     * Blank input is ignored. Otherwise the prompt is snapshotted, every result
     * and stat from the previous run is cleared, the loading flags are raised,
     * and tokenization and compression are started in parallel.
     */
    const handleSubmit = async () => {
        if (!inputText.trim()) {
            return;
        }
        setOriginalText(inputText);
        setOriginalTextStats(null);
        setOriginalResponse(null);
        setOriginalResponseStats(null);
        setSqwishText(null);
        // Clear the compression stats too, so nothing from the previous run survives.
        setSqwishTextStats(null);
        setSqwishResponse(null);
        setSqwishResponseStats(null);
        setErrorMessage(null);
        setLoading({ original_reponse: true, sqwish: true, sqwish_reponse: false });
        deltasComputedRef.current = false; // Allow the delta effect to run again for this run
        await Promise.all([tokenize_original(inputText), sqwish_response(inputText)]);
    };

    /**
     * Percentage of the input removed by compression, formatted like "42%".
     *
     * The badge is labelled "input tokens removed", so the token counts already
     * shown under the two prompt bubbles are used when both are available.
     * When they are not (e.g. the tokenizer request failed), the ratio of
     * character lengths is used as an approximation.
     * NOTE(review): assumes `tokenize` and `sqwish` count tokens comparably — confirm.
     *
     * @returns The formatted percentage, or `null` until both texts exist.
     */
    const calculateTokensCompression = () => {
        if (!originalText || !sqwishText) {
            return null;
        }
        const originalTokens = originalTextStats ? originalTextStats.number_tokens : null;
        const compressedTokens = sqwishTextStats ? sqwishTextStats.number_tokens : null;

        // Default to character lengths; originalText is non-empty here, so `before` > 0.
        let before = originalText.length;
        let after = sqwishText.length;
        if (typeof originalTokens === 'number' && originalTokens > 0 && typeof compressedTokens === 'number') {
            before = originalTokens;
            after = compressedTokens;
        }
        return `${((100 * (before - after)) / before).toFixed(0)}%`;
    };

    /**
     * Formats a signed delta as an arrow followed by its magnitude:
     * "↑5" for positive values, "↓3" for negative ones. Zero is rendered as
     * plain "0", since a down arrow would wrongly suggest a decrease.
     *
     * @param value The signed delta to format.
     * @returns The arrow-prefixed magnitude, or "0" when there is no change.
     */
    const format_value_with_arrows = (value: number): string => {
        if (value === 0) {
            return '0';
        }
        return `${value > 0 ? '↑' : '↓'}${Math.abs(value)}`;
    };

    /**
     * Renders the token count of the original prompt, or an empty `<div>` while
     * no stats are available. A missing count is shown as 0.
     */
    const format_original_text_stats = (originalTextStats: { number_tokens: number | null } | null) => {
        if (!originalTextStats) {
            return <div></div>;
        }
        const tokenCount = Math.floor(originalTextStats.number_tokens || 0);
        return (
            <div>
                {tokenCount} tokens
            </div>
        );
    };

    /**
     * Renders throughput (tokens/s) and TTFT (ms) for the original response,
     * both floored, or an empty `<div>` while no stats are available.
     */
    const format_original_response_stats = (originalResponseStats: { throughput: number; TTFT: number } | null) => {
        if (!originalResponseStats) {
            return <div></div>;
        }
        const tokensPerSecond = Math.floor(originalResponseStats.throughput || 0);
        const firstTokenMs = Math.floor(originalResponseStats.TTFT || 0);
        return (
            <div>
                {tokensPerSecond} tokens/s | {firstTokenMs} ms
            </div>
        );
    };

    /**
     * Renders throughput and TTFT for the compressed-prompt response.
     *
     * Once both deltas have been computed they are shown in parentheses next to
     * each figure: a throughput increase is green and a decrease red, while a
     * TTFT increase is red and a decrease green. Deltas count as "computed" when
     * they are not null; a delta of exactly 0 is still displayed rather than
     * being mistaken for "not computed yet".
     *
     * @param sqwishResponseStats Stats for the compressed response, or `null`.
     * @returns The stats line, or an empty `<div>` while no stats are available.
     */
    const format_sqwish_response_stats = (sqwishResponseStats: { throughput: number; throughput_delta: number | null; TTFT: number, TTFT_delta: number | null } | null) => {
        if (!sqwishResponseStats) {
            return <div></div>;
        }
        const { throughput, throughput_delta, TTFT, TTFT_delta } = sqwishResponseStats;
        return (
            <div>
                {throughput_delta != null && TTFT_delta != null ? (
                    <>{Math.floor(throughput)} <b style={{ color: throughput_delta > 0 ? '#008000' : '#FF0000' }}>({format_value_with_arrows(Math.floor(throughput_delta))})</b> tokens/s | {Math.floor(TTFT)} <b style={{ color: TTFT_delta > 0 ? '#FF0000' : '#008000' }}>({format_value_with_arrows(Math.floor(TTFT_delta))})</b> ms</>
                ) : (
                    <>{Math.floor(throughput)} tokens/s | {Math.floor(TTFT) || 0} ms</>
                )}
            </div>
        );
    };

    /**
     * Renders token count, throughput (tokens/s) and total latency (ms) for the
     * compressed prompt, all floored, or an empty `<div>` while no stats exist.
     */
    const format_sqwish_text_stats = (sqwishTextStats: { number_tokens: number; throughput: number; total_latency: number } | null) => {
        if (!sqwishTextStats) {
            return <div></div>;
        }
        const tokenCount = Math.floor(sqwishTextStats.number_tokens);
        const tokensPerSecond = Math.floor(sqwishTextStats.throughput);
        const latencyMs = Math.floor(sqwishTextStats.total_latency);
        return (
            <div>
                {tokenCount} tokens | {tokensPerSecond} tokens/s | {latencyMs} ms
            </div>
        );
    };

    // Render tree:
    //  - a full-size background image and a "Get access" link in the top-right corner;
    //  - the logo (only while `hasCompressed` is false) and the SearchBar, which
    //    receives the input state, `handleSubmit`, the error message and the
    //    `hasCompressed` flag with its setter;
    //  - once `hasCompressed` is true, two panels (stacked, in a row from the `lg`
    //    breakpoint): the left one shows the original prompt and its LLM response,
    //    the right one the compressed prompt (or a queue notice until it arrives)
    //    and its LLM response. Each panel shows a typing loader until its data is in;
    //  - a badge with calculateTokensCompression(), faded in via CSSTransition once
    //    both the original and the compressed text exist.
    // `hasCompressed` is never set in this file; presumably SearchBar flips it on
    // submit — confirm in SearchBar.
    return (
        <div className="relative w-full h-full bg-gray-100">
            <div className="absolute inset-0 overflow-hidden z-0">
                <img src={backgroundImage} alt="background" className="w-full h-full" />
            </div>
            <div className="absolute top-4 right-4 flex justify-end">
                <Link
                    to="/access"
                    target="_blank"
                    rel="noopener noreferrer"
                    className="shadow-md hover:shadow-lg cursor-pointer bg-black text-white py-2 px-4 s-body rounded-[62.5rem] hover:bg-black/70 transition-colors duration-300 h-10 flex items-center justify-center z-40"
                >
                    Get access
                </Link>
            </div>
            <div className={`pt-14 lg:pt-0 h-full flex flex-col items-center lg:justify-center w-full overflow-x-auto overflow-y-auto ${!hasCompressed ? 'justify-center' : ''}`}>
                <div className={`relative w-full flex flex-col items-center justify-center transition-transform duration-500 z-10`}>
                    {/* Input/Search Bar */}
                    {!hasCompressed && (
                        <Link to="/" className="-mt-20 relative z-20">
                            <img src={sqwishLogo} alt="Sqwish Logo" className="h-[62px] w-auto mb-4" />
                        </Link>
                    )}
                    <SearchBar
                        inputText={inputText}
                        setInputText={setInputText}
                        handleSubmit={handleSubmit}
                        errorMessage={errorMessage}
                        hasCompressed={hasCompressed}
                        setHasCompressed={setHasCompressed}
                    />
                    {/* Transitioned Content */}
                    {hasCompressed && (
                        <div
                            className={`relative transition-all duration-500 w-full z-0`}
                        >
                            <div className="flex flex-col lg:flex-row gap-4 w-full items-start justify-center px-8">

                                {/* Left Column */}
                                <div className="relative flex flex-col items-start lg:w-[50%] max-w-xl h-[80vh] min-h-[450px] bg-pink px-6 rounded-[20px]">
                                    {((isSwitchOn && loading.original_reponse) || !(originalResponse && sqwishResponse && originalResponseStats)) && <div className="self-end pr-6 lg:pr-0 lg:self-center absolute top-[26px]">
                                        <div className="dot-typing-loader">
                                        </div>
                                    </div>}
                                    <div className="p-2 text-4xl h-14 font-bold" />
                                    <div className='w-full h-full flex flex-col'>
                                        {originalText &&
                                            <div ref={originalPromptRef} className="w-full flex flex-col">
                                                <ChatBubble
                                                    key={isSwitchOn ? "on" : "off"}
                                                    maxHeight={isSwitchOn ? "28vh" : "70vh"}
                                                    text={originalText}
                                                    role="user"
                                                    targetRGBA="rgba(250, 234, 247, 1)"
                                                    stats={format_original_text_stats(originalTextStats)}
                                                />
                                            </div>
                                        }
                                        {isSwitchOn && originalResponse && sqwishResponse && originalResponseStats && (
                                            <div className="absolute w-full flex flex-col bottom-0">
                                                <ChatBubble maxHeight="45vh" text={originalResponse} role={INFERENCE_MODEL} targetRGBA="rgba(244, 204, 236, 1)" stats={format_original_response_stats(originalResponseStats)} />
                                            </div>
                                        )}
                                    </div>

                                </div>

                                {/* Right Column */}
                                <div className="relative flex flex-col items-start lg:w-[50%] max-w-xl h-[80vh] min-h-[450px] bg-yellow px-6 rounded-[20px]">
                                    {((loading.sqwish || (isSwitchOn && loading.sqwish_reponse)) || !(sqwishResponse && sqwishResponseStats)) && <div className="self-end pr-6 lg:pr-0 lg:self-center absolute top-[26px]">
                                        <div className="dot-typing-loader">
                                        </div>
                                    </div>}
                                    <div className="p-2 text-4xl font-bold">
                                        <a href="/" target="_blank" rel="noopener noreferrer">
                                            <img src={sqwishLogo} alt="Sqwish Logo" className="h-10 w-auto" />
                                        </a>
                                    </div>
                                    <div className='w-full h-full flex flex-col'>
                                        {sqwishText && sqwishTextStats &&
                                            <>
                                                <div ref={compressedPromptRef} className="w-full flex flex-col">
                                                    <ChatBubble
                                                        key={isSwitchOn ? "on" : "off"}
                                                        maxHeight={isSwitchOn ? "20vh" : "50vh"}
                                                        text={sqwishText}
                                                        role="sqwish"
                                                        targetRGBA="rgba(255, 243, 215, 1)"
                                                        stats={format_sqwish_text_stats(sqwishTextStats)}
                                                    />
                                                </div>
                                            </>
                                        }
                                        {!(sqwishText && sqwishTextStats) &&
                                            <>
                                                <div className="w-full self-end flex flex-col h-[20vh] items-center justify-center mt-4 min-w-[280px]">
                                                    <div className="italic">
                                                        You are in a short queue…
                                                    </div>
                                                </div>
                                            </>
                                        }
                                        {isSwitchOn && sqwishResponse && sqwishResponseStats && (
                                            <div className="absolute w-full flex flex-col bottom-0">
                                                <ChatBubble
                                                    maxHeight="45vh" text={sqwishResponse} role={INFERENCE_MODEL} targetRGBA="rgba(255, 226, 155, 1)" stats={format_sqwish_response_stats(sqwishResponseStats)} />
                                            </div>
                                        )}
                                    </div>
                                </div>
                            </div>
                            <div className='absolute bottom-0 left-1/2 -translate-x-1/2 translate-y-1/2'>
                                <CSSTransition
                                    in={!!(originalText && sqwishText)}
                                    nodeRef={nodeRef}
                                    timeout={300}
                                    classNames="fade"
                                    unmountOnExit
                                >
                                    <div ref={nodeRef} className="w-56 h-16 rounded-[20px] bg-[#bfe1a0] border-2 border-black flex flex-col items-center justify-center mt-12 lg:mt-0">
                                        <div className="w-full flex flex-row">
                                            <div className="w-32 flex items-center justify-center text-4xl font-bold ml-3">
                                                {calculateTokensCompression()}
                                            </div>
                                            <div className="w-28 flex flex-col justify-center items-center text-center text-xs">
                                                <span>input tokens</span>
                                                <span>removed</span>
                                            </div>
                                        </div>
                                    </div>
                                </CSSTransition>
                            </div>
                        </div>
                    )}
                </div>
            </div >
        </div>
    );
};

export default Playground;
