import { Vector2, Vector3 } from 'three';
import type * as zod from 'zod';

import type { CartesianPose, Plane } from '@sb/geometry';
import {
  findRayPlaneIntersection,
  applyCompoundPose,
  getTooltipOrientationPerpendicularToPlane,
  cameraPoseFromWristPose,
  castCameraRay,
  ZERO_POSE,
} from '@sb/geometry';
import type { CameraIntrinsics } from '@sb/integrations/camera/types';
import { getCalibrationOffset } from '@sb/integrations/WristCamera/calibrationHelpers';
import type { CalibrationData } from '@sb/integrations/WristCamera/types/CalibrationData';
import type {
  WristCameraConfiguration,
  WristCameraAccuracyCalibrationEntry,
} from '@sb/integrations/WristCamera/types/Configuration';
import * as log from '@sb/log';
import type { ArmJointPositions } from '@sb/motion-planning';
import type { Blob2D } from '@sb/routine-runner';
import { FailureKind } from '@sb/routine-runner/FailureKind';
import type { ArmPosition } from '@sb/routine-runner/types';
import { six } from '@sb/utilities';

import Step from '../Step';
import type { StepPlayArguments } from '../Step';

import Arguments from './Arguments';
import Variables from './Variables';

/** Step argument shape, derived from the `Arguments` zod schema. */
type Arguments = zod.infer<typeof Arguments>;

/** Step variable shape, derived from the `Variables` zod schema. */
type Variables = zod.infer<typeof Variables>;

/** Logger namespace shared by every log call in this step. */
const ns = log.namespace('locateStep');

/**
 * Placeholder joint angles (all six set to 0) stored with every located position.
 * Solving for the real joint angles is intentionally postponed until a move-arm
 * step actually plays the position.
 */
const NULL_JOINT_ANGLES: ArmJointPositions = six(0);

export default class LocateStep extends Step<Arguments, Variables> {
  public static areSubstepsRequired = false;

  public static Arguments = Arguments;

  public static Variables = Variables;

  /** Resets step variables: no result yet and a result count of zero. */
  public initializeVariableState(): void {
    this.variables = {
      latestResult: null,
      resultCount: 0,
    };
  }

  /**
   * Resolves a plane space item into the three points that define it.
   *
   * @param planeID - ID of a space item expected to be of kind `'plane'`.
   * @returns The plane built from the item's first three positions, in order:
   *   origin, +X point, +Y point.
   * @throws If no ID is given, the item is not a plane, or it has fewer than
   *   three positions.
   */
  private getPlaneFromSpaceItemID(planeID: string | undefined): Plane {
    if (!planeID) {
      throw new Error('No planeID provided');
    }

    const plane = this.routineContext.getSpaceItem(planeID);

    if (plane.kind !== 'plane') {
      throw new Error(`SpaceItem ${planeID} is not a plane`);
    }

    // Three points are needed to define the plane; without this guard the
    // missing entries would be undefined and fail later inside geometry code.
    if (plane.positions.length < 3) {
      throw new Error(
        `SpaceItem ${planeID} has ${plane.positions.length} positions, but a plane requires 3`,
      );
    }

    const [origin, plusX, plusY] = plane.positions.map(
      ({ pose }) => new Vector3(pose.x, pose.y, pose.z),
    );

    return { origin, plusX, plusY };
  }

  /**
   * Takes 2D pixel coordinates from a camera image and returns the 3D pose of
   * that point in the base's coordinate system.
   *
   * A ray is cast from the camera position through the pixel and intersected
   * with `plane`. The resulting point is oriented using the plane and the
   * blob's rotation, then corrected by `cameraCalibrationCorrection`.
   *
   * @returns The corrected pose, or `undefined` when the ray does not intersect
   *   the plane.
   */
  private deproject(
    blobResult: { x: number; y: number; rotation: number },
    cameraIntrinsics: CameraIntrinsics,
    cameraPose: CartesianPose,
    plane: Plane,
    cameraCalibrationCorrection: CartesianPose,
  ): CartesianPose | undefined {
    const rayOrigin = {
      x: cameraPose.x,
      y: cameraPose.y,
      z: cameraPose.z,
    };

    const rayDirection = castCameraRay(
      cameraIntrinsics,
      cameraPose,
      new Vector2(blobResult.x, blobResult.y),
    );

    // Ray and plane are now in the same coordinate system, so intersect them.
    const intersectionPoint = findRayPlaneIntersection(
      rayOrigin,
      rayDirection,
      plane,
    );

    if (intersectionPoint === undefined) {
      return undefined;
    }

    const tooltipPose = getTooltipOrientationPerpendicularToPlane(
      plane,
      blobResult.rotation,
    );

    const intersection: CartesianPose = {
      ...intersectionPoint,
      i: tooltipPose.x,
      j: tooltipPose.y,
      k: tooltipPose.z,
      w: tooltipPose.w,
    };

    // Directly apply the calibration correction to the intersection point.
    return applyCompoundPose(intersection, cameraCalibrationCorrection);
  }

