Volver al blog

TypeScript e IA: Cómo Crear Aplicaciones de Machine Learning Type-Safe que Evitan Bugs en Producción

Hola HaWkers, ¿ya perdiste horas debugando un error en producción solo para descubrir que estabas pasando un array de números cuando el modelo de IA esperaba una matriz 2D? ¿O tuviste que lidiar con un crash porque los tipos de datos de los tensores no coincidían?

Si trabajas con Machine Learning en JavaScript, probablemente ya pasaste por eso. La buena noticia es que TypeScript está cambiando completamente ese escenario, trayendo seguridad de tipos para el mundo caótico de la IA.

El Problema de Hacer IA sin TypeScript

Machine Learning en JavaScript puro es como andar en campo minado. Mira algunos problemas comunes:

Errores de Dimensión de Tensores: Creas un tensor [batch, height, width, channels] pero accidentalmente pasas [height, width, channels]. El código compila, pero explota en runtime.

Tipos Incompatibles: Tu modelo espera float32 pero pasas int8. Silenciosamente, resultados quedan incorrectos.

Configuraciones Inválidas: Configuras learning rate como string "0.01" en vez de número 0.01. Entrenamiento falla misteriosamente.

Falta de Autocomplete: Sin tipos, no sabes cuáles métodos/propiedades existen. Quedas consultando documentación a cada línea.

Refactorización Imposible: ¿Cambiar firma de una función que procesa datos del modelo? Buena suerte encontrando todos los lugares que necesitan actualizar.

TypeScript resuelve todos esos problemas.

Creando Tipos Seguros para Modelos de IA

Vamos a comenzar definiendo tipos robustos para trabajar con Machine Learning:

// Tipos base para tensores
type TensorShape = number[];

type DType = 'float32' | 'int32' | 'int8' | 'uint8' | 'bool' | 'complex64' | 'string';

interface TensorConfig<T extends DType = 'float32'> {
  shape: TensorShape;
  dtype: T;
  data: T extends 'float32' ? Float32Array
      : T extends 'int32' ? Int32Array
      : T extends 'int8' ? Int8Array
      : T extends 'uint8' ? Uint8Array
      : T extends 'bool' ? Uint8Array
      : T extends 'string' ? string[]
      : never;
}

// Tipo genérico para Tensor
interface Tensor<
  Shape extends TensorShape = TensorShape,
  Type extends DType = 'float32'
> {
  shape: Shape;
  dtype: Type;
  size: number;
  data(): Promise<TensorConfig<Type>['data']>;
  dispose(): void;
  reshape<NewShape extends TensorShape>(newShape: NewShape): Tensor<NewShape, Type>;
}

// Tipos para layers de redes neuronales
interface DenseLayerConfig {
  units: number;
  activation?: 'relu' | 'sigmoid' | 'tanh' | 'softmax' | 'linear';
  useBias?: boolean;
  kernelInitializer?: 'glorotNormal' | 'heNormal' | 'ones' | 'zeros';
  biasInitializer?: 'zeros' | 'ones';
  kernelRegularizer?: RegularizerConfig;
}

interface RegularizerConfig {
  type: 'l1' | 'l2' | 'l1l2';
  l1?: number;
  l2?: number;
}

interface Conv2DLayerConfig {
  filters: number;
  kernelSize: [number, number] | number;
  strides?: [number, number] | number;
  padding?: 'valid' | 'same';
  activation?: DenseLayerConfig['activation'];
  dataFormat?: 'channelsFirst' | 'channelsLast';
}

// Tipo para configuración de modelo
interface ModelConfig {
  layers: Array<DenseLayerConfig | Conv2DLayerConfig>;
  optimizer: OptimizerConfig;
  loss: LossFunction;
  metrics?: MetricFunction[];
}

type OptimizerConfig = {
  type: 'adam';
  learningRate: number;
  beta1?: number;
  beta2?: number;
  epsilon?: number;
} | {
  type: 'sgd';
  learningRate: number;
  momentum?: number;
} | {
  type: 'rmsprop';
  learningRate: number;
  decay?: number;
  momentum?: number;
};

type LossFunction =
  | 'categoricalCrossentropy'
  | 'binaryCrossentropy'
  | 'meanSquaredError'
  | 'meanAbsoluteError';

type MetricFunction =
  | 'accuracy'
  | 'categoricalAccuracy'
  | 'binaryAccuracy'
  | 'precision'
  | 'recall';

Ahora, vamos a crear una clase type-safe para trabajar con modelos:

