421. Maximum XOR of Two Numbers in an Array #

Approach #

I. Naive (O(n^2)) [TLE❌] #

We can find the maximum XOR by brute force: for every pair (a, b) we compute a ^ b and keep the largest value seen.

Code #

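#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    int findMaximumXOR(vector<int>& nums) {
        int n = nums.size();  // signed, so n - 1 cannot wrap around if nums is empty
        int maxi = 0;
        // Try every unordered pair (i, j) with i < j.
        for (int i = 0; i < n - 1; i++) {
            for (int j = i + 1; j < n; j++) {
                maxi = max(maxi, nums[i] ^ nums[j]);
            }
        }
        return maxi;
    }
};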

II. Optimized (O(32n)) [Passed✅] #

We can solve this problem with bit manipulation. The idea is to walk the bits of each number from most significant to least significant, because a 1 in a higher bit of the result outweighs all the lower bits combined. To keep track of the bits we store every number in a binary Trie whose root branches on the MSB. Steps:

  • Store each num in the trie, one bit per level, MSB first.
  • To find the maximum xor value for each num, we walk its bits down the trie such that:
    1. If the current bit of num is 1, try to visit child node 0 in the trie.
    2. If the current bit of num is 0, try to visit child node 1 in the trie.
    3. Else we just visit whatever child the current node has.
  • This gives us the maximum xor for that num: every level where we managed to take the opposite bit puts a 1 in the result, and every level where our try failed (i.e., we can't find a 0 for a 1 or vice versa) puts a 0. The answer is the largest such value over all nums.
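To make the MSB-first order concrete, here is a small sketch that writes a number the way the trie sees it. The helper name `bitsMsbFirst` and the `width` parameter are made up for this illustration and are not part of the solution.

```cpp
#include <string>

// Hypothetical helper: the low `width` bits of x as a string,
// most significant bit first (the same order the trie is walked in).
std::string bitsMsbFirst(int x, int width) {
    std::string s;
    for (int i = width - 1; i >= 0; --i) {
        s += ((x >> i) & 1) ? '1' : '0';
    }
    return s;
}
```

For example, with 5 bits, 5 is `00101` and 25 is `11001`. Walking 5 down a trie that contains 25, we can take the opposite bit on the first three levels and have to take the same bit on the last two, so the result is `11100`, which is 28.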

Code #

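#include <algorithm>
#include <vector>
using namespace std;

struct TrieNode {
    TrieNode *child[2] = {};  // child[0] / child[1]: next node for bit 0 / bit 1

    // Free the whole subtree so the trie does not leak.
    ~TrieNode() {
        delete child[0];
        delete child[1];
    }

    // Insert number into the trie, one bit per level, MSB first.
    void increase(int number) {
        TrieNode *cur = this;
        for (int i = 31; i >= 0; --i) {
            int bit = (number >> i) & 1;
            if (cur->child[bit] == nullptr) cur->child[bit] = new TrieNode();
            cur = cur->child[bit];
        }
    }

    // Best XOR of number with any inserted value.
    // Assumes at least one value has been inserted and all values are non-negative,
    // so bit 31 is always 0 and (1 << 31) is never evaluated.
    int findMax(int number) {
        TrieNode *cur = this;
        int ans = 0;
        for (int i = 31; i >= 0; --i) {
            int bit = (number >> i) & 1;
            if (cur->child[1 - bit] != nullptr) {
                // The opposite bit exists: this bit of the XOR is 1.
                cur = cur->child[1 - bit];
                ans |= (1 << i);
            } else cur = cur->child[bit];  // forced to take the same bit: XOR bit is 0
        }
        return ans;
    }
};

class Solution {
public:
    int findMaximumXOR(vector<int> &nums) {
        TrieNode root;  // lives on the stack; its destructor frees all nodes
        for (int x : nums) root.increase(x);

        int ans = 0;
        if (nums.empty()) return ans;  // nothing to query in an empty trie

        for (int x : nums) ans = max(ans, root.findMax(x));
        return ans;
    }
};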
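As a quick sanity check, the two approaches can be compared on a small input. The sketch below is self-contained and is not the code above: `maxXorBrute` and `maxXorTrie` are names invented here, and the trie is swapped for an index-based one stored in a `std::vector` (child index 0 means "no child"). It assumes every number is non-negative and below 2^31, so only bits 30 down to 0 are used.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <vector>

// Reference answer: try every pair.
int maxXorBrute(const std::vector<int>& nums) {
    int best = 0;
    for (std::size_t i = 0; i < nums.size(); ++i)
        for (std::size_t j = i + 1; j < nums.size(); ++j)
            best = std::max(best, nums[i] ^ nums[j]);
    return best;
}

// Same greedy walk as the trie solution, with nodes stored in a vector.
// Node 0 is the root; a child index of 0 means the child does not exist.
int maxXorTrie(const std::vector<int>& nums) {
    if (nums.empty()) return 0;
    std::vector<std::array<int, 2>> trie(1, std::array<int, 2>{{0, 0}});

    // Insert every number, MSB first (assumption: values fit in bits 30..0).
    for (int x : nums) {
        int cur = 0;
        for (int i = 30; i >= 0; --i) {
            int bit = (x >> i) & 1;
            if (trie[cur][bit] == 0) {
                trie[cur][bit] = static_cast<int>(trie.size());
                trie.push_back(std::array<int, 2>{{0, 0}});
            }
            cur = trie[cur][bit];
        }
    }

    // For each number, prefer the opposite bit at every level.
    int best = 0;
    for (int x : nums) {
        int cur = 0, got = 0;
        for (int i = 30; i >= 0; --i) {
            int bit = (x >> i) & 1;
            int want = bit ^ 1;
            if (trie[cur][want] != 0) {
                got |= (1 << i);
                cur = trie[cur][want];
            } else {
                cur = trie[cur][bit];
            }
        }
        best = std::max(best, got);
    }
    return best;
}
```

On `{3, 10, 5, 25, 2, 8}` both functions return 28 (from 5 ^ 25), and on a single-element input both return 0.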

Credits to @hiepit for his clear explanation and elegant C++ code; this code taught me a lot!

Thanks🙋‍♂️