Scala の練習

Base64 の encoder と decoder を、Wikipedia のアルゴリズム解説をもとに書いてみる。Decoder から始めた。いろいろ試行錯誤するうちに collection 周りの復習ができたし、できないことも分かってきた。REPL で試しながら書けるのはやはり便利。

object Base64Decoder {
  /**
   * Decodes the Base64 text file named by `args(0)` and writes the raw bytes to stdout.
   * The whole file is decoded at once, so line wrapping need not fall on 4-character
   * boundaries.
   */
  def main(args: Array[String]): Unit = {
    import scala.io.Source
    val source = Source.fromFile(args(0))
    // Close the file even if reading fails.
    val text = try source.mkString finally source.close()
    System.out.write(decode(text))
    System.out.flush()
  }

  /**
   * Decodes Base64 text into bytes.
   *
   * CR and LF characters are ignored, trailing '=' padding is stripped, and an unpadded
   * final group of 2 or 3 symbols is accepted. Zero bytes in the payload are preserved.
   *
   * @param str Base64 text, possibly spanning several lines
   * @return the decoded bytes (empty for empty input)
   * @throws IllegalArgumentException if the data ends with a single dangling symbol
   * @throws NoSuchElementException if a character outside the Base64 alphabet is present
   */
  def decode(str: String): Array[Byte] = {
    val cleaned = str.filterNot(c => c == '\r' || c == '\n')
    // Strip only the trailing padding: a short final group tells us how many bytes it carries,
    // so real zero bytes in the data are never discarded.
    val body = cleaned.take(cleaned.lastIndexWhere(_ != '=') + 1)
    body.grouped(4).flatMap { quad =>
      require(quad.length >= 2, "truncated Base64 input: dangling symbol '" + quad + "'")
      // Accumulate 6 bits per symbol, then left-align the result in a 24-bit word.
      val bits = quad.foldLeft(0)((acc, c) => acc << 6 | c2vmap(c)) << (6 * (4 - quad.length))
      // k symbols carry k-1 whole bytes; extract them from the most significant end.
      (0 until quad.length - 1).map(k => ((bits >> (16 - 8 * k)) & 0xff).toByte)
    }.toArray
  }

  /** The 64-symbol Base64 alphabet; a symbol's index is its 6-bit value. */
  val ctable: String = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

  /** Reverse lookup from alphabet symbol to its 6-bit value, built once. */
  val c2vmap: Map[Char, Int] = ctable.zipWithIndex.toMap
}

Encoder はなかなか上手く書けない。試行錯誤したまま全体を見直せていないので、まだうまくまとめる余地はあると思うけど...

object Base64Encoder {
  /**
   * Encodes the file named by `args(0)` and prints the Base64 text to stdout.
   * The file length is reported on stderr so it does not mix with the encoded output.
   */
  def main(args: Array[String]): Unit = {
    import java.nio.file.{Files, Paths}
    // readAllBytes reads until EOF; a single InputStream.read may return fewer bytes.
    val bytes = Files.readAllBytes(Paths.get(args(0)))
    System.err.println("file len " + bytes.length)
    println(encode(bytes))
  }

  /** The 64-symbol Base64 alphabet; a symbol's index is its 6-bit value. */
  val ctable: String = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

  /**
   * Encodes `indata` as padded Base64 text.
   *
   * @param indata     bytes to encode (may be empty)
   * @param lineLength maximum characters per output line; lines are joined with "\n"
   *                   and no trailing newline is added. A value <= 0 disables wrapping.
   * @return the encoded text ("" for empty input)
   */
  def encode(indata: Array[Byte], lineLength: Int = 64): String = {
    val sb = new StringBuilder
    for (group <- indata.grouped(3)) {
      // Pack the 1-3 bytes, masked to unsigned (avoids sign extension of bytes >= 0x80),
      // into the top of a 24-bit word.
      val bits = group.foldLeft(0)((acc, b) => acc << 8 | (b & 0xff)) << (8 * (3 - group.length))
      // n input bytes yield n+1 significant 6-bit symbols; the rest of the quad is '=' padding.
      val nChars = group.length + 1
      for (k <- 0 until nChars) sb += ctable((bits >> (18 - 6 * k)) & 0x3f)
      for (_ <- nChars until 4) sb += '='
    }
    val flat = sb.toString
    if (lineLength <= 0 || flat.isEmpty) flat
    else flat.grouped(lineLength).mkString("\n")
  }

  /** Rounds `v` up to the nearest multiple of `base` (for non-negative `v` and positive `base`). */
  def roundup(v: Int, base: Int) = ((v + base - 1)/base)*base
}