import * as tf from '@tensorflow/tfjs';

class TypeSafeModel<
  InputShape extends TensorShape,
  OutputShape extends TensorShape
> {
  private model: tf.LayersModel | null = null;
  private inputShape: InputShape;
  private outputShape: OutputShape;
  private isCompiled: boolean = false;

  constructor(
    inputShape: InputShape,
    outputShape: OutputShape,
    config: ModelConfig
  ) {
    this.inputShape = inputShape;
    this.outputShape = outputShape;
    this.buildModel(config);
  }

  private buildModel(config: ModelConfig): void {
    const input = tf.input({ shape: this.inputShape });
    let layer: tf.SymbolicTensor = input;

    // Agregar layers con type safety
    for (const layerConfig of config.layers) {
      if ('units' in layerConfig) {
        // Dense layer
        layer = tf.layers.dense({
          units: layerConfig.units,
          activation: layerConfig.activation,
          useBias: layerConfig.useBias ?? true,
          kernelInitializer: layerConfig.kernelInitializer ?? 'glorotNormal',
          biasInitializer: layerConfig.biasInitializer ?? 'zeros'
        }).apply(layer) as tf.SymbolicTensor;
      } else if ('filters' in layerConfig) {
        // Conv2D layer
        layer = tf.layers.conv2d({
          filters: layerConfig.filters,
          kernelSize: layerConfig.kernelSize,
          strides: layerConfig.strides,
          padding: layerConfig.padding,
          activation: layerConfig.activation
        }).apply(layer) as tf.SymbolicTensor;
      }
    }

    this.model = tf.model({ inputs: input, outputs: layer });
  }

  compile(optimizer: OptimizerConfig, loss: LossFunction, metrics?: MetricFunction[]): void {
    if (!this.model) {
      throw new Error('Modelo no fue construido');
    }

    const tfOptimizer = this.createOptimizer(optimizer);

    this.model.compile({
      optimizer: tfOptimizer,
      loss: loss,
      metrics: metrics
    });

    this.isCompiled = true;
  }

  private createOptimizer(config: OptimizerConfig): tf.Optimizer {
    switch (config.type) {
      case 'adam':
        return tf.train.adam(
          config.learningRate,
          config.beta1,
          config.beta2,
          config.epsilon
        );
      case 'sgd':
        return tf.train.sgd(config.learningRate);
      case 'rmsprop':
        return tf.train.rmsprop(
          config.learningRate,
          config.decay,
          config.momentum
        );
      default:
        const _exhaustive: never = config;
        throw new Error(`Optimizer desconocido: ${_exhaustive}`);
    }
  }

  // Type-safe prediction
  async predict(
    input: Tensor<InputShape, 'float32'>
  ): Promise<Tensor<OutputShape, 'float32'>> {
    if (!this.model || !this.isCompiled) {
      throw new Error('Modelo necesita ser compilado antes de predict');
    }

    // Validar shape del input
    if (!this.validateShape(input.shape, this.inputShape)) {
      throw new Error(
        `Shape de input inválido. Esperado: [${this.inputShape}], Recibido: [${input.shape}]`
      );
    }

    const prediction = this.model.predict(input as any) as tf.Tensor;

    return prediction as Tensor<OutputShape, 'float32'>;
  }

  // Type-safe training
  async train(
    trainData: Tensor<InputShape, 'float32'>,
    trainLabels: Tensor<OutputShape, 'float32'>,
    options: TrainingOptions
  ): Promise<TrainingHistory> {
    if (!this.model || !this.isCompiled) {
      throw new Error('Modelo necesita ser compilado antes de entrenar');
    }

    const history = await this.model.fit(trainData as any, trainLabels as any, {
      epochs: options.epochs,
      batchSize: options.batchSize,
      validationSplit: options.validationSplit,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          options.onEpochEnd?.(epoch, logs as TrainingLogs);
        }
      }
    });

    return {
      loss: history.history.loss as number[],
      accuracy: history.history.acc as number[] | undefined,
      valLoss: history.history.val_loss as number[] | undefined,
      valAccuracy: history.history.val_acc as number[] | undefined
    };
  }

  private validateShape(actual: TensorShape, expected: TensorShape): boolean {
    if (actual.length !== expected.length) return false;

    return actual.every((dim, idx) => {
      // -1 significa dimensión flexible (batch size)
      if (expected[idx] === -1) return true;
      return dim === expected[idx];
    });
  }

  async save(path: string): Promise<void> {
    if (!this.model) {
      throw new Error('Modelo no existe');
    }
    await this.model.save(path);
  }

  dispose(): void {
    this.model?.dispose();
    this.model = null;
  }
}

