import React, { useState, useEffect, useRef } from 'react';
import { llm_inference, sqwish } from './api';
import llamaImg from './images/llama3-70b.webp';
import sqwishImg from './images/sqwish-llama3-70b.webp';
import { CSSTransition } from 'react-transition-group';

/** Props for a paragraph that renders (possibly bold-marked) text. */
interface TextProps {
    /**
     * Text to display. `null` means there is nothing to show yet
     * (e.g. a response that has not arrived).
     */
    text: string | null;
}

/** Token counts for each of the four texts shown in the side-by-side comparison. */
interface TokenCounts {
    /** Tokens in the prompt exactly as the user submitted it. */
    input: number;
    /** Tokens in the LLM's answer to the original prompt. */
    llmOutput: number;
    /** Tokens in the Sqwish-compressed prompt. */
    sqwishInput: number;
    /** Tokens in the LLM's answer to the compressed prompt. */
    finalOutput: number;
}

/**
 * Converts `**bold**` spans in `text` into `<strong>` elements.
 *
 * A bold span is `**` + one or more non-asterisk characters + `**`. Everything
 * else is passed through verbatim as plain strings, including stray runs such
 * as `**` or `****`.
 *
 * @param text - Raw text to render; `null` or `""` yields `null`.
 * @returns An array of strings and `<strong>` elements usable as React children,
 *   or `null` when there is no text.
 */
const parseTextToBold = (text: string | null) => {
    if (!text) return null;
    // With a capturing group in the separator, `split` interleaves the captured
    // `**...**` matches with the surrounding text, so matches always sit at odd indices.
    // Deciding by index (instead of re-testing each part for leading/trailing `**`)
    // stops plain segments such as "****" from becoming empty <strong> tags and losing text.
    const parts = text.split(/(\*\*[^*]+\*\*)/g);
    return parts.map((part, index) =>
        index % 2 === 1 ? <strong key={index}>{part.slice(2, -2)}</strong> : part
    );
};

/**
 * Paragraph that preserves whitespace/newlines in `text` and renders
 * `**bold**` spans as `<strong>`. Renders an empty paragraph for `null`.
 */
const CustomizedParagraph: React.FC<TextProps> = (props) => {
    const content = parseTextToBold(props.text);
    return <p className="custom-paragraph whitespace-pre-wrap">{content}</p>;
};

/**
 * Asks the backend how many tokens `text` occupies for `model`.
 *
 * @param text - Text to tokenize.
 * @param model - Model/tokenizer name forwarded to the backend; defaults to "gpt-4".
 * @returns The numeric `token_count` from the JSON response.
 * @throws Error when the endpoint answers with a non-2xx status, or when the
 *   response body has no finite numeric `token_count`. Previously both cases
 *   leaked `undefined` into the token arithmetic.
 */
const countTokens = async (text: string, model = 'gpt-4'): Promise<number> => {
    const response = await fetch('/api/v1/count_tokens', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify({ text, model }),
    });
    if (!response.ok) {
        throw new Error(`count_tokens request failed: ${response.status} ${response.statusText}`);
    }
    const data: unknown = await response.json();
    const tokenCount =
        typeof data === 'object' && data !== null
            ? (data as Record<string, unknown>).token_count
            : undefined;
    if (typeof tokenCount !== 'number' || !Number.isFinite(tokenCount)) {
        throw new Error('count_tokens response did not contain a numeric token_count');
    }
    return tokenCount;
};

/** Centred loading spinner shown while a panel waits on the backend. */
const Spinner: React.FC = () => (
    <div className="h-full flex items-center justify-center"><div className="loader"></div></div>
);

/**
 * Keeps `tokenCounts[field]` in sync with `text` by asking the backend for its token count.
 *
 * The count is reset to 0 whenever `text` changes (and stays 0 while `text` is
 * null/empty), so a count from a previous submission is never displayed. A response
 * for an outdated `text` is discarded via the effect cleanup, so a slow request
 * cannot overwrite a newer one. Counting failures are logged and leave the count at 0.
 */
const useTokenCount = (
    text: string | null,
    field: keyof TokenCounts,
    setTokenCounts: React.Dispatch<React.SetStateAction<TokenCounts>>,
) => {
    useEffect(() => {
        setTokenCounts(prev => (prev[field] === 0 ? prev : { ...prev, [field]: 0 }));
        if (!text) return undefined;

        let cancelled = false;
        countTokens(text)
            .then(count => {
                if (!cancelled) setTokenCounts(prev => ({ ...prev, [field]: count }));
            })
            .catch(err => {
                if (!cancelled) console.error(`Failed to count tokens for ${field}`, err);
            });
        return () => {
            cancelled = true;
        };
    }, [text, field, setTokenCounts]);
};

