2011年9月13日 星期二

C++ Iterator 與 Lambda Function

今天寫程式碼的時候遇到一個問題:要怎麼讓一個存有指標的陣列依照指標指向的內容做排序呢?例如該怎麼作才可以讓下面的程式碼依序印出 1~7 呢?

int main() {
  int a[] = { 5, 4, 1, 2, 7, 3, 6 };

  vector<int *> b;
  b.push_back(&a[6]);
  b.push_back(&a[5]);
  b.push_back(&a[2]);
  b.push_back(&a[0]);
  b.push_back(&a[4]);
  b.push_back(&a[1]);
  b.push_back(&a[3]);

  for (size_t i = 0; i < b.size(); ++i) {
    cout << *b[i] << endl;
  }

  return EXIT_SUCCESS;
}

我們一開始可能會想要用 <algorithm> 裡面的 sort,不過很可惜,下面的程式碼一定是錯的:

#include <algorithm>

int main() {
  // ... 略 ...

  sort(b.begin(), b.end());

  for (size_t i = 0; i < b.size(); ++i) {
    cout << *b[i] << endl;
  }

  return EXIT_SUCCESS;
}

這是因為在這種情況下,sort 函式比較的是位址的大小,而不是每一個 Pointer 指向的 int 的值。而且令人遺憾的是我們不能重載 Pointer Type。當然最簡單的解決方法是:
struct int_ptr_less_than_comparator {
public:
  bool operator()(int const *pa, int const *pb) const {
    return *pa < *pb;
  }
};

int main() {
  // ... 略 ... 

  sort(b.begin(), b.end(), int_ptr_less_than_comparator());

  for (size_t i = 0; i < b.size(); ++i) {
    cout << *b[i] << endl;
  }
 
  return EXIT_SUCCESS;
}

我們直接寫一個比較函式 (函式物件),然後把它作為 sort 函式的第三個參數。這樣一來結果是對了,不過感覺可以寫得更為一般化。首先我們可以觀察到 int_ptr_less_than_comparator 是由二個部分組成:(1) 先對 pa 與 pb 取值 (dereference) (2) 比較二者的大小。其中,第二部分和 std::less 的功能是一樣的。而第一部分可以寫一個叫作 dereference_to 的函式物件來解決:

#include <functional>

template <typename T>
struct dereference_to : public std::unary_function<T *, T &> {
public:
  T &operator()(T *ptr) const {
    return *ptr;
  }
};


這個 dereference_to 看起來很複雜,不過不要被他嚇到了!它只是一個帶有 operator() 的 class template。舉例來說,如果我們把 int 代入 T:

struct dereference_to<int>
: public std::unary_function<int *, int &> {
public:
  int &operator()(int *ptr) const {
    return ptr;
  }
}

這樣如果我們忽略 std::unary_function,剩下的應該都不難理解,簡單的說,就是我們可以宣告一個 dereference_to<int> 物件,然後透過其 operator() 運算子幫我們 dereference operator() 的第一個參數。例如:

int main() {
  int c = 5;
  dereference_to<int> dref;
  cout << dref(&c) << endl; // 印出 5
}


所以我們現在有二種函式物件:std::less 與 dereference_to,可是我們還需要一個方法幫我們把 dereference_to 與 std::less 組裝成 int_ptr_less_than_comparator。如果你的編譯器附有非標準的 compose2,你就可以用以下程式碼「組合」出 int_ptr_less_than_comparator:

compose2(less<int>(), dereference_to<int>(), dereference_to<int>())

可是 compose2 畢竟不是標準的一部分,所以我們要自己寫一份:

template <typename F, typename G, typename H>
struct binary_composer
: public binary_function<typename G::argument_type,
                         typename H::argument_type,
                         typename F::result_type>
{
private:
  F f;
  G g;
  H h;

public:
  // 這些 typedef 是要讓 Type Traits 機制可以正常運作。一開始看不懂沒有關係。
  // 只要依樣畫葫蘆即可!
  typedef typename G::argument_type first_argument_type;
  typedef typename H::argument_type second_argument_type;
  typedef typename F::result_type result_type;

public:
  binary_composer(F const &f_, G const &g_, H const &h_)
    : f(f_), g(g_), h(h_) { }


  // 理論上這部分要 overload 四個版本 (不同的 const) 不過為了簡單起見我只寫了一個。
  result_type operator()(first_argument_type const &arg1,
                         second_argument_type const &arg2) const {
    return f(g(arg1), h(arg2));
  }
};

// 因為 class 的 constructor 是要在 class template 具現化之後才能被呼叫,所以我們
// 必須要寫一大串才可以產生 binary_composer 的 object。這很不方便,所以我們可以
// 再寫一個 function template,讓 C++ 的型別推導機制自動推導出來。
template <typename F, typename G, typename H>
binary_composer<F, G, H> compose2(F const &f, G const &g, H const &h) {
  return binary_composer<F, G, H>(f, g, h);
}
 
這樣萬事俱備,只要把它們組合起來就完成了!

int main() {
  // ... 略 ...

  sort(b.begin(), b.end(),
       compose2(less<int>(),
                dereference_to<int>(),
                dereference_to<int>()));

  for (size_t i = 0; i < b.size(); ++i) {
    cout << *b[i] << endl;
  }

  return EXIT_SUCCESS;
}