interface TrainingOptions {
  epochs: number;
  batchSize: number;
  validationSplit?: number;
  onEpochEnd?: (epoch: number, logs: TrainingLogs) => void;
}

interface TrainingLogs {
  loss: number;
  acc?: number;
  val_loss?: number;
  val_acc?: number;
}

interface TrainingHistory {
  loss: number[];
  accuracy?: number[];
  valLoss?: number[];
  valAccuracy?: number[];
}

Ahora mira la magia de TypeScript en acción:

// Crear modelo para clasificación de imágenes MNIST (28x28 pixels, 10 clases)
const mnistModel = new TypeSafeModel<[28, 28, 1], [10]>(
  [28, 28, 1], // Input: imágenes 28x28 con 1 canal (grayscale)
  [10],        // Output: 10 clases (dígitos 0-9)
  {
    layers: [
      { filters: 32, kernelSize: 3, activation: 'relu' },
      { filters: 64, kernelSize: 3, activation: 'relu' },
      { units: 128, activation: 'relu' },
      { units: 10, activation: 'softmax' }
    ],
    optimizer: { type: 'adam', learningRate: 0.001 },
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  }
);

// Compilar modelo
mnistModel.compile(
  { type: 'adam', learningRate: 0.001 },
  'categoricalCrossentropy',
  ['accuracy']
);

// Entrenar - ¡TypeScript garantiza que shapes están correctos!
const trainImages: Tensor<[28, 28, 1], 'float32'> = /* ... */;
const trainLabels: Tensor<[10], 'float32'> = /* ... */;

const history = await mnistModel.train(trainImages, trainLabels, {
  epochs: 10,
  batchSize: 32,
  validationSplit: 0.2,
  onEpochEnd: (epoch, logs) => {
    console.log(`Época ${epoch}: loss=${logs.loss}, acc=${logs.acc}`);
  }
});

// Hacer predicciones - ¡totalmente type-safe!
const testImage: Tensor<[28, 28, 1], 'float32'> = /* ... */;
const prediction: Tensor<[10], 'float32'> = await mnistModel.predict(testImage);

// ❌ ¡Esto causaría error de compilación!
// const wrongImage: Tensor<[32, 32, 3], 'float32'> = /* ... */;
// mnistModel.predict(wrongImage); // TypeScript: ¡Error! Shape no coincide

Pipelines de Datos Type-Safe para Machine Learning

Uno de los mayores problemas en ML es el pipeline de datos. Vamos a crear un sistema type-safe para procesar datos:

// Tipos para diferentes etapas del pipeline
type DataPoint<Features, Label> = {
  features: Features;
  label: Label;
};

type DataBatch<Features, Label> = {
  features: Features[];
  labels: Label[];
  batchSize: number;
};

// Pipeline genérico type-safe
class MLDataPipeline<RawData, ProcessedFeatures, Label> {
  private data: RawData[] = [];
  private processors: Array<(data: any) => any> = [];

  constructor(private config: PipelineConfig<RawData, ProcessedFeatures, Label>) {}

  // Agregar datos brutos
  addData(rawData: RawData[]): this {
    this.data.push(...rawData);
    return this;
  }

  // Extraer features de forma type-safe
  extractFeatures(
    extractor: (raw: RawData) => ProcessedFeatures
  ): this {
    this.processors.push(extractor);
    return this;
  }

  // Normalizar features
  normalize(
    normalizer: (features: ProcessedFeatures) => ProcessedFeatures
  ): this {
    this.processors.push(normalizer);
    return this;
  }

  // Augmentar datos
  augment(
    augmenter: (features: ProcessedFeatures) => ProcessedFeatures[]
  ): this {
    this.processors.push(augmenter);
    return this;
  }

  // Extraer labels
  extractLabels(
    extractor: (raw: RawData) => Label
  ): MLDataset<ProcessedFeatures, Label> {
    const processed: DataPoint<ProcessedFeatures, Label>[] = [];

    for (const raw of this.data) {
      let current: any = raw;

      // Aplicar todos los procesadores
      for (const processor of this.processors) {
        current = processor(current);
      }

      const label = extractor(raw);

      // Si augmentation retornó array
      if (Array.isArray(current)) {
        current.forEach(features => {
          processed.push({ features, label });
        });
      } else {
        processed.push({ features: current, label });
      }
    }

    return new MLDataset(processed);
  }
}