/**
 * Side-by-side demo. The left panel shows the LLM's answer to the prompt as typed.
 * The right panel shows the Sqwish-compressed prompt and the LLM's answer to it.
 * A badge between them reports the percentage of input tokens removed by compression.
 */
const App: React.FC = () => {
    const [inputText, setInputText] = useState('');
    const [submitText, setSubmitText] = useState<string | null>(null);
    const [llmResponse, setLlmResponse] = useState<string | null>(null);
    const [sqwishResponse, setSqwishResponse] = useState<string | null>(null);
    const [finalResponse, setFinalResponse] = useState<string | null>(null);
    const [tokenCounts, setTokenCounts] = useState<TokenCounts>({
        input: 0,
        llmOutput: 0,
        sqwishInput: 0,
        finalOutput: 0
    });
    const [loading, setLoading] = useState({
        llm: false,
        sqwish: false,
        final: false
    });
    // Bumped on every submission; async results tagged with an older id are ignored,
    // so re-submitting while requests are in flight cannot mix answers from two prompts.
    const requestIdRef = useRef(0);
    // Lets CSSTransition animate without the deprecated findDOMNode.
    const compressionBadgeRef = useRef<HTMLDivElement>(null);

    // One independent effect per text: a new response only triggers a count for itself.
    useTokenCount(submitText, 'input', setTokenCounts);
    useTokenCount(llmResponse, 'llmOutput', setTokenCounts);
    useTokenCount(sqwishResponse, 'sqwishInput', setTokenCounts);
    useTokenCount(finalResponse, 'finalOutput', setTokenCounts);

    const showCompression = !!(submitText && finalResponse && sqwishResponse && llmResponse && (tokenCounts.input > tokenCounts.sqwishInput));

    const isCurrent = (requestId: number) => requestIdRef.current === requestId;

    /** Baseline: run the LLM on the untouched prompt (left panel). */
    const runOriginal = async (text: string, requestId: number) => {
        try {
            const result = await llm_inference(text);
            if (isCurrent(requestId)) setLlmResponse(result);
        } catch (err) {
            console.error('LLM inference on the original input failed', err);
        } finally {
            // Always clear the spinner, even on failure, unless a newer submission owns it.
            if (isCurrent(requestId)) setLoading(prev => ({ ...prev, llm: false }));
        }
    };

    /** Compress the prompt with Sqwish, then run the LLM on the result (right panel). */
    const runSqwished = async (text: string, requestId: number) => {
        try {
            const compressed = await sqwish(text);
            if (!isCurrent(requestId)) return;
            setSqwishResponse(compressed);
            setLoading(prev => ({ ...prev, sqwish: false, final: true }));

            const result = await llm_inference(compressed);
            if (isCurrent(requestId)) setFinalResponse(result);
        } catch (err) {
            console.error('Sqwish compression or inference on the compressed input failed', err);
        } finally {
            if (isCurrent(requestId)) setLoading(prev => ({ ...prev, sqwish: false, final: false }));
        }
    };

    const handleSubmit = async () => {
        if (!inputText.trim()) return;
        requestIdRef.current += 1;
        const requestId = requestIdRef.current;

        setSubmitText(inputText);
        setLlmResponse(null);
        setSqwishResponse(null);
        setFinalResponse(null);
        // Reset `final` too: a superseded run may have left it set.
        setLoading({ llm: true, sqwish: true, final: false });
        // Both pipelines handle their own errors, so this never rejects.
        await Promise.all([runOriginal(inputText, requestId), runSqwished(inputText, requestId)]);
    };

    const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
        // Enter submits. Shift+Enter falls through to insert a newline, and an Enter
        // that confirms an IME composition (e.g. CJK input) must not submit.
        if (e.key !== 'Enter' || e.shiftKey || e.nativeEvent.isComposing) return;
        e.preventDefault();
        void handleSubmit();
    };

    return (
        <div className="min-h-screen flex flex-col items-center justify-center bg-gray-100">
            <div className="w-full flex flex-col items-center justify-center">
                <div className="grid grid-cols-2 gap-2 relative">
                    <div className="p-4 bg-[#eabca8] shadow rounded w-[600px] h-[660px] rounded-[20px] relative">
                        <img src={llamaImg} alt="Logo Llama" className="absolute top-1/2 h-[300px] transform -translate-y-1/2 left-20"/>
                        <div className="absolute right-4 h-full">
                            <div className="grid grid-cols-1 gap-6">
                                <div className="p-4 bg-white shadow rounded w-96 h-48 relative rounded-[20px]">
                                    <h2 className="font-bold text-center mb-3">Original Input</h2>
                                    <textarea
                                        className="w-full p-2 border rounded h-32"
                                        placeholder="Enter text here"
                                        value={inputText}
                                        onChange={(e) => setInputText(e.target.value)}
                                        onKeyDown={handleKeyDown}
                                    />
                                    <button
                                        className="absolute bottom-2 right-2 px-3 py-1 bg-[#e4491e] text-white rounded rounded-[20px]"
                                        aria-label="Submit"
                                        onClick={() => void handleSubmit()}
                                    >
                                        ➔
                                    </button>
                                </div>
                                <div className="p-4 bg-white shadow rounded w-96 h-[380px] rounded-[20px] relative">
                                    <h2 className="font-bold text-center mb-3 italic">Response to Original</h2>
                                    <div className="h-[310px] w-full overflow-y-auto">
                                        {loading.llm ? <Spinner /> : <CustomizedParagraph text={llmResponse} />}
                                    </div>
                                </div>
                            </div>
                            {/* {(submitText && llmResponse) && (
                                <p className='mt-3 ml-2 whitespace-pre-wrap'>
                                    <strong>Input tokens: </strong>{`${tokenCounts.input} tkns / `}
                                    <strong>Output tokens: </strong>{`${tokenCounts.llmOutput} tkns \n`}
                                    <strong>LLM cost: </strong>{`$${((tokenCounts.input + tokenCounts.llmOutput) * 0.03 / 1000).toFixed(4)}`}
                                </p>
                            )} */}
                        </div>
                    </div>
                    <div className="p-4 bg-[#ccd5ae] shadow rounded w-[600px] h-[660px] rounded-[20px] relative">
                        <img src={sqwishImg} alt="Logo Sqwish" className="absolute top-1/2 h-[300px] transform -translate-y-1/2 right-8"/>
                        <div className="grid grid-cols-1 gap-6">
                            <div className="p-4 bg-white shadow rounded w-96 h-48 relative rounded-[20px]">
                                <h2 className="font-bold text-center mb-3 italic">Sqwished Input</h2>
                                <div className="h-32 w-full overflow-y-auto">
                                    {loading.sqwish ? <Spinner /> : <CustomizedParagraph text={sqwishResponse} />}
                                </div>
                            </div>

                            <div className="p-4 bg-white shadow rounded w-96 h-[380px] rounded-[20px] relative">
                                <h2 className="font-bold text-center mb-3 italic">Response to Sqwished</h2>
                                <div className="h-[310px] w-full overflow-y-auto">
                                    {loading.final ? <Spinner /> : <CustomizedParagraph text={finalResponse} />}
                                </div>
                            </div>
                        </div>
                        {/* {(submitText && finalResponse && sqwishResponse) && (
                            <p className='mt-3 ml-2 whitespace-pre-wrap'>
                                <strong>Input tokens: </strong>{`${tokenCounts.sqwishInput} tkns / `}
                                <strong>Output tokens: </strong>{`${tokenCounts.finalOutput} tkns \n`}
                                <strong>LLM cost: </strong>{`$${((tokenCounts.sqwishInput + tokenCounts.finalOutput) * 0.03 / 1000).toFixed(4)} `}
                                <strong>Sqwish cost: </strong>{`$${(tokenCounts.input * 0.0015 / 1000).toFixed(4)} `}
                                <strong>Total cost: </strong>{`$${((tokenCounts.sqwishInput + tokenCounts.finalOutput) * 0.03 / 1000 + tokenCounts.input * 0.0015 / 1000).toFixed(4)}`}
                            </p>
                        )} */}
                    </div>
                    <div className='absolute bottom-0 left-1/2 -translate-x-1/2 translate-y-1/2'>
                        <CSSTransition
                            in={showCompression}
                            nodeRef={compressionBadgeRef}
                            timeout={300}
                            classNames="fade"
                            unmountOnExit
                        >
                            <div ref={compressionBadgeRef} className="w-56 h-16 rounded rounded-[20px] bg-white border border-black flex">
                                <div className="w-32 flex items-center justify-center text-4xl font-bold ml-3">
                                    {(tokenCounts.input && tokenCounts.sqwishInput) ? `${((100 * (tokenCounts.input - tokenCounts.sqwishInput)) / tokenCounts.input).toFixed(0)}%` : null}
                                </div>
                                <div className="w-28 flex flex-col justify-center items-center text-center text-xs">
                                    <span>input tokens</span>
                                    <span>removed</span>
                                </div>
                            </div>
                        </CSSTransition>
                    </div>
                </div>
            </div>
        </div>
    );
};

export default App;