import { BaseAPI } from './base_api';
import {
  CursorPaginatedResponse,
  ModelClassConfig,
  JobsListResponse,
  JobInfo,
  TBServerInfo,
  ModelsListResponse,
  ModelInfo,
  ApiResponse,
  SingleModelReportSummaryData,
  JobInfoParsed,
} from '@clef/shared/types';
import { API_GATEWAY_URL } from '../constants';

/**
 * Prefix handed to `BaseAPI` as the first constructor argument when the default `MaglevApi`
 * instance is created at the bottom of this file (presumably prepended to every request path —
 * confirm in `BaseAPI`).
 */
export const maglevPrefix = 'train';

/** Sort direction accepted by the cursor-paginated list endpoints below. */
type SortOrder = 'asc' | 'desc';

/**
 * Cursor arguments shared by the cursor-paginated list endpoints below.
 *
 * NOTE(review): because both members contain only optional keys, TypeScript still accepts an
 * object carrying both `startingAfter` and `endingBefore`; the union documents intent rather
 * than enforcing mutual exclusion.
 */
type CursorPaginationParams = { startingAfter?: string } | { endingBefore?: string };

/**
 * Client for the training endpoints: model architectures, trained models, training/eval jobs,
 * TensorBoard servers and job logs. Every request goes through the `get` / `postJSON` helpers
 * inherited from `BaseAPI`.
 */
class MaglevApi extends BaseAPI {
  /** Lists the available model architecture configurations (`get` on `model_archs`). */
  async getModelArchConfigs(): Promise<ModelClassConfig[]> {
    return this.get(`model_archs`, /* params = */ {}, /* dataOnly = */ true, {
      credentials: 'include',
    });
  }

  /**
   * Lists trained models for a project (`get` on `models`).
   *
   * @param projectId - Project whose models are listed.
   * @param includeEval - Sent as the `includeEval` request param; defaults to `false`.
   * @param limit - Sent as the `limit` request param; defaults to 50.
   */
  async getTrainedModels(
    projectId: number,
    includeEval = false,
    limit = 50,
  ): Promise<ModelsListResponse> {
    return this.get(
      `models`,
      /* params = */ { projectId, includeEval, limit },
      /* dataOnly = */ true,
    );
  }

  /**
   * Fetches one cursor-paginated page of trained models (`get` on `v2/models`).
   *
   * @param params - Filters and sorting, plus a `startingAfter` or `endingBefore` cursor.
   */
  async getTrainedModelsPage(
    params: {
      sortOrder?: SortOrder;
      projectId: number;
      favorite?: boolean;
      sortBy?: string;
      limit?: number;
    } & CursorPaginationParams,
  ): Promise<CursorPaginatedResponse<ModelInfo>> {
    return this.get(`v2/models`, params, /* dataOnly = */ true);
  }

  /** Starts a training run (`postJSON` to `create_training_run`) with `params` as the body. */
  async postCreateTrainingJob(params: object): Promise<{}> {
    return this.postJSON(`create_training_run`, params);
  }

  /**
   * Starts a fast training run (`postJSON` to `v2/create_fast_training_run/`).
   * The response payload carries the new `runId`.
   */
  async postCreateFastTrainingJob(params: object): Promise<ApiResponse<{ runId: string }>> {
    // NOTE(review): unlike most routes here, this path ends with '/' — confirm against the
    // server route before normalizing it.
    return this.postJSON(`v2/create_fast_training_run/`, params);
  }

  /**
   * Retries a fast training job (`postJSON` to `retry_fast_training_run/`).
   * The response payload carries a `jobId`.
   */
  async postRetryFastTrainingJob(params: {
    jobId: string;
    projectId: number;
    orgId: number;
  }): Promise<ApiResponse<{ jobId: string }>> {
    // NOTE(review): trailing '/' kept as-is; confirm against the server route before changing.
    return this.postJSON(`retry_fast_training_run/`, params);
  }

  /** Saves the model produced by a fast training job (`postJSON` to `save_fast_model`). */
  async postSaveFastModel(params: {
    projectId: number;
    jobId: string;
    name: string;
    description?: string;
    testSplitKey: string;
    scoreThreshold: number;
    exportedDatasetId: number;
  }): Promise<{}> {
    return this.postJSON(`save_fast_model`, params);
  }

  /** Starts an evaluation run (`postJSON` to `create_eval_run`) with `params` as the body. */
  async postCreateEvalJob(params: object): Promise<{}> {
    return this.postJSON(`create_eval_run`, params);
  }

  /** Sets or clears the favorite flag of a job (`postJSON` to `job/favorite`). */
  async postFavoriteEvalJob(params: {
    isFavorited: boolean;
    projectId: number;
    jobId: string;
  }): Promise<{}> {
    return this.postJSON(`job/favorite`, params);
  }