接下來我們把焦點放到用來印出數字的 for 迴圈,有沒有辦法寫得更為一般化呢?一般我們要印數字的時候,我們可以用 copy 加上 ostream_iterator 達成目標。不過這次 dereference_to 又為我們帶來一些麻煩!這次我們要寫 Output Iterator 的 Decorator:

template <typename I, typename D>
struct output_iterator_decorator
: public iterator<output_iterator_tag, typename D::argument_type> {
private:
  typedef output_iterator_decorator<I, D> THISCLASS;

private:
  I iterator;
  D decorator;

public:
  output_iterator_decorator(I const &iterator_, D const &decorator_)
  : iterator(iterator_), decorator(decorator_) {
  }

  THISCLASS &operator=(typename D::argument_type const &arg) {
    // 先呼叫 decorator 再把回傳值 assign 給 iterator
    *iterator = decorator(arg);
    return *this;
  }

  // dereference 運算子
  THISCLASS &operator*() {
    return *this;
  }

  // 前置遞增運算子
  THISCLASS &operator++() {
    ++iterator;
    return *this;
  }

  // 後置遞增運算子
  THISCLASS &operator++(int) {
    iterator++;
    return *this;
  }
};

// 和前面一樣,利用 C++ 的型別推導具現化 output_iterator_decorator
template <typename I, typename D>
output_iterator_decorator<I, D> decorate(I iter, D decorator) {
  return output_iterator_decorator<I, D>(iter, decorator);
}

有了以上的準備我們就可以把 for 迴圈變成:

  copy(b.begin(), b.end(),
       decorate(ostream_iterator<int>(cout, "\n"),
                dereference_to<int>()));

到此我們再回頭看一下完整的 main 函式:

int main() {
  int a[] = { 5, 4, 1, 2, 7, 3, 6 };

  vector<int *> b;
  b.push_back(&a[6]);
  b.push_back(&a[5]);
  b.push_back(&a[2]);
  b.push_back(&a[0]);
  b.push_back(&a[4]);
  b.push_back(&a[1]);
  b.push_back(&a[3]);

  sort(b.begin(), b.end(),
       compose2(less<int>(), dereference_to<int>(),
                             dereference_to<int>()));

  copy(b.begin(), b.end(),
       decorate(ostream_iterator<int>(cout, "\n"),
                dereference_to<int>()));

  return EXIT_SUCCESS;
}

看起來還是很不直觀,而且為了省幾行程式碼,結果寫了更多難懂的程式碼,好像有一點本末倒置。所以我們再更進一步!如果你可以用 Boost.Lambda 函式庫,你可以寫以下的程式碼:

#include <iostream>

#include <algorithm>
#include <vector>

#include <boost/lambda/lambda.hpp>

using namespace boost::lambda;
using namespace std;

int main() {
  int a[] = { 5, 4, 1, 2, 7, 3, 6 };

  vector<int *> b;
  b.push_back(&a[6]);
  b.push_back(&a[5]);
  b.push_back(&a[2]);
  b.push_back(&a[0]);
  b.push_back(&a[4]);
  b.push_back(&a[1]);
  b.push_back(&a[3]);

  // 排序 (以 int * 指向的值)
  sort(b.begin(), b.end(), ((*_1) < (*_2)));

  // 一個一個印出來
  for_each(b.begin(), b.end(), cout << *_1 << "\n");

  return EXIT_SUCCESS;
}

對!你沒有看錯!這樣就沒有了!這裡的 _1 與 _2 是神秘的 Function Object,Boost.Lambda 函式庫會自動幫你組合成一個合適的 Function Object。(*_1) 就像是我們的 dereference_to,而 ((*_1) < (*_2)) 就像是我們的 int_ptr_less_than_comparator。至於印出數字的部分,我另外做了一些改動。所以不能使用 copy,而是要改成 for_each。

Boost.Lambda 函式庫雖然看起來很漂亮,不過它本身是用很複雜的技巧寫出來的。正常運作的時候很漂亮,不過寫出 bug 的時候,編譯器會噴很多錯誤訊息,一般人根本不可能看得懂。某種意義上來說這已經是函式庫的極限了!

最新的 C++11 加入了 Lambda function (就是下面粗體字的部分),我們再也不需要 Boost.Lambda 也可以達到類似的效果:

#include <iostream>

#include <algorithm>
#include <vector>

using namespace std;

int main() {
  int a[] = { 5, 4, 1, 2, 7, 3, 6 };
 
  vector<int *> b; 
  b.push_back(&a[6]);
  b.push_back(&a[5]);
  b.push_back(&a[2]);
  b.push_back(&a[0]);
  b.push_back(&a[4]);
  b.push_back(&a[1]);
  b.push_back(&a[3]);

  // 排序 (以 int * 指向的值)
  sort(b.begin(), b.end(), [](int *pa, int *pb) { return *pa < *pb; });

  // 一個一個印出來
  for_each(b.begin(), b.end(), [](int *p) { cout << *p << "\n"; });

  return EXIT_SUCCESS;
}

讓我們期待 C++11 的到來吧!

--
備註:範例程式碼 http://www.csie.ntu.edu.tw/~b97073/B/iterator.tgz