import { ArrowUpward, AttachFile } from "@mui/icons-material";
import {
  CircularProgress,
  Fab,
  IconButton,
  InputBase,
  Paper,
} from "@mui/material";
import { styled } from "@mui/material/styles";
import { Session } from "@supabase/supabase-js";
import { useSnackbar } from "notistack";
import React, { useState } from "react";
import { Edge, Node } from "reactflow";
import BackendApi from "../../BackendApi";
import { aiRequest } from "./ai_request";

// Styled components
/**
 * Rounded container for the prompt bar: a horizontal flex row with vertically
 * centred children, drawn on the theme's paper background.
 */
const StyledPaper = styled(Paper)(({ theme }) => ({
  // `p` is an sx-prop shorthand and is not interpreted inside styled() style
  // objects (it would be emitted as an invalid CSS property), so spell it out.
  padding: "2px 4px",
  display: "flex",
  alignItems: "center",
  backgroundColor: theme.palette.background.paper,
  borderRadius: 30,
  minWidth: "50vw",
  height: 50,
}));

/** Props accepted by the `AIInput` prompt bar. */
interface OpenAIInputProps {
  /** Current flow-graph nodes; rendered to text and sent along with the prompt. */
  nodes: Node[];
  /** Current flow-graph edges; rendered to text and sent along with the prompt. */
  edges: Edge[];

  /** API client forwarded to `aiRequest`. */
  backendApi: BackendApi;
  /** Auth session forwarded to `aiRequest`. */
  session: Session;
  /** Active organisation id. NOTE(review): not used by `AIInput` at present. */
  activeOrgId: string;

  /** Invoked with the `newNodes`/`newEdges` produced by `aiRequest`. */
  addAiResponse: Function;
  /** Forwarded to `aiRequest`, presumably so it can append chat messages — confirm. */
  setAiMessages: Function;

  /** Placeholder text shown in the empty input. */
  placeholder?: string;
  /** NOTE(review): accepted but not currently forwarded to the input. */
  maxRows?: number;
}

/**
 * Chat-style prompt bar for asking the AI about the current flow graph.
 *
 * On submit, the trimmed prompt and a plain-text rendering of `nodes`/`edges`
 * are passed to `aiRequest`. The `newNodes`/`newEdges` it returns are handed to
 * `addAiResponse`. The input is cleared only when that round trip succeeds. If
 * it throws, the prompt is kept so the user can retry, and an error snackbar is
 * shown. While a request is in flight, the input and submit button are disabled
 * and a spinner replaces the arrow icon.
 */
const AIInput: React.FC<OpenAIInputProps> = ({
  placeholder = "Ask a question about your software...",
  // NOTE(review): `maxRows` and `activeOrgId` are accepted but not used below.
  maxRows = 1,
  nodes,
  edges,
  backendApi,
  activeOrgId,
  session,
  addAiResponse,
  setAiMessages,
}) => {
  const [prompt, setPrompt] = useState<string>("");
  const [loading, setLoading] = useState(false);

  const { enqueueSnackbar } = useSnackbar();

  /**
   * Human-readable node label: "name (label) product_id" when the node's data
   * carries a name, otherwise "label product_id".
   */
  const getNodeName = (node: Node): string => {
    // The name is read from `node.data`, so the presence check must look there
    // too; checking `node.name` tested a different field than the one printed.
    if (node.data.name) {
      return `${node.data.name} (${node.data.label}) ${node.data.product_id}`;
    }
    return `${node.data.label} ${node.data.product_id}`;
  };

  /** Display name of an edge endpoint, or "Not Connected" when no node has that id. */
  const getNodeNameForEdge = (
    nodeId: string,
    nodeIdToName: Map<string, string>
  ): string => nodeIdToName.get(nodeId) ?? "Not Connected";

  /** "source -> target" description of an edge, using node display names. */
  const getEdgeName = (edge: Edge, nodeIdToName: Map<string, string>): string =>
    `${getNodeNameForEdge(edge.source, nodeIdToName)} -> ${getNodeNameForEdge(
      edge.target,
      nodeIdToName
    )}`;

  /**
   * Serializes the graph as plain text: a "Graph Nodes" header followed by one
   * node per line, a blank line, then a "Graph Edges" header followed by one
   * edge per line.
   */
  const buildTextGraph = (): string => {
    // A Map (instead of a plain object queried with `in`) cannot accidentally
    // match inherited keys such as "constructor" or "toString".
    const nodeIdToName = new Map<string, string>();
    const nodeLines = nodes.map((node) => {
      const name = getNodeName(node);
      nodeIdToName.set(node.id, name); // the last node with a given id wins
      return name;
    });
    const edgeLines = edges.map((edge) => getEdgeName(edge, nodeIdToName));
    return `Graph Nodes\n${nodeLines.join("\n")}\n\nGraph Edges\n${edgeLines.join("\n")}`;
  };

  /**
   * Sends `request` plus the serialized graph to the AI and applies the result.
   * `loading` is always reset, even when the request throws; otherwise the
   * input would stay disabled forever. Errors are re-thrown to the caller.
   */
  const handleSubmitReq = async (request: string): Promise<void> => {
    setLoading(true);
    try {
      const textGraph = buildTextGraph();
      const { newNodes, newEdges } = await aiRequest(
        request,
        textGraph,
        backendApi,
        session,
        nodes,
        setAiMessages,
        enqueueSnackbar
      );
      addAiResponse(newNodes, newEdges);
    } finally {
      setLoading(false);
    }
  };

  /** Form submit handler: ignores empty prompts and submissions while busy. */
  const handleSubmit = async (e: React.FormEvent<HTMLFormElement>) => {
    e.preventDefault();
    const trimmed = prompt.trim();
    if (!trimmed || loading) {
      return;
    }
    try {
      await handleSubmitReq(trimmed);
      // Clear only after success so a failed prompt stays in the box for a retry.
      setPrompt("");
    } catch (err: unknown) {
      // React does not await onSubmit handlers, so without this catch a failure
      // would surface only as an unhandled promise rejection.
      console.error(err);
      enqueueSnackbar("AI request failed. Please try again.", {
        variant: "error",
      });
    }
  };

  return (
    <form onSubmit={handleSubmit}>
      <StyledPaper elevation={5}>
        {/* NOTE(review): the attach button has no click handler wired up. */}
        <IconButton disableRipple disableFocusRipple>
          <AttachFile sx={{ mr: 1, ml: 1 }} />
        </IconButton>
        <InputBase
          fullWidth
          sx={{ ml: 1, flex: 1 }}
          placeholder={placeholder}
          disabled={loading}
          value={prompt}
          onChange={(e) => setPrompt(e.target.value)}
        />
        {/* `borderRadius` was misspelled `boarderRadius` and silently ignored. */}
        <Fab
          sx={{ borderRadius: 10, mr: 1, width: 35, height: 35 }}
          disableRipple
          disableFocusRipple
          type="submit"
          disabled={!prompt.trim() || loading}
        >
          {loading ? (
            <CircularProgress size={24} color="inherit" />
          ) : (
            <ArrowUpward />
          )}
        </Fab>
      </StyledPaper>
    </form>
  );
};

export default AIInput;
