import React, { useState } from 'react';
import {
  useGetSelectedProjectQuery,
  useGetProjectVersionedDefectsQuery,
} from '@/serverStore/projects';
import { RegisteredModel } from '@clef/shared/types';
import { useDatasetExportedWithVersionsQuery } from '@/serverStore/dataset';
import { Predict } from '@/components/Predict';
import { DeployComponent } from '../../FirstRunLiveExperienceDialog/DeployComponent';
import { makeStyles } from '@material-ui/core';
import { ReachLimitModelDialog } from '../../TrainModelButtonGroup/ReachLimitModelDialog';
import { useCheckCreditReachLimit } from '@/hooks/useSubscriptions';
import { queryClient } from '@/serverStore';
import { projectModelQueryKeys } from '@/serverStore/projectModels';

/**
 * JSS classes for the "Try this model" dialog.
 * `modalBanner` is applied to the Predict banner container to give it theme-based padding.
 */
const useStyles = makeStyles(theme => {
  const bannerPadding = theme.spacing(3);
  return {
    modalBanner: { padding: bannerPadding },
  };
});

/** Props accepted by the `TryModelDialog` component. */
export interface TryModelDialogProps {
  /** Callback invoked when the dialog is dismissed. */
  onClose: () => void;
  /** The registered model the user wants to run predictions with. */
  model: RegisteredModel;
}

/**
 * "Try this model" modal: lets the user run predictions with a registered model.
 *
 * If `useCheckCreditReachLimit` reports that the limit has been reached, a
 * `ReachLimitModelDialog` is rendered instead of the prediction UI.
 * Renders nothing until the selected project exposes both an id and a dataset id.
 *
 * @param model - Model to try. Its `confidence` seeds the prediction threshold, and its
 *   `datasetVersionId` selects which versioned defect set is loaded.
 * @param onClose - Called when either the prediction modal or the reach-limit dialog is closed.
 */
export const TryModelDialog: React.FC<TryModelDialogProps> = ({
  model,
  onClose,
}: TryModelDialogProps) => {
  // Fallback keeps every destructured key present while the project query has no data.
  const {
    name,
    id: projectId,
    datasetId,
    labelType,
  } = useGetSelectedProjectQuery().data ?? {
    name: '',
    id: undefined,
    datasetId: undefined,
    labelType: undefined,
  };
  const styles = useStyles();
  const { datasetVersionId } = model;
  const { data: datasetExported } = useDatasetExportedWithVersionsQuery({
    includeNotCompleted: true,
    includeFastEasy: true,
  });
  // Resolve the model's dataset version id to the `version` used to query versioned defects.
  const versionedDatasetContentId = datasetExported?.datasetVersions?.find(
    e => e.id === datasetVersionId,
  )?.version;

  const { hasReachLimit } = useCheckCreditReachLimit();
  // The reach-limit dialog is only rendered when the limit has been hit, so it must start open.
  // (It previously started `false` and nothing ever set it to `true`, so it never appeared.)
  const [openReachLimitModelDialog, setOpenReachLimitModelDialog] = useState(true);
  // NOTE(review): `data` is undefined while the query is loading; the `!` only silences the
  // type checker. Confirm `Predict` tolerates an undefined `defectMap`.
  const allDefects = useGetProjectVersionedDefectsQuery(versionedDatasetContentId).data!;

  const [thresholdForPredict, setThresholdForPredict] = useState(model.confidence);

  // Every hook is called above this line, so this early return keeps hook order stable.
  if (!projectId || !datasetId) return null;

  const modelInfo = {
    id: model.id,
    threshold: thresholdForPredict ?? undefined,
  };

  if (hasReachLimit) {
    return (
      <ReachLimitModelDialog
        projectId={projectId}
        datasetId={datasetId}
        open={openReachLimitModelDialog}
        onClose={() => {
          setOpenReachLimitModelDialog(false);
          // Let the parent know the dialog was dismissed so it can reset its own state.
          onClose();
        }}
      />
    );
  }

  return (
    <Predict
      selectedProject={{
        projectId,
        datasetId,
        name,
        labelType,
      }}
      modelInfo={modelInfo}
      updateThreshold={newThreshold => {
        setThresholdForPredict(newThreshold);
      }}
      defectMap={allDefects}
      title="Try this model"
      onCloseModal={() => {
        onClose();
        // Refresh the project's model list once the user leaves the try-model flow.
        queryClient.invalidateQueries(projectModelQueryKeys.list(projectId));
      }}
      classes={{
        bannerContainer: styles.modalBanner,
      }}
      bannerComponent={null}
      deployComponent={<DeployComponent modelInfo={modelInfo} labelType={labelType} />}
    />
  );
};

// Default export for consumers that import the dialog without naming it.
export { TryModelDialog as default };