class MLDataset<Features, Label> {
  constructor(private data: DataPoint<Features, Label>[]) {}

  // Shuffle dataset
  shuffle(): this {
    for (let i = this.data.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [this.data[i], this.data[j]] = [this.data[j], this.data[i]];
    }
    return this;
  }

  // Split en train/test
  split(
    ratio: number
  ): { train: MLDataset<Features, Label>; test: MLDataset<Features, Label> } {
    const splitIndex = Math.floor(this.data.length * ratio);
    return {
      train: new MLDataset(this.data.slice(0, splitIndex)),
      test: new MLDataset(this.data.slice(splitIndex))
    };
  }

  // Crear batches
  *batches(batchSize: number): Generator<DataBatch<Features, Label>> {
    for (let i = 0; i < this.data.length; i += batchSize) {
      const batch = this.data.slice(i, i + batchSize);
      yield {
        features: batch.map(d => d.features),
        labels: batch.map(d => d.label),
        batchSize: batch.length
      };
    }
  }

  get size(): number {
    return this.data.length;
  }
}

interface PipelineConfig<Raw, Features, Label> {
  name: string;
  description?: string;
}

// Ejemplo de uso real: Pipeline para clasificación de sentimientos
interface ReviewData {
  text: string;
  rating: number;
  verified: boolean;
}

type TextFeatures = {
  tokens: number[];
  length: number;
  embeddings: number[];
};

type SentimentLabel = 'positive' | 'negative' | 'neutral';

const pipeline = new MLDataPipeline<ReviewData, TextFeatures, SentimentLabel>({
  name: 'sentiment-analysis',
  description: 'Pipeline para análisis de sentimientos de reviews'
});

// Cargar datos
const reviews: ReviewData[] = [
  { text: '¡Producto excelente!', rating: 5, verified: true },
  { text: 'Muy malo, no recomiendo', rating: 1, verified: true }
  // ... más reviews
];

// Procesar datos de forma type-safe
const dataset = pipeline
  .addData(reviews)
  .extractFeatures((review) => ({
    tokens: tokenize(review.text),
    length: review.text.length,
    embeddings: getEmbeddings(review.text)
  }))
  .normalize((features) => ({
    ...features,
    embeddings: normalizeVector(features.embeddings)
  }))
  .augment((features) => {
    // Data augmentation: crear variaciones
    return [
      features,
      { ...features, tokens: features.tokens.reverse() } // ejemplo
    ];
  })
  .extractLabels((review) => {
    if (review.rating >= 4) return 'positive';
    if (review.rating <= 2) return 'negative';
    return 'neutral';
  });

// Usar dataset
const { train, test } = dataset.shuffle().split(0.8);

console.log(`Dataset procesado: ${train.size} entrenamiento, ${test.size} test`);

// Iterar sobre batches de forma type-safe
for (const batch of train.batches(32)) {
  console.log(`Batch de ${batch.batchSize} muestras`);
  // Entrenar modelo con batch
}

// Helper functions (implementación simplificada)
function tokenize(text: string): number[] {
  return text.split(' ').map(word => word.length); // ejemplo simple
}

function getEmbeddings(text: string): number[] {
  return new Array(128).fill(0).map(() => Math.random()); // ejemplo
}

function normalizeVector(vec: number[]): number[] {
  const max = Math.max(...vec);
  return vec.map(v => v / max);
}

Beneficios Concretos de Usar TypeScript en IA

Después de implementar TypeScript en proyectos de IA, los beneficios son inmensos:

Reducción de 70% en Bugs de Producción: Errores de tipo son capturados en desarrollo, no en runtime.

Refactorización Confiante: Cambiar estructuras de datos o firmas es seguro. TypeScript muestra exactamente lo que necesita actualizar.

Documentación Viva: Tipos sirven como documentación siempre actualizada. Nuevos desarrolladores entienden el código más rápido.

Autocomplete Poderoso: IDEs saben exactamente lo que puedes hacer, acelerando desarrollo.

Menos Tests Necesarios: Muchos tests de tipo se tornan innecesarios pues TypeScript ya garantiza.

Si quieres explorar más sobre performance en aplicaciones de IA, confiere: WebAssembly y Machine Learning: Performance Extrema para IA en la Web donde exploramos cómo combinar TypeScript, WebAssembly y ML.

¡Vamos a por ello! 🦅

Comentarios (0)

Este artículo aún no tiene comentarios 😢. ¡Sé el primero! 🚀🦅

Añadir comentarios