  /** Sets or clears the favorite flag of a trained model (`postJSON` to `model/favorite`). */
  async postFavoriteModel(params: {
    isFavorited: boolean;
    trainJobId: string;
    projectId: number;
  }): Promise<{}> {
    return this.postJSON(`model/favorite`, params);
  }

  /** Marks the model of `trainJobId` as the project's best model (`postJSON` to `model/best`). */
  async postBestModel(params: { projectId: number; trainJobId: string }): Promise<{}> {
    return this.postJSON(`model/best`, params);
  }

  /**
   * Requests a TensorBoard server for a job (`postJSON` to `kickstart_tb_server`).
   *
   * @returns The response wrapping the server's `TBServerInfo`.
   */
  async postStartTensorboard(projectId: number, jobId: string): Promise<ApiResponse<TBServerInfo>> {
    return this.postJSON(
      `kickstart_tb_server`,
      /* body = */ { projectId, jobId },
      {
        credentials: 'include',
      },
    );
  }

  /**
   * Lists jobs for a project (`get` on `jobs`).
   *
   * @param limit - Sent as the `limit` request param; defaults to 50.
   */
  async getJobs(projectId: number, limit = 50): Promise<JobsListResponse> {
    return this.get(`jobs`, /* params = */ { projectId, limit }, /* dataOnly = */ true, {
      credentials: 'include',
    });
  }

  /**
   * Fetches one cursor-paginated page of eval jobs for a trained model (`get` on
   * `v3/jobs/eval`). Each item may carry a report `summary`.
   */
  async getEvalJobsPage(
    params: {
      trainJobId: ModelInfo['trainJobId'];
      sortOrder?: SortOrder;
      projectId: number;
      favorite?: boolean;
      sortBy?: string;
      limit: number;
    } & CursorPaginationParams,
  ): Promise<CursorPaginatedResponse<JobInfoParsed & { summary?: SingleModelReportSummaryData }>> {
    return this.get(`v3/jobs/eval`, params, /* dataOnly = */ true);
  }

  /** Fetches one cursor-paginated page of a project's jobs (`get` on `v2/jobs`). */
  async getJobsPage(
    params: {
      sortOrder?: SortOrder;
      projectId: number;
      sortBy?: string;
      limit: number;
    } & CursorPaginationParams,
  ): Promise<CursorPaginatedResponse<JobInfoParsed>> {
    return this.get(`v2/jobs`, params, /* dataOnly = */ true);
  }

  /**
   * Fetches a single job (`get` on `job`) and parses its `hyperParams` JSON string.
   *
   * @param projectId - Project the job belongs to.
   * @param jobId - Job to look up.
   * @param orgId - Optional organization id, sent as a request param.
   * @returns The first job of the response with `hyperParams` parsed. If the JSON is malformed
   *   or absent, `hyperParams` is `undefined`.
   *   NOTE: when the response contains no job this resolves to `undefined`, even though the
   *   declared type (kept for caller compatibility) does not say so — callers should guard.
   */
  async getJobDetails(projectId: number, jobId: string, orgId?: number): Promise<JobInfoParsed> {
    const res: JobInfo[] = await this.get(
      `job`,
      /* params = */ { projectId, jobId, orgId },
      /* dataOnly = */ true,
      { credentials: 'include' },
    );
    const jobInfo = res && res[0];
    if (!jobInfo) {
      return undefined as never;
    }

    let hyperParams;
    try {
      hyperParams = jobInfo.hyperParams ? JSON.parse(jobInfo.hyperParams) : undefined;
    } catch {
      // Best effort: a malformed hyperParams string should not hide the rest of the job.
      hyperParams = undefined;
    }
    return {
      ...jobInfo,
      hyperParams,
    } as JobInfoParsed;
  }

  /**
   * Fetches the truncated log text of a job (`get` on `job/<jobId>/truncated_log`).
   *
   * `jobId` is URI-encoded because it is interpolated into the path: characters such as `/`,
   * `?` or `#` would otherwise change which route is requested.
   */
  async getJobLogs(projectId: number, jobId: string): Promise<string> {
    return this.get(
      `job/${encodeURIComponent(jobId)}/truncated_log`,
      /* params = */ { projectId },
      /* dataOnly = */ true,
      {
        credentials: 'include',
      },
    );
  }

  /** Stops a fast training run (`postJSON` to `stop_fast_training_run`). */
  async postStopTraining(params: { projectId: number; jobId: string }): Promise<{}> {
    return this.postJSON(`stop_fast_training_run`, params);
  }
}

// Shared client instance. `MaglevApi` itself is not exported from this module, so consumers
// use this default instance. `API_GATEWAY_URL` is presumably the base URL that `BaseAPI`
// resolves request paths against — confirm in `BaseAPI`.
export default new MaglevApi(maglevPrefix, API_GATEWAY_URL);
