program(1.0)
[buildInfo = dict, tensor>({{"coremlc-component-MIL", "3401.3.1"}, {"coremlc-version", "3401.4.1"}, {"coremltools-component-torch", "2.6.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.2"}})]
{
    func main(tensor decoder_input_ids, tensor state_1, tensor state_2) {
            tensor input_1_axis_0 = const()[name = tensor("input_1_axis_0"), val = tensor(0)];
            tensor input_1_batch_dims_0 = const()[name = tensor("input_1_batch_dims_0"), val = tensor(0)];
            tensor input_1_validate_indices_0 = const()[name = tensor("input_1_validate_indices_0"), val = tensor(false)];
            tensor prediction_embed_weight_to_fp16 = const()[name = tensor("prediction_embed_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(64)))];
            tensor decoder_input_ids_to_int16_dtype_0 = const()[name = tensor("decoder_input_ids_to_int16_dtype_0"), val = tensor("int16")];
            tensor decoder_input_ids_to_int16 = cast(dtype = decoder_input_ids_to_int16_dtype_0, x = decoder_input_ids)[name = tensor("cast_6")];
            tensor input_1_cast_fp16_cast_uint16 = gather(axis = input_1_axis_0, batch_dims = input_1_batch_dims_0, indices = decoder_input_ids_to_int16, validate_indices = input_1_validate_indices_0, x = prediction_embed_weight_to_fp16)[name = tensor("input_1_cast_fp16_cast_uint16")];
            tensor input_3_axes_0 = const()[name = tensor("input_3_axes_0"), val = tensor([1])];
            tensor input_3_cast_fp16 = expand_dims(axes = input_3_axes_0, x = input_1_cast_fp16_cast_uint16)[name = tensor("input_3_cast_fp16")];

            tensor hx_1_axes_0 = const()[name = tensor("hx_1_axes_0"), val = tensor([1])];
            tensor hx_1_cast_fp16 = expand_dims(axes = hx_1_axes_0, x = state_1)[name = tensor("hx_1_cast_fp16")];
            tensor hx_axes_0 = const()[name = tensor("hx_axes_0"), val = tensor([1])];
            tensor hx_cast_fp16 = expand_dims(axes = hx_axes_0, x = state_2)[name = tensor("hx_cast_fp16")];
            tensor split_0_num_splits_0 = const()[name = tensor("split_0_num_splits_0"), val = tensor(2)];
            tensor split_0_axis_0 = const()[name = tensor("split_0_axis_0"), val = tensor(0)];
            tensor split_0_cast_fp16_0, tensor split_0_cast_fp16_1 = split(axis = split_0_axis_0, num_splits = split_0_num_splits_0, x = hx_1_cast_fp16)[name = tensor("split_0_cast_fp16")];
            tensor split_1_num_splits_0 = const()[name = tensor("split_1_num_splits_0"), val = tensor(2)];
            tensor split_1_axis_0 = const()[name = tensor("split_1_axis_0"), val = tensor(0)];
            tensor split_1_cast_fp16_0, tensor split_1_cast_fp16_1 = split(axis = split_1_axis_0, num_splits = split_1_num_splits_0, x = hx_cast_fp16)[name = tensor("split_1_cast_fp16")];

            tensor output_lstm_layer_0_lstm_h0_squeeze_axes_0 = const()[name = tensor("output_lstm_layer_0_lstm_h0_squeeze_axes_0"), val = tensor([0])];
            tensor output_lstm_layer_0_lstm_h0_squeeze_cast_fp16 = squeeze(axes = output_lstm_layer_0_lstm_h0_squeeze_axes_0, x = split_0_cast_fp16_0)[name = tensor("output_lstm_layer_0_lstm_h0_squeeze_cast_fp16")];
            tensor output_lstm_layer_0_lstm_c0_squeeze_axes_0 = const()[name = tensor("output_lstm_layer_0_lstm_c0_squeeze_axes_0"), val = tensor([0])];
            tensor output_lstm_layer_0_lstm_c0_squeeze_cast_fp16 = squeeze(axes = output_lstm_layer_0_lstm_c0_squeeze_axes_0, x = split_1_cast_fp16_0)[name = tensor("output_lstm_layer_0_lstm_c0_squeeze_cast_fp16")];
            tensor output_lstm_layer_0_direction_0 = const()[name = tensor("output_lstm_layer_0_direction_0"), val = tensor("forward")];
            tensor output_lstm_layer_0_output_sequence_0 = const()[name = tensor("output_lstm_layer_0_output_sequence_0"), val = tensor(true)];
            tensor output_lstm_layer_0_recurrent_activation_0 = const()[name = tensor("output_lstm_layer_0_recurrent_activation_0"), val = tensor("sigmoid")];
            tensor output_lstm_layer_0_cell_activation_0 = const()[name = tensor("output_lstm_layer_0_cell_activation_0"), val = tensor("tanh")];
            tensor output_lstm_layer_0_activation_0 = const()[name = tensor("output_lstm_layer_0_activation_0"), val = tensor("tanh")];
            tensor concat_1_to_fp16 = const()[name = tensor("concat_1_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(1312128)))];
            tensor concat_2_to_fp16 = const()[name = tensor("concat_2_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(4588992)))];
            tensor concat_0_to_fp16 = const()[name = tensor("concat_0_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(7865856)))];
            tensor output_lstm_layer_0_cast_fp16_0, tensor output_lstm_layer_0_cast_fp16_1, tensor output_lstm_layer_0_cast_fp16_2 = lstm(activation = output_lstm_layer_0_activation_0, bias = concat_0_to_fp16, cell_activation = output_lstm_layer_0_cell_activation_0, direction = output_lstm_layer_0_direction_0, initial_c = output_lstm_layer_0_lstm_c0_squeeze_cast_fp16, initial_h = output_lstm_layer_0_lstm_h0_squeeze_cast_fp16, output_sequence = output_lstm_layer_0_output_sequence_0, recurrent_activation = output_lstm_layer_0_recurrent_activation_0, weight_hh = concat_2_to_fp16, weight_ih = concat_1_to_fp16, x = input_3_cast_fp16)[name = tensor("output_lstm_layer_0_cast_fp16")];

            tensor output_lstm_h0_squeeze_axes_0 = const()[name = tensor("output_lstm_h0_squeeze_axes_0"), val = tensor([0])];
            tensor output_lstm_h0_squeeze_cast_fp16 = squeeze(axes = output_lstm_h0_squeeze_axes_0, x = split_0_cast_fp16_1)[name = tensor("output_lstm_h0_squeeze_cast_fp16")];
            tensor output_lstm_c0_squeeze_axes_0 = const()[name = tensor("output_lstm_c0_squeeze_axes_0"), val = tensor([0])];
            tensor output_lstm_c0_squeeze_cast_fp16 = squeeze(axes = output_lstm_c0_squeeze_axes_0, x = split_1_cast_fp16_1)[name = tensor("output_lstm_c0_squeeze_cast_fp16")];
            tensor output_direction_0 = const()[name = tensor("output_direction_0"), val = tensor("forward")];
            tensor output_output_sequence_0 = const()[name = tensor("output_output_sequence_0"), val = tensor(true)];
            tensor output_recurrent_activation_0 = const()[name = tensor("output_recurrent_activation_0"), val = tensor("sigmoid")];
            tensor output_cell_activation_0 = const()[name = tensor("output_cell_activation_0"), val = tensor("tanh")];
            tensor output_activation_0 = const()[name = tensor("output_activation_0"), val = tensor("tanh")];
            tensor concat_4_to_fp16 = const()[name = tensor("concat_4_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(7871040)))];
            tensor concat_5_to_fp16 = const()[name = tensor("concat_5_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(11147904)))];
            tensor concat_3_to_fp16 = const()[name = tensor("concat_3_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(14424768)))];
            tensor output_cast_fp16_0, tensor output_cast_fp16_1, tensor output_cast_fp16_2 = lstm(activation = output_activation_0, bias = concat_3_to_fp16, cell_activation = output_cell_activation_0, direction = output_direction_0, initial_c = output_lstm_c0_squeeze_cast_fp16, initial_h = output_lstm_h0_squeeze_cast_fp16, output_sequence = output_output_sequence_0, recurrent_activation = output_recurrent_activation_0, weight_hh = concat_5_to_fp16, weight_ih = concat_4_to_fp16, x = output_lstm_layer_0_cast_fp16_0)[name = tensor("output_cast_fp16")];

            tensor var_32_axis_0 = const()[name = tensor("op_32_axis_0"), val = tensor(0)];
            tensor var_32_cast_fp16 = stack(axis = var_32_axis_0, values = (output_lstm_layer_0_cast_fp16_1, output_cast_fp16_1))[name = tensor("op_32_cast_fp16")];
            tensor var_33_axis_0 = const()[name = tensor("op_33_axis_0"), val = tensor(0)];
            tensor var_33_cast_fp16 = stack(axis = var_33_axis_0, values = (output_lstm_layer_0_cast_fp16_2, output_cast_fp16_2))[name = tensor("op_33_cast_fp16")];
            tensor input_axes_0 = const()[name = tensor("input_axes_0"), val = tensor([1])];
            tensor input_cast_fp16 = squeeze(axes = input_axes_0, x = output_cast_fp16_0)[name = tensor("input_cast_fp16")];
            tensor var_35_axes_0 = const()[name = tensor("op_35_axes_0"), val = tensor([1])];
            tensor new_state_1 = squeeze(axes = var_35_axes_0, x = var_32_cast_fp16)[name = tensor("op_35_cast_fp16")];
            tensor var_36_axes_0 = const()[name = tensor("op_36_axes_0"), val = tensor([1])];
            tensor new_state_2 = squeeze(axes = var_36_axes_0, x = var_33_cast_fp16)[name = tensor("op_36_cast_fp16")];

            tensor joint_projection_weight_to_fp16 = const()[name = tensor("joint_projection_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(14429952)))];
            tensor joint_projection_bias_to_fp16 = const()[name = tensor("joint_projection_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(15249216)))];
            tensor decoder_output_projected = linear(bias = joint_projection_bias_to_fp16, weight = joint_projection_weight_to_fp16, x = input_cast_fp16)[name = tensor("linear_0_cast_fp16")];
        } -> (decoder_output_projected, new_state_1, new_state_2);
}