  /**
   * Computes the wrist-camera accuracy correction for the given joint angles.
   *
   * @param jointAngles - Current arm joint angles.
   * @param accuracyCalibration - Recorded (joint angles, offset) calibration entries.
   * @returns `ZERO_POSE` when there are no entries; otherwise the offset that
   *   `getCalibrationOffset` derives from the entries for `jointAngles`.
   */
  protected getCalibrationCorrection(
    jointAngles: ArmJointPositions,
    accuracyCalibration: WristCameraAccuracyCalibrationEntry[],
  ): CartesianPose {
    if (accuracyCalibration.length === 0) {
      log.info(ns`locate.play`, 'No calibration data found');

      return ZERO_POSE;
    }

    const calibrationData: CalibrationData = {
      jointPositionsList: accuracyCalibration.map((entry) => entry.jointAngles),
      offsets: accuracyCalibration.map((entry) => entry.offset),
    };

    const calibrationOffset = getCalibrationOffset(
      jointAngles,
      calibrationData,
    );

    log.info(ns`locate.play`, 'Calibration offset', calibrationOffset);

    return calibrationOffset;
  }

  /**
   * Returns the first configured equipment entry of kind `'WristCamera'`, or
   * `undefined` when none is configured.
   */
  private getCameraConfig(): WristCameraConfiguration | undefined {
    // TODO(bpatmiller) - this approach does not seem to work with the existing wrist cam implementation
    const wristCamera = this.routineContext.equipment
      .getEquipmentList()
      .find((equipment) => equipment.kind === 'WristCamera');

    return wristCamera as WristCameraConfiguration | undefined;
  }

  /**
   * Captures a color frame from the configured camera and runs the 2D
   * detection method selected in the step arguments on it.
   *
   * @throws If `method.kind` is not a supported detection method.
   */
  private async get2DBlobs(): Promise<Blob2D[]> {
    const { method, regionOfInterest, camera } = this.args;
    const image = await this.routineContext.camera.getColorFrame(camera);

    switch (method.kind) {
      case 'BlobDetection2D': {
        return this.routineContext.vision.detect2DBlobs(
          image,
          regionOfInterest,
          method.settings,
        );
      }
      case 'ShapeDetection2D': {
        return this.routineContext.vision.detect2DShapes(
          image,
          method.templateImage,
          regionOfInterest,
          method.settings,
        );
      }
      default:
        // Report the kind: interpolating the object itself would print
        // "[object Object]".
        throw new Error(
          `Unsupported method ${(method as { kind: string }).kind}`,
        );
    }
  }

  /**
   * Detects blobs in the camera image and projects them onto the configured
   * plane. The user transform is then applied, and the resulting positions
   * are stored in the step variables and in the target position list.
   * Any error is logged and reported through `fail`.
   */
  public async _play({ fail }: StepPlayArguments): Promise<void> {
    try {
      log.info(ns`locate.play`, 'Playing Locate step');

      const { transform, filters, planeID } = this.args;

      const blobs2D = await this.get2DBlobs();

      // Take a single kinematic snapshot so the joint angles used for
      // calibration and the wrist pose used for the camera pose are
      // consistent (the original read them separately, across an await).
      const { jointAngles, wristPose } =
        this.routineContext.getRoutineRunnerState().kinematicState;

      const plane = this.getPlaneFromSpaceItemID(planeID);

      const cameraConfig = this.getCameraConfig();

      const cameraCorrection = cameraConfig
        ? this.getCalibrationCorrection(
            jointAngles,
            cameraConfig.accuracyCalibration,
          )
        : ZERO_POSE;

      // Prefer intrinsics from a completed intrinsics calibration; otherwise
      // fall back to the intrinsics that shipped with the camera.
      const intrinsics: CameraIntrinsics =
        cameraConfig?.intrinsics ??
        (await this.routineContext.camera.getIntrinsics());

      const cameraPose = cameraPoseFromWristPose(wristPose);

      log.info(ns`locate.play`, 'Blobs detected', blobs2D);

      // Sort a copy so the detector's result array (already logged) is not
      // mutated. Note score is currently meaningless.
      const rankedBlobs2D = [...blobs2D].sort((a, b) => b.score - a.score);

      const filteredBlobs2D = rankedBlobs2D.slice(
        0,
        filters.resultsLimit ?? rankedBlobs2D.length,
      );

      // Blobs whose camera ray misses the plane are dropped.
      const blobs6D = filteredBlobs2D
        .map((blob) =>
          this.deproject(blob, intrinsics, cameraPose, plane, cameraCorrection),
        )
        .filter((pose): pose is CartesianPose => pose !== undefined);

      // Apply user-defined offset and rotation.
      const transformedBlobs6D = blobs6D.map((pose) =>
        applyCompoundPose(pose, transform),
      );

      log.info(ns`locate.play`, 'Transformed blobs', transformedBlobs6D);

      const positions: ArmPosition[] = transformedBlobs6D.map((pose) => ({
        pose,
        jointAngles: NULL_JOINT_ANGLES,
      }));

      this.setVariable('latestResult', positions);
      this.setVariable('resultCount', positions.length);

      this.routineContext.setPositionListEntries(
        this.args.positionListID,
        positions,
      );
    } catch (error) {
      log.error(ns`locate.play`, 'Locate failed', error);

      // The caught value is not guaranteed to be an Error; avoid reporting
      // "undefined" as the failure reason when something else was thrown.
      const failureReason =
        error instanceof Error ? error.message : String(error);

      return fail({
        failure: {
          kind: FailureKind.StepPlayFailure,
          stepKind: 'Locate',
        },
        failureReason,
        error,
      });
    }
  }
}
