import { message, type FormInstance } from 'antd'
import { useRef } from 'react'
import { LLMChannels } from '@apis/llm/model'
import {
  LLMMessageStructType,
  getModelMessageStructTypeByChannel,
  transformLLMMessage,
} from '@/features/nodes/utils/llm'
import type { LLMNodeData, LLMNodeDataForm } from '@/features/nodes/llm'
import { useLLMStore } from '@/store/llm'

/**
 * One turn of a multi-turn conversation stored on an LLM node.
 */
export interface Message {
  /** Author of this turn. */
  role: 'user' | 'assistant' | 'system'
  /** Text content of this turn. */
  content: string
}

/**
 * Hook that reconciles an LLM node's form state when the user switches models.
 *
 * It remembers the channel of the previously applied model so that
 * `beforeChange` can:
 * - convert the multi-turn message list when the old and new channels use
 *   different message struct types;
 * - force streaming on for channels that do not support non-streaming output;
 * - clamp `presence_penalty` to the selected model's configured range.
 *
 * @param form - antd form instance holding the node's editable inputs.
 * @param propsChannel - channel of the model the node was initialised with;
 *   used as the "previous" channel for the first model change.
 * @returns `beforeChange`, which builds the node inputs for a new model
 *   selection, and `modelMessageStructType`, the message struct type of the
 *   channel stored in the form at render time.
 */
export function useModelChange(
  form: FormInstance<LLMNodeDataForm>,
  propsChannel: LLMChannels,
) {
  const { llmModelList } = useLLMStore()

  // Channel of the last applied model. It is compared against the new channel
  // to decide whether the messages need converting.
  const lastModelChannel = useRef<LLMChannels>(propsChannel)
  const currentModelSetting = form.getFieldValue(['inputs', 'modelSetting'])
  const currentChannel = currentModelSetting?.channel
  // NOTE(review): currentChannel may be undefined when modelSetting is not set
  // yet; the result then depends on getModelMessageStructTypeByChannel — confirm.
  const modelMessageStructType = getModelMessageStructTypeByChannel(
    currentChannel!,
  )

  // Determine whether switching between the two channels requires converting
  // the message data between message struct types.
  const checkMessageTransform = (
    newChannel: LLMChannels,
    oldChannel: LLMChannels,
  ) => {
    const newModelMessageStructType =
      getModelMessageStructTypeByChannel(newChannel)
    const oldModelMessageStructType =
      getModelMessageStructTypeByChannel(oldChannel)
    return {
      needTransform: newModelMessageStructType !== oldModelMessageStructType,
      newModelMessageStructType,
      oldModelMessageStructType,
    }
  }

  /**
   * Converts `messages` to the new channel's struct type if it differs from
   * the old one. The converted list is written back to the form, and the
   * (possibly converted) list is returned.
   */
  const checkAndTransformMessage = (
    messages: Message[],
    newChannel: LLMChannels,
    oldChannel: LLMChannels,
  ) => {
    const {
      needTransform,
      newModelMessageStructType,
      oldModelMessageStructType,
    } = checkMessageTransform(newChannel, oldChannel)
    if (!needTransform) {
      return messages
    }

    const newMessage = transformLLMMessage(
      messages,
      oldModelMessageStructType,
      newModelMessageStructType,
    )
    form.setFieldValue(['inputs', 'messages'], newMessage)
    if (newModelMessageStructType === LLMMessageStructType.LIKE_BAIDU) {
      // Notify the user only when switching from a GPT-like to a Baidu-like
      // struct: the GPT struct is compatible with Baidu's, but not vice versa.
      message.success(
        '该模型多轮对话结构与上一个模型不兼容，已自动转换，请检查',
      )
    }
    return newMessage
  }

  /**
   * XunFei does not support non-streaming output. For that channel, streaming
   * is forced on in the form and `true` is returned. Otherwise `stream` is
   * returned unchanged.
   */
  const checkAndSafetyStream = (targetChannel: LLMChannels, stream: boolean) => {
    if (targetChannel === LLMChannels.XunFei && !stream) {
      message.success('讯飞模型不支持非流式输出，已自动切换')
      form.setFieldValue(['inputs', 'stream'], true)
      return true
    }
    return stream
  }

  /**
   * Builds the node inputs for a newly selected model.
   *
   * Side effects:
   * - writes the resolved `presence_penalty` to the form;
   * - may write `stream` and `messages` to the form;
   * - may show a toast;
   * - records the new channel as the last applied channel.
   */
  const beforeChange = (inputs: {
    modelSetting: LLMNodeDataForm['inputs']['modelSetting']
    stream: boolean
    messages?: Message[]
  }) => {
    const { modelSetting, stream, messages } = inputs
    const {
      channel,
      model,
      temperature,
      top_p,
      presence_penalty,
      frequency_penalty,
      outputType,
    } = modelSetting

    const modelDetail = llmModelList.find(each => each.model === model)

    // Resolve presence_penalty against the model's param config:
    // - fall back to the configured default when unset;
    // - clamp into [min, max];
    // - treat a missing bound as unbounded.
    // Only clamp when a value exists, because Math.min/Math.max with
    // undefined produce NaN.
    let surePresencePenalty = presence_penalty
    const paramConfig = modelDetail?.feature?.param_config
    if (paramConfig) {
      const base = surePresencePenalty ?? paramConfig.default
      if (base != null) {
        surePresencePenalty = Math.max(
          Math.min(base, paramConfig.max ?? Infinity),
          paramConfig.min ?? -Infinity,
        )
      }
    }

    form.setFieldValue(
      ['inputs', 'modelSetting', 'presence_penalty'],
      surePresencePenalty,
    )

    const newInputs: LLMNodeData['inputs'] = {
      model,
      channel,
      // `??` instead of `||`, so an explicit 0 (e.g. temperature 0) is kept
      // and only a missing value falls back to the default.
      temperature: temperature ?? 0.7,
      top_p: top_p ?? 0.6,
      presence_penalty: surePresencePenalty,
      frequency_penalty: frequency_penalty ?? 0.0,
      plugin: { json_mode: outputType === 'json' },
      stream: checkAndSafetyStream(channel as LLMChannels, stream),
      ...(messages
        ? {
            messages: checkAndTransformMessage(
              messages,
              channel as LLMChannels,
              lastModelChannel.current,
            ),
          }
        : {}),
    }

    lastModelChannel.current = channel as LLMChannels
    return newInputs
  }

  return {
    beforeChange,
    modelMessageStructType,
  }